-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathmink_resnet.py
140 lines (123 loc) · 5.17 KB
/
mink_resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
# Copyright (c) OpenRobotLab. All rights reserved.
# Follow https://github.com/NVIDIA/MinkowskiEngine/blob/master/examples/resnet.py # noqa
# and mmcv.cnn.ResNet
from typing import List, Union
try:
import MinkowskiEngine as ME
from MinkowskiEngine import SparseTensor
from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
except ImportError:
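# BasicBlock / Bottleneck / SparseTensor are referenced while the MinkResNet
# class body is evaluated (``arch_settings`` and type annotations), so the
# names must be bound even when MinkowskiEngine is not installed.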
ME = BasicBlock = Bottleneck = SparseTensor = None
import torch.nn as nn
from mmengine.model import BaseModule
from embodiedscan.registry import MODELS
@MODELS.register_module()
class MinkResNet(BaseModule):
    r"""Minkowski ResNet backbone. See `4D Spatio-Temporal ConvNets
    <https://arxiv.org/abs/1904.08755>`_ for more details.

    The network consists of a stem (stride-2 sparse convolution, instance
    norm, ReLU and an optional stride-2 max pooling) followed by
    ``num_stages`` residual stages. Stage ``i`` (0-based) is built with
    ``64 * 2**i`` planes and a stride of 2 in its first block.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input channels, 3 for RGB.
        num_stages (int): Resnet stages, between 1 and 4. Defaults to 4.
        pool (bool): Whether to add max pooling after first conv.
            Defaults to True.

    Raises:
        ImportError: If MinkowskiEngine could not be imported.
        KeyError: If ``depth`` is not a key of ``arch_settings``.
        AssertionError: If ``num_stages`` is outside ``[1, 4]``.
    """

    # depth -> (residual block class, number of blocks in each stage)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth: int,
                 in_channels: int,
                 num_stages: int = 4,
                 pool: bool = True):
        super().__init__()
        if ME is None:
            raise ImportError(
                'Please follow `get_started.md` to install MinkowskiEngine.')
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        # Raised explicitly instead of using an ``assert`` statement so the
        # check is not stripped under ``python -O``; the exception type is
        # kept for backward compatibility.
        if not 1 <= num_stages <= 4:
            raise AssertionError(
                f'num_stages must be in [1, 4], got {num_stages}')

        block, stage_blocks = self.arch_settings[depth]
        stage_blocks = stage_blocks[:num_stages]
        self.num_stages = num_stages
        self.pool = pool

        # Running number of input channels; updated by ``_make_layer`` as
        # stages are appended.
        self.inplanes = 64
        self.conv1 = ME.MinkowskiConvolution(in_channels,
                                             self.inplanes,
                                             kernel_size=3,
                                             stride=2,
                                             dimension=3)
        # Maybe BatchNorm would be better, but we follow the original
        # implementation.
        self.norm1 = ME.MinkowskiInstanceNorm(self.inplanes)
        self.relu = ME.MinkowskiReLU(inplace=True)
        if self.pool:
            self.maxpool = ME.MinkowskiMaxPooling(kernel_size=2,
                                                  stride=2,
                                                  dimension=3)

        # Stages are registered as ``layer1`` ... ``layer{num_stages}``;
        # ``forward`` looks them up by the same names.
        for i, num_blocks in enumerate(stage_blocks):
            setattr(
                self, f'layer{i + 1}',
                self._make_layer(block, 64 * 2**i, num_blocks, stride=2))

    def init_weights(self):
        """Initialize weights.

        Every ``MinkowskiConvolution`` kernel gets Kaiming-normal
        initialization (``fan_out``, ``relu``); every ``MinkowskiBatchNorm``
        gets weight 1 and bias 0. Other module types are not touched here.
        """
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(m.kernel,
                                         mode='fan_out',
                                         nonlinearity='relu')
            elif isinstance(m, ME.MinkowskiBatchNorm):
                nn.init.constant_(m.bn.weight, 1)
                nn.init.constant_(m.bn.bias, 0)

    def _make_layer(self, block: Union[BasicBlock, Bottleneck], planes: int,
                    blocks: int, stride: int) -> nn.Module:
        """Make single level of residual blocks.

        Updates ``self.inplanes`` to ``planes * block.expansion`` so that the
        next call starts from this stage's output width.

        Args:
            block (BasicBlock | Bottleneck): Residual block class.
            planes (int): Number of convolution filters.
            blocks (int): Number of blocks in the layers.
            stride (int): Stride of the first convolutional layer.

        Returns:
            nn.Module: With residual blocks.
        """
        out_planes = planes * block.expansion

        # A projection shortcut (1x1 conv + batch norm) is created whenever
        # the first block is strided or changes the channel count.
        downsample = None
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                ME.MinkowskiConvolution(self.inplanes,
                                        out_planes,
                                        kernel_size=1,
                                        stride=stride,
                                        dimension=3),
                ME.MinkowskiBatchNorm(out_planes))

        # Only the first block applies the stride and the shortcut.
        layers = [
            block(self.inplanes,
                  planes,
                  stride=stride,
                  downsample=downsample,
                  dimension=3)
        ]
        self.inplanes = out_planes
        layers.extend(
            block(self.inplanes, planes, stride=1, dimension=3)
            for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x: SparseTensor) -> List[SparseTensor]:
        """Forward pass of ResNet.

        Args:
            x (ME.SparseTensor): Input sparse tensor.

        Returns:
            list[ME.SparseTensor]: Output sparse tensors, one per stage in
            stage order.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        if self.pool:
            x = self.maxpool(x)

        outs = []
        for i in range(self.num_stages):
            x = getattr(self, f'layer{i + 1}')(x)
            outs.append(x)
        return outs