Skip to content

Commit 1ad5759

Browse files
ppwwyyxxfacebook-github-bot
authored andcommitted
use topk instead of sort
Summary: now pytorch/pytorch#22812 is fixed Reviewed By: zhanghang1989 Differential Revision: D32610251 fbshipit-source-id: e099a2c53f71cca95af35aafc26ab59f9613c07b
1 parent 4606450 commit 1ad5759

File tree

3 files changed

+4
-14
lines changed

3 files changed

+4
-14
lines changed

detectron2/modeling/meta_arch/dense_detector.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,8 @@ def _decode_per_level_predictions(
208208

209209
# 2. Keep top k top scoring boxes only
210210
num_topk = min(topk_candidates, topk_idxs.size(0))
211-
# torch.sort is actually faster than .topk (https://github.com/pytorch/pytorch/issues/22812)
212-
pred_scores, idxs = pred_scores.sort(descending=True)
213-
pred_scores = pred_scores[:num_topk]
214-
topk_idxs = topk_idxs[idxs[:num_topk]]
211+
pred_scores, idxs = pred_scores.topk(num_topk)
212+
topk_idxs = topk_idxs[idxs]
215213

216214
anchor_idxs, classes_idxs = topk_idxs.unbind(dim=1)
217215

detectron2/modeling/proposal_generator/proposal_utils.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,7 @@ def find_top_rpn_proposals(
7272
else:
7373
num_proposals_i = min(Hi_Wi_A, pre_nms_topk)
7474

75-
# sort is faster than topk: https://github.com/pytorch/pytorch/issues/22812
76-
# topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
77-
logits_i, idx = logits_i.sort(descending=True, dim=1)
78-
topk_scores_i = logits_i.narrow(1, 0, num_proposals_i)
79-
topk_idx = idx.narrow(1, 0, num_proposals_i)
75+
topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
8076

8177
# each is N x topk
8278
topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx] # N x topk x 4

detectron2/modeling/proposal_generator/rrpn.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,7 @@ def find_top_rrpn_proposals(
7373
else:
7474
num_proposals_i = min(Hi_Wi_A, pre_nms_topk)
7575

76-
# sort is faster than topk (https://github.com/pytorch/pytorch/issues/22812)
77-
# topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
78-
logits_i, idx = logits_i.sort(descending=True, dim=1)
79-
topk_scores_i = logits_i[batch_idx, :num_proposals_i]
80-
topk_idx = idx[batch_idx, :num_proposals_i]
76+
topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1)
8177

8278
# each is N x topk
8379
topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx] # N x topk x 5

0 commit comments

Comments
 (0)