File tree 3 files changed +4
-14
lines changed
3 files changed +4
-14
lines changed Original file line number Diff line number Diff line change @@ -208,10 +208,8 @@ def _decode_per_level_predictions(
208
208
209
209
# 2. Keep top k top scoring boxes only
210
210
num_topk = min (topk_candidates , topk_idxs .size (0 ))
211
- # torch.sort is actually faster than .topk (https://github.com/pytorch/pytorch/issues/22812)
212
- pred_scores , idxs = pred_scores .sort (descending = True )
213
- pred_scores = pred_scores [:num_topk ]
214
- topk_idxs = topk_idxs [idxs [:num_topk ]]
211
+ pred_scores , idxs = pred_scores .topk (num_topk )
212
+ topk_idxs = topk_idxs [idxs ]
215
213
216
214
anchor_idxs , classes_idxs = topk_idxs .unbind (dim = 1 )
217
215
Original file line number Diff line number Diff line change @@ -72,11 +72,7 @@ def find_top_rpn_proposals(
72
72
else :
73
73
num_proposals_i = min (Hi_Wi_A , pre_nms_topk )
74
74
75
- # sort is faster than topk: https://github.com/pytorch/pytorch/issues/22812
76
- # topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
77
- logits_i , idx = logits_i .sort (descending = True , dim = 1 )
78
- topk_scores_i = logits_i .narrow (1 , 0 , num_proposals_i )
79
- topk_idx = idx .narrow (1 , 0 , num_proposals_i )
75
+ topk_scores_i , topk_idx = logits_i .topk (num_proposals_i , dim = 1 )
80
76
81
77
# each is N x topk
82
78
topk_proposals_i = proposals_i [batch_idx [:, None ], topk_idx ] # N x topk x 4
Original file line number Diff line number Diff line change @@ -73,11 +73,7 @@ def find_top_rrpn_proposals(
73
73
else :
74
74
num_proposals_i = min (Hi_Wi_A , pre_nms_topk )
75
75
76
- # sort is faster than topk (https://github.com/pytorch/pytorch/issues/22812)
77
- # topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
78
- logits_i , idx = logits_i .sort (descending = True , dim = 1 )
79
- topk_scores_i = logits_i [batch_idx , :num_proposals_i ]
80
- topk_idx = idx [batch_idx , :num_proposals_i ]
76
+ topk_scores_i , topk_idx = logits_i .topk (num_proposals_i , dim = 1 )
81
77
82
78
# each is N x topk
83
79
topk_proposals_i = proposals_i [batch_idx [:, None ], topk_idx ] # N x topk x 5
You can’t perform that action at this time.
0 commit comments