@@ -324,7 +324,7 @@ def build_nasnet_cifar(
324
324
325
325
326
326
def build_nasnet_mobile (images , num_classes ,
327
- is_training = True , is_batchnorm_training = True ,
327
+ is_training = True ,
328
328
final_endpoint = None ):
329
329
"""Build NASNet Mobile model for the ImageNet Dataset."""
330
330
hparams = _mobile_imagenet_config ()
@@ -348,32 +348,31 @@ def build_nasnet_mobile(images, num_classes,
348
348
reduction_cell = nasnet_utils .NasNetAReductionCell (
349
349
hparams .num_conv_filters , hparams .drop_path_keep_prob ,
350
350
total_num_cells , hparams .total_training_steps )
351
- with arg_scope ([slim .dropout , nasnet_utils .drop_path ],
351
+ with arg_scope ([slim .dropout , nasnet_utils .drop_path , slim . batch_norm ],
352
352
is_training = is_training ):
353
- with arg_scope ([slim .batch_norm ], is_training = is_batchnorm_training ):
354
- with arg_scope ([slim .avg_pool2d ,
355
- slim .max_pool2d ,
356
- slim .conv2d ,
357
- slim .batch_norm ,
358
- slim .separable_conv2d ,
359
- nasnet_utils .factorized_reduction ,
360
- nasnet_utils .global_avg_pool ,
361
- nasnet_utils .get_channel_index ,
362
- nasnet_utils .get_channel_dim ],
363
- data_format = hparams .data_format ):
364
- return _build_nasnet_base (images ,
365
- normal_cell = normal_cell ,
366
- reduction_cell = reduction_cell ,
367
- num_classes = num_classes ,
368
- hparams = hparams ,
369
- is_training = is_training ,
370
- stem_type = 'imagenet' ,
371
- final_endpoint = final_endpoint )
353
+ with arg_scope ([slim .avg_pool2d ,
354
+ slim .max_pool2d ,
355
+ slim .conv2d ,
356
+ slim .batch_norm ,
357
+ slim .separable_conv2d ,
358
+ nasnet_utils .factorized_reduction ,
359
+ nasnet_utils .global_avg_pool ,
360
+ nasnet_utils .get_channel_index ,
361
+ nasnet_utils .get_channel_dim ],
362
+ data_format = hparams .data_format ):
363
+ return _build_nasnet_base (images ,
364
+ normal_cell = normal_cell ,
365
+ reduction_cell = reduction_cell ,
366
+ num_classes = num_classes ,
367
+ hparams = hparams ,
368
+ is_training = is_training ,
369
+ stem_type = 'imagenet' ,
370
+ final_endpoint = final_endpoint )
372
371
build_nasnet_mobile .default_image_size = 224
373
372
374
373
375
374
def build_nasnet_large (images , num_classes ,
376
- is_training = True , is_batchnorm_training = True ,
375
+ is_training = True ,
377
376
final_endpoint = None ):
378
377
"""Build NASNet Large model for the ImageNet Dataset."""
379
378
hparams = _large_imagenet_config (is_training = is_training )
@@ -397,27 +396,26 @@ def build_nasnet_large(images, num_classes,
397
396
reduction_cell = nasnet_utils .NasNetAReductionCell (
398
397
hparams .num_conv_filters , hparams .drop_path_keep_prob ,
399
398
total_num_cells , hparams .total_training_steps )
400
- with arg_scope ([slim .dropout , nasnet_utils .drop_path ],
399
+ with arg_scope ([slim .dropout , nasnet_utils .drop_path , slim . batch_norm ],
401
400
is_training = is_training ):
402
- with arg_scope ([slim .batch_norm ], is_training = is_batchnorm_training ):
403
- with arg_scope ([slim .avg_pool2d ,
404
- slim .max_pool2d ,
405
- slim .conv2d ,
406
- slim .batch_norm ,
407
- slim .separable_conv2d ,
408
- nasnet_utils .factorized_reduction ,
409
- nasnet_utils .global_avg_pool ,
410
- nasnet_utils .get_channel_index ,
411
- nasnet_utils .get_channel_dim ],
412
- data_format = hparams .data_format ):
413
- return _build_nasnet_base (images ,
414
- normal_cell = normal_cell ,
415
- reduction_cell = reduction_cell ,
416
- num_classes = num_classes ,
417
- hparams = hparams ,
418
- is_training = is_training ,
419
- stem_type = 'imagenet' ,
420
- final_endpoint = final_endpoint )
401
+ with arg_scope ([slim .avg_pool2d ,
402
+ slim .max_pool2d ,
403
+ slim .conv2d ,
404
+ slim .batch_norm ,
405
+ slim .separable_conv2d ,
406
+ nasnet_utils .factorized_reduction ,
407
+ nasnet_utils .global_avg_pool ,
408
+ nasnet_utils .get_channel_index ,
409
+ nasnet_utils .get_channel_dim ],
410
+ data_format = hparams .data_format ):
411
+ return _build_nasnet_base (images ,
412
+ normal_cell = normal_cell ,
413
+ reduction_cell = reduction_cell ,
414
+ num_classes = num_classes ,
415
+ hparams = hparams ,
416
+ is_training = is_training ,
417
+ stem_type = 'imagenet' ,
418
+ final_endpoint = final_endpoint )
421
419
build_nasnet_large .default_image_size = 331
422
420
423
421
0 commit comments