@@ -11,7 +11,8 @@ def __init__(self, cfg):
11
11
self .cfg = cfg
12
12
self .cls_headers = nn .ModuleList ()
13
13
self .reg_headers = nn .ModuleList ()
14
- for level , (boxes_per_location , out_channels ) in enumerate (zip (cfg .MODEL .PRIORS .BOXES_PER_LOCATION , cfg .MODEL .BACKBONE .OUT_CHANNELS )):
14
+ for level , (boxes_per_location , out_channels ) in enumerate (
15
+ zip (cfg .MODEL .PRIORS .BOXES_PER_LOCATION , cfg .MODEL .BACKBONE .OUT_CHANNELS )):
15
16
self .cls_headers .append (self .cls_block (level , out_channels , boxes_per_location ))
16
17
self .reg_headers .append (self .reg_block (level , out_channels , boxes_per_location ))
17
18
self .reset_parameters ()
@@ -36,7 +37,8 @@ def forward(self, features):
36
37
bbox_pred .append (reg_header (feature ).permute (0 , 2 , 3 , 1 ).contiguous ())
37
38
38
39
batch_size = features [0 ].shape [0 ]
39
- cls_logits = torch .cat ([c .view (c .shape [0 ], - 1 ) for c in cls_logits ], dim = 1 ).view (batch_size , - 1 , self .cfg .MODEL .NUM_CLASSES )
40
+ cls_logits = torch .cat ([c .view (c .shape [0 ], - 1 ) for c in cls_logits ], dim = 1 ).view (batch_size , - 1 ,
41
+ self .cfg .MODEL .NUM_CLASSES )
40
42
bbox_pred = torch .cat ([l .view (l .shape [0 ], - 1 ) for l in bbox_pred ], dim = 1 ).view (batch_size , - 1 , 4 )
41
43
42
44
return cls_logits , bbox_pred
@@ -45,7 +47,8 @@ def forward(self, features):
45
47
@registry .BOX_PREDICTORS .register ('SSDBoxPredictor' )
46
48
class SSDBoxPredictor (BoxPredictor ):
47
49
def cls_block (self , level , out_channels , boxes_per_location ):
48
- return nn .Conv2d (out_channels , boxes_per_location * self .cfg .MODEL .NUM_CLASSES , kernel_size = 3 , stride = 1 , padding = 1 )
50
+ return nn .Conv2d (out_channels , boxes_per_location * self .cfg .MODEL .NUM_CLASSES , kernel_size = 3 , stride = 1 ,
51
+ padding = 1 )
49
52
50
53
def reg_block (self , level , out_channels , boxes_per_location ):
51
54
return nn .Conv2d (out_channels , boxes_per_location * 4 , kernel_size = 3 , stride = 1 , padding = 1 )
@@ -57,10 +60,15 @@ def cls_block(self, level, out_channels, boxes_per_location):
57
60
num_levels = len (self .cfg .MODEL .BACKBONE .OUT_CHANNELS )
58
61
if level == num_levels - 1 :
59
62
return nn .Conv2d (out_channels , boxes_per_location * self .cfg .MODEL .NUM_CLASSES , kernel_size = 1 )
60
- return SeparableConv2d (out_channels , boxes_per_location * self .cfg .MODEL .NUM_CLASSES , kernel_size = 3 , stride = 1 , padding = 1 )
63
+ return SeparableConv2d (out_channels , boxes_per_location * self .cfg .MODEL .NUM_CLASSES , kernel_size = 3 , stride = 1 ,
64
+ padding = 1 )
61
65
62
66
def reg_block (self , level , out_channels , boxes_per_location ):
63
67
num_levels = len (self .cfg .MODEL .BACKBONE .OUT_CHANNELS )
64
68
if level == num_levels - 1 :
65
69
return nn .Conv2d (out_channels , boxes_per_location * 4 , kernel_size = 1 )
66
- return SeparableConv2d (out_channels , boxes_per_location * 4 , kernel_size = 3 , stride = 1 , padding = 1 )
70
+ return SeparableConv2d (out_channels , boxes_per_location * 4 , kernel_size = 3 , stride = 1 , padding = 1 )
71
+
72
+
73
+ def build_box_predictor (cfg ):
74
+ return registry .BOX_PREDICTORS [cfg .MODEL .BOX_HEAD .PREDICTOR ](cfg )
0 commit comments