@@ -681,57 +681,65 @@ def check_functional_vs_PIL_vs_scripted(
681
681
@pytest .mark .parametrize ('device' , cpu_and_gpu ())
682
682
@pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
683
683
@pytest .mark .parametrize ('config' , [{"brightness_factor" : f } for f in (0.1 , 0.5 , 1.0 , 1.34 , 2.5 )])
684
- def test_adjust_brightness (device , dtype , config ):
684
+ @pytest .mark .parametrize ('channels' , [1 , 3 ])
685
+ def test_adjust_brightness (device , dtype , config , channels ):
685
686
check_functional_vs_PIL_vs_scripted (
686
687
F .adjust_brightness ,
687
688
F_pil .adjust_brightness ,
688
689
F_t .adjust_brightness ,
689
690
config ,
690
691
device ,
691
692
dtype ,
693
+ channels ,
692
694
)
693
695
694
696
695
697
@pytest .mark .parametrize ('device' , cpu_and_gpu ())
696
698
@pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
697
- def test_invert (device , dtype ):
699
+ @pytest .mark .parametrize ('channels' , [1 , 3 ])
700
+ def test_invert (device , dtype , channels ):
698
701
check_functional_vs_PIL_vs_scripted (
699
702
F .invert ,
700
703
F_pil .invert ,
701
704
F_t .invert ,
702
705
{},
703
706
device ,
704
707
dtype ,
708
+ channels ,
705
709
tol = 1.0 ,
706
710
agg_method = "max"
707
711
)
708
712
709
713
710
714
@pytest .mark .parametrize ('device' , cpu_and_gpu ())
711
715
@pytest .mark .parametrize ('config' , [{"bits" : bits } for bits in range (0 , 8 )])
712
- def test_posterize (device , config ):
716
+ @pytest .mark .parametrize ('channels' , [1 , 3 ])
717
+ def test_posterize (device , config , channels ):
713
718
check_functional_vs_PIL_vs_scripted (
714
719
F .posterize ,
715
720
F_pil .posterize ,
716
721
F_t .posterize ,
717
722
config ,
718
723
device ,
719
724
dtype = None ,
725
+ channels = channels ,
720
726
tol = 1.0 ,
721
727
agg_method = "max" ,
722
728
)
723
729
724
730
725
731
@pytest .mark .parametrize ('device' , cpu_and_gpu ())
726
732
@pytest .mark .parametrize ('config' , [{"threshold" : threshold } for threshold in [0 , 64 , 128 , 192 , 255 ]])
727
- def test_solarize1 (device , config ):
733
+ @pytest .mark .parametrize ('channels' , [1 , 3 ])
734
+ def test_solarize1 (device , config , channels ):
728
735
check_functional_vs_PIL_vs_scripted (
729
736
F .solarize ,
730
737
F_pil .solarize ,
731
738
F_t .solarize ,
732
739
config ,
733
740
device ,
734
741
dtype = None ,
742
+ channels = channels ,
735
743
tol = 1.0 ,
736
744
agg_method = "max" ,
737
745
)
@@ -740,14 +748,16 @@ def test_solarize1(device, config):
740
748
@pytest .mark .parametrize ('device' , cpu_and_gpu ())
741
749
@pytest .mark .parametrize ('dtype' , (torch .float32 , torch .float64 ))
742
750
@pytest .mark .parametrize ('config' , [{"threshold" : threshold } for threshold in [0.0 , 0.25 , 0.5 , 0.75 , 1.0 ]])
743
- def test_solarize2 (device , dtype , config ):
751
+ @pytest .mark .parametrize ('channels' , [1 , 3 ])
752
+ def test_solarize2 (device , dtype , config , channels ):
744
753
check_functional_vs_PIL_vs_scripted (
745
754
F .solarize ,
746
755
lambda img , threshold : F_pil .solarize (img , 255 * threshold ),
747
756
F_t .solarize ,
748
757
config ,
749
758
device ,
750
759
dtype ,
760
+ channels ,
751
761
tol = 1.0 ,
752
762
agg_method = "max" ,
753
763
)
@@ -756,34 +766,39 @@ def test_solarize2(device, dtype, config):
756
766
@pytest .mark .parametrize ('device' , cpu_and_gpu ())
757
767
@pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
758
768
@pytest .mark .parametrize ('config' , [{"sharpness_factor" : f } for f in [0.2 , 0.5 , 1.0 , 1.5 , 2.0 ]])
759
- def test_adjust_sharpness (device , dtype , config ):
769
+ @pytest .mark .parametrize ('channels' , [1 , 3 ])
770
+ def test_adjust_sharpness (device , dtype , config , channels ):
760
771
check_functional_vs_PIL_vs_scripted (
761
772
F .adjust_sharpness ,
762
773
F_pil .adjust_sharpness ,
763
774
F_t .adjust_sharpness ,
764
775
config ,
765
776
device ,
766
777
dtype ,
778
+ channels ,
767
779
)
768
780
769
781
770
782
@pytest .mark .parametrize ('device' , cpu_and_gpu ())
771
783
@pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
772
- def test_autocontrast (device , dtype ):
784
+ @pytest .mark .parametrize ('channels' , [1 , 3 ])
785
+ def test_autocontrast (device , dtype , channels ):
773
786
check_functional_vs_PIL_vs_scripted (
774
787
F .autocontrast ,
775
788
F_pil .autocontrast ,
776
789
F_t .autocontrast ,
777
790
{},
778
791
device ,
779
792
dtype ,
793
+ channels ,
780
794
tol = 1.0 ,
781
795
agg_method = "max"
782
796
)
783
797
784
798
785
799
@pytest .mark .parametrize ('device' , cpu_and_gpu ())
786
- def test_equalize (device ):
800
+ @pytest .mark .parametrize ('channels' , [1 , 3 ])
801
+ def test_equalize (device , channels ):
787
802
torch .use_deterministic_algorithms (False )
788
803
check_functional_vs_PIL_vs_scripted (
789
804
F .equalize ,
@@ -792,6 +807,7 @@ def test_equalize(device):
792
807
{},
793
808
device ,
794
809
dtype = None ,
810
+ channels = channels ,
795
811
tol = 1.0 ,
796
812
agg_method = "max" ,
797
813
)
@@ -809,35 +825,39 @@ def test_adjust_contrast(device, dtype, config, channels):
809
825
config ,
810
826
device ,
811
827
dtype ,
812
- channels = channels
828
+ channels
813
829
)
814
830
815
831
816
832
@pytest .mark .parametrize ('device' , cpu_and_gpu ())
817
833
@pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
818
834
@pytest .mark .parametrize ('config' , [{"saturation_factor" : f } for f in [0.5 , 0.75 , 1.0 , 1.5 , 2.0 ]])
819
- def test_adjust_saturation (device , dtype , config ):
835
+ @pytest .mark .parametrize ('channels' , [1 , 3 ])
836
+ def test_adjust_saturation (device , dtype , config , channels ):
820
837
check_functional_vs_PIL_vs_scripted (
821
838
F .adjust_saturation ,
822
839
F_pil .adjust_saturation ,
823
840
F_t .adjust_saturation ,
824
841
config ,
825
842
device ,
826
- dtype
843
+ dtype ,
844
+ channels
827
845
)
828
846
829
847
830
848
@pytest .mark .parametrize ('device' , cpu_and_gpu ())
831
849
@pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
832
850
@pytest .mark .parametrize ('config' , [{"hue_factor" : f } for f in [- 0.45 , - 0.25 , 0.0 , 0.25 , 0.45 ]])
833
- def test_adjust_hue (device , dtype , config ):
851
+ @pytest .mark .parametrize ('channels' , [1 , 3 ])
852
+ def test_adjust_hue (device , dtype , config , channels ):
834
853
check_functional_vs_PIL_vs_scripted (
835
854
F .adjust_hue ,
836
855
F_pil .adjust_hue ,
837
856
F_t .adjust_hue ,
838
857
config ,
839
858
device ,
840
859
dtype ,
860
+ channels ,
841
861
tol = 16.1 ,
842
862
agg_method = "max"
843
863
)
@@ -846,14 +866,16 @@ def test_adjust_hue(device, dtype, config):
846
866
@pytest .mark .parametrize ('device' , cpu_and_gpu ())
847
867
@pytest .mark .parametrize ('dtype' , (None , torch .float32 , torch .float64 ))
848
868
@pytest .mark .parametrize ('config' , [{"gamma" : g1 , "gain" : g2 } for g1 , g2 in zip ([0.8 , 1.0 , 1.2 ], [0.7 , 1.0 , 1.3 ])])
849
- def test_adjust_gamma (device , dtype , config ):
869
+ @pytest .mark .parametrize ('channels' , [1 , 3 ])
870
+ def test_adjust_gamma (device , dtype , config , channels ):
850
871
check_functional_vs_PIL_vs_scripted (
851
872
F .adjust_gamma ,
852
873
F_pil .adjust_gamma ,
853
874
F_t .adjust_gamma ,
854
875
config ,
855
876
device ,
856
877
dtype ,
878
+ channels ,
857
879
)
858
880
859
881
0 commit comments