Skip to content

Commit 39cf02a

Browse files
authored
improve COCO prototype (#4650)
* improve COCO prototype * test 2017 annotations * add option to include captions * fix categories and add tests * cleanup * add correct image size to bounding boxes * fix annotation collation * appease mypy * add benchmark * always use image as reference * another refactor * add support for segmentations * add support for segmentations * fix CI dependencies
1 parent 3d8723d commit 39cf02a

File tree

9 files changed

+498
-105
lines changed

9 files changed

+498
-105
lines changed

.circleci/config.yml

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.circleci/config.yml.in

+1-1
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ jobs:
351351
- install_torchvision
352352
- install_prototype_dependencies
353353
- pip_install:
354-
args: scipy
354+
args: scipy pycocotools
355355
descr: Install optional dependencies
356356
- run:
357357
name: Enable prototype tests

test/builtin_dataset_mocks.py

+114
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import functools
22
import gzip
3+
import json
34
import lzma
45
import pathlib
56
import pickle
@@ -8,6 +9,7 @@
89
from typing import Any, Dict, Tuple
910

1011
import numpy as np
12+
import PIL.Image
1113
import pytest
1214
import torch
1315
from datasets_utils import create_image_folder, make_tar, make_zip
@@ -18,7 +20,9 @@
1820
from torchvision.prototype.datasets._api import find
1921
from torchvision.prototype.utils._internal import add_suggestion
2022

23+
2124
make_tensor = functools.partial(_make_tensor, device="cpu")
25+
make_scalar = functools.partial(make_tensor, ())
2226

2327
__all__ = ["load"]
2428

@@ -490,3 +494,113 @@ def imagenet(info, root, config):
490494
make_tar(root, f"{devkit_root}.tar.gz", devkit_root, compression="gz")
491495

492496
return num_samples
497+
498+
499+
class CocoMockData:
500+
@classmethod
501+
def _make_images_archive(cls, root, name, *, num_samples):
502+
image_paths = create_image_folder(
503+
root, name, file_name_fn=lambda idx: f"{idx:012d}.jpg", num_examples=num_samples
504+
)
505+
506+
images_meta = []
507+
for path in image_paths:
508+
with PIL.Image.open(path) as image:
509+
width, height = image.size
510+
images_meta.append(dict(file_name=path.name, id=int(path.stem), width=width, height=height))
511+
512+
make_zip(root, f"{name}.zip")
513+
514+
return images_meta
515+
516+
@classmethod
517+
def _make_annotations_json(
518+
cls,
519+
root,
520+
name,
521+
*,
522+
images_meta,
523+
fn,
524+
):
525+
num_anns_per_image = torch.randint(1, 5, (len(images_meta),))
526+
num_anns_total = int(num_anns_per_image.sum())
527+
ann_ids_iter = iter(torch.arange(num_anns_total)[torch.randperm(num_anns_total)])
528+
529+
anns_meta = []
530+
for image_meta, num_anns in zip(images_meta, num_anns_per_image):
531+
for _ in range(num_anns):
532+
ann_id = int(next(ann_ids_iter))
533+
anns_meta.append(dict(fn(ann_id, image_meta), id=ann_id, image_id=image_meta["id"]))
534+
anns_meta.sort(key=lambda ann: ann["id"])
535+
536+
with open(root / name, "w") as file:
537+
json.dump(dict(images=images_meta, annotations=anns_meta), file)
538+
539+
return num_anns_per_image
540+
541+
@staticmethod
542+
def _make_instances_data(ann_id, image_meta):
543+
def make_rle_segmentation():
544+
height, width = image_meta["height"], image_meta["width"]
545+
numel = height * width
546+
counts = []
547+
while sum(counts) <= numel:
548+
counts.append(int(torch.randint(5, 8, ())))
549+
if sum(counts) > numel:
550+
counts[-1] -= sum(counts) - numel
551+
return dict(counts=counts, size=[height, width])
552+
553+
return dict(
554+
segmentation=make_rle_segmentation(),
555+
bbox=make_tensor((4,), dtype=torch.float32, low=0).tolist(),
556+
iscrowd=True,
557+
area=float(make_scalar(dtype=torch.float32)),
558+
category_id=int(make_scalar(dtype=torch.int64)),
559+
)
560+
561+
@staticmethod
562+
def _make_captions_data(ann_id, image_meta):
563+
return dict(caption=f"Caption {ann_id} describing image {image_meta['id']}.")
564+
565+
@classmethod
566+
def _make_annotations(cls, root, name, *, images_meta):
567+
num_anns_per_image = torch.zeros((len(images_meta),), dtype=torch.int64)
568+
for annotations, fn in (
569+
("instances", cls._make_instances_data),
570+
("captions", cls._make_captions_data),
571+
):
572+
num_anns_per_image += cls._make_annotations_json(
573+
root, f"{annotations}_{name}.json", images_meta=images_meta, fn=fn
574+
)
575+
576+
return int(num_anns_per_image.sum())
577+
578+
@classmethod
579+
def generate(
580+
cls,
581+
root,
582+
*,
583+
year,
584+
num_samples,
585+
):
586+
annotations_dir = root / "annotations"
587+
annotations_dir.mkdir()
588+
589+
for split in ("train", "val"):
590+
config_name = f"{split}{year}"
591+
592+
images_meta = cls._make_images_archive(root, config_name, num_samples=num_samples)
593+
cls._make_annotations(
594+
annotations_dir,
595+
config_name,
596+
images_meta=images_meta,
597+
)
598+
599+
make_zip(root, f"annotations_trainval{year}.zip", annotations_dir)
600+
601+
return num_samples
602+
603+
604+
@dataset_mocks.register_mock_data_fn
605+
def coco(info, root, config):
606+
return CocoMockData.generate(root, year=config.year, num_samples=5)

test/datasets_utils.py

+7
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,13 @@ def _split_files_or_dirs(root, *files_or_dirs):
866866

867867
def _make_archive(root, name, *files_or_dirs, opener, adder, remove=True):
868868
archive = pathlib.Path(root) / name
869+
if not files_or_dirs:
870+
dir = archive.parent / archive.name.replace("".join(archive.suffixes), "")
871+
if dir.exists() and dir.is_dir():
872+
files_or_dirs = (dir,)
873+
else:
874+
raise ValueError("No file or dir provided.")
875+
869876
files, dirs = _split_files_or_dirs(root, *files_or_dirs)
870877

871878
with opener(archive) as fh:

test/test_prototype_builtin_datasets.py

+20-8
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@ def to_bytes(file):
1313
return file.read()
1414

1515

16+
def config_id(name, config):
17+
parts = [name]
18+
for name, value in config.items():
19+
if isinstance(value, bool):
20+
part = ("" if value else "no_") + name
21+
else:
22+
part = str(value)
23+
parts.append(part)
24+
return "-".join(parts)
25+
26+
1627
def dataset_parametrization(*names, decoder=to_bytes):
1728
if not names:
1829
# TODO: Replace this with torchvision.prototype.datasets.list() as soon as all builtin datasets are supported
@@ -27,16 +38,17 @@ def dataset_parametrization(*names, decoder=to_bytes):
2738
"caltech256",
2839
"caltech101",
2940
"imagenet",
41+
"coco",
3042
)
3143

