Commit ace739a 1 parent 0fb7f04 commit ace739a Copy full SHA for ace739a
File tree 2 files changed +39
-0
lines changed
2 files changed +39
-0
lines changed Original file line number Diff line number Diff line change @@ -91,6 +91,26 @@ def __repr__(self) -> str:
91
91
s += "num_instances={})" .format (len (self .tensor ))
92
92
return s
93
93
94
+ @staticmethod
95
+ def cat (keypoints_list : List ["Keypoints" ]) -> "Keypoints" :
96
+ """
97
+ Concatenates a list of Keypoints into a single Keypoints
98
+
99
+ Arguments:
100
+ keypoints_list (list[Keypoints])
101
+
102
+ Returns:
103
+ Keypoints: the concatenated Keypoints
104
+ """
105
+ assert isinstance (keypoints_list , (list , tuple ))
106
+ assert len (keypoints_list ) > 0
107
+ assert all (isinstance (keypoints , Keypoints ) for keypoints in keypoints_list )
108
+
109
+ cat_kpts = type (keypoints_list [0 ])(
110
+ torch .cat ([kpts .tensor for kpts in keypoints_list ], dim = 0 )
111
+ )
112
+ return cat_kpts
113
+
94
114
95
115
# TODO make this nicer, this is a direct translation from C2 (but removing the inner loop)
96
116
def _keypoints_to_heatmap (
Original file line number Diff line number Diff line change
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ import unittest
3
+ import torch
4
+
5
+ from detectron2 .structures .keypoints import Keypoints
6
+
7
+
8
+ class TestKeypoints (unittest .TestCase ):
9
+ def test_cat_keypoints (self ):
10
+ keypoints1 = Keypoints (torch .rand (2 , 21 , 3 ))
11
+ keypoints2 = Keypoints (torch .rand (4 , 21 , 3 ))
12
+
13
+ cat_keypoints = keypoints1 .cat ([keypoints1 , keypoints2 ])
14
+ self .assertTrue (torch .all (cat_keypoints .tensor [:2 ] == keypoints1 .tensor ).item ())
15
+ self .assertTrue (torch .all (cat_keypoints .tensor [2 :] == keypoints2 .tensor ).item ())
16
+
17
+
18
+ if __name__ == "__main__" :
19
+ unittest .main ()
You can’t perform that action at this time.
0 commit comments