@@ -14,57 +14,44 @@
 import torch
 import torch.nn as nn
 from wtorch.utils import *
-from wtorch.nn import *
-from einops import rearrange
 
 
 BN_MOMENTUM = 0.1
 logger = logging.getLogger(__name__)
 
 
-def conv3x3(in_planes, out_planes, stride=1, bias=False):
+def conv3x3(in_planes, out_planes, stride=1):
     """3x3 convolution with padding"""
     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
-                     padding=1, bias=bias)
+                     padding=1, bias=False)
 
-def get_norm(planes, momentum=None, type="layer_norm"):
-    if type == "bn":
-        return nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
-    elif type == "layer_norm":
-        return LayerNorm(planes)
-
-def get_activation_fn(*args, **kwargs):
-    return nn.GELU()
 
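Note: dropping the bias argument here is the standard conv + BatchNorm pairing; BN's per-channel affine shift makes a preceding conv bias redundant, so bias=False only removes dead parameters. A minimal sketch with hypothetical sizes, assuming the imports and definitions above:

    conv_bn = nn.Sequential(conv3x3(16, 32),
                            nn.BatchNorm2d(32, momentum=BN_MOMENTUM))
    y = conv_bn(torch.randn(1, 16, 8, 8))  # -> (1, 32, 8, 8)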
 class BasicBlock(nn.Module):
     expansion = 1
 
     def __init__(self, inplanes, planes, stride=1, downsample=None):
         super(BasicBlock, self).__init__()
         self.conv1 = conv3x3(inplanes, planes, stride)
-        self.bn1 = get_norm(planes, momentum=BN_MOMENTUM)
-        self.relu = get_activation_fn(inplace=True)
-        self.conv2 = conv3x3(planes, planes, bias=True)
+        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
         self.downsample = downsample
         self.stride = stride
-        if self.stride == 1:
-            self.drop_path = nn.Dropout2d(p=0.08)
-        else:
-            self.drop_path = None
 
     def forward(self, x):
         residual = x
 
         out = self.conv1(x)
         out = self.bn1(out)
+        out = self.relu(out)
 
         out = self.conv2(out)
+        out = self.bn2(out)
 
         if self.downsample is not None:
             residual = self.downsample(x)
-
-        if self.drop_path is not None:
-            out = self.drop_path(out)
+
         out += residual
         out = self.relu(out)
 
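With bn2 restored and the drop_path branch removed, BasicBlock is again the stock two-conv residual unit. A quick shape check with toy, hypothetical sizes:

    block = BasicBlock(64, 64)               # stride=1, so no projection is needed
    out = block(torch.randn(2, 64, 32, 32))  # the residual add requires matching shapes
    assert out.shape == (2, 64, 32, 32)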
@@ -77,30 +64,37 @@ class Bottleneck(nn.Module):
     def __init__(self, inplanes, planes, stride=1, downsample=None):
         super(Bottleneck, self).__init__()
         self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                                padding=1, bias=False)
-        self.bn2 = get_norm(planes, momentum=BN_MOMENTUM)
+        self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
         self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
-                               bias=True)
-        self.relu = get_activation_fn(inplace=True)
+                               bias=False)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion,
+                                  momentum=BN_MOMENTUM)
+        self.relu = nn.ReLU(inplace=True)
         self.downsample = downsample
         self.stride = stride
 
     def forward(self, x):
         residual = x
 
         out = self.conv1(x)
+        out = self.bn1(out)
         out = self.relu(out)
 
         out = self.conv2(out)
         out = self.bn2(out)
+        out = self.relu(out)
 
         out = self.conv3(out)
+        out = self.bn3(out)
 
         if self.downsample is not None:
             residual = self.downsample(x)
 
         out += residual
+        out = self.relu(out)
 
         return out
 
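Bottleneck.expansion is 4 in the stock class (its definition sits above this hunk), so conv3 widens planes fourfold and the skip needs a 1x1 projection whenever channel counts differ. A hedged sketch with assumed widths:

    down = nn.Sequential(
        nn.Conv2d(64, 256, kernel_size=1, bias=False),
        nn.BatchNorm2d(256, momentum=BN_MOMENTUM),
    )
    block = Bottleneck(64, 64, downsample=down)  # 64 in, 64 * 4 out
    assert block(torch.randn(1, 64, 16, 16)).shape == (1, 256, 16, 16)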
@@ -121,7 +115,7 @@ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
         self.branches = self._make_branches(
             num_branches, blocks, num_blocks, num_channels)
         self.fuse_layers = self._make_fuse_layers()
-        self.relu = get_activation_fn(True)
+        self.relu = nn.ReLU(True)
 
     def _check_branches(self, num_branches, blocks, num_blocks,
                         num_inchannels, num_channels):
@@ -154,7 +148,7 @@ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
                     num_channels[branch_index] * block.expansion,
                     kernel_size=1, stride=stride, bias=False
                 ),
-                get_norm(
+                nn.BatchNorm2d(
                     num_channels[branch_index] * block.expansion,
                     momentum=BN_MOMENTUM
                 ),
@@ -209,7 +203,7 @@ def _make_fuse_layers(self):
                                 num_inchannels[i],
                                 1, 1, 0, bias=False
                             ),
-                            get_norm(num_inchannels[i]),
+                            nn.BatchNorm2d(num_inchannels[i]),
                             nn.Upsample(scale_factor=2**(j-i), mode='nearest')
                         )
                     )
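For j > i the fuse path first aligns channels with a 1x1 conv + BN, then restores resolution by nearest-neighbour upsampling with factor 2**(j-i). A sketch of the j - i = 1 case, with hypothetical branch widths:

    fuse_up = nn.Sequential(
        nn.Conv2d(64, 32, 1, 1, 0, bias=False),
        nn.BatchNorm2d(32),
        nn.Upsample(scale_factor=2, mode='nearest'),
    )
    assert fuse_up(torch.randn(1, 64, 16, 12)).shape == (1, 32, 32, 24)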
@@ -227,7 +221,7 @@ def _make_fuse_layers(self):
                                         num_outchannels_conv3x3,
                                         3, 2, 1, bias=False
                                     ),
-                                    get_norm(num_outchannels_conv3x3)
+                                    nn.BatchNorm2d(num_outchannels_conv3x3)
                                 )
                             )
                         else:
@@ -239,8 +233,8 @@ def _make_fuse_layers(self):
                                         num_outchannels_conv3x3,
                                         3, 2, 1, bias=False
                                     ),
-                                    get_norm(num_outchannels_conv3x3),
-                                    get_activation_fn(True)
+                                    nn.BatchNorm2d(num_outchannels_conv3x3),
+                                    nn.ReLU(True)
                                 )
                             )
                 fuse_layer.append(nn.Sequential(*conv3x3s))
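For j < i the path stacks i - j stride-2 3x3 convs; every step but the last ends in a ReLU, and the final step stays linear because the fused sum is activated afterwards. A sketch of the i - j = 2 case, widths assumed:

    fuse_down = nn.Sequential(
        nn.Sequential(nn.Conv2d(32, 32, 3, 2, 1, bias=False),
                      nn.BatchNorm2d(32), nn.ReLU(True)),
        nn.Sequential(nn.Conv2d(32, 128, 3, 2, 1, bias=False),
                      nn.BatchNorm2d(128)),  # last step: no ReLU
    )
    assert fuse_down(torch.randn(1, 32, 64, 48)).shape == (1, 128, 16, 12)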
@@ -277,31 +271,6 @@ def forward(self, x):
     'BOTTLENECK': Bottleneck
 }
 
-class FCBlock(nn.Module):
-    def __init__(self, channels, width, height):
-        super().__init__()
-        channels1 = width * height
-        self.fc0 = nn.Linear(channels1, channels1, bias=False)
-        self.norm0 = nn.LayerNorm(channels1)
-        self.fc1 = nn.Linear(channels, channels, bias=False)
-        self.norm1 = nn.LayerNorm(channels)
-        self.relu = get_activation_fn()
-
-    def forward(self, x):
-        residual = x
-        shape = x.shape
-        x = rearrange(x, 'b c h w -> b c (h w)')
-        x = self.fc0(x)
-        x = self.norm0(x)
-        x = rearrange(x, 'b c s -> b s c')
-        x = self.fc1(x)
-        x = self.norm1(x)
-        x = rearrange(x, 'b s c -> b c s')
-        x = torch.reshape(x, shape)
-        x = x + residual
-        x = self.relu(x)
-        return x
-
 
 
 class PoseHighResolutionNet(nn.Module):
@@ -311,10 +280,13 @@ def __init__(self, cfg, **kwargs):
         super(PoseHighResolutionNet, self).__init__()
 
         # stem net
-        self.conv1 = nn.Conv2d(3, 64, kernel_size=4, stride=4,
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                                bias=False)
-        self.bn1 = get_norm(64, momentum=BN_MOMENTUM)
-        self.relu = get_activation_fn(inplace=True)
+        self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
+                               bias=False)
+        self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
+        self.relu = nn.ReLU(inplace=True)
         self.layer1 = self._make_layer(Bottleneck, 64, 4)
 
         self.stage2_cfg = extra['STAGE2']
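Both stems reduce resolution by 4x, but the restored one does it with two 3x3 stride-2 convs, each followed by BN and ReLU, instead of one 4x4 stride-4 patchify conv. A quick check of the spatial arithmetic, input size assumed:

    stem = nn.Sequential(
        nn.Conv2d(3, 64, 3, 2, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
        nn.Conv2d(64, 64, 3, 2, 1, bias=False), nn.BatchNorm2d(64), nn.ReLU(True),
    )
    assert stem(torch.randn(1, 3, 256, 192)).shape == (1, 64, 64, 48)  # 256/4 x 192/4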
@@ -349,8 +321,6 @@ def __init__(self, cfg, **kwargs):
         self.stage4, pre_stage_channels = self._make_stage(
             self.stage4_cfg, num_channels, multi_scale_output=False)
 
-        self.fc_block0 = FCBlock(pre_stage_channels[0], 48, 64)
-        self.fc_block1 = FCBlock(pre_stage_channels[0], 48, 64)
         self.final_layer = nn.Conv2d(
             in_channels=pre_stage_channels[0],
             out_channels=cfg['MODEL']['NUM_JOINTS'],
@@ -378,8 +348,8 @@ def _make_transition_layer(
                             num_channels_cur_layer[i],
                             3, 1, 1, bias=False
                         ),
-                        get_norm(num_channels_cur_layer[i]),
-                        get_activation_fn(inplace=True)
+                        nn.BatchNorm2d(num_channels_cur_layer[i]),
+                        nn.ReLU(inplace=True)
                     )
                 )
             else:
@@ -395,8 +365,8 @@ def _make_transition_layer(
                         nn.Conv2d(
                             inchannels, outchannels, 3, 2, 1, bias=False
                         ),
-                        get_norm(outchannels),
-                        get_activation_fn(inplace=True)
+                        nn.BatchNorm2d(outchannels),
+                        nn.ReLU(inplace=True)
                     )
                 )
             transition_layers.append(nn.Sequential(*conv3x3s))
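A transition either re-maps an existing branch (3x3, stride 1) or, as here, spawns a new lower-resolution branch from the previous stage's last one (3x3, stride 2), always as conv + BN + ReLU. Sketch of a new 128-channel branch, widths assumed:

    new_branch = nn.Sequential(
        nn.Conv2d(64, 128, 3, 2, 1, bias=False),
        nn.BatchNorm2d(128),
        nn.ReLU(inplace=True),
    )
    assert new_branch(torch.randn(1, 64, 32, 24)).shape == (1, 128, 16, 12)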
@@ -411,7 +381,7 @@ def _make_layer(self, block, planes, blocks, stride=1):
                     self.inplanes, planes * block.expansion,
                     kernel_size=1, stride=stride, bias=False
                 ),
-                get_norm(planes * block.expansion, momentum=BN_MOMENTUM),
+                nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
             )
 
         layers = []
@@ -468,6 +438,9 @@ def forward(self, x):
         x = self.conv1(x)
         x = self.bn1(x)
         x = self.relu(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
         x = self.layer1(x)
 
         x_list = []
@@ -493,9 +466,8 @@ def forward(self, x):
             else:
                 x_list.append(y_list[i])
         y_list = self.stage4(x_list)
-        x = self.fc_block0(y_list[0])
-        x = self.fc_block1(x)
-        x = self.final_layer(x)
+
+        x = self.final_layer(y_list[0])
 
         return x
 
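With the FC blocks gone, the head consumes the highest-resolution branch directly and emits one heatmap per joint. A hedged sketch of the head alone; the kernel size and widths are assumptions, since the nn.Conv2d call above is truncated in this diff:

    final = nn.Conv2d(in_channels=32, out_channels=17, kernel_size=1)
    heatmaps = final(torch.randn(1, 32, 64, 48))  # (1, 17, 64, 48): one map per joint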
@@ -508,7 +480,7 @@ def init_weights(self, pretrained=''):
                 for name, _ in m.named_parameters():
                     if name in ['bias']:
                         nn.init.constant_(m.bias, 0)
-            elif isinstance(m, LayerNorm):
+            elif isinstance(m, nn.BatchNorm2d):
                 nn.init.constant_(m.weight, 1)
                 nn.init.constant_(m.bias, 0)
             elif isinstance(m, nn.ConvTranspose2d):