7
7
from common_utils import set_rng_seed
8
8
from torchvision import models
9
9
from torchvision .models ._utils import IntermediateLayerGetter
10
- from torchvision .models .detection .backbone_utils import mobilenet_backbone , resnet_fpn_backbone
10
+ from torchvision .models .detection .backbone_utils import BackboneWithFPN , mobilenet_backbone , resnet_fpn_backbone
11
11
from torchvision .models .feature_extraction import create_feature_extractor , get_graph_node_names
12
12
13
13
@@ -19,7 +19,9 @@ def get_available_models():
19
19
@pytest .mark .parametrize ("backbone_name" , ("resnet18" , "resnet50" ))
20
20
def test_resnet_fpn_backbone (backbone_name ):
21
21
x = torch .rand (1 , 3 , 300 , 300 , dtype = torch .float32 , device = "cpu" )
22
- y = resnet_fpn_backbone (backbone_name = backbone_name , pretrained = False )(x )
22
+ model = resnet_fpn_backbone (backbone_name = backbone_name , pretrained = False )
23
+ assert isinstance (model , BackboneWithFPN )
24
+ y = model (x )
23
25
assert list (y .keys ()) == ["0" , "1" , "2" , "3" , "pool" ]
24
26
25
27
with pytest .raises (ValueError , match = r"Trainable layers should be in the range" ):
@@ -38,6 +40,10 @@ def test_mobilenet_backbone(backbone_name):
38
40
mobilenet_backbone (backbone_name , False , fpn = True , returned_layers = [- 1 , 0 , 1 , 2 ])
39
41
with pytest .raises (ValueError , match = r"Each returned layer should be in the range" ):
40
42
mobilenet_backbone (backbone_name , False , fpn = True , returned_layers = [3 , 4 , 5 , 6 ])
43
+ model_fpn = mobilenet_backbone (backbone_name , False , fpn = True )
44
+ assert isinstance (model_fpn , BackboneWithFPN )
45
+ model = mobilenet_backbone (backbone_name , False , fpn = False )
46
+ assert isinstance (model , torch .nn .Sequential )
41
47
42
48
43
49
# Needed by TestFxFeatureExtraction.test_leaf_module_and_function
0 commit comments