32-
params = []
33-
for name in names:
34-
for config in datasets.info(name)._configs:
35-
id = f"{name}-{'-'.join([str(value) for value in config.values()])}"
36-
dataset, mock_info = builtin_dataset_mocks.load(name, decoder=decoder, **config)
37-
params.append(pytest.param(dataset, mock_info, id=id))
38-
39-
return pytest.mark.parametrize(("dataset", "mock_info"), params)
44+
return pytest.mark.parametrize(
45+
("dataset", "mock_info"),
46+
[
47+
pytest.param(*builtin_dataset_mocks.load(name, decoder=decoder, **config), id=config_id(name, config))
48+
for name in names
49+
for config in datasets.info(name)._configs
50+
],
51+
)
4052

4153

4254
class TestCommon:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
__background__,N/A
2+
person,person
3+
bicycle,vehicle
4+
car,vehicle
5+
motorcycle,vehicle
6+
airplane,vehicle
7+
bus,vehicle
8+
train,vehicle
9+
truck,vehicle
10+
boat,vehicle
11+
traffic light,outdoor
12+
fire hydrant,outdoor
13+
N/A,N/A
14+
stop sign,outdoor
15+
parking meter,outdoor
16+
bench,outdoor
17+
bird,animal
18+
cat,animal
19+
dog,animal
20+
horse,animal
21+
sheep,animal
22+
cow,animal
23+
elephant,animal
24+
bear,animal
25+
zebra,animal
26+
giraffe,animal
27+
N/A,N/A
28+
backpack,accessory
29+
umbrella,accessory
30+
N/A,N/A
31+
N/A,N/A
32+
handbag,accessory
33+
tie,accessory
34+
suitcase,accessory
35+
frisbee,sports
36+
skis,sports
37+
snowboard,sports
38+
sports ball,sports
39+
kite,sports
40+
baseball bat,sports
41+
baseball glove,sports
42+
skateboard,sports
43+
surfboard,sports
44+
tennis racket,sports
45+
bottle,kitchen
46+
N/A,N/A
47+
wine glass,kitchen
48+
cup,kitchen
49+
fork,kitchen
50+
knife,kitchen
51+
spoon,kitchen
52+
bowl,kitchen
53+
banana,food
54+
apple,food
55+
sandwich,food
56+
orange,food
57+
broccoli,food
58+
carrot,food
59+
hot dog,food
60+
pizza,food
61+
donut,food
62+
cake,food
63+
chair,furniture
64+
couch,furniture
65+
potted plant,furniture
66+
bed,furniture
67+
N/A,N/A
68+
dining table,furniture
69+
N/A,N/A
70+
N/A,N/A
71+
toilet,furniture
72+
N/A,N/A
73+
tv,electronic
74+
laptop,electronic
75+
mouse,electronic
76+
remote,electronic
77+
keyboard,electronic
78+
cell phone,electronic
79+
microwave,appliance
80+
oven,appliance
81+
toaster,appliance
82+
sink,appliance
83+
refrigerator,appliance
84+
N/A,N/A
85+
book,indoor
86+
clock,indoor
87+
vase,indoor
88+
scissors,indoor
89+
teddy bear,indoor
90+
hair drier,indoor
91+
toothbrush,indoor

0 commit comments

Comments
 (0)