Skip to content

Commit b5d81f0

Browse files
oke-adityafmassa
andauthored
Add JIT tests (#4472)
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
1 parent a9d710a commit b5d81f0

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

test/test_ops.py

+24
Original file line numberDiff line numberDiff line change
@@ -959,6 +959,14 @@ def area_check(box, expected, tolerance=1e-4):
959959
expected = torch.tensor([605113.875, 600495.1875, 592247.25])
960960
area_check(box_tensor, expected)
961961

962+
def test_box_area_jit(self):
963+
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float)
964+
TOLERANCE = 1e-3
965+
expected = ops.box_area(box_tensor)
966+
scripted_fn = torch.jit.script(ops.box_area)
967+
scripted_area = scripted_fn(box_tensor)
968+
torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=TOLERANCE)
969+
962970

963971
class TestBoxIou:
964972
def test_iou(self):
@@ -980,6 +988,14 @@ def iou_check(box, expected, tolerance=1e-4):
980988
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
981989
iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-4)
982990

991+
def test_iou_jit(self):
992+
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
993+
TOLERANCE = 1e-3
994+
expected = ops.box_iou(box_tensor, box_tensor)
995+
scripted_fn = torch.jit.script(ops.box_iou)
996+
scripted_iou = scripted_fn(box_tensor, box_tensor)
997+
torch.testing.assert_close(scripted_iou, expected, rtol=0.0, atol=TOLERANCE)
998+
983999

9841000
class TestGenBoxIou:
9851001
def test_gen_iou(self):
@@ -1001,6 +1017,14 @@ def gen_iou_check(box, expected, tolerance=1e-4):
10011017
expected = torch.tensor([[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]])
10021018
gen_iou_check(box_tensor, expected, tolerance=0.002 if dtype == torch.float16 else 1e-3)
10031019

1020+
def test_giou_jit(self):
1021+
box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float)
1022+
TOLERANCE = 1e-3
1023+
expected = ops.generalized_box_iou(box_tensor, box_tensor)
1024+
scripted_fn = torch.jit.script(ops.generalized_box_iou)
1025+
scripted_iou = scripted_fn(box_tensor, box_tensor)
1026+
torch.testing.assert_close(scripted_iou, expected, rtol=0.0, atol=TOLERANCE)
1027+
10041028

10051029
class TestMasksToBoxes:
10061030
def test_masks_box(self):

0 commit comments

Comments
 (0)