Skip to content

Commit ace739a

Browse files
Zhipeng Fanfacebook-github-bot
Zhipeng Fan
authored andcommitted
cat Keypoints
Summary: Add cat methods to Keypoints Reviewed By: ppwwyyxx Differential Revision: D30199199 fbshipit-source-id: 7d2c63810e7e05931c90b7ea5102a734b7ce8af4
1 parent 0fb7f04 commit ace739a

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

detectron2/structures/keypoints.py

+20
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,26 @@ def __repr__(self) -> str:
9191
s += "num_instances={})".format(len(self.tensor))
9292
return s
9393

94+
@staticmethod
95+
def cat(keypoints_list: List["Keypoints"]) -> "Keypoints":
96+
"""
97+
Concatenates a list of Keypoints into a single Keypoints
98+
99+
Arguments:
100+
keypoints_list (list[Keypoints])
101+
102+
Returns:
103+
Keypoints: the concatenated Keypoints
104+
"""
105+
assert isinstance(keypoints_list, (list, tuple))
106+
assert len(keypoints_list) > 0
107+
assert all(isinstance(keypoints, Keypoints) for keypoints in keypoints_list)
108+
109+
cat_kpts = type(keypoints_list[0])(
110+
torch.cat([kpts.tensor for kpts in keypoints_list], dim=0)
111+
)
112+
return cat_kpts
113+
94114

95115
# TODO make this nicer, this is a direct translation from C2 (but removing the inner loop)
96116
def _keypoints_to_heatmap(

tests/structures/test_keypoints.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
import unittest
3+
import torch
4+
5+
from detectron2.structures.keypoints import Keypoints
6+
7+
8+
class TestKeypoints(unittest.TestCase):
9+
def test_cat_keypoints(self):
10+
keypoints1 = Keypoints(torch.rand(2, 21, 3))
11+
keypoints2 = Keypoints(torch.rand(4, 21, 3))
12+
13+
cat_keypoints = keypoints1.cat([keypoints1, keypoints2])
14+
self.assertTrue(torch.all(cat_keypoints.tensor[:2] == keypoints1.tensor).item())
15+
self.assertTrue(torch.all(cat_keypoints.tensor[2:] == keypoints2.tensor).item())
16+
17+
18+
if __name__ == "__main__":
19+
unittest.main()

0 commit comments

Comments
 (0)