Skip to content

Commit f483e71

Browse files
authored
Added gray image support to adjust_saturation function (#4480)
* update channels parameter to every calling to check_functional_vs_PIL_vs_scripted * update adjust_saturation * update docstrings for functional transformations * parametrize channels * update docstring of ColorJitter class * move channels to class's parameter * remove testing channels for geometric transforms * revert redundant changes * revert redundant changes * update grayscale test cases for randaugment, autoaugment, trivialaugment * update docstrings of randaugment, autoaugment, trivialaugment * update docstring of ColorJitter * fix adjust_hue's docstring * change test equal tolerance * refactor grayscale tests * make get_grayscale_test_image private
1 parent 3e27eb2 commit f483e71

8 files changed

+96
-53
lines changed

test/common_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def _create_data_batch(height=3, width=3, channels=3, num_samples=4, device="cpu
148148
return batch_tensor
149149

150150

151-
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
151+
assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=1e-6)
152152

153153

154154
def get_list_of_videos(tmpdir, num_videos=5, sizes=None, fps=None):

test/test_functional_tensor.py

+35-13
Original file line numberDiff line numberDiff line change
@@ -681,57 +681,65 @@ def check_functional_vs_PIL_vs_scripted(
681681
@pytest.mark.parametrize('device', cpu_and_gpu())
682682
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
683683
@pytest.mark.parametrize('config', [{"brightness_factor": f} for f in (0.1, 0.5, 1.0, 1.34, 2.5)])
684-
def test_adjust_brightness(device, dtype, config):
684+
@pytest.mark.parametrize('channels', [1, 3])
685+
def test_adjust_brightness(device, dtype, config, channels):
685686
check_functional_vs_PIL_vs_scripted(
686687
F.adjust_brightness,
687688
F_pil.adjust_brightness,
688689
F_t.adjust_brightness,
689690
config,
690691
device,
691692
dtype,
693+
channels,
692694
)
693695

694696

695697
@pytest.mark.parametrize('device', cpu_and_gpu())
696698
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
697-
def test_invert(device, dtype):
699+
@pytest.mark.parametrize('channels', [1, 3])
700+
def test_invert(device, dtype, channels):
698701
check_functional_vs_PIL_vs_scripted(
699702
F.invert,
700703
F_pil.invert,
701704
F_t.invert,
702705
{},
703706
device,
704707
dtype,
708+
channels,
705709
tol=1.0,
706710
agg_method="max"
707711
)
708712

709713

710714
@pytest.mark.parametrize('device', cpu_and_gpu())
711715
@pytest.mark.parametrize('config', [{"bits": bits} for bits in range(0, 8)])
712-
def test_posterize(device, config):
716+
@pytest.mark.parametrize('channels', [1, 3])
717+
def test_posterize(device, config, channels):
713718
check_functional_vs_PIL_vs_scripted(
714719
F.posterize,
715720
F_pil.posterize,
716721
F_t.posterize,
717722
config,
718723
device,
719724
dtype=None,
725+
channels=channels,
720726
tol=1.0,
721727
agg_method="max",
722728
)
723729

724730

725731
@pytest.mark.parametrize('device', cpu_and_gpu())
726732
@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0, 64, 128, 192, 255]])
727-
def test_solarize1(device, config):
733+
@pytest.mark.parametrize('channels', [1, 3])
734+
def test_solarize1(device, config, channels):
728735
check_functional_vs_PIL_vs_scripted(
729736
F.solarize,
730737
F_pil.solarize,
731738
F_t.solarize,
732739
config,
733740
device,
734741
dtype=None,
742+
channels=channels,
735743
tol=1.0,
736744
agg_method="max",
737745
)
@@ -740,14 +748,16 @@ def test_solarize1(device, config):
740748
@pytest.mark.parametrize('device', cpu_and_gpu())
741749
@pytest.mark.parametrize('dtype', (torch.float32, torch.float64))
742750
@pytest.mark.parametrize('config', [{"threshold": threshold} for threshold in [0.0, 0.25, 0.5, 0.75, 1.0]])
743-
def test_solarize2(device, dtype, config):
751+
@pytest.mark.parametrize('channels', [1, 3])
752+
def test_solarize2(device, dtype, config, channels):
744753
check_functional_vs_PIL_vs_scripted(
745754
F.solarize,
746755
lambda img, threshold: F_pil.solarize(img, 255 * threshold),
747756
F_t.solarize,
748757
config,
749758
device,
750759
dtype,
760+
channels,
751761
tol=1.0,
752762
agg_method="max",
753763
)
@@ -756,34 +766,39 @@ def test_solarize2(device, dtype, config):
756766
@pytest.mark.parametrize('device', cpu_and_gpu())
757767
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
758768
@pytest.mark.parametrize('config', [{"sharpness_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]])
759-
def test_adjust_sharpness(device, dtype, config):
769+
@pytest.mark.parametrize('channels', [1, 3])
770+
def test_adjust_sharpness(device, dtype, config, channels):
760771
check_functional_vs_PIL_vs_scripted(
761772
F.adjust_sharpness,
762773
F_pil.adjust_sharpness,
763774
F_t.adjust_sharpness,
764775
config,
765776
device,
766777
dtype,
778+
channels,
767779
)
768780

769781

770782
@pytest.mark.parametrize('device', cpu_and_gpu())
771783
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
772-
def test_autocontrast(device, dtype):
784+
@pytest.mark.parametrize('channels', [1, 3])
785+
def test_autocontrast(device, dtype, channels):
773786
check_functional_vs_PIL_vs_scripted(
774787
F.autocontrast,
775788
F_pil.autocontrast,
776789
F_t.autocontrast,
777790
{},
778791
device,
779792
dtype,
793+
channels,
780794
tol=1.0,
781795
agg_method="max"
782796
)
783797

784798

785799
@pytest.mark.parametrize('device', cpu_and_gpu())
786-
def test_equalize(device):
800+
@pytest.mark.parametrize('channels', [1, 3])
801+
def test_equalize(device, channels):
787802
torch.use_deterministic_algorithms(False)
788803
check_functional_vs_PIL_vs_scripted(
789804
F.equalize,
@@ -792,6 +807,7 @@ def test_equalize(device):
792807
{},
793808
device,
794809
dtype=None,
810+
channels=channels,
795811
tol=1.0,
796812
agg_method="max",
797813
)
@@ -809,35 +825,39 @@ def test_adjust_contrast(device, dtype, config, channels):
809825
config,
810826
device,
811827
dtype,
812-
channels=channels
828+
channels
813829
)
814830

815831

816832
@pytest.mark.parametrize('device', cpu_and_gpu())
817833
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
818834
@pytest.mark.parametrize('config', [{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]])
819-
def test_adjust_saturation(device, dtype, config):
835+
@pytest.mark.parametrize('channels', [1, 3])
836+
def test_adjust_saturation(device, dtype, config, channels):
820837
check_functional_vs_PIL_vs_scripted(
821838
F.adjust_saturation,
822839
F_pil.adjust_saturation,
823840
F_t.adjust_saturation,
824841
config,
825842
device,
826-
dtype
843+
dtype,
844+
channels
827845
)
828846

829847

830848
@pytest.mark.parametrize('device', cpu_and_gpu())
831849
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
832850
@pytest.mark.parametrize('config', [{"hue_factor": f} for f in [-0.45, -0.25, 0.0, 0.25, 0.45]])
833-
def test_adjust_hue(device, dtype, config):
851+
@pytest.mark.parametrize('channels', [1, 3])
852+
def test_adjust_hue(device, dtype, config, channels):
834853
check_functional_vs_PIL_vs_scripted(
835854
F.adjust_hue,
836855
F_pil.adjust_hue,
837856
F_t.adjust_hue,
838857
config,
839858
device,
840859
dtype,
860+
channels,
841861
tol=16.1,
842862
agg_method="max"
843863
)
@@ -846,14 +866,16 @@ def test_adjust_hue(device, dtype, config):
846866
@pytest.mark.parametrize('device', cpu_and_gpu())
847867
@pytest.mark.parametrize('dtype', (None, torch.float32, torch.float64))
848868
@pytest.mark.parametrize('config', [{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])])
849-
def test_adjust_gamma(device, dtype, config):
869+
@pytest.mark.parametrize('channels', [1, 3])
870+
def test_adjust_gamma(device, dtype, config, channels):
850871
check_functional_vs_PIL_vs_scripted(
851872
F.adjust_gamma,
852873
F_pil.adjust_gamma,
853874
F_t.adjust_gamma,
854875
config,
855876
device,
856877
dtype,
878+
channels,
857879
)
858880

859881

test/test_transforms.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@
2626
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
2727

2828

29+
def _get_grayscale_test_image(img, fill=None):
30+
img = img.convert('L')
31+
fill = (fill[0], ) if isinstance(fill, tuple) else fill
32+
return img, fill
33+
34+
2935
class TestConvertImageDtype:
3036
@pytest.mark.parametrize('input_dtype, output_dtype', cycle_over(float_dtypes()))
3137
def test_float_to_float(self, input_dtype, output_dtype):
@@ -1482,9 +1488,12 @@ def test_five_crop(single_dim):
14821488

14831489
@pytest.mark.parametrize('policy', transforms.AutoAugmentPolicy)
14841490
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
1485-
def test_autoaugment(policy, fill):
1491+
@pytest.mark.parametrize('grayscale', [True, False])
1492+
def test_autoaugment(policy, fill, grayscale):
14861493
random.seed(42)
14871494
img = Image.open(GRACE_HOPPER)
1495+
if grayscale:
1496+
img, fill = _get_grayscale_test_image(img, fill)
14881497
transform = transforms.AutoAugment(policy=policy, fill=fill)
14891498
for _ in range(100):
14901499
img = transform(img)
@@ -1494,9 +1503,12 @@ def test_autoaugment(policy, fill):
14941503
@pytest.mark.parametrize('num_ops', [1, 2, 3])
14951504
@pytest.mark.parametrize('magnitude', [7, 9, 11])
14961505
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
1497-
def test_randaugment(num_ops, magnitude, fill):
1506+
@pytest.mark.parametrize('grayscale', [True, False])
1507+
def test_randaugment(num_ops, magnitude, fill, grayscale):
14981508
random.seed(42)
14991509
img = Image.open(GRACE_HOPPER)
1510+
if grayscale:
1511+
img, fill = _get_grayscale_test_image(img, fill)
15001512
transform = transforms.RandAugment(num_ops=num_ops, magnitude=magnitude, fill=fill)
15011513
for _ in range(100):
15021514
img = transform(img)
@@ -1505,9 +1517,12 @@ def test_randaugment(num_ops, magnitude, fill):
15051517

15061518
@pytest.mark.parametrize('fill', [None, 85, (128, 128, 128)])
15071519
@pytest.mark.parametrize('num_magnitude_bins', [10, 13, 30])
1508-
def test_trivialaugmentwide(fill, num_magnitude_bins):
1520+
@pytest.mark.parametrize('grayscale', [True, False])
1521+
def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale):
15091522
random.seed(42)
15101523
img = Image.open(GRACE_HOPPER)
1524+
if grayscale:
1525+
img, fill = _get_grayscale_test_image(img, fill)
15111526
transform = transforms.TrivialAugmentWide(fill=fill, num_magnitude_bins=num_magnitude_bins)
15121527
for _ in range(100):
15131528
img = transform(img)

0 commit comments

Comments
 (0)