18
18
import tensorflow as tf
19
19
20
20
from deeplab .core import xception
21
+ from nets .mobilenet import mobilenet as mobilenet_lib
22
+ from nets .mobilenet import mobilenet_v2
21
23
22
24
23
25
slim = tf .contrib .slim
24
26
27
+ # Default end point for MobileNetv2.
28
+ _MOBILENET_V2_FINAL_ENDPOINT = 'layer_18'
29
+
30
+
31
+ def _mobilenet_v2 (net ,
32
+ depth_multiplier ,
33
+ output_stride ,
34
+ reuse = None ,
35
+ scope = None ,
36
+ final_endpoint = None ):
37
+ """Auxiliary function to add support for 'reuse' to mobilenet_v2.
38
+
39
+ Args:
40
+ net: Input tensor of shape [batch_size, height, width, channels].
41
+ depth_multiplier: Float multiplier for the depth (number of channels)
42
+ for all convolution ops. The value must be greater than zero. Typical
43
+ usage will be to set this value in (0, 1) to reduce the number of
44
+ parameters or computation cost of the model.
45
+ output_stride: An integer that specifies the requested ratio of input to
46
+ output spatial resolution. If not None, then we invoke atrous convolution
47
+ if necessary to prevent the network from reducing the spatial resolution
48
+ of the activation maps. Allowed values are 8 (accurate fully convolutional
49
+ mode), 16 (fast fully convolutional mode), 32 (classification mode).
50
+ reuse: Reuse model variables.
51
+ scope: Optional variable scope.
52
+ final_endpoint: The endpoint to construct the network up to.
53
+
54
+ Returns:
55
+ Features extracted by MobileNetv2.
56
+ """
57
+ with tf .variable_scope (
58
+ scope , 'MobilenetV2' , [net ], reuse = reuse ) as scope :
59
+ return mobilenet_lib .mobilenet_base (
60
+ net ,
61
+ conv_defs = mobilenet_v2 .V2_DEF ,
62
+ multiplier = depth_multiplier ,
63
+ final_endpoint = final_endpoint or _MOBILENET_V2_FINAL_ENDPOINT ,
64
+ output_stride = output_stride ,
65
+ scope = scope )
66
+
25
67
26
68
# A map from network name to network function.
27
69
networks_map = {
70
+ 'mobilenet_v2' : _mobilenet_v2 ,
28
71
'xception_65' : xception .xception_65 ,
29
72
}
30
73
31
74
# A map from network name to network arg scope.
32
75
arg_scopes_map = {
76
+ 'mobilenet_v2' : mobilenet_v2 .training_scope ,
33
77
'xception_65' : xception .xception_arg_scope ,
34
78
}
35
79
38
82
39
83
# A dictionary from network name to a map of end point features.
40
84
networks_to_feature_maps = {
85
+ 'mobilenet_v2' : {
86
+ # The provided checkpoint does not include decoder module.
87
+ DECODER_END_POINTS : None ,
88
+ },
41
89
'xception_65' : {
42
90
DECODER_END_POINTS : [
43
91
'entry_flow/block2/unit_1/xception_module/'
49
97
# A map from feature extractor name to the network name scope used in the
50
98
# ImageNet pretrained versions of these models.
51
99
name_scope = {
100
+ 'mobilenet_v2' : 'MobilenetV2' ,
52
101
'xception_65' : 'xception_65' ,
53
102
}
54
103
@@ -68,6 +117,7 @@ def _preprocess_zero_mean_unit_range(inputs):
68
117
69
118
70
119
_PREPROCESS_FN = {
120
+ 'mobilenet_v2' : _preprocess_zero_mean_unit_range ,
71
121
'xception_65' : _preprocess_zero_mean_unit_range ,
72
122
}
73
123
@@ -99,6 +149,8 @@ def mean_pixel(model_variant=None):
99
149
def extract_features (images ,
100
150
output_stride = 8 ,
101
151
multi_grid = None ,
152
+ depth_multiplier = 1.0 ,
153
+ final_endpoint = None ,
102
154
model_variant = None ,
103
155
weight_decay = 0.0001 ,
104
156
reuse = None ,
@@ -114,6 +166,9 @@ def extract_features(images,
114
166
images: A tensor of size [batch, height, width, channels].
115
167
output_stride: The ratio of input to output spatial resolution.
116
168
multi_grid: Employ a hierarchy of different atrous rates within network.
169
+ depth_multiplier: Float multiplier for the depth (number of channels)
170
+ for all convolution ops used in MobileNet.
171
+ final_endpoint: The MobileNet endpoint to construct the network up to.
117
172
model_variant: Model variant for feature extraction.
118
173
weight_decay: The weight decay for model variables.
119
174
reuse: Reuse the model variables or not.
@@ -159,7 +214,17 @@ def extract_features(images,
159
214
reuse = reuse ,
160
215
scope = name_scope [model_variant ])
161
216
elif 'mobilenet' in model_variant :
162
- raise ValueError ('MobileNetv2 support is coming soon.' )
217
+ arg_scope = arg_scopes_map [model_variant ](
218
+ is_training = (is_training and fine_tune_batch_norm ),
219
+ weight_decay = weight_decay )
220
+ features , end_points = get_network (
221
+ model_variant , preprocess_images , arg_scope )(
222
+ inputs = images ,
223
+ depth_multiplier = depth_multiplier ,
224
+ output_stride = output_stride ,
225
+ reuse = reuse ,
226
+ scope = name_scope [model_variant ],
227
+ final_endpoint = final_endpoint )
163
228
else :
164
229
raise ValueError ('Unknown model variant %s.' % model_variant )
165
230
0 commit comments