@@ -79,7 +79,9 @@ def normalize_dimensions(img_pil):
79
79
("RGB" , ImageReadMode .RGB ),
80
80
],
81
81
)
82
- def test_decode_jpeg (img_path , pil_mode , mode ):
82
+ @pytest .mark .parametrize ("scripted" , (False , True ))
83
+ @pytest .mark .parametrize ("decode_fun" , (decode_jpeg , decode_image ))
84
+ def test_decode_jpeg (img_path , pil_mode , mode , scripted , decode_fun ):
83
85
84
86
with Image .open (img_path ) as img :
85
87
is_cmyk = img .mode == "CMYK"
@@ -92,7 +94,9 @@ def test_decode_jpeg(img_path, pil_mode, mode):
92
94
93
95
img_pil = normalize_dimensions (img_pil )
94
96
data = read_file (img_path )
95
- img_ljpeg = decode_image (data , mode = mode )
97
+ if scripted :
98
+ decode_fun = torch .jit .script (decode_fun )
99
+ img_ljpeg = decode_fun (data , mode = mode )
96
100
97
101
# Permit a small variation on pixel values to account for implementation
98
102
# differences between Pillow and LibJPEG.
@@ -188,7 +192,12 @@ def test_damaged_corrupt_images(img_path):
188
192
("RGBA" , ImageReadMode .RGB_ALPHA ),
189
193
],
190
194
)
191
- def test_decode_png (img_path , pil_mode , mode ):
195
+ @pytest .mark .parametrize ("scripted" , (False , True ))
196
+ @pytest .mark .parametrize ("decode_fun" , (decode_png , decode_image ))
197
+ def test_decode_png (img_path , pil_mode , mode , scripted , decode_fun ):
198
+
199
+ if scripted :
200
+ decode_fun = torch .jit .script (decode_fun )
192
201
193
202
with Image .open (img_path ) as img :
194
203
if pil_mode is not None :
@@ -202,15 +211,15 @@ def test_decode_png(img_path, pil_mode, mode):
202
211
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
203
212
with pytest .raises (RuntimeError , match = "At most 8-bit PNG images are supported" ):
204
213
data = read_file (img_path )
205
- img_lpng = decode_image (data , mode = mode )
214
+ img_lpng = decode_fun (data , mode = mode )
206
215
207
216
img_lpng = _read_png_16 (img_path , mode = mode )
208
217
assert img_lpng .dtype == torch .int32
209
218
# PIL converts 16 bits pngs in uint8
210
219
img_lpng = torch .round (img_lpng / (2 ** 16 - 1 ) * 255 ).to (torch .uint8 )
211
220
else :
212
221
data = read_file (img_path )
213
- img_lpng = decode_image (data , mode = mode )
222
+ img_lpng = decode_fun (data , mode = mode )
214
223
215
224
tol = 0 if pil_mode is None else 1
216
225
@@ -239,11 +248,13 @@ def test_decode_png_errors():
239
248
"img_path" ,
240
249
[pytest .param (png_path , id = _get_safe_image_name (png_path )) for png_path in get_images (IMAGE_DIR , ".png" )],
241
250
)
242
- def test_encode_png (img_path ):
251
+ @pytest .mark .parametrize ("scripted" , (True , False ))
252
+ def test_encode_png (img_path , scripted ):
243
253
pil_image = Image .open (img_path )
244
254
img_pil = torch .from_numpy (np .array (pil_image ))
245
255
img_pil = img_pil .permute (2 , 0 , 1 )
246
- png_buf = encode_png (img_pil , compression_level = 6 )
256
+ encode = torch .jit .script (encode_png ) if scripted else encode_png
257
+ png_buf = encode (img_pil , compression_level = 6 )
247
258
248
259
rec_img = Image .open (io .BytesIO (bytes (png_buf .tolist ())))
249
260
rec_img = torch .from_numpy (np .array (rec_img ))
@@ -270,27 +281,39 @@ def test_encode_png_errors():
270
281
"img_path" ,
271
282
[pytest .param (png_path , id = _get_safe_image_name (png_path )) for png_path in get_images (IMAGE_DIR , ".png" )],
272
283
)
273
- def test_write_png (img_path , tmpdir ):
284
+ @pytest .mark .parametrize ("scripted" , (True , False ))
285
+ def test_write_png (img_path , tmpdir , scripted ):
274
286
pil_image = Image .open (img_path )
275
287
img_pil = torch .from_numpy (np .array (pil_image ))
276
288
img_pil = img_pil .permute (2 , 0 , 1 )
277
289
278
290
filename , _ = os .path .splitext (os .path .basename (img_path ))
279
291
torch_png = os .path .join (tmpdir , f"{ filename } _torch.png" )
280
- write_png (img_pil , torch_png , compression_level = 6 )
292
+ write = torch .jit .script (write_png ) if scripted else write_png
293
+ write (img_pil , torch_png , compression_level = 6 )
281
294
saved_image = torch .from_numpy (np .array (Image .open (torch_png )))
282
295
saved_image = saved_image .permute (2 , 0 , 1 )
283
296
284
297
assert_equal (img_pil , saved_image )
285
298
286
299
287
- def test_read_file (tmpdir ):
300
+ def test_read_image ():
301
+ # Just testing torchcsript, the functionality is somewhat tested already in other tests.
302
+ path = next (get_images (IMAGE_ROOT , ".jpg" ))
303
+ out = read_image (path )
304
+ out_scripted = torch .jit .script (read_image )(path )
305
+ torch .testing .assert_close (out , out_scripted , atol = 0 , rtol = 0 )
306
+
307
+
308
+ @pytest .mark .parametrize ("scripted" , (True , False ))
309
+ def test_read_file (tmpdir , scripted ):
288
310
fname , content = "test1.bin" , b"TorchVision\211 \n "
289
311
fpath = os .path .join (tmpdir , fname )
290
312
with open (fpath , "wb" ) as f :
291
313
f .write (content )
292
314
293
- data = read_file (fpath )
315
+ fun = torch .jit .script (read_file ) if scripted else read_file
316
+ data = fun (fpath )
294
317
expected = torch .tensor (list (content ), dtype = torch .uint8 )
295
318
os .unlink (fpath )
296
319
assert_equal (data , expected )
@@ -311,11 +334,13 @@ def test_read_file_non_ascii(tmpdir):
311
334
assert_equal (data , expected )
312
335
313
336
314
- def test_write_file (tmpdir ):
337
+ @pytest .mark .parametrize ("scripted" , (True , False ))
338
+ def test_write_file (tmpdir , scripted ):
315
339
fname , content = "test1.bin" , b"TorchVision\211 \n "
316
340
fpath = os .path .join (tmpdir , fname )
317
341
content_tensor = torch .tensor (list (content ), dtype = torch .uint8 )
318
- write_file (fpath , content_tensor )
342
+ write = torch .jit .script (write_file ) if scripted else write_file
343
+ write (fpath , content_tensor )
319
344
320
345
with open (fpath , "rb" ) as f :
321
346
saved_content = f .read ()
@@ -464,7 +489,8 @@ def test_encode_jpeg_errors():
464
489
"img_path" ,
465
490
[pytest .param (jpeg_path , id = _get_safe_image_name (jpeg_path )) for jpeg_path in get_images (ENCODE_JPEG , ".jpg" )],
466
491
)
467
- def test_encode_jpeg (img_path ):
492
+ @pytest .mark .parametrize ("scripted" , (True , False ))
493
+ def test_encode_jpeg (img_path , scripted ):
468
494
img = read_image (img_path )
469
495
470
496
pil_img = F .to_pil_image (img )
@@ -473,8 +499,9 @@ def test_encode_jpeg(img_path):
473
499
474
500
encoded_jpeg_pil = torch .frombuffer (buf .getvalue (), dtype = torch .uint8 )
475
501
502
+ encode = torch .jit .script (encode_jpeg ) if scripted else encode_jpeg
476
503
for src_img in [img , img .contiguous ()]:
477
- encoded_jpeg_torch = encode_jpeg (src_img , quality = 75 )
504
+ encoded_jpeg_torch = encode (src_img , quality = 75 )
478
505
assert_equal (encoded_jpeg_torch , encoded_jpeg_pil )
479
506
480
507
@@ -483,15 +510,17 @@ def test_encode_jpeg(img_path):
483
510
"img_path" ,
484
511
[pytest .param (jpeg_path , id = _get_safe_image_name (jpeg_path )) for jpeg_path in get_images (ENCODE_JPEG , ".jpg" )],
485
512
)
486
- def test_write_jpeg (img_path , tmpdir ):
513
+ @pytest .mark .parametrize ("scripted" , (True , False ))
514
+ def test_write_jpeg (img_path , tmpdir , scripted ):
487
515
tmpdir = Path (tmpdir )
488
516
img = read_image (img_path )
489
517
pil_img = F .to_pil_image (img )
490
518
491
519
torch_jpeg = str (tmpdir / "torch.jpg" )
492
520
pil_jpeg = str (tmpdir / "pil.jpg" )
493
521
494
- write_jpeg (img , torch_jpeg , quality = 75 )
522
+ write = torch .jit .script (write_jpeg ) if scripted else write_jpeg
523
+ write (img , torch_jpeg , quality = 75 )
495
524
pil_img .save (pil_jpeg , quality = 75 )
496
525
497
526
with open (torch_jpeg , "rb" ) as f :
0 commit comments