Skip to content

Commit 59dc938

Browse files
authored
perform out of bounds check for single values and two tuples in ColorJitter (#7133)
1 parent d509156 commit 59dc938

File tree

4 files changed

+18
-9
lines changed

4 files changed

+18
-9
lines changed

test/test_prototype_transforms_consistency.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def __init__(
317317
ArgsKwargs(saturation=(0.8, 0.9)),
318318
ArgsKwargs(hue=0.3),
319319
ArgsKwargs(hue=(-0.1, 0.2)),
320-
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.6),
320+
ArgsKwargs(brightness=0.1, contrast=0.4, saturation=0.5, hue=0.3),
321321
],
322322
closeness_kwargs={"atol": 1e-5, "rtol": 1e-5},
323323
),

test/test_transforms.py

+6
Original file line numberDiff line numberDiff line change
@@ -1798,6 +1798,12 @@ def test_color_jitter():
17981798
color_jitter.__repr__()
17991799

18001800

1801+
@pytest.mark.parametrize("hue", [1, (-1, 1)])
1802+
def test_color_jitter_hue_out_of_bounds(hue):
1803+
with pytest.raises(ValueError, match=re.escape("hue values should be between (-0.5, 0.5)")):
1804+
transforms.ColorJitter(hue=hue)
1805+
1806+
18011807
@pytest.mark.parametrize("seed", range(10))
18021808
@pytest.mark.skipif(stats is None, reason="scipy.stats not available")
18031809
def test_random_erasing(seed):

torchvision/prototype/transforms/_color.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,12 @@ def _check_input(
7777
value = [center - value, center + value]
7878
if clip_first_on_zero:
7979
value[0] = max(value[0], 0.0)
80-
elif isinstance(value, collections.abc.Sequence) and len(value) == 2:
81-
if not bound[0] <= value[0] <= value[1] <= bound[1]:
82-
raise ValueError(f"{name} values should be between {bound}")
83-
else:
80+
elif not (isinstance(value, collections.abc.Sequence) and len(value) == 2):
8481
raise TypeError(f"{name} should be a single number or a sequence with length 2.")
8582

83+
if not bound[0] <= value[0] <= value[1] <= bound[1]:
84+
raise ValueError(f"{name} values should be between {bound}, but got {value}.")
85+
8686
return None if value[0] == value[1] == center else (float(value[0]), float(value[1]))
8787

8888
@staticmethod

torchvision/transforms/transforms.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -1195,16 +1195,19 @@ def _check_input(self, value, name, center=1, bound=(0, float("inf")), clip_firs
11951195
if clip_first_on_zero:
11961196
value[0] = max(value[0], 0.0)
11971197
elif isinstance(value, (tuple, list)) and len(value) == 2:
1198-
if not bound[0] <= value[0] <= value[1] <= bound[1]:
1199-
raise ValueError(f"{name} values should be between {bound}")
1198+
value = [float(value[0]), float(value[1])]
12001199
else:
12011200
raise TypeError(f"{name} should be a single number or a list/tuple with length 2.")
12021201

1202+
if not bound[0] <= value[0] <= value[1] <= bound[1]:
1203+
raise ValueError(f"{name} values should be between {bound}, but got {value}.")
1204+
12031205
# if value is 0 or (1., 1.) for brightness/contrast/saturation
12041206
# or (0., 0.) for hue, do nothing
12051207
if value[0] == value[1] == center:
1206-
value = None
1207-
return value
1208+
return None
1209+
else:
1210+
return tuple(value)
12081211

12091212
@staticmethod
12101213
def get_params(

0 commit comments

Comments
 (0)