|
1 |
| -import unittest |
2 |
| - |
3 |
| - |
4 | 1 | import torch
|
5 | 2 | from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
|
6 | 3 |
|
| 4 | +import pytest |
7 | 5 |
|
8 |
| -class ResnetFPNBackboneTester(unittest.TestCase): |
9 |
| - @classmethod |
10 |
| - def setUpClass(cls): |
11 |
| - cls.dtype = torch.float32 |
12 |
| - |
13 |
| - def test_resnet18_fpn_backbone(self): |
14 |
| - device = torch.device('cpu') |
15 |
| - x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device) |
16 |
| - resnet18_fpn = resnet_fpn_backbone(backbone_name='resnet18', pretrained=False) |
17 |
| - y = resnet18_fpn(x) |
18 |
| - self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool']) |
19 | 6 |
|
20 |
| - def test_resnet50_fpn_backbone(self): |
21 |
| - device = torch.device('cpu') |
22 |
| - x = torch.rand(1, 3, 300, 300, dtype=self.dtype, device=device) |
23 |
| - resnet50_fpn = resnet_fpn_backbone(backbone_name='resnet50', pretrained=False) |
24 |
| - y = resnet50_fpn(x) |
25 |
| - self.assertEqual(list(y.keys()), ['0', '1', '2', '3', 'pool']) |
| 7 | +@pytest.mark.parametrize('backbone_name', ('resnet18', 'resnet50')) |
| 8 | +def test_resnet_fpn_backbone(backbone_name): |
| 9 | + x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device='cpu') |
| 10 | + y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x) |
| 11 | + assert list(y.keys()) == ['0', '1', '2', '3', 'pool'] |
0 commit comments