
Commit cf8f014

Fix nasnet image classification and object detection

Fix nasnet image classification and object detection by moving the option to turn batch norm training on or off into its own arg_scope, used only by detection.

1 parent b3f04bc · commit cf8f014
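
In caller terms, the `is_batchnorm_training` argument disappears from the Slim builders, and detection gets its own arg scope instead. A minimal sketch of both call sites after this commit (TF 1.x with `tf.contrib.slim`; the import paths and the 1001-class count are assumptions, not part of the diff):

import tensorflow as tf

from nets.nasnet import nasnet
from object_detection.models import faster_rcnn_nas_feature_extractor as frcnn_nas

slim = tf.contrib.slim

# Classification: batch norm training now simply follows `is_training`.
with tf.Graph().as_default():
  images = tf.placeholder(tf.float32, [None, 331, 331, 3])
  with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
    logits, _ = nasnet.build_nasnet_large(
        images, num_classes=1001, is_training=True)  # 1001 classes: assumed ImageNet setup

# Detection: batch norm training is toggled separately through the new
# detection-only arg scope added by this commit.
with tf.Graph().as_default():
  images = tf.placeholder(tf.float32, [None, 331, 331, 3])
  with slim.arg_scope(frcnn_nas.nasnet_large_arg_scope_for_detection(
      is_batch_norm_training=False)):
    _, end_points = nasnet.build_nasnet_large(
        images, num_classes=None, is_training=False,
        final_endpoint='Cell_11')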

File tree

2 files changed: +59, -44 lines


research/object_detection/models/faster_rcnn_nas_feature_extractor.py (+19, -2)

@@ -30,6 +30,23 @@
 slim = tf.contrib.slim


+def nasnet_large_arg_scope_for_detection(is_batch_norm_training=False):
+  """Defines the default arg scope for the NASNet-A Large for object detection.
+
+  This provides a small edit to switch batch norm training on and off.
+
+  Args:
+    is_batch_norm_training: Boolean indicating whether to train with batch norm.
+
+  Returns:
+    An `arg_scope` to use for the NASNet Large Model.
+  """
+  imagenet_scope = nasnet.nasnet_large_arg_scope()
+  with arg_scope(imagenet_scope):
+    with arg_scope([slim.batch_norm], is_training=is_batch_norm_training) as sc:
+      return sc
+
+
 # Note: This is largely a copy of _build_nasnet_base inside nasnet.py but
 # with special edits to remove instantiation of the stem and the special
 # ability to receive as input a pair of hidden states.
@@ -163,11 +180,11 @@ def _extract_proposal_features(self, preprocessed_inputs, scope):
       raise ValueError('`preprocessed_inputs` must be 4 dimensional, got a '
                        'tensor of shape %s' % preprocessed_inputs.get_shape())

-    with slim.arg_scope(nasnet.nasnet_large_arg_scope()):
+    with slim.arg_scope(nasnet_large_arg_scope_for_detection(
+        is_batch_norm_training=self._train_batch_norm)):
       _, end_points = nasnet.build_nasnet_large(
           preprocessed_inputs, num_classes=None,
           is_training=self._is_training,
-          is_batchnorm_training=self._train_batch_norm,
           final_endpoint='Cell_11')

     # Note that both 'Cell_10' and 'Cell_11' have equal depth = 2016.
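
The new helper works by layering a second `arg_scope` over the stock ImageNet one, so only the `is_training` default of `slim.batch_norm` is replaced and every other default from `nasnet_large_arg_scope()` is kept. A small sketch of that layering in isolation (the `detection_scope` name is illustrative, not part of the commit):

import tensorflow as tf
from nets.nasnet import nasnet  # assumed import path, as used by the detection code

slim = tf.contrib.slim

# Start from the stock ImageNet arg scope, then override one default.
imagenet_scope = nasnet.nasnet_large_arg_scope()
with slim.arg_scope(imagenet_scope):
  with slim.arg_scope([slim.batch_norm],
                      is_training=False) as detection_scope:
    pass

# `detection_scope` is an ordinary scope dict: only the slim.batch_norm entry
# gains the overridden `is_training`, while conv2d, separable_conv2d, etc. keep
# whatever defaults nasnet_large_arg_scope() set for them. The dict can later
# be re-activated with `slim.arg_scope(detection_scope)`, which is exactly how
# _extract_proposal_features uses the scope returned by the new helper.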

research/slim/nets/nasnet/nasnet.py (+40, -42)

@@ -324,7 +324,7 @@ def build_nasnet_cifar(


 def build_nasnet_mobile(images, num_classes,
-                        is_training=True, is_batchnorm_training=True,
+                        is_training=True,
                         final_endpoint=None):
   """Build NASNet Mobile model for the ImageNet Dataset."""
   hparams = _mobile_imagenet_config()
@@ -348,32 +348,31 @@ def build_nasnet_mobile(images, num_classes,
   reduction_cell = nasnet_utils.NasNetAReductionCell(
       hparams.num_conv_filters, hparams.drop_path_keep_prob,
       total_num_cells, hparams.total_training_steps)
-  with arg_scope([slim.dropout, nasnet_utils.drop_path],
+  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                  is_training=is_training):
-    with arg_scope([slim.batch_norm], is_training=is_batchnorm_training):
-      with arg_scope([slim.avg_pool2d,
-                      slim.max_pool2d,
-                      slim.conv2d,
-                      slim.batch_norm,
-                      slim.separable_conv2d,
-                      nasnet_utils.factorized_reduction,
-                      nasnet_utils.global_avg_pool,
-                      nasnet_utils.get_channel_index,
-                      nasnet_utils.get_channel_dim],
-                     data_format=hparams.data_format):
-        return _build_nasnet_base(images,
-                                  normal_cell=normal_cell,
-                                  reduction_cell=reduction_cell,
-                                  num_classes=num_classes,
-                                  hparams=hparams,
-                                  is_training=is_training,
-                                  stem_type='imagenet',
-                                  final_endpoint=final_endpoint)
+    with arg_scope([slim.avg_pool2d,
+                    slim.max_pool2d,
+                    slim.conv2d,
+                    slim.batch_norm,
+                    slim.separable_conv2d,
+                    nasnet_utils.factorized_reduction,
+                    nasnet_utils.global_avg_pool,
+                    nasnet_utils.get_channel_index,
+                    nasnet_utils.get_channel_dim],
+                   data_format=hparams.data_format):
+      return _build_nasnet_base(images,
+                                normal_cell=normal_cell,
+                                reduction_cell=reduction_cell,
+                                num_classes=num_classes,
+                                hparams=hparams,
+                                is_training=is_training,
+                                stem_type='imagenet',
+                                final_endpoint=final_endpoint)
 build_nasnet_mobile.default_image_size = 224


 def build_nasnet_large(images, num_classes,
-                       is_training=True, is_batchnorm_training=True,
+                       is_training=True,
                        final_endpoint=None):
   """Build NASNet Large model for the ImageNet Dataset."""
   hparams = _large_imagenet_config(is_training=is_training)
@@ -397,27 +396,26 @@ def build_nasnet_large(images, num_classes,
   reduction_cell = nasnet_utils.NasNetAReductionCell(
       hparams.num_conv_filters, hparams.drop_path_keep_prob,
       total_num_cells, hparams.total_training_steps)
-  with arg_scope([slim.dropout, nasnet_utils.drop_path],
+  with arg_scope([slim.dropout, nasnet_utils.drop_path, slim.batch_norm],
                  is_training=is_training):
-    with arg_scope([slim.batch_norm], is_training=is_batchnorm_training):
-      with arg_scope([slim.avg_pool2d,
-                      slim.max_pool2d,
-                      slim.conv2d,
-                      slim.batch_norm,
-                      slim.separable_conv2d,
-                      nasnet_utils.factorized_reduction,
-                      nasnet_utils.global_avg_pool,
-                      nasnet_utils.get_channel_index,
-                      nasnet_utils.get_channel_dim],
-                     data_format=hparams.data_format):
-        return _build_nasnet_base(images,
-                                  normal_cell=normal_cell,
-                                  reduction_cell=reduction_cell,
-                                  num_classes=num_classes,
-                                  hparams=hparams,
-                                  is_training=is_training,
-                                  stem_type='imagenet',
-                                  final_endpoint=final_endpoint)
+    with arg_scope([slim.avg_pool2d,
+                    slim.max_pool2d,
+                    slim.conv2d,
+                    slim.batch_norm,
+                    slim.separable_conv2d,
+                    nasnet_utils.factorized_reduction,
+                    nasnet_utils.global_avg_pool,
+                    nasnet_utils.get_channel_index,
+                    nasnet_utils.get_channel_dim],
+                   data_format=hparams.data_format):
+      return _build_nasnet_base(images,
+                                normal_cell=normal_cell,
+                                reduction_cell=reduction_cell,
+                                num_classes=num_classes,
+                                hparams=hparams,
+                                is_training=is_training,
+                                stem_type='imagenet',
+                                final_endpoint=final_endpoint)
 build_nasnet_large.default_image_size = 331

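
With `is_batchnorm_training` gone, the Slim classification entry points take only `is_training`, and batch norm statistics are updated whenever that flag is set. A hedged usage sketch for the mobile builder (the arg-scope helper and the 1001-class count come from the wider Slim setup, not from this diff):

import tensorflow as tf
from nets.nasnet import nasnet  # assumed: research/slim on PYTHONPATH

slim = tf.contrib.slim

size = nasnet.build_nasnet_mobile.default_image_size  # 224, as set above
images = tf.placeholder(tf.float32, [None, size, size, 3])

with slim.arg_scope(nasnet.nasnet_mobile_arg_scope()):
  logits, end_points = nasnet.build_nasnet_mobile(
      images, num_classes=1001, is_training=False)

predictions = tf.argmax(logits, axis=1)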
