
Commit bb3aae7

NicolasHug and pmeier authored
Add --backend and --use-v2 support to detection refs (#7732)
Co-authored-by: Philip Meier <github.pmeier@posteo.de>
1 parent 08c9938 commit bb3aae7

8 files changed: +166 −106 lines changed
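For context, the `--backend` and `--use-v2` flags named in the commit title are plumbed through the reference training scripts; the `train.py` wiring is not part of the excerpt below, so the following argparse sketch is hypothetical and only illustrates how the flags would reach the presets.

    # Hypothetical sketch (the train.py changes are not shown in this excerpt);
    # only the --backend / --use-v2 flag names come from the commit title.
    import argparse

    import presets  # references/detection/presets.py, as changed below

    parser = argparse.ArgumentParser()
    parser.add_argument("--backend", default="pil", type=str, help="pil, tensor or datapoint")
    parser.add_argument("--use-v2", action="store_true", help="use torchvision.transforms.v2")
    args = parser.parse_args()

    train_transforms = presets.DetectionPresetTrain(
        data_augmentation="hflip", backend=args.backend, use_v2=args.use_v2
    )
    eval_transforms = presets.DetectionPresetEval(backend=args.backend, use_v2=args.use_v2)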

references/classification/presets.py

+25 −22

@@ -15,6 +15,9 @@ def get_module(use_v2):
 
 
 class ClassificationPresetTrain:
+    # Note: this transform assumes that the input to forward() are always PIL
+    # images, regardless of the backend parameter. We may change that in the
+    # future though, if we change the output type from the dataset.
     def __init__(
         self,
         *,
@@ -30,42 +33,42 @@ def __init__(
         backend="pil",
         use_v2=False,
     ):
-        module = get_module(use_v2)
+        T = get_module(use_v2)
 
         transforms = []
         backend = backend.lower()
         if backend == "tensor":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
         elif backend != "pil":
             raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
 
-        transforms.append(module.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
+        transforms.append(T.RandomResizedCrop(crop_size, interpolation=interpolation, antialias=True))
         if hflip_prob > 0:
-            transforms.append(module.RandomHorizontalFlip(hflip_prob))
+            transforms.append(T.RandomHorizontalFlip(hflip_prob))
         if auto_augment_policy is not None:
             if auto_augment_policy == "ra":
-                transforms.append(module.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
+                transforms.append(T.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
             elif auto_augment_policy == "ta_wide":
-                transforms.append(module.TrivialAugmentWide(interpolation=interpolation))
+                transforms.append(T.TrivialAugmentWide(interpolation=interpolation))
             elif auto_augment_policy == "augmix":
-                transforms.append(module.AugMix(interpolation=interpolation, severity=augmix_severity))
+                transforms.append(T.AugMix(interpolation=interpolation, severity=augmix_severity))
             else:
-                aa_policy = module.AutoAugmentPolicy(auto_augment_policy)
-                transforms.append(module.AutoAugment(policy=aa_policy, interpolation=interpolation))
+                aa_policy = T.AutoAugmentPolicy(auto_augment_policy)
+                transforms.append(T.AutoAugment(policy=aa_policy, interpolation=interpolation))
 
         if backend == "pil":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
 
         transforms.extend(
             [
-                module.ConvertImageDtype(torch.float),
-                module.Normalize(mean=mean, std=std),
+                T.ConvertImageDtype(torch.float),
+                T.Normalize(mean=mean, std=std),
             ]
         )
         if random_erase_prob > 0:
-            transforms.append(module.RandomErasing(p=random_erase_prob))
+            transforms.append(T.RandomErasing(p=random_erase_prob))
 
-        self.transforms = module.Compose(transforms)
+        self.transforms = T.Compose(transforms)
 
     def __call__(self, img):
         return self.transforms(img)
@@ -83,28 +86,28 @@ def __init__(
         backend="pil",
         use_v2=False,
     ):
-        module = get_module(use_v2)
+        T = get_module(use_v2)
         transforms = []
         backend = backend.lower()
         if backend == "tensor":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
         elif backend != "pil":
             raise ValueError(f"backend can be 'tensor' or 'pil', but got {backend}")
 
         transforms += [
-            module.Resize(resize_size, interpolation=interpolation, antialias=True),
-            module.CenterCrop(crop_size),
+            T.Resize(resize_size, interpolation=interpolation, antialias=True),
+            T.CenterCrop(crop_size),
         ]
 
         if backend == "pil":
-            transforms.append(module.PILToTensor())
+            transforms.append(T.PILToTensor())
 
         transforms += [
-            module.ConvertImageDtype(torch.float),
-            module.Normalize(mean=mean, std=std),
+            T.ConvertImageDtype(torch.float),
+            T.Normalize(mean=mean, std=std),
         ]
 
-        self.transforms = module.Compose(transforms)
+        self.transforms = T.Compose(transforms)
 
     def __call__(self, img):
         return self.transforms(img)
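Both classification presets now take backend and use_v2; get_module(use_v2), defined earlier in this file outside the hunk, presumably returns either the v1 or the v2 transforms namespace. A minimal usage sketch, with illustrative crop/resize values and assuming it runs from references/classification/:

    from PIL import Image

    from presets import ClassificationPresetTrain, ClassificationPresetEval

    # the sizes below are just example values
    train_tf = ClassificationPresetTrain(crop_size=176, backend="tensor", use_v2=True)
    eval_tf = ClassificationPresetEval(crop_size=224, resize_size=232)

    img = Image.new("RGB", (500, 375))  # stand-in for a dataset sample
    out = train_tf(img)  # normalized float tensor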

references/detection/coco_utils.py

+20 −15

@@ -7,6 +7,7 @@
 import transforms as T
 from pycocotools import mask as coco_mask
 from pycocotools.coco import COCO
+from torchvision.datasets import wrap_dataset_for_transforms_v2
 
 
 class FilterAndRemapCocoCategories:
@@ -49,7 +50,6 @@ def __call__(self, image, target):
         w, h = image.size
 
         image_id = target["image_id"]
-        image_id = torch.tensor([image_id])
 
         anno = target["annotations"]
 
@@ -126,10 +126,6 @@ def _has_valid_annotation(anno):
             return True
         return False
 
-    if not isinstance(dataset, torchvision.datasets.CocoDetection):
-        raise TypeError(
-            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
-        )
     ids = []
     for ds_idx, img_id in enumerate(dataset.ids):
         ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
@@ -196,12 +192,15 @@ def convert_to_coco_api(ds):
 
 
 def get_coco_api_from_dataset(dataset):
+    # FIXME: This is... awful?
    for _ in range(10):
        if isinstance(dataset, torchvision.datasets.CocoDetection):
            break
        if isinstance(dataset, torch.utils.data.Subset):
            dataset = dataset.dataset
-    if isinstance(dataset, torchvision.datasets.CocoDetection):
+    if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
+        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
+    ):
         return dataset.coco
     return convert_to_coco_api(dataset)
 
@@ -220,25 +219,29 @@ def __getitem__(self, idx):
         return img, target
 
 
-def get_coco(root, image_set, transforms, mode="instances"):
+def get_coco(root, image_set, transforms, mode="instances", use_v2=False):
     anno_file_template = "{}_{}2017.json"
     PATHS = {
         "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
         "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
         # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
     }
 
-    t = [ConvertCocoPolysToMask()]
-
-    if transforms is not None:
-        t.append(transforms)
-    transforms = T.Compose(t)
-
     img_folder, ann_file = PATHS[image_set]
     img_folder = os.path.join(root, img_folder)
     ann_file = os.path.join(root, ann_file)
 
-    dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
+    if use_v2:
+        dataset = torchvision.datasets.CocoDetection(img_folder, ann_file, transforms=transforms)
+        # TODO: need to update target_keys to handle masks for segmentation!
+        dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
+    else:
+        t = [ConvertCocoPolysToMask()]
+        if transforms is not None:
+            t.append(transforms)
+        transforms = T.Compose(t)
+
+        dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
 
     if image_set == "train":
         dataset = _coco_remove_images_without_annotations(dataset)
@@ -248,5 +251,7 @@ def get_coco(root, image_set, transforms, mode="instances"):
     return dataset
 
 
-def get_coco_kp(root, image_set, transforms):
+def get_coco_kp(root, image_set, transforms, use_v2=False):
+    if use_v2:
+        raise ValueError("KeyPoints aren't supported by transforms V2 yet.")
     return get_coco(root, image_set, transforms, mode="person_keypoints")
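With use_v2=True, get_coco() builds a plain torchvision CocoDetection and wraps it with wrap_dataset_for_transforms_v2, limiting the target to the keys detection needs; image_id then stays a plain int, which is what the engine.py changes below rely on. A minimal sketch of that branch, with an assumed local COCO path:

    import torchvision
    from torchvision.datasets import wrap_dataset_for_transforms_v2

    # /data/coco is an assumed path
    dataset = torchvision.datasets.CocoDetection(
        "/data/coco/train2017", "/data/coco/annotations/instances_train2017.json"
    )
    dataset = wrap_dataset_for_transforms_v2(dataset, target_keys={"boxes", "labels", "image_id"})
    img, target = dataset[0]
    # target["boxes"] and target["labels"] are tensors; target["image_id"] is a plain int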

references/detection/engine.py

+2 −2

@@ -26,7 +26,7 @@ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, sc
 
     for images, targets in metric_logger.log_every(data_loader, print_freq, header):
         images = list(image.to(device) for image in images)
-        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
+        targets = [{k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()} for t in targets]
         with torch.cuda.amp.autocast(enabled=scaler is not None):
             loss_dict = model(images, targets)
             losses = sum(loss for loss in loss_dict.values())
@@ -97,7 +97,7 @@ def evaluate(model, data_loader, device):
         outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
         model_time = time.time() - model_time
 
-        res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
+        res = {target["image_id"]: output for target, output in zip(targets, outputs)}
         evaluator_time = time.time()
         coco_evaluator.update(res)
         evaluator_time = time.time() - evaluator_time
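Since v2-wrapped targets can contain non-tensor values (notably the plain-int image_id), train_one_epoch now only moves tensor values to the device, and evaluate keys its results by the int directly. A small illustration of the guard, with made-up values:

    import torch

    device = "cpu"
    targets = [{"boxes": torch.rand(3, 4), "labels": torch.tensor([1, 2, 3]), "image_id": 42}]
    targets = [
        {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in t.items()}
        for t in targets
    ]
    # tensors are moved; the int image_id passes through unchanged and can be
    # used directly as a key in the results dict handed to the COCO evaluator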

references/detection/group_by_aspect_ratio.py

+3 −1

@@ -164,7 +164,9 @@ def compute_aspect_ratios(dataset, indices=None):
     if hasattr(dataset, "get_height_and_width"):
         return _compute_aspect_ratios_custom_dataset(dataset, indices)
 
-    if isinstance(dataset, torchvision.datasets.CocoDetection):
+    if isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
+        getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
+    ):
         return _compute_aspect_ratios_coco_dataset(dataset, indices)
 
     if isinstance(dataset, torchvision.datasets.VOCDetection):
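Datasets wrapped by wrap_dataset_for_transforms_v2 keep the original CocoDetection in a _dataset attribute (an internal detail), so the aspect-ratio grouping and get_coco_api_from_dataset both look at it via getattr. The check in isolation (_is_coco_like is just an illustrative name):

    import torchvision

    def _is_coco_like(dataset):
        # also accepts a v2-wrapped dataset whose ._dataset is the real CocoDetection
        return isinstance(dataset, torchvision.datasets.CocoDetection) or isinstance(
            getattr(dataset, "_dataset", None), torchvision.datasets.CocoDetection
        )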

references/detection/presets.py

+89 −53

@@ -1,73 +1,109 @@
+from collections import defaultdict
+
 import torch
-import transforms as T
+import transforms as reference_transforms
+
+
+def get_modules(use_v2):
+    # We need a protected import to avoid the V2 warning in case just V1 is used
+    if use_v2:
+        import torchvision.datapoints
+        import torchvision.transforms.v2
+
+        return torchvision.transforms.v2, torchvision.datapoints
+    else:
+        return reference_transforms, None
 
 
 class DetectionPresetTrain:
-    def __init__(self, *, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
+    # Note: this transform assumes that the input to forward() are always PIL
+    # images, regardless of the backend parameter.
+    def __init__(
+        self,
+        *,
+        data_augmentation,
+        hflip_prob=0.5,
+        mean=(123.0, 117.0, 104.0),
+        backend="pil",
+        use_v2=False,
+    ):
+
+        T, datapoints = get_modules(use_v2)
+
+        transforms = []
+        backend = backend.lower()
+        if backend == "datapoint":
+            transforms.append(T.ToImageTensor())
+        elif backend == "tensor":
+            transforms.append(T.PILToTensor())
+        elif backend != "pil":
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
         if data_augmentation == "hflip":
-            self.transforms = T.Compose(
-                [
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [T.RandomHorizontalFlip(p=hflip_prob)]
         elif data_augmentation == "lsj":
-            self.transforms = T.Compose(
-                [
-                    T.ScaleJitter(target_size=(1024, 1024)),
-                    T.FixedSizeCrop(size=(1024, 1024), fill=mean),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.ScaleJitter(target_size=(1024, 1024), antialias=True),
+                # TODO: FixedSizeCrop below doesn't work on tensors!
+                reference_transforms.FixedSizeCrop(size=(1024, 1024), fill=mean),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "multiscale":
-            self.transforms = T.Compose(
-                [
-                    T.RandomShortestSize(
-                        min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333
-                    ),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.RandomShortestSize(min_size=(480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800), max_size=1333),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "ssd":
-            self.transforms = T.Compose(
-                [
-                    T.RandomPhotometricDistort(),
-                    T.RandomZoomOut(fill=list(mean)),
-                    T.RandomIoUCrop(),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            fill = defaultdict(lambda: mean, {datapoints.Mask: 0}) if use_v2 else list(mean)
+            transforms += [
+                T.RandomPhotometricDistort(),
+                T.RandomZoomOut(fill=fill),
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         elif data_augmentation == "ssdlite":
-            self.transforms = T.Compose(
-                [
-                    T.RandomIoUCrop(),
-                    T.RandomHorizontalFlip(p=hflip_prob),
-                    T.PILToTensor(),
-                    T.ConvertImageDtype(torch.float),
-                ]
-            )
+            transforms += [
+                T.RandomIoUCrop(),
+                T.RandomHorizontalFlip(p=hflip_prob),
+            ]
         else:
             raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
 
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2.
+            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+
+        transforms += [T.ConvertImageDtype(torch.float)]
+
+        if use_v2:
+            transforms += [
+                T.ConvertBoundingBoxFormat(datapoints.BoundingBoxFormat.XYXY),
+                T.SanitizeBoundingBox(),
+            ]
+
+        self.transforms = T.Compose(transforms)
+
     def __call__(self, img, target):
         return self.transforms(img, target)
 
 
 class DetectionPresetEval:
-    def __init__(self):
-        self.transforms = T.Compose(
-            [
-                T.PILToTensor(),
-                T.ConvertImageDtype(torch.float),
-            ]
-        )
+    def __init__(self, backend="pil", use_v2=False):
+        T, _ = get_modules(use_v2)
+        transforms = []
+        backend = backend.lower()
+        if backend == "pil":
+            # Note: we could just convert to pure tensors even in v2?
+            transforms += [T.ToImageTensor() if use_v2 else T.PILToTensor()]
+        elif backend == "tensor":
+            transforms += [T.PILToTensor()]
+        elif backend == "datapoint":
+            transforms += [T.ToImageTensor()]
+        else:
+            raise ValueError(f"backend can be 'datapoint', 'tensor' or 'pil', but got {backend}")
+
+        transforms += [T.ConvertImageDtype(torch.float)]
+        self.transforms = T.Compose(transforms)
 
     def __call__(self, img, target):
         return self.transforms(img, target)
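get_modules(use_v2) returns (torchvision.transforms.v2, torchvision.datapoints) when use_v2 is set, and (the local reference transforms, None) otherwise; both presets now accept backend in {"pil", "tensor", "datapoint"}. A usage sketch, assuming it runs from references/detection/ on a torchvision build where transforms.v2 and datapoints are available; the augmentation name is illustrative:

    import presets

    train_tf = presets.DetectionPresetTrain(data_augmentation="ssd", backend="datapoint", use_v2=True)
    eval_tf = presets.DetectionPresetEval(backend="datapoint", use_v2=True)
    # with use_v2=True the train pipeline also converts boxes to XYXY and appends
    # SanitizeBoundingBox so degenerate boxes are dropped before training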
