Skip to content

Commit 08c9938

Browse files
authored
Add --use-v2 support to classification references (#7724)
1 parent 23b0938 commit 08c9938

File tree

2 files changed

+45
-26
lines changed

2 files changed

+45
-26
lines changed

references/classification/presets.py

+42-26
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,19 @@
11
import torch
2-
from torchvision.transforms import autoaugment, transforms
32
from torchvision.transforms.functional import InterpolationMode
43

54

def get_module(use_v2):
    """Return the transforms module to build presets from.

    Imports are deliberately deferred into each branch: importing
    ``torchvision.transforms.v2`` emits a warning, so the v2 module must
    only be touched when the caller actually asked for it.
    """
    if not use_v2:
        import torchvision.transforms as v1_module

        return v1_module

    import torchvision.transforms.v2 as v2_module

    return v2_module
617
class ClassificationPresetTrain:
718
def __init__(
819
self,
@@ -17,41 +28,44 @@ def __init__(
1728
augmix_severity=3,
1829
random_erase_prob=0.0,
1930
backend="pil",
31+
use_v2=False,
2032
):
21-
trans = []
33+
module = get_module(use_v2)
34+
35+
transforms = []
2236
backend = backend.lower()
2337
if backend == "tensor":
24-
trans.append(transforms.PILToTensor())
38+
transforms.append(module.PILToTensor())
2539
elif backend != "pil":
2640
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
2741

28-
trans.append(transforms.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
42+
transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
2943
if hflip_prob > 0:
30-
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
44+
transforms.append(module.RandomHorizontalFlip(hflip_prob))
3145
if auto_augment_policy is not None:
3246
if auto_augment_policy == "ra":
33-
trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
47+
transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
3448
elif auto_augment_policy == "ta_wide":
35-
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
49+
transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
3650
elif auto_augment_policy == "augmix":
37-
trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
51+
transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
3852
else:
39-
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
40-
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
53+
aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
54+
transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))
4155

4256
if backend == "pil":
43-
trans.append(transforms.PILToTensor())
57+
transforms.append(module.PILToTensor())
4458

45-
trans.extend(
59+
transforms.extend(
4660
[
47-
transforms.ConvertImageDtype(torch.float),
48-
transforms.Normalize(mean=mean, std=std),
61+
module.ConvertImageDtype(torch.float),
62+
module.Normalize(mean=mean, std=std),
4963
]
5064
)
5165
if random_erase_prob > 0:
52-
trans.append(transforms.RandomErasing(p=random_erase_prob))
66+
transforms.append(module.RandomErasing(p=random_erase_prob))
5367

54-
self.transforms = transforms.Compose(trans)
68+
self.transforms = module.Compose(transforms)
5569

5670
def __call__(self, img):
5771
return self.transforms(img)
@@ -67,28 +81,30 @@ def __init__(
6781
std=(0.229, 0.224, 0.225),
6882
interpolation=InterpolationMode.BILINEAR,
6983
backend="pil",
84+
use_v2=False,
7085
):
71-
trans = []
86+
module = get_module(use_v2)
87+
transforms = []
7288
backend = backend.lower()
7389
if backend == "tensor":
74-
trans.append(transforms.PILToTensor())
90+
transforms.append(module.PILToTensor())
7591
elif backend != "pil":
7692
raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
7793

78-
trans += [
79-
transforms.Resize(resize_size, interpolation=interpolation, antialias=True),
80-
transforms.CenterCrop(crop_size),
94+
transforms += [
95+
module.Resize(resize_size, interpolation=interpolation, antialias=True),
96+
module.CenterCrop(crop_size),
8197
]
8298

8399
if backend == "pil":
84-
trans.append(transforms.PILToTensor())
100+
transforms.append(module.PILToTensor())
85101

86-
trans += [
87-
transforms.ConvertImageDtype(torch.float),
88-
transforms.Normalize(mean=mean, std=std),
102+
transforms += [
103+
module.ConvertImageDtype(torch.float),
104+
module.Normalize(mean=mean, std=std),
89105
]
90106

91-
self.transforms = transforms.Compose(trans)
107+
self.transforms = module.Compose(transforms)
92108

93109
def __call__(self, img):
94110
return self.transforms(img)

references/classification/train.py

+3
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def load_data(traindir, valdir, args):
145145
ra_magnitude=ra_magnitude,
146146
augmix_severity=augmix_severity,
147147
backend=args.backend,
148+
use_v2=args.use_v2,
148149
),
149150
)
150151
if args.cache_dataset:
@@ -172,6 +173,7 @@ def load_data(traindir, valdir, args):
172173
resize_size=val_resize_size,
173174
interpolation=interpolation,
174175
backend=args.backend,
176+
use_v2=args.use_v2,
175177
)
176178

177179
dataset_test = torchvision.datasets.ImageFolder(
@@ -516,6 +518,7 @@ def get_args_parser(add_help=True):
516518
)
517519
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
518520
parser.add_argument("--backend", default="PIL", type=str.lower, help="PIL or tensor - case insensitive")
521+
parser.add_argument("--use-v2", action="store_true", help="Use V2 transforms")
519522
return parser
520523

521524

0 commit comments

Comments
 (0)