Skip to content

Commit 090d823

Browse files
authored
Improve test of backbone utils (#5552)
1 parent a8bde78 commit 090d823

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

test/test_backbone_utils.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from common_utils import set_rng_seed
88
from torchvision import models
99
from torchvision.models._utils import IntermediateLayerGetter
10-
from torchvision.models.detection.backbone_utils import mobilenet_backbone, resnet_fpn_backbone
10+
from torchvision.models.detection.backbone_utils import BackboneWithFPN, mobilenet_backbone, resnet_fpn_backbone
1111
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
1212

1313

@@ -19,7 +19,9 @@ def get_available_models():
1919
@pytest.mark.parametrize("backbone_name", ("resnet18", "resnet50"))
2020
def test_resnet_fpn_backbone(backbone_name):
2121
x = torch.rand(1, 3, 300, 300, dtype=torch.float32, device="cpu")
22-
y = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)(x)
22+
model = resnet_fpn_backbone(backbone_name=backbone_name, pretrained=False)
23+
assert isinstance(model, BackboneWithFPN)
24+
y = model(x)
2325
assert list(y.keys()) == ["0", "1", "2", "3", "pool"]
2426

2527
with pytest.raises(ValueError, match=r"Trainable layers should be in the range"):
@@ -38,6 +40,10 @@ def test_mobilenet_backbone(backbone_name):
3840
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[-1, 0, 1, 2])
3941
with pytest.raises(ValueError, match=r"Each returned layer should be in the range"):
4042
mobilenet_backbone(backbone_name, False, fpn=True, returned_layers=[3, 4, 5, 6])
43+
model_fpn = mobilenet_backbone(backbone_name, False, fpn=True)
44+
assert isinstance(model_fpn, BackboneWithFPN)
45+
model = mobilenet_backbone(backbone_name, False, fpn=False)
46+
assert isinstance(model, torch.nn.Sequential)
4147

4248

4349
# Needed by TestFxFeatureExtraction.test_leaf_module_and_function

0 commit comments

Comments
 (0)