Skip to content

Commit c7bcfad

Browse files
authored
Add torchscript test for io image stuff (#8313)
1 parent eb815ae commit c7bcfad

File tree

1 file changed

+46
-17
lines changed

1 file changed

+46
-17
lines changed

test/test_image.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ def normalize_dimensions(img_pil):
7979
("RGB", ImageReadMode.RGB),
8080
],
8181
)
82-
def test_decode_jpeg(img_path, pil_mode, mode):
82+
@pytest.mark.parametrize("scripted", (False, True))
83+
@pytest.mark.parametrize("decode_fun", (decode_jpeg, decode_image))
84+
def test_decode_jpeg(img_path, pil_mode, mode, scripted, decode_fun):
8385

8486
with Image.open(img_path) as img:
8587
is_cmyk = img.mode == "CMYK"
@@ -92,7 +94,9 @@ def test_decode_jpeg(img_path, pil_mode, mode):
9294

9395
img_pil = normalize_dimensions(img_pil)
9496
data = read_file(img_path)
95-
img_ljpeg = decode_image(data, mode=mode)
97+
if scripted:
98+
decode_fun = torch.jit.script(decode_fun)
99+
img_ljpeg = decode_fun(data, mode=mode)
96100

97101
# Permit a small variation on pixel values to account for implementation
98102
# differences between Pillow and LibJPEG.
@@ -188,7 +192,12 @@ def test_damaged_corrupt_images(img_path):
188192
("RGBA", ImageReadMode.RGB_ALPHA),
189193
],
190194
)
191-
def test_decode_png(img_path, pil_mode, mode):
195+
@pytest.mark.parametrize("scripted", (False, True))
196+
@pytest.mark.parametrize("decode_fun", (decode_png, decode_image))
197+
def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
198+
199+
if scripted:
200+
decode_fun = torch.jit.script(decode_fun)
192201

193202
with Image.open(img_path) as img:
194203
if pil_mode is not None:
@@ -202,15 +211,15 @@ def test_decode_png(img_path, pil_mode, mode):
202211
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
203212
with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
204213
data = read_file(img_path)
205-
img_lpng = decode_image(data, mode=mode)
214+
img_lpng = decode_fun(data, mode=mode)
206215

207216
img_lpng = _read_png_16(img_path, mode=mode)
208217
assert img_lpng.dtype == torch.int32
209218
# PIL converts 16 bits pngs in uint8
210219
img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
211220
else:
212221
data = read_file(img_path)
213-
img_lpng = decode_image(data, mode=mode)
222+
img_lpng = decode_fun(data, mode=mode)
214223

215224
tol = 0 if pil_mode is None else 1
216225

@@ -239,11 +248,13 @@ def test_decode_png_errors():
239248
"img_path",
240249
[pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
241250
)
242-
def test_encode_png(img_path):
251+
@pytest.mark.parametrize("scripted", (True, False))
252+
def test_encode_png(img_path, scripted):
243253
pil_image = Image.open(img_path)
244254
img_pil = torch.from_numpy(np.array(pil_image))
245255
img_pil = img_pil.permute(2, 0, 1)
246-
png_buf = encode_png(img_pil, compression_level=6)
256+
encode = torch.jit.script(encode_png) if scripted else encode_png
257+
png_buf = encode(img_pil, compression_level=6)
247258

248259
rec_img = Image.open(io.BytesIO(bytes(png_buf.tolist())))
249260
rec_img = torch.from_numpy(np.array(rec_img))
@@ -270,27 +281,39 @@ def test_encode_png_errors():
270281
"img_path",
271282
[pytest.param(png_path, id=_get_safe_image_name(png_path)) for png_path in get_images(IMAGE_DIR, ".png")],
272283
)
273-
def test_write_png(img_path, tmpdir):
284+
@pytest.mark.parametrize("scripted", (True, False))
285+
def test_write_png(img_path, tmpdir, scripted):
274286
pil_image = Image.open(img_path)
275287
img_pil = torch.from_numpy(np.array(pil_image))
276288
img_pil = img_pil.permute(2, 0, 1)
277289

278290
filename, _ = os.path.splitext(os.path.basename(img_path))
279291
torch_png = os.path.join(tmpdir, f"{filename}_torch.png")
280-
write_png(img_pil, torch_png, compression_level=6)
292+
write = torch.jit.script(write_png) if scripted else write_png
293+
write(img_pil, torch_png, compression_level=6)
281294
saved_image = torch.from_numpy(np.array(Image.open(torch_png)))
282295
saved_image = saved_image.permute(2, 0, 1)
283296

284297
assert_equal(img_pil, saved_image)
285298

286299

287-
def test_read_file(tmpdir):
300+
def test_read_image():
301+
# Just testing torchcsript, the functionality is somewhat tested already in other tests.
302+
path = next(get_images(IMAGE_ROOT, ".jpg"))
303+
out = read_image(path)
304+
out_scripted = torch.jit.script(read_image)(path)
305+
torch.testing.assert_close(out, out_scripted, atol=0, rtol=0)
306+
307+
308+
@pytest.mark.parametrize("scripted", (True, False))
309+
def test_read_file(tmpdir, scripted):
288310
fname, content = "test1.bin", b"TorchVision\211\n"
289311
fpath = os.path.join(tmpdir, fname)
290312
with open(fpath, "wb") as f:
291313
f.write(content)
292314

293-
data = read_file(fpath)
315+
fun = torch.jit.script(read_file) if scripted else read_file
316+
data = fun(fpath)
294317
expected = torch.tensor(list(content), dtype=torch.uint8)
295318
os.unlink(fpath)
296319
assert_equal(data, expected)
@@ -311,11 +334,13 @@ def test_read_file_non_ascii(tmpdir):
311334
assert_equal(data, expected)
312335

313336

314-
def test_write_file(tmpdir):
337+
@pytest.mark.parametrize("scripted", (True, False))
338+
def test_write_file(tmpdir, scripted):
315339
fname, content = "test1.bin", b"TorchVision\211\n"
316340
fpath = os.path.join(tmpdir, fname)
317341
content_tensor = torch.tensor(list(content), dtype=torch.uint8)
318-
write_file(fpath, content_tensor)
342+
write = torch.jit.script(write_file) if scripted else write_file
343+
write(fpath, content_tensor)
319344

320345
with open(fpath, "rb") as f:
321346
saved_content = f.read()
@@ -464,7 +489,8 @@ def test_encode_jpeg_errors():
464489
"img_path",
465490
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
466491
)
467-
def test_encode_jpeg(img_path):
492+
@pytest.mark.parametrize("scripted", (True, False))
493+
def test_encode_jpeg(img_path, scripted):
468494
img = read_image(img_path)
469495

470496
pil_img = F.to_pil_image(img)
@@ -473,8 +499,9 @@ def test_encode_jpeg(img_path):
473499

474500
encoded_jpeg_pil = torch.frombuffer(buf.getvalue(), dtype=torch.uint8)
475501

502+
encode = torch.jit.script(encode_jpeg) if scripted else encode_jpeg
476503
for src_img in [img, img.contiguous()]:
477-
encoded_jpeg_torch = encode_jpeg(src_img, quality=75)
504+
encoded_jpeg_torch = encode(src_img, quality=75)
478505
assert_equal(encoded_jpeg_torch, encoded_jpeg_pil)
479506

480507

@@ -483,15 +510,17 @@ def test_encode_jpeg(img_path):
483510
"img_path",
484511
[pytest.param(jpeg_path, id=_get_safe_image_name(jpeg_path)) for jpeg_path in get_images(ENCODE_JPEG, ".jpg")],
485512
)
486-
def test_write_jpeg(img_path, tmpdir):
513+
@pytest.mark.parametrize("scripted", (True, False))
514+
def test_write_jpeg(img_path, tmpdir, scripted):
487515
tmpdir = Path(tmpdir)
488516
img = read_image(img_path)
489517
pil_img = F.to_pil_image(img)
490518

491519
torch_jpeg = str(tmpdir / "torch.jpg")
492520
pil_jpeg = str(tmpdir / "pil.jpg")
493521

494-
write_jpeg(img, torch_jpeg, quality=75)
522+
write = torch.jit.script(write_jpeg) if scripted else write_jpeg
523+
write(img, torch_jpeg, quality=75)
495524
pil_img.save(pil_jpeg, quality=75)
496525

497526
with open(torch_jpeg, "rb") as f:

0 commit comments

Comments
 (0)