1
1
import torch
2
- from torchvision .transforms import autoaugment , transforms
3
2
from torchvision .transforms .functional import InterpolationMode
4
3
5
4
5
def get_module(use_v2):
    """Return the torchvision transforms namespace selected by ``use_v2``.

    Args:
        use_v2: When truthy, ``torchvision.transforms.v2`` is imported and
            returned; otherwise ``torchvision.transforms`` is.

    Returns:
        The imported module object.
    """
    # The imports are deliberately deferred to call time: importing the V2
    # namespace is guarded so that the V2 warning is avoided when only V1
    # is used.
    if not use_v2:
        import torchvision.transforms as v1_namespace

        return v1_namespace

    import torchvision.transforms.v2 as v2_namespace

    return v2_namespace
15
+
16
+
6
17
class ClassificationPresetTrain :
7
18
def __init__ (
8
19
self ,
@@ -17,41 +28,44 @@ def __init__(
17
28
augmix_severity = 3 ,
18
29
random_erase_prob = 0.0 ,
19
30
backend = "pil" ,
31
+ use_v2 = False ,
20
32
):
21
- trans = []
33
+ module = get_module (use_v2 )
34
+
35
+ transforms = []
22
36
backend = backend .lower ()
23
37
if backend == "tensor" :
24
- trans .append (transforms .PILToTensor ())
38
+ transforms .append (module .PILToTensor ())
25
39
elif backend != "pil" :
26
40
raise ValueError (f"backend can be 'tensor' or 'pil', but got { backend } " )
27
41
28
- trans .append (transforms .RandomResizedCrop (crop_size , interpolation = interpolation , antialias = True ))
42
+ transforms .append (module .RandomResizedCrop (crop_size , interpolation = interpolation , antialias = True ))
29
43
if hflip_prob > 0 :
30
- trans .append (transforms .RandomHorizontalFlip (hflip_prob ))
44
+ transforms .append (module .RandomHorizontalFlip (hflip_prob ))
31
45
if auto_augment_policy is not None :
32
46
if auto_augment_policy == "ra" :
33
- trans .append (autoaugment .RandAugment (interpolation = interpolation , magnitude = ra_magnitude ))
47
+ transforms .append (module .RandAugment (interpolation = interpolation , magnitude = ra_magnitude ))
34
48
elif auto_augment_policy == "ta_wide" :
35
- trans .append (autoaugment .TrivialAugmentWide (interpolation = interpolation ))
49
+ transforms .append (module .TrivialAugmentWide (interpolation = interpolation ))
36
50
elif auto_augment_policy == "augmix" :
37
- trans .append (autoaugment .AugMix (interpolation = interpolation , severity = augmix_severity ))
51
+ transforms .append (module .AugMix (interpolation = interpolation , severity = augmix_severity ))
38
52
else :
39
- aa_policy = autoaugment .AutoAugmentPolicy (auto_augment_policy )
40
- trans .append (autoaugment .AutoAugment (policy = aa_policy , interpolation = interpolation ))
53
+ aa_policy = module .AutoAugmentPolicy (auto_augment_policy )
54
+ transforms .append (module .AutoAugment (policy = aa_policy , interpolation = interpolation ))
41
55
42
56
if backend == "pil" :
43
- trans .append (transforms .PILToTensor ())
57
+ transforms .append (module .PILToTensor ())
44
58
45
- trans .extend (
59
+ transforms .extend (
46
60
[
47
- transforms .ConvertImageDtype (torch .float ),
48
- transforms .Normalize (mean = mean , std = std ),
61
+ module .ConvertImageDtype (torch .float ),
62
+ module .Normalize (mean = mean , std = std ),
49
63
]
50
64
)
51
65
if random_erase_prob > 0 :
52
- trans .append (transforms .RandomErasing (p = random_erase_prob ))
66
+ transforms .append (module .RandomErasing (p = random_erase_prob ))
53
67
54
- self .transforms = transforms .Compose (trans )
68
+ self .transforms = module .Compose (transforms )
55
69
56
70
def __call__(self, img):
    """Run ``img`` through the composed pipeline held in ``self.transforms``."""
    pipeline = self.transforms
    return pipeline(img)
@@ -67,28 +81,30 @@ def __init__(
67
81
std = (0.229 , 0.224 , 0.225 ),
68
82
interpolation = InterpolationMode .BILINEAR ,
69
83
backend = "pil" ,
84
+ use_v2 = False ,
70
85
):
71
- trans = []
86
+ module = get_module (use_v2 )
87
+ transforms = []
72
88
backend = backend .lower ()
73
89
if backend == "tensor" :
74
- trans .append (transforms .PILToTensor ())
90
+ transforms .append (module .PILToTensor ())
75
91
elif backend != "pil" :
76
92
raise ValueError (f"backend can be 'tensor' or 'pil', but got { backend } " )
77
93
78
- trans += [
79
- transforms .Resize (resize_size , interpolation = interpolation , antialias = True ),
80
- transforms .CenterCrop (crop_size ),
94
+ transforms += [
95
+ module .Resize (resize_size , interpolation = interpolation , antialias = True ),
96
+ module .CenterCrop (crop_size ),
81
97
]
82
98
83
99
if backend == "pil" :
84
- trans .append (transforms .PILToTensor ())
100
+ transforms .append (module .PILToTensor ())
85
101
86
- trans += [
87
- transforms .ConvertImageDtype (torch .float ),
88
- transforms .Normalize (mean = mean , std = std ),
102
+ transforms += [
103
+ module .ConvertImageDtype (torch .float ),
104
+ module .Normalize (mean = mean , std = std ),
89
105
]
90
106
91
- self .transforms = transforms .Compose (trans )
107
+ self .transforms = module .Compose (transforms )
92
108
93
109
def __call__(self, img):
    """Run ``img`` through the composed pipeline held in ``self.transforms``."""
    pipeline = self.transforms
    return pipeline(img)
0 commit comments