17
17
import functools
18
18
import tensorflow as tf
19
19
20
+ from deeplab .core import resnet_v1_beta
20
21
from deeplab .core import xception
21
- from nets . mobilenet import mobilenet as mobilenet_lib
22
+ from tensorflow . contrib . slim . nets import resnet_utils
22
23
from nets .mobilenet import mobilenet_v2
23
24
24
25
@@ -56,10 +57,12 @@ def _mobilenet_v2(net,
56
57
"""
57
58
with tf .variable_scope (
58
59
scope , 'MobilenetV2' , [net ], reuse = reuse ) as scope :
59
- return mobilenet_lib .mobilenet_base (
60
+ return mobilenet_v2 .mobilenet_base (
60
61
net ,
61
62
conv_defs = mobilenet_v2 .V2_DEF ,
62
- multiplier = depth_multiplier ,
63
+ depth_multiplier = depth_multiplier ,
64
+ min_depth = 8 if depth_multiplier == 1.0 else 1 ,
65
+ divisible_by = 8 if depth_multiplier == 1.0 else 1 ,
63
66
final_endpoint = final_endpoint or _MOBILENET_V2_FINAL_ENDPOINT ,
64
67
output_stride = output_stride ,
65
68
scope = scope )
@@ -68,13 +71,25 @@ def _mobilenet_v2(net,
68
71
# A map from network name to network function.
69
72
networks_map = {
70
73
'mobilenet_v2' : _mobilenet_v2 ,
74
+ 'resnet_v1_50' : resnet_v1_beta .resnet_v1_50 ,
75
+ 'resnet_v1_50_beta' : resnet_v1_beta .resnet_v1_50_beta ,
76
+ 'resnet_v1_101' : resnet_v1_beta .resnet_v1_101 ,
77
+ 'resnet_v1_101_beta' : resnet_v1_beta .resnet_v1_101_beta ,
78
+ 'xception_41' : xception .xception_41 ,
71
79
'xception_65' : xception .xception_65 ,
80
+ 'xception_71' : xception .xception_71 ,
72
81
}
73
82
74
83
# A map from network name to network arg scope.
75
84
arg_scopes_map = {
76
85
'mobilenet_v2' : mobilenet_v2 .training_scope ,
86
+ 'resnet_v1_50' : resnet_utils .resnet_arg_scope ,
87
+ 'resnet_v1_50_beta' : resnet_utils .resnet_arg_scope ,
88
+ 'resnet_v1_101' : resnet_utils .resnet_arg_scope ,
89
+ 'resnet_v1_101_beta' : resnet_utils .resnet_arg_scope ,
90
+ 'xception_41' : xception .xception_arg_scope ,
77
91
'xception_65' : xception .xception_arg_scope ,
92
+ 'xception_71' : xception .xception_arg_scope ,
78
93
}
79
94
80
95
# Names for end point features.
@@ -86,19 +101,49 @@ def _mobilenet_v2(net,
86
101
# The provided checkpoint does not include decoder module.
87
102
DECODER_END_POINTS : None ,
88
103
},
104
+ 'resnet_v1_50' : {
105
+ DECODER_END_POINTS : ['block1/unit_2/bottleneck_v1/conv3' ],
106
+ },
107
+ 'resnet_v1_50_beta' : {
108
+ DECODER_END_POINTS : ['block1/unit_2/bottleneck_v1/conv3' ],
109
+ },
110
+ 'resnet_v1_101' : {
111
+ DECODER_END_POINTS : ['block1/unit_2/bottleneck_v1/conv3' ],
112
+ },
113
+ 'resnet_v1_101_beta' : {
114
+ DECODER_END_POINTS : ['block1/unit_2/bottleneck_v1/conv3' ],
115
+ },
116
+ 'xception_41' : {
117
+ DECODER_END_POINTS : [
118
+ 'entry_flow/block2/unit_1/xception_module/'
119
+ 'separable_conv2_pointwise' ,
120
+ ],
121
+ },
89
122
'xception_65' : {
90
123
DECODER_END_POINTS : [
91
124
'entry_flow/block2/unit_1/xception_module/'
92
125
'separable_conv2_pointwise' ,
93
126
],
94
- }
127
+ },
128
+ 'xception_71' : {
129
+ DECODER_END_POINTS : [
130
+ 'entry_flow/block2/unit_1/xception_module/'
131
+ 'separable_conv2_pointwise' ,
132
+ ],
133
+ },
95
134
}
96
135
97
136
# A map from feature extractor name to the network name scope used in the
98
137
# ImageNet pretrained versions of these models.
99
138
name_scope = {
100
139
'mobilenet_v2' : 'MobilenetV2' ,
140
+ 'resnet_v1_50' : 'resnet_v1_50' ,
141
+ 'resnet_v1_50_beta' : 'resnet_v1_50' ,
142
+ 'resnet_v1_101' : 'resnet_v1_101' ,
143
+ 'resnet_v1_101_beta' : 'resnet_v1_101' ,
144
+ 'xception_41' : 'xception_41' ,
101
145
'xception_65' : 'xception_65' ,
146
+ 'xception_71' : 'xception_71' ,
102
147
}
103
148
104
149
# Mean pixel value.
@@ -118,7 +163,13 @@ def _preprocess_zero_mean_unit_range(inputs):
118
163
119
164
_PREPROCESS_FN = {
120
165
'mobilenet_v2' : _preprocess_zero_mean_unit_range ,
166
+ 'resnet_v1_50' : _preprocess_subtract_imagenet_mean ,
167
+ 'resnet_v1_50_beta' : _preprocess_zero_mean_unit_range ,
168
+ 'resnet_v1_101' : _preprocess_subtract_imagenet_mean ,
169
+ 'resnet_v1_101_beta' : _preprocess_zero_mean_unit_range ,
170
+ 'xception_41' : _preprocess_zero_mean_unit_range ,
121
171
'xception_65' : _preprocess_zero_mean_unit_range ,
172
+ 'xception_71' : _preprocess_zero_mean_unit_range ,
122
173
}
123
174
124
175
@@ -140,7 +191,8 @@ def mean_pixel(model_variant=None):
140
191
Returns:
141
192
Mean pixel value.
142
193
"""
143
- if model_variant is None :
194
+ if model_variant in ['resnet_v1_50' ,
195
+ 'resnet_v1_101' ] or model_variant is None :
144
196
return _MEAN_RGB
145
197
else :
146
198
return [127.5 , 127.5 , 127.5 ]
@@ -159,7 +211,8 @@ def extract_features(images,
159
211
regularize_depthwise = False ,
160
212
preprocess_images = True ,
161
213
num_classes = None ,
162
- global_pool = False ):
214
+ global_pool = False ,
215
+ use_bounded_activations = False ):
163
216
"""Extracts features by the particular model_variant.
164
217
165
218
Args:
@@ -184,6 +237,8 @@ def extract_features(images,
184
237
to None for dense prediction tasks.
185
238
global_pool: Global pooling for image classification task. Defaults to
186
239
False, since dense prediction tasks do not use this.
240
+ use_bounded_activations: Whether or not to use bounded activations. Bounded
241
+ activations better lend themselves to quantized inference.
187
242
188
243
Returns:
189
244
features: A tensor of size [batch, feature_height, feature_width,
@@ -195,7 +250,25 @@ def extract_features(images,
195
250
Raises:
196
251
ValueError: Unrecognized model variant.
197
252
"""
198
- if 'xception' in model_variant :
253
+ if 'resnet' in model_variant :
254
+ arg_scope = arg_scopes_map [model_variant ](
255
+ weight_decay = weight_decay ,
256
+ batch_norm_decay = 0.95 ,
257
+ batch_norm_epsilon = 1e-5 ,
258
+ batch_norm_scale = True ,
259
+ activation_fn = tf .nn .relu6 if use_bounded_activations else tf .nn .relu )
260
+ features , end_points = get_network (
261
+ model_variant , preprocess_images , arg_scope )(
262
+ inputs = images ,
263
+ num_classes = num_classes ,
264
+ is_training = (is_training and fine_tune_batch_norm ),
265
+ global_pool = global_pool ,
266
+ output_stride = output_stride ,
267
+ multi_grid = multi_grid ,
268
+ reuse = reuse ,
269
+ scope = name_scope [model_variant ],
270
+ use_bounded_activations = use_bounded_activations )
271
+ elif 'xception' in model_variant :
199
272
arg_scope = arg_scopes_map [model_variant ](
200
273
weight_decay = weight_decay ,
201
274
batch_norm_decay = 0.9997 ,
0 commit comments