@@ -99,7 +99,7 @@ def _script(obj):
99
99
return torch .jit .script (obj )
100
100
except Exception as error :
101
101
name = getattr (obj , "__name__" , obj .__class__ .__name__ )
102
- raise AssertionError (f"Trying to `torch.jit.script` ' { name } ' raised the error above." ) from error
102
+ raise AssertionError (f"Trying to `torch.jit.script` ` { name } ` raised the error above." ) from error
103
103
104
104
105
105
def _check_kernel_scripted_vs_eager (kernel , input , * args , rtol , atol , ** kwargs ):
@@ -553,10 +553,12 @@ def affine_bounding_boxes(bounding_boxes):
553
553
554
554
class TestResize :
555
555
INPUT_SIZE = (17 , 11 )
556
- OUTPUT_SIZES = [17 , [17 ], (17 ,), [12 , 13 ], (12 , 13 )]
556
+ OUTPUT_SIZES = [17 , [17 ], (17 ,), None , [12 , 13 ], (12 , 13 )]
557
557
558
558
def _make_max_size_kwarg (self , * , use_max_size , size ):
559
- if use_max_size :
559
+ if size is None :
560
+ max_size = min (list (self .INPUT_SIZE ))
561
+ elif use_max_size :
560
562
if not (isinstance (size , int ) or len (size ) == 1 ):
561
563
# This would result in an `ValueError`
562
564
return None
@@ -568,10 +570,13 @@ def _make_max_size_kwarg(self, *, use_max_size, size):
568
570
return dict (max_size = max_size )
569
571
570
572
def _compute_output_size (self , * , input_size , size , max_size ):
571
- if not (isinstance (size , int ) or len (size ) == 1 ):
573
+ if size is None :
574
+ size = max_size
575
+
576
+ elif not (isinstance (size , int ) or len (size ) == 1 ):
572
577
return tuple (size )
573
578
574
- if not isinstance (size , int ):
579
+ elif not isinstance (size , int ):
575
580
size = size [0 ]
576
581
577
582
old_height , old_width = input_size
@@ -658,10 +663,13 @@ def test_kernel_video(self):
658
663
[make_image_tensor , make_image_pil , make_image , make_bounding_boxes , make_segmentation_mask , make_video ],
659
664
)
660
665
def test_functional (self , size , make_input ):
666
+ max_size_kwarg = self ._make_max_size_kwarg (use_max_size = size is None , size = size )
667
+
661
668
check_functional (
662
669
F .resize ,
663
670
make_input (self .INPUT_SIZE ),
664
671
size = size ,
672
+ ** max_size_kwarg ,
665
673
antialias = True ,
666
674
check_scripted_smoke = not isinstance (size , int ),
667
675
)
@@ -695,11 +703,13 @@ def test_functional_signature(self, kernel, input_type):
695
703
],
696
704
)
697
705
def test_transform (self , size , device , make_input ):
706
+ max_size_kwarg = self ._make_max_size_kwarg (use_max_size = size is None , size = size )
707
+
698
708
check_transform (
699
- transforms .Resize (size = size , antialias = True ),
709
+ transforms .Resize (size = size , ** max_size_kwarg , antialias = True ),
700
710
make_input (self .INPUT_SIZE , device = device ),
701
711
# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
702
- check_v1_compatibility = dict (rtol = 0 , atol = 1 ),
712
+ check_v1_compatibility = dict (rtol = 0 , atol = 1 ) if size is not None else False ,
703
713
)
704
714
705
715
def _check_output_size (self , input , output , * , size , max_size ):
@@ -801,7 +811,11 @@ def test_functional_pil_antialias_warning(self):
801
811
],
802
812
)
803
813
def test_max_size_error (self , size , make_input ):
804
- if isinstance (size , int ) or len (size ) == 1 :
814
+ if size is None :
815
+ # value can be anything other than an integer
816
+ max_size = None
817
+ match = "max_size must be an integer when size is None"
818
+ elif isinstance (size , int ) or len (size ) == 1 :
805
819
max_size = (size if isinstance (size , int ) else size [0 ]) - 1
806
820
match = "must be strictly greater than the requested size"
807
821
else :
@@ -812,6 +826,37 @@ def test_max_size_error(self, size, make_input):
812
826
with pytest .raises (ValueError , match = match ):
813
827
F .resize (make_input (self .INPUT_SIZE ), size = size , max_size = max_size , antialias = True )
814
828
829
+ if isinstance (size , list ) and len (size ) != 1 :
830
+ with pytest .raises (ValueError , match = "max_size should only be passed if size is None or specifies" ):
831
+ F .resize (make_input (self .INPUT_SIZE ), size = size , max_size = 500 )
832
+
833
+ @pytest .mark .parametrize (
834
+ "input_size, max_size, expected_size" ,
835
+ [
836
+ ((10 , 10 ), 10 , (10 , 10 )),
837
+ ((10 , 20 ), 40 , (20 , 40 )),
838
+ ((20 , 10 ), 40 , (40 , 20 )),
839
+ ((10 , 20 ), 10 , (5 , 10 )),
840
+ ((20 , 10 ), 10 , (10 , 5 )),
841
+ ],
842
+ )
843
+ @pytest .mark .parametrize (
844
+ "make_input" ,
845
+ [
846
+ make_image_tensor ,
847
+ make_image_pil ,
848
+ make_image ,
849
+ make_bounding_boxes ,
850
+ make_segmentation_mask ,
851
+ make_detection_masks ,
852
+ make_video ,
853
+ ],
854
+ )
855
+ def test_resize_size_none (self , input_size , max_size , expected_size , make_input ):
856
+ img = make_input (input_size )
857
+ out = F .resize (img , size = None , max_size = max_size )
858
+ assert F .get_size (out )[- 2 :] == list (expected_size )
859
+
815
860
@pytest .mark .parametrize ("interpolation" , INTERPOLATION_MODES )
816
861
@pytest .mark .parametrize (
817
862
"make_input" ,
@@ -834,7 +879,7 @@ def test_interpolation_int(self, interpolation, make_input):
834
879
assert_equal (actual , expected )
835
880
836
881
def test_transform_unknown_size_error (self ):
837
- with pytest .raises (ValueError , match = "size can either be an integer or a sequence of one or two integers" ):
882
+ with pytest .raises (ValueError , match = "size can be an integer, a sequence of one or two integers, or None " ):
838
883
transforms .Resize (size = object ())
839
884
840
885
@pytest .mark .parametrize (
0 commit comments