Skip to content

Commit 9475032

Browse files
NickVeldleoxiaobin
authored andcommitted
unify addressing to cfg, reuse cfg['MODEL']['EXTRA']
some data retrieving from cfg looks like retrieving from dict, some like from Namespace's instance or custom class' instance "extra" is defined, why does the code ignore it later?
1 parent 015c946 commit 9475032

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

lib/models/pose_hrnet.py

+10-10
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ class PoseHighResolutionNet(nn.Module):
275275

276276
def __init__(self, cfg, **kwargs):
277277
self.inplanes = 64
278-
extra = cfg.MODEL.EXTRA
278+
extra = cfg['MODEL']['EXTRA']
279279
super(PoseHighResolutionNet, self).__init__()
280280

281281
# stem net
@@ -288,7 +288,7 @@ def __init__(self, cfg, **kwargs):
288288
self.relu = nn.ReLU(inplace=True)
289289
self.layer1 = self._make_layer(Bottleneck, 64, 4)
290290

291-
self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2']
291+
self.stage2_cfg = extra['STAGE2']
292292
num_channels = self.stage2_cfg['NUM_CHANNELS']
293293
block = blocks_dict[self.stage2_cfg['BLOCK']]
294294
num_channels = [
@@ -298,7 +298,7 @@ def __init__(self, cfg, **kwargs):
298298
self.stage2, pre_stage_channels = self._make_stage(
299299
self.stage2_cfg, num_channels)
300300

301-
self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3']
301+
self.stage3_cfg = extra['STAGE3']
302302
num_channels = self.stage3_cfg['NUM_CHANNELS']
303303
block = blocks_dict[self.stage3_cfg['BLOCK']]
304304
num_channels = [
@@ -309,7 +309,7 @@ def __init__(self, cfg, **kwargs):
309309
self.stage3, pre_stage_channels = self._make_stage(
310310
self.stage3_cfg, num_channels)
311311

312-
self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4']
312+
self.stage4_cfg = extra['STAGE4']
313313
num_channels = self.stage4_cfg['NUM_CHANNELS']
314314
block = blocks_dict[self.stage4_cfg['BLOCK']]
315315
num_channels = [
@@ -322,13 +322,13 @@ def __init__(self, cfg, **kwargs):
322322

323323
self.final_layer = nn.Conv2d(
324324
in_channels=pre_stage_channels[0],
325-
out_channels=cfg.MODEL.NUM_JOINTS,
326-
kernel_size=extra.FINAL_CONV_KERNEL,
325+
out_channels=cfg['MODEL']['NUM_JOINTS'],
326+
kernel_size=extra['FINAL_CONV_KERNEL'],
327327
stride=1,
328-
padding=1 if extra.FINAL_CONV_KERNEL == 3 else 0
328+
padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0
329329
)
330330

331-
self.pretrained_layers = cfg['MODEL']['EXTRA']['PRETRAINED_LAYERS']
331+
self.pretrained_layers = extra['PRETRAINED_LAYERS']
332332

333333
def _make_transition_layer(
334334
self, num_channels_pre_layer, num_channels_cur_layer):
@@ -495,7 +495,7 @@ def init_weights(self, pretrained=''):
495495
def get_pose_net(cfg, is_train, **kwargs):
496496
model = PoseHighResolutionNet(cfg, **kwargs)
497497

498-
if is_train and cfg.MODEL.INIT_WEIGHTS:
499-
model.init_weights(cfg.MODEL.PRETRAINED)
498+
if is_train and cfg['MODEL']['INIT_WEIGHTS']:
499+
model.init_weights(cfg['MODEL']['PRETRAINED'])
500500

501501
return model

0 commit comments

Comments
 (0)