Skip to content

Commit 555200f

Browse files
kondelafacebook-github-bot
authored andcommittedJan 10, 2020
Refactor test_boxes, test_model_zoo, test_rpn (facebookresearch#672)
Summary: I've noticed that there is some inconsistency between different tests, i.e. some use python's builtin `assert` while others use `unittest.TestCase` asserts (a clear preference is for the latter). I went through `test_boxes.py`, `test_model_zoo.py`, and `test_rpn.py` and refactored it. Also some asserting did not work so I fixed that as well + did some minor refactoring in general. Pull Request resolved: facebookresearch#672 Differential Revision: D19352228 Pulled By: ppwwyyxx fbshipit-source-id: 0504324b919584d6dfd9f140ed402bb3e33adc02
1 parent ed1d74b commit 555200f

File tree

3 files changed

+35
-26
lines changed

3 files changed

+35
-26
lines changed
 

‎tests/test_boxes.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def test_box_convert_list(self):
1818
for tp in [list, tuple]:
1919
box = tp([5, 5, 10, 10])
2020
output = self._convert_xy_to_wh(box)
21-
self.assertTrue(isinstance(output, tp))
22-
self.assertTrue(output == tp([5, 5, 5, 5]))
21+
self.assertIsInstance(output, tp)
22+
self.assertEqual(output, tp([5, 5, 5, 5]))
2323

2424
with self.assertRaises(Exception):
2525
self._convert_xy_to_wh([box])
@@ -45,8 +45,8 @@ def test_box_convert_xywha_to_xyxy_list(self):
4545
for tp in [list, tuple]:
4646
box = tp([50, 50, 30, 20, 0])
4747
output = self._convert_xywha_to_xyxy(box)
48-
self.assertTrue(isinstance(output, tp))
49-
self.assertTrue(output == tp([35, 40, 65, 60]))
48+
self.assertIsInstance(output, tp)
49+
self.assertEqual(output, tp([35, 40, 65, 60]))
5050

5151
with self.assertRaises(Exception):
5252
self._convert_xywha_to_xyxy([box])
@@ -107,7 +107,7 @@ def test_pairwise_iou(self):
107107

108108
ious = pairwise_iou(Boxes(boxes1), Boxes(boxes2))
109109

110-
assert torch.allclose(ious, expected_ious)
110+
self.assertTrue(torch.allclose(ious, expected_ious))
111111

112112

113113
if __name__ == "__main__":

‎tests/test_model_zoo.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
class TestModelZoo(unittest.TestCase):
1212
def test_get_returns_model(self):
1313
model = model_zoo.get("Misc/scratch_mask_rcnn_R_50_FPN_3x_gn.yaml", trained=False)
14-
assert isinstance(model, GeneralizedRCNN), model
15-
assert isinstance(model.backbone, FPN), model.backbone
14+
self.assertIsInstance(model, GeneralizedRCNN)
15+
self.assertIsInstance(model.backbone, FPN)
1616

1717
def test_get_invalid_model(self):
1818
self.assertRaises(RuntimeError, model_zoo.get, "Invalid/config.yaml")

‎tests/test_rpn.py

+28-19
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_rpn(self):
4141
"loss_rpn_loc": torch.tensor(0.0990132466),
4242
}
4343
for name in expected_losses.keys():
44-
assert torch.allclose(proposal_losses[name], expected_losses[name])
44+
self.assertTrue(torch.allclose(proposal_losses[name], expected_losses[name]))
4545

4646
expected_proposal_boxes = [
4747
Boxes(torch.tensor([[0, 0, 10, 10], [7.3365392685, 0, 10, 10]])),
@@ -63,13 +63,15 @@ def test_rpn(self):
6363
torch.tensor([0.1415634006, 0.0989848152, 0.0565387346, -0.0072308783, -0.0428492837]),
6464
]
6565

66-
for i in range(len(image_sizes)):
67-
assert len(proposals[i]) == len(expected_proposal_boxes[i])
68-
assert proposals[i].image_size == (image_sizes[i][0], image_sizes[i][1])
69-
assert torch.allclose(
70-
proposals[i].proposal_boxes.tensor, expected_proposal_boxes[i].tensor
66+
for proposal, expected_proposal_box, im_size, expected_objectness_logit in zip(
67+
proposals, expected_proposal_boxes, image_sizes, expected_objectness_logits
68+
):
69+
self.assertEqual(len(proposal), len(expected_proposal_box))
70+
self.assertEqual(proposal.image_size, im_size)
71+
self.assertTrue(
72+
torch.allclose(proposal.proposal_boxes.tensor, expected_proposal_box.tensor)
7173
)
72-
assert torch.allclose(proposals[i].objectness_logits, expected_objectness_logits[i])
74+
self.assertTrue(torch.allclose(proposal.objectness_logits, expected_objectness_logit))
7375

7476
def test_rrpn(self):
7577
torch.manual_seed(121)
@@ -103,7 +105,7 @@ def test_rrpn(self):
103105
"loss_rpn_loc": torch.tensor(0.1552739739),
104106
}
105107
for name in expected_losses.keys():
106-
assert torch.allclose(proposal_losses[name], expected_losses[name])
108+
self.assertTrue(torch.allclose(proposal_losses[name], expected_losses[name]))
107109

108110
expected_proposal_boxes = [
109111
RotatedBoxes(
@@ -185,25 +187,32 @@ def test_rrpn(self):
185187

186188
torch.set_printoptions(precision=8, sci_mode=False)
187189

188-
for i in range(len(image_sizes)):
189-
assert len(proposals[i]) == len(expected_proposal_boxes[i])
190-
assert proposals[i].image_size == (image_sizes[i][0], image_sizes[i][1])
190+
for proposal, expected_proposal_box, im_size, expected_objectness_logit in zip(
191+
proposals, expected_proposal_boxes, image_sizes, expected_objectness_logits
192+
):
193+
self.assertEqual(len(proposal), len(expected_proposal_box))
194+
self.assertEqual(proposal.image_size, im_size)
191195
# It seems that there's some randomness in the result across different machines:
192196
# This test can be run on a local machine for 100 times with exactly the same result,
193197
# However, a different machine might produce slightly different results,
194198
# thus the atol here.
195199
err_msg = "computed proposal boxes = {}, expected {}".format(
196-
proposals[i].proposal_boxes.tensor, expected_proposal_boxes[i].tensor
200+
proposal.proposal_boxes.tensor, expected_proposal_box.tensor
197201
)
198-
assert torch.allclose(
199-
proposals[i].proposal_boxes.tensor, expected_proposal_boxes[i].tensor, atol=1e-5
200-
), err_msg
202+
self.assertTrue(
203+
torch.allclose(
204+
proposal.proposal_boxes.tensor, expected_proposal_box.tensor, atol=1e-5
205+
),
206+
err_msg,
207+
)
208+
201209
err_msg = "computed objectness logits = {}, expected {}".format(
202-
proposals[i].objectness_logits, expected_objectness_logits[i]
210+
proposal.objectness_logits, expected_objectness_logit
211+
)
212+
self.assertTrue(
213+
torch.allclose(proposal.objectness_logits, expected_objectness_logit, atol=1e-5),
214+
err_msg,
203215
)
204-
assert torch.allclose(
205-
proposals[i].objectness_logits, expected_objectness_logits[i], atol=1e-5
206-
), err_msg
207216

208217

209218
if __name__ == "__main__":

0 commit comments

Comments
 (0)
Please sign in to comment.