forked from openai/improved-diffusion
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path cifar10.py
43 lines (35 loc) · 944 Bytes
/
cifar10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import os
import shutil
import tempfile

import torchvision
from tqdm.auto import tqdm
# Human-readable CIFAR-10 class names. Position i holds the name used for the
# integer label i when naming the dumped image files.
CLASSES = tuple(
    "plane car bird cat deer dog frog horse ship truck".split()
)
def main():
    """Export the CIFAR-10 train and test splits as individual PNG files.

    Each split ``s`` is written to ``cifar_<s>/`` (relative to the current
    working directory), one file per image, named
    ``<class name>_<index:05d>.png`` with the class name taken from
    ``CLASSES``. A split whose output directory already exists is skipped.

    Both pending splits are loaded from one shared temporary directory. The
    CIFAR-10 download is a single archive that contains both splits, so for
    the second split torchvision finds the existing files instead of
    downloading them again.
    """
    pending = []
    for split in ["train", "test"]:
        out_dir = f"cifar_{split}"
        if os.path.exists(out_dir):
            print(f"skipping split {split} since {out_dir} already exists.")
            continue
        pending.append((split, out_dir))
    if not pending:
        # Nothing to do: avoid creating a temp dir or touching the network.
        return

    with tempfile.TemporaryDirectory() as tmp_dir:
        for split, out_dir in pending:
            print("downloading...")
            dataset = torchvision.datasets.CIFAR10(
                root=tmp_dir, train=split == "train", download=True
            )
            print("dumping images...")
            _dump_split(dataset, out_dir)


def _dump_split(dataset, out_dir):
    """Save every image of *dataset* as ``<class>_<index:05d>.png`` in *out_dir*.

    Images are first written to ``<out_dir>.partial``. That staging directory
    is renamed to *out_dir* only after every image has been saved, so an
    interrupted run never leaves a half-written *out_dir* that later runs
    would skip as "already exists". A staging directory left over from an
    earlier interrupted run is deleted before dumping starts.
    """
    staging = f"{out_dir}.partial"
    if os.path.isdir(staging):
        shutil.rmtree(staging)
    os.mkdir(staging)
    for i in tqdm(range(len(dataset))):
        image, label = dataset[i]
        # label indexes CLASSES; image is expected to provide .save()
        # (a PIL image under torchvision's default CIFAR10 transform).
        image.save(os.path.join(staging, f"{CLASSES[label]}_{i:05d}.png"))
    # staging is a sibling of out_dir (same parent), so this is a plain
    # rename on one filesystem and out_dir appears only once complete.
    os.rename(staging, out_dir)


if __name__ == "__main__":
    main()