Skip to content

Commit 0013d93

Browse files
authored
Port test_backbone_utils.py to pytest (#3991)
1 parent 182f80d commit 0013d93

File tree

1 file changed

+6
-20
lines changed

1 file changed

+6
-20
lines changed

test/test_backbone_utils.py

+6-20
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,11 @@
1-
import unittest
2-
3-
41
import torch
52
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
63

4+
import pytest
75

8-
class ResnetFPNBackboneTester(unittest.TestCase):
9-
@classmethod
10-
def setUpClass(cls):
11-
cls.dtype = torch.float32
12-
13-
def test_resnet18_fpn_backbone(self):
14-
device = torch.device('cpu')
15-
x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device)
16-
resnet18_fpn = resnet_fpn_backbone(backbone_name='resnet18', pretrained=False)
17-
y = resnet18_fpn(x)
18-
self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool'])
196

20-
def test_resnet50_fpn_backbone(self):
21-
device = torch.device('cpu')
22-
x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device)
23-
resnet50_fpn = resnet_fpn_backbone(backbone_name='resnet50', pretrained=False)
24-
y = resnet50_fpn(x)
25-
self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool'])
7+
@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50'))
8+
def test_resnet_fpn_backbone(backbone_name):
9+
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu')
10+
y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
11+
assert list(y.keys()) == ['0', '1', '2', '3', 'pool']

0 commit comments

Comments
 (0)