@@ -237,9 +237,6 @@ def multi_scale_logits(images,
  # Setup default values.
  if not image_pyramid:
    image_pyramid = [1.0]
-  if model_options.crop_size is None and model_options.add_image_level_feature:
-    raise ValueError(
-        'Crop size must be specified for using image-level feature.')
  crop_height = (
      model_options.crop_size[0]
      if model_options.crop_size else tf.shape(images)[1])
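Note: with the ValueError gone, crop_size may legitimately be None (e.g. whole-image inference), so every consumer below falls back to dynamic shapes. A minimal sketch of that fallback pattern, using a hypothetical helper name (_spatial_dims is not part of model.py):

    import tensorflow as tf  # TF 1.x, as in this codebase

    def _spatial_dims(images, crop_size):
      # Static Python ints when crop_size is known; dynamic scalar
      # Tensors from tf.shape otherwise. Callers must handle both.
      if crop_size:
        return crop_size[0], crop_size[1]
      shape = tf.shape(images)
      return shape[1], shape[2]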
@@ -378,18 +375,39 @@ def extract_features(images,
        branch_logits = []

        if model_options.add_image_level_feature:
-          pool_height = scale_dimension(model_options.crop_size[0],
-                                        1. / model_options.output_stride)
-          pool_width = scale_dimension(model_options.crop_size[1],
-                                       1. / model_options.output_stride)
-          image_feature = slim.avg_pool2d(
-              features, [pool_height, pool_width], [pool_height, pool_width],
-              padding='VALID')
+          if model_options.crop_size is not None:
+            image_pooling_crop_size = model_options.image_pooling_crop_size
+            # If image_pooling_crop_size is not specified, use crop_size.
+            if image_pooling_crop_size is None:
+              image_pooling_crop_size = model_options.crop_size
+            pool_height = scale_dimension(image_pooling_crop_size[0],
+                                          1. / model_options.output_stride)
+            pool_width = scale_dimension(image_pooling_crop_size[1],
+                                         1. / model_options.output_stride)
+            image_feature = slim.avg_pool2d(
+                features, [pool_height, pool_width], [1, 1], padding='VALID')
+            resize_height = scale_dimension(model_options.crop_size[0],
+                                            1. / model_options.output_stride)
+            resize_width = scale_dimension(model_options.crop_size[1],
+                                           1. / model_options.output_stride)
+          else:
+            # If crop_size is None, we simply do global pooling.
+            pool_height = tf.shape(features)[1]
+            pool_width = tf.shape(features)[2]
+            image_feature = tf.reduce_mean(features, axis=[1, 2])[:, tf.newaxis,
+                                                                  tf.newaxis]
+            resize_height = pool_height
+            resize_width = pool_width
          image_feature = slim.conv2d(
              image_feature, depth, 1, scope=IMAGE_POOLING_SCOPE)
          image_feature = tf.image.resize_bilinear(
-              image_feature, [pool_height, pool_width], align_corners=True)
-          image_feature.set_shape([None, pool_height, pool_width, depth])
+              image_feature, [resize_height, resize_width], align_corners=True)
+          # set_shape needs static dims; use None when they are Tensors.
+          if isinstance(resize_height, tf.Tensor):
+            resize_height = None
+          if isinstance(resize_width, tf.Tensor):
+            resize_width = None
+          image_feature.set_shape([None, resize_height, resize_width, depth])
          branch_logits.append(image_feature)

        # Employ a 1x1 convolution.
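Note: the pooling stride changes from [pool_height, pool_width] to [1, 1], so when image_pooling_crop_size is smaller than crop_size the pooled map keeps spatial extent instead of collapsing to 1x1, and is then resized back to the crop-derived grid. Assuming scale_dimension follows the (dim - 1) * scale + 1 convention used elsewhere in DeepLab, the arithmetic works out as in this sketch:

    def scale_dimension_sketch(dim, scale):
      # Static-dimension path of DeepLab's scale_dimension (assumption;
      # the real helper also accepts scalar Tensors).
      return int((float(dim) - 1.0) * scale + 1.0)

    # crop_size 513 with output_stride 16 gives a 33x33 feature map.
    assert scale_dimension_sketch(513, 1.0 / 16) == 33
    # image_pooling_crop_size 257 gives a 17x17 pooling kernel; with
    # stride [1, 1] and VALID padding the pooled map is 33 - 17 + 1 = 17.
    assert scale_dimension_sketch(257, 1.0 / 16) == 17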
@@ -453,9 +471,14 @@ def _get_logits(images,
      fine_tune_batch_norm=fine_tune_batch_norm)

  if model_options.decoder_output_stride is not None:
-    decoder_height = scale_dimension(model_options.crop_size[0],
+    if model_options.crop_size is None:
+      height = tf.shape(images)[1]
+      width = tf.shape(images)[2]
+    else:
+      height, width = model_options.crop_size
+    decoder_height = scale_dimension(height,
                                     1.0 / model_options.decoder_output_stride)
-    decoder_width = scale_dimension(model_options.crop_size[1],
+    decoder_width = scale_dimension(width,
                                    1.0 / model_options.decoder_output_stride)
    features = refine_by_decoder(
        features,
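Note: when crop_size is None, height and width here are scalar Tensors, so scale_dimension must accept Tensors as well as ints, and the decoder's bilinear resize must take a dynamic target size. A small standalone illustration of the latter (not from model.py):

    import tensorflow as tf  # TF 1.x

    images = tf.placeholder(tf.float32, [None, None, None, 3])
    # Dynamic target size: unknown until the graph runs.
    h = tf.shape(images)[1] // 4
    w = tf.shape(images)[2] // 4
    resized = tf.image.resize_bilinear(images, [h, w], align_corners=True)
    # resized's static shape stays [None, None, None, 3]; only the
    # run-time shape is concrete.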
@@ -557,8 +580,11 @@ def refine_by_decoder(features,
          for j, feature in enumerate(decoder_features_list):
            decoder_features_list[j] = tf.image.resize_bilinear(
                feature, [decoder_height, decoder_width], align_corners=True)
-            decoder_features_list[j].set_shape(
-                [None, decoder_height, decoder_width, None])
+            h = (None if isinstance(decoder_height, tf.Tensor)
+                 else decoder_height)
+            w = (None if isinstance(decoder_width, tf.Tensor)
+                 else decoder_width)
+            decoder_features_list[j].set_shape([None, h, w, None])
          decoder_depth = 256
          if decoder_use_separable_conv:
            decoder_features = split_separable_conv2d(
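Note: the set_shape guards in the last two hunks follow one pattern: record static dims when they are Python ints, degrade to None when they are Tensors. A sketch of that pattern as a standalone helper (set_spatial_shape is hypothetical, not in model.py):

    import tensorflow as tf

    def set_spatial_shape(feature, height, width):
      # set_shape only records graph-construction-time information:
      # concrete ints pass through, dynamic Tensors become None.
      h = None if isinstance(height, tf.Tensor) else height
      w = None if isinstance(width, tf.Tensor) else width
      feature.set_shape([None, h, w, None])
      return feature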