Skip to content

Commit 7896ffd

Browse files
sidijjuNicolasHug
andauthored
Allow v2 Resize to resize longer edge exactly to max_size (#8459)
Co-authored-by: Nicolas Hug <contact@nicolas-hug.com>
1 parent 1023987 commit 7896ffd

File tree

4 files changed

+110
-38
lines changed

4 files changed

+110
-38
lines changed

test/test_transforms_v2.py

+54-9
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def _script(obj):
9999
return torch.jit.script(obj)
100100
except Exception as error:
101101
name = getattr(obj, "__name__", obj.__class__.__name__)
102-
raise AssertionError(f"Trying to `torch.jit.script` '{name}' raised the error above.") from error
102+
raise AssertionError(f"Trying to `torch.jit.script` `{name}` raised the error above.") from error
103103

104104

105105
def _check_kernel_scripted_vs_eager(kernel, input, *args, rtol, atol, **kwargs):
@@ -553,10 +553,12 @@ def affine_bounding_boxes(bounding_boxes):
553553

554554
class TestResize:
555555
INPUT_SIZE = (17, 11)
556-
OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]
556+
OUTPUT_SIZES = [17, [17], (17,), None, [12, 13], (12, 13)]
557557

558558
def _make_max_size_kwarg(self, *, use_max_size, size):
559-
if use_max_size:
559+
if size is None:
560+
max_size = min(list(self.INPUT_SIZE))
561+
elif use_max_size:
560562
if not (isinstance(size, int) or len(size) == 1):
561563
# This would result in an `ValueError`
562564
return None
@@ -568,10 +570,13 @@ def _make_max_size_kwarg(self, *, use_max_size, size):
568570
return dict(max_size=max_size)
569571

570572
def _compute_output_size(self, *, input_size, size, max_size):
571-
if not (isinstance(size, int) or len(size) == 1):
573+
if size is None:
574+
size = max_size
575+
576+
elif not (isinstance(size, int) or len(size) == 1):
572577
return tuple(size)
573578

574-
if not isinstance(size, int):
579+
elif not isinstance(size, int):
575580
size = size[0]
576581

577582
old_height, old_width = input_size
@@ -658,10 +663,13 @@ def test_kernel_video(self):
658663
[make_image_tensor, make_image_pil, make_image, make_bounding_boxes, make_segmentation_mask, make_video],
659664
)
660665
def test_functional(self, size, make_input):
666+
max_size_kwarg = self._make_max_size_kwarg(use_max_size=size is None, size=size)
667+
661668
check_functional(
662669
F.resize,
663670
make_input(self.INPUT_SIZE),
664671
size=size,
672+
**max_size_kwarg,
665673
antialias=True,
666674
check_scripted_smoke=not isinstance(size, int),
667675
)
@@ -695,11 +703,13 @@ def test_functional_signature(self, kernel, input_type):
695703
],
696704
)
697705
def test_transform(self, size, device, make_input):
706+
max_size_kwarg = self._make_max_size_kwarg(use_max_size=size is None, size=size)
707+
698708
check_transform(
699-
transforms.Resize(size=size, antialias=True),
709+
transforms.Resize(size=size, **max_size_kwarg, antialias=True),
700710
make_input(self.INPUT_SIZE, device=device),
701711
# atol=1 due to Resize v2 is using native uint8 interpolate path for bilinear and nearest modes
702-
check_v1_compatibility=dict(rtol=0, atol=1),
712+
check_v1_compatibility=dict(rtol=0, atol=1) if size is not None else False,
703713
)
704714

705715
def _check_output_size(self, input, output, *, size, max_size):
@@ -801,7 +811,11 @@ def test_functional_pil_antialias_warning(self):
801811
],
802812
)
803813
def test_max_size_error(self, size, make_input):
804-
if isinstance(size, int) or len(size) == 1:
814+
if size is None:
815+
# value can be anything other than an integer
816+
max_size = None
817+
match = "max_size must be an integer when size is None"
818+
elif isinstance(size, int) or len(size) == 1:
805819
max_size = (size if isinstance(size, int) else size[0]) - 1
806820
match = "must be strictly greater than the requested size"
807821
else:
@@ -812,6 +826,37 @@ def test_max_size_error(self, size, make_input):
812826
with pytest.raises(ValueError, match=match):
813827
F.resize(make_input(self.INPUT_SIZE), size=size, max_size=max_size, antialias=True)
814828

829+
if isinstance(size, list) and len(size) != 1:
830+
with pytest.raises(ValueError, match="max_size should only be passed if size is None or specifies"):
831+
F.resize(make_input(self.INPUT_SIZE), size=size, max_size=500)
832+
833+
@pytest.mark.parametrize(
834+
"input_size, max_size, expected_size",
835+
[
836+
((10, 10), 10, (10, 10)),
837+
((10, 20), 40, (20, 40)),
838+
((20, 10), 40, (40, 20)),
839+
((10, 20), 10, (5, 10)),
840+
((20, 10), 10, (10, 5)),
841+
],
842+
)
843+
@pytest.mark.parametrize(
844+
"make_input",
845+
[
846+
make_image_tensor,
847+
make_image_pil,
848+
make_image,
849+
make_bounding_boxes,
850+
make_segmentation_mask,
851+
make_detection_masks,
852+
make_video,
853+
],
854+
)
855+
def test_resize_size_none(self, input_size, max_size, expected_size, make_input):
856+
img = make_input(input_size)
857+
out = F.resize(img, size=None, max_size=max_size)
858+
assert F.get_size(out)[-2:] == list(expected_size)
859+
815860
@pytest.mark.parametrize("interpolation", INTERPOLATION_MODES)
816861
@pytest.mark.parametrize(
817862
"make_input",
@@ -834,7 +879,7 @@ def test_interpolation_int(self, interpolation, make_input):
834879
assert_equal(actual, expected)
835880

836881
def test_transform_unknown_size_error(self):
837-
with pytest.raises(ValueError, match="size can either be an integer or a sequence of one or two integers"):
882+
with pytest.raises(ValueError, match="size can be an integer, a sequence of one or two integers, or None"):
838883
transforms.Resize(size=object())
839884

840885
@pytest.mark.parametrize(

torchvision/transforms/functional.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -351,13 +351,22 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
351351

352352

353353
def _compute_resized_output_size(
354-
image_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
354+
image_size: Tuple[int, int],
355+
size: Optional[List[int]],
356+
max_size: Optional[int] = None,
357+
allow_size_none: bool = False, # only True in v2
355358
) -> List[int]:
356-
if len(size) == 1: # specified size only for the smallest edge
357-
h, w = image_size
358-
short, long = (w, h) if w <= h else (h, w)
359+
h, w = image_size
360+
short, long = (w, h) if w <= h else (h, w)
361+
if size is None:
362+
if not allow_size_none:
363+
raise ValueError("This should never happen!!")
364+
if not isinstance(max_size, int):
365+
raise ValueError(f"max_size must be an integer when size is None, but got {max_size} instead.")
366+
new_short, new_long = int(max_size * short / long), max_size
367+
new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
368+
elif len(size) == 1: # specified size only for the smallest edge
359369
requested_new_short = size if isinstance(size, int) else size[0]
360-
361370
new_short, new_long = requested_new_short, int(requested_new_short * long / short)
362371

363372
if max_size is not None:

torchvision/transforms/v2/_geometry.py

+29-14
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,15 @@ class Resize(Transform):
7575
the image can have ``[..., C, H, W]`` shape. A bounding box can have ``[..., 4]`` shape.
7676
7777
Args:
78-
size (sequence or int): Desired output size. If size is a sequence like
79-
(h, w), output size will be matched to this. If size is an int,
80-
smaller edge of the image will be matched to this number.
81-
i.e, if height > width, then image will be rescaled to
82-
(size * height / width, size).
78+
size (sequence, int, or None): Desired
79+
output size.
80+
81+
- If size is a sequence like (h, w), output size will be matched to this.
82+
- If size is an int, smaller edge of the image will be matched to this
83+
number. i.e, if height > width, then image will be rescaled to
84+
(size * height / width, size).
85+
- If size is None, the output shape is determined by the ``max_size``
86+
parameter.
8387
8488
.. note::
8589
In torchscript mode size as single int is not supported, use a sequence of length 1: ``[size, ]``.
@@ -89,13 +93,21 @@ class Resize(Transform):
8993
``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
9094
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
9195
max_size (int, optional): The maximum allowed for the longer edge of
92-
the resized image. If the longer edge of the image is greater
93-
than ``max_size`` after being resized according to ``size``,
94-
``size`` will be overruled so that the longer edge is equal to
95-
``max_size``.
96-
As a result, the smaller edge may be shorter than ``size``. This
97-
is only supported if ``size`` is an int (or a sequence of length
98-
1 in torchscript mode).
96+
the resized image.
97+
98+
- If ``size`` is an int: if the longer edge of the image is greater
99+
than ``max_size`` after being resized according to ``size``,
100+
``size`` will be overruled so that the longer edge is equal to
101+
``max_size``. As a result, the smaller edge may be shorter than
102+
``size``. This is only supported if ``size`` is an int (or a
103+
sequence of length 1 in torchscript mode).
104+
- If ``size`` is None: the longer edge of the image will be matched
105+
to max_size. i.e, if height > width, then image will be rescaled
106+
to (max_size, max_size * width / height).
107+
108+
This should be left to ``None`` (default) when ``size`` is a
109+
sequence.
110+
99111
antialias (bool, optional): Whether to apply antialiasing.
100112
It only affects **tensors** with bilinear or bicubic modes and it is
101113
ignored otherwise: on PIL images, antialiasing is always applied on
@@ -120,7 +132,7 @@ class Resize(Transform):
120132

121133
def __init__(
122134
self,
123-
size: Union[int, Sequence[int]],
135+
size: Union[int, Sequence[int], None],
124136
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
125137
max_size: Optional[int] = None,
126138
antialias: Optional[bool] = True,
@@ -131,9 +143,12 @@ def __init__(
131143
size = [size]
132144
elif isinstance(size, Sequence) and len(size) in {1, 2}:
133145
size = list(size)
146+
elif size is None:
147+
if not isinstance(max_size, int):
148+
raise ValueError(f"max_size must be an integer when size is None, but got {max_size} instead.")
134149
else:
135150
raise ValueError(
136-
f"size can either be an integer or a sequence of one or two integers, but got {size} instead."
151+
f"size can be an integer, a sequence of one or two integers, or None, but got {size} instead."
137152
)
138153
self.size = size
139154

torchvision/transforms/v2/functional/_geometry.py

+13-10
Original file line numberDiff line numberDiff line change
@@ -159,21 +159,21 @@ def vertical_flip_video(video: torch.Tensor) -> torch.Tensor:
159159

160160

161161
def _compute_resized_output_size(
162-
canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
162+
canvas_size: Tuple[int, int], size: Optional[List[int]], max_size: Optional[int] = None
163163
) -> List[int]:
164164
if isinstance(size, int):
165165
size = [size]
166-
elif max_size is not None and len(size) != 1:
166+
elif max_size is not None and size is not None and len(size) != 1:
167167
raise ValueError(
168-
"max_size should only be passed if size specifies the length of the smaller edge, "
168+
"max_size should only be passed if size is None or specifies the length of the smaller edge, "
169169
"i.e. size should be an int or a sequence of length 1 in torchscript mode."
170170
)
171-
return __compute_resized_output_size(canvas_size, size=size, max_size=max_size)
171+
return __compute_resized_output_size(canvas_size, size=size, max_size=max_size, allow_size_none=True)
172172

173173

174174
def resize(
175175
inpt: torch.Tensor,
176-
size: List[int],
176+
size: Optional[List[int]],
177177
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
178178
max_size: Optional[int] = None,
179179
antialias: Optional[bool] = True,
@@ -206,7 +206,7 @@ def _do_native_uint8_resize_on_cpu(interpolation: InterpolationMode) -> bool:
206206
@_register_kernel_internal(resize, tv_tensors.Image)
207207
def resize_image(
208208
image: torch.Tensor,
209-
size: List[int],
209+
size: Optional[List[int]],
210210
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
211211
max_size: Optional[int] = None,
212212
antialias: Optional[bool] = True,
@@ -310,7 +310,7 @@ def __resize_image_pil_dispatch(
310310
return _resize_image_pil(image, size=size, interpolation=interpolation, max_size=max_size)
311311

312312

313-
def resize_mask(mask: torch.Tensor, size: List[int], max_size: Optional[int] = None) -> torch.Tensor:
313+
def resize_mask(mask: torch.Tensor, size: Optional[List[int]], max_size: Optional[int] = None) -> torch.Tensor:
314314
if mask.ndim < 3:
315315
mask = mask.unsqueeze(0)
316316
needs_squeeze = True
@@ -334,7 +334,10 @@ def _resize_mask_dispatch(
334334

335335

336336
def resize_bounding_boxes(
337-
bounding_boxes: torch.Tensor, canvas_size: Tuple[int, int], size: List[int], max_size: Optional[int] = None
337+
bounding_boxes: torch.Tensor,
338+
canvas_size: Tuple[int, int],
339+
size: Optional[List[int]],
340+
max_size: Optional[int] = None,
338341
) -> Tuple[torch.Tensor, Tuple[int, int]]:
339342
old_height, old_width = canvas_size
340343
new_height, new_width = _compute_resized_output_size(canvas_size, size=size, max_size=max_size)
@@ -353,7 +356,7 @@ def resize_bounding_boxes(
353356

354357
@_register_kernel_internal(resize, tv_tensors.BoundingBoxes, tv_tensor_wrapper=False)
355358
def _resize_bounding_boxes_dispatch(
356-
inpt: tv_tensors.BoundingBoxes, size: List[int], max_size: Optional[int] = None, **kwargs: Any
359+
inpt: tv_tensors.BoundingBoxes, size: Optional[List[int]], max_size: Optional[int] = None, **kwargs: Any
357360
) -> tv_tensors.BoundingBoxes:
358361
output, canvas_size = resize_bounding_boxes(
359362
inpt.as_subclass(torch.Tensor), inpt.canvas_size, size, max_size=max_size
@@ -364,7 +367,7 @@ def _resize_bounding_boxes_dispatch(
364367
@_register_kernel_internal(resize, tv_tensors.Video)
365368
def resize_video(
366369
video: torch.Tensor,
367-
size: List[int],
370+
size: Optional[List[int]],
368371
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
369372
max_size: Optional[int] = None,
370373
antialias: Optional[bool] = True,

0 commit comments

Comments
 (0)