Skip to content

Commit 7d80ba6

Browse files
committed
fix(box): box_predictor无法导入
1 parent fef04ba commit 7d80ba6

File tree

5 files changed

+15
-12
lines changed

5 files changed

+15
-12
lines changed

py/ssd/models/box_head/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .build import build_box_head, build_box_predictor
1+
from .build import build_box_head

py/ssd/models/box_head/box_head.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from ssd.models import registry
55
from ssd.models.anchors.prior_box import PriorBox
6-
from ssd.models.box_head import build_box_predictor
6+
from ssd.models.box_head.box_predictor import build_box_predictor
77
from ssd.utils import box_utils
88
from .inference import PostProcessor
99
from .loss import MultiBoxLoss

py/ssd/models/box_head/box_predictor.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@ def __init__(self, cfg):
1111
self.cfg = cfg
1212
self.cls_headers = nn.ModuleList()
1313
self.reg_headers = nn.ModuleList()
14-
for level, (boxes_per_location, out_channels) in enumerate(zip(cfg.MODEL.PRIORS.BOXES_PER_LOCATION, cfg.MODEL.BACKBONE.OUT_CHANNELS)):
14+
for level, (boxes_per_location, out_channels) in enumerate(
15+
zip(cfg.MODEL.PRIORS.BOXES_PER_LOCATION, cfg.MODEL.BACKBONE.OUT_CHANNELS)):
1516
self.cls_headers.append(self.cls_block(level, out_channels, boxes_per_location))
1617
self.reg_headers.append(self.reg_block(level, out_channels, boxes_per_location))
1718
self.reset_parameters()
@@ -36,7 +37,8 @@ def forward(self, features):
3637
bbox_pred.append(reg_header(feature).permute(0, 2, 3, 1).contiguous())
3738

3839
batch_size = features[0].shape[0]
39-
cls_logits = torch.cat([c.view(c.shape[0], -1) for c in cls_logits], dim=1).view(batch_size, -1, self.cfg.MODEL.NUM_CLASSES)
40+
cls_logits = torch.cat([c.view(c.shape[0], -1) for c in cls_logits], dim=1).view(batch_size, -1,
41+
self.cfg.MODEL.NUM_CLASSES)
4042
bbox_pred = torch.cat([l.view(l.shape[0], -1) for l in bbox_pred], dim=1).view(batch_size, -1, 4)
4143

4244
return cls_logits, bbox_pred
@@ -45,7 +47,8 @@ def forward(self, features):
4547
@registry.BOX_PREDICTORS.register('SSDBoxPredictor')
4648
class SSDBoxPredictor(BoxPredictor):
4749
def cls_block(self, level, out_channels, boxes_per_location):
48-
return nn.Conv2d(out_channels, boxes_per_location * self.cfg.MODEL.NUM_CLASSES, kernel_size=3, stride=1, padding=1)
50+
return nn.Conv2d(out_channels, boxes_per_location * self.cfg.MODEL.NUM_CLASSES, kernel_size=3, stride=1,
51+
padding=1)
4952

5053
def reg_block(self, level, out_channels, boxes_per_location):
5154
return nn.Conv2d(out_channels, boxes_per_location * 4, kernel_size=3, stride=1, padding=1)
@@ -57,10 +60,15 @@ def cls_block(self, level, out_channels, boxes_per_location):
5760
num_levels = len(self.cfg.MODEL.BACKBONE.OUT_CHANNELS)
5861
if level == num_levels - 1:
5962
return nn.Conv2d(out_channels, boxes_per_location * self.cfg.MODEL.NUM_CLASSES, kernel_size=1)
60-
return SeparableConv2d(out_channels, boxes_per_location * self.cfg.MODEL.NUM_CLASSES, kernel_size=3, stride=1, padding=1)
63+
return SeparableConv2d(out_channels, boxes_per_location * self.cfg.MODEL.NUM_CLASSES, kernel_size=3, stride=1,
64+
padding=1)
6165

6266
def reg_block(self, level, out_channels, boxes_per_location):
6367
num_levels = len(self.cfg.MODEL.BACKBONE.OUT_CHANNELS)
6468
if level == num_levels - 1:
6569
return nn.Conv2d(out_channels, boxes_per_location * 4, kernel_size=1)
66-
return SeparableConv2d(out_channels, boxes_per_location * 4, kernel_size=3, stride=1, padding=1)
70+
return SeparableConv2d(out_channels, boxes_per_location * 4, kernel_size=3, stride=1, padding=1)
71+
72+
73+
def build_box_predictor(cfg):
74+
return registry.BOX_PREDICTORS[cfg.MODEL.BOX_HEAD.PREDICTOR](cfg)

py/ssd/models/box_head/build.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,7 @@
99

1010
from ssd.models import registry
1111
from .box_head import SSDBoxHead
12-
from .box_predictor import SSDBoxPredictor, SSDLiteBoxPredictor
1312

1413

1514
def build_box_head(cfg):
1615
return registry.BOX_HEADS[cfg.MODEL.BOX_HEAD.NAME](cfg)
17-
18-
19-
def build_box_predictor(cfg):
20-
return registry.BOX_PREDICTORS[cfg.MODEL.BOX_HEAD.PREDICTOR](cfg)
File renamed without changes.

0 commit comments

Comments
 (0)