@@ -822,32 +822,19 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
822
822
if (w <= h and w == size_w ) or (h <= w and h == size_h ):
823
823
return img
824
824
825
- # make image NCHW
826
- need_squeeze = False
827
- if img .ndim < 4 :
828
- img = img .unsqueeze (dim = 0 )
829
- need_squeeze = True
830
-
831
825
mode = _interpolation_modes [interpolation ]
832
826
833
- out_dtype = img .dtype
834
- need_cast = False
835
- if img .dtype not in (torch .float32 , torch .float64 ):
836
- need_cast = True
837
- img = img .to (torch .float32 )
827
+ img , need_cast , need_squeeze , out_dtype = _cast_squeeze_in (img , [torch .float32 , torch .float64 ])
838
828
839
829
# Define align_corners to avoid warnings
840
830
align_corners = False if mode in ["bilinear" , "bicubic" ] else None
841
831
842
832
img = interpolate (img , size = [size_h , size_w ], mode = mode , align_corners = align_corners )
843
833
844
- if need_squeeze :
845
- img = img .squeeze ( dim = 0 )
834
+ if mode == "bicubic" and out_dtype == torch . uint8 :
835
+ img = img .clamp ( min = 0 , max = 255 )
846
836
847
- if need_cast :
848
- if mode == "bicubic" :
849
- img = img .clamp (min = 0 , max = 255 )
850
- img = img .to (out_dtype )
837
+ img = _cast_squeeze_out (img , need_cast = need_cast , need_squeeze = need_squeeze , out_dtype = out_dtype )
851
838
852
839
return img
853
840
@@ -879,7 +866,7 @@ def _assert_grid_transform_inputs(
879
866
raise ValueError ("Resampling mode '{}' is unsupported with Tensor input" .format (resample ))
880
867
881
868
882
- def _cast_squeeze_in (img : Tensor , req_dtype : torch .dtype ) -> Tuple [Tensor , bool , bool , torch .dtype ]:
869
+ def _cast_squeeze_in (img : Tensor , req_dtypes : List [ torch .dtype ] ) -> Tuple [Tensor , bool , bool , torch .dtype ]:
883
870
need_squeeze = False
884
871
# make image NCHW
885
872
if img .ndim < 4 :
@@ -888,8 +875,9 @@ def _cast_squeeze_in(img: Tensor, req_dtype: torch.dtype) -> Tuple[Tensor, bool,
888
875
889
876
out_dtype = img .dtype
890
877
need_cast = False
891
- if out_dtype != req_dtype :
878
+ if out_dtype not in req_dtypes :
892
879
need_cast = True
880
+ req_dtype = req_dtypes [0 ]
893
881
img = img .to (req_dtype )
894
882
return img , need_cast , need_squeeze , out_dtype
895
883
@@ -899,15 +887,17 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp
899
887
img = img .squeeze (dim = 0 )
900
888
901
889
if need_cast :
902
- # it is better to round before cast
903
- img = torch .round (img ).to (out_dtype )
890
+ if out_dtype in (torch .uint8 , torch .int8 , torch .int16 , torch .int32 , torch .int64 ):
891
+ # it is better to round before cast
892
+ img = torch .round (img )
893
+ img = img .to (out_dtype )
904
894
905
895
return img
906
896
907
897
908
898
def _apply_grid_transform (img : Tensor , grid : Tensor , mode : str ) -> Tensor :
909
899
910
- img , need_cast , need_squeeze , out_dtype = _cast_squeeze_in (img , grid .dtype )
900
+ img , need_cast , need_squeeze , out_dtype = _cast_squeeze_in (img , [ grid .dtype , ] )
911
901
912
902
if img .shape [0 ] > 1 :
913
903
# Apply same grid to a batch of images
@@ -1168,7 +1158,7 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
1168
1158
kernel = _get_gaussian_kernel2d (kernel_size , sigma , dtype = dtype , device = img .device )
1169
1159
kernel = kernel .expand (img .shape [- 3 ], 1 , kernel .shape [0 ], kernel .shape [1 ])
1170
1160
1171
- img , need_cast , need_squeeze , out_dtype = _cast_squeeze_in (img , kernel .dtype )
1161
+ img , need_cast , need_squeeze , out_dtype = _cast_squeeze_in (img , [ kernel .dtype , ] )
1172
1162
1173
1163
# padding = (left, right, top, bottom)
1174
1164
padding = [kernel_size [0 ] // 2 , kernel_size [0 ] // 2 , kernel_size [1 ] // 2 , kernel_size [1 ] // 2 ]
0 commit comments