Skip to content

Commit f655e6a

Browse files
authored
Housekeeping improvements (#2964):
- fixed a problem with error computation between results
- refactored tensor casting for resize
- fixed round usage
1 parent 46f6083 commit f655e6a

File tree

4 files changed

+19
-28
lines changed

4 files changed

+19
-28
lines changed

Diff for: test/common_utils.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,8 @@ def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None, agg_meth
352352
if np_pil_image.ndim == 2:
353353
np_pil_image = np_pil_image[:, :, None]
354354
pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1))).to(tensor)
355-
err = getattr(torch, agg_method)(tensor - pil_tensor).item()
355+
# error value can be mean absolute error, max abs error
356+
err = getattr(torch, agg_method)(torch.abs(tensor - pil_tensor)).item()
356357
self.assertTrue(
357358
err < tol,
358359
msg="{}: err={}, tol={}: \n{}\nvs\n{}".format(msg, err, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])

Diff for: test/test_functional_tensor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,8 @@ def test_adjust_hue(self):
352352
F_pil.adjust_hue,
353353
F_t.adjust_hue,
354354
[{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]],
355-
tol=0.1,
356-
agg_method="mean"
355+
tol=16.1,
356+
agg_method="max"
357357
)
358358

359359
def test_adjust_gamma(self):

Diff for: test/test_transforms_tensor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,13 @@ def test_color_jitter(self):
111111
for f in [0.2, 0.5, (-0.2, 0.3), [-0.4, 0.5]]:
112112
meth_kwargs = {"hue": f}
113113
self._test_class_op(
114-
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
114+
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=16.1, agg_method="max"
115115
)
116116

117117
# All 4 parameters together
118118
meth_kwargs = {"brightness": 0.2, "contrast": 0.2, "saturation": 0.2, "hue": 0.2}
119119
self._test_class_op(
120-
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=0.1, agg_method="mean"
120+
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=12.1, agg_method="max"
121121
)
122122

123123
def test_pad(self):

Diff for: torchvision/transforms/functional_tensor.py

+13-23
Original file line numberDiff line numberDiff line change
@@ -822,32 +822,19 @@ def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
822822
if (w <= h and w == size_w) or (h <= w and h == size_h):
823823
return img
824824

825-
# make image NCHW
826-
need_squeeze = False
827-
if img.ndim < 4:
828-
img = img.unsqueeze(dim=0)
829-
need_squeeze = True
830-
831825
mode = _interpolation_modes[interpolation]
832826

833-
out_dtype = img.dtype
834-
need_cast = False
835-
if img.dtype not in (torch.float32, torch.float64):
836-
need_cast = True
837-
img = img.to(torch.float32)
827+
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
838828

839829
# Define align_corners to avoid warnings
840830
align_corners = False if mode in ["bilinear", "bicubic"] else None
841831

842832
img = interpolate(img, size=[size_h, size_w], mode=mode, align_corners=align_corners)
843833

844-
if need_squeeze:
845-
img = img.squeeze(dim=0)
834+
if mode == "bicubic" and out_dtype == torch.uint8:
835+
img = img.clamp(min=0, max=255)
846836

847-
if need_cast:
848-
if mode == "bicubic":
849-
img = img.clamp(min=0, max=255)
850-
img = img.to(out_dtype)
837+
img = _cast_squeeze_out(img, need_cast=need_cast, need_squeeze=need_squeeze, out_dtype=out_dtype)
851838

852839
return img
853840

@@ -879,7 +866,7 @@ def _assert_grid_transform_inputs(
879866
raise ValueError("Resampling mode '{}' is unsupported with Tensor input".format(resample))
880867

881868

882-
def _cast_squeeze_in(img: Tensor, req_dtype: torch.dtype) -> Tuple[Tensor, bool, bool, torch.dtype]:
869+
def _cast_squeeze_in(img: Tensor, req_dtypes: List[torch.dtype]) -> Tuple[Tensor, bool, bool, torch.dtype]:
883870
need_squeeze = False
884871
# make image NCHW
885872
if img.ndim < 4:
@@ -888,8 +875,9 @@ def _cast_squeeze_in(img: Tensor, req_dtype: torch.dtype) -> Tuple[Tensor, bool,
888875

889876
out_dtype = img.dtype
890877
need_cast = False
891-
if out_dtype != req_dtype:
878+
if out_dtype not in req_dtypes:
892879
need_cast = True
880+
req_dtype = req_dtypes[0]
893881
img = img.to(req_dtype)
894882
return img, need_cast, need_squeeze, out_dtype
895883

@@ -899,15 +887,17 @@ def _cast_squeeze_out(img: Tensor, need_cast: bool, need_squeeze: bool, out_dtyp
899887
img = img.squeeze(dim=0)
900888

901889
if need_cast:
902-
# it is better to round before cast
903-
img = torch.round(img).to(out_dtype)
890+
if out_dtype in (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64):
891+
# it is better to round before cast
892+
img = torch.round(img)
893+
img = img.to(out_dtype)
904894

905895
return img
906896

907897

908898
def _apply_grid_transform(img: Tensor, grid: Tensor, mode: str) -> Tensor:
909899

910-
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, grid.dtype)
900+
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [grid.dtype, ])
911901

912902
if img.shape[0] > 1:
913903
# Apply same grid to a batch of images
@@ -1168,7 +1158,7 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
11681158
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
11691159
kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
11701160

1171-
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, kernel.dtype)
1161+
img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [kernel.dtype, ])
11721162

11731163
# padding = (left, right, top, bottom)
11741164
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]

0 commit comments

Comments (0)