-
Notifications
You must be signed in to change notification settings - Fork 7k
/
Copy pathtest_datasets_utils.py
131 lines (114 loc) · 5.76 KB
/
test_datasets_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import os
import sys
import tempfile
import torchvision.datasets.utils as utils
import unittest
import zipfile
import tarfile
import gzip
import warnings
from torch._utils_internal import get_file_path_2
from urllib.error import URLError
from common_utils import get_tmp_dir
# Absolute path to the JPEG fixture stored in the ``assets`` directory that
# sits next to this module. The MD5 digest asserted in ``Tester`` is pinned to
# the exact bytes of this file.
TEST_FILE = get_file_path_2(
    os.path.split(os.path.abspath(__file__))[0],
    'assets',
    'grace_hopper_517x606.jpg',
)
class Tester(unittest.TestCase):
    """Tests for the helpers in ``torchvision.datasets.utils``.

    Covers MD5 / integrity checking, URL downloads (these need network
    access), archive extraction (zip, tar, tar.gz, tgz, tar.xz, gz) and
    ``verify_str_arg`` validation.
    """

    # MD5 digest of the ``TEST_FILE`` fixture image.
    _TEST_FILE_MD5 = '9c0bb82894bb3af7f7675ef2b3b6dcdc'
    # Payload written into, and expected back from, every archive built below.
    _CONTENT = 'this is the content'
    # Member name used inside the zip / tar archives.
    _MEMBER = 'file.tst'

    def _assert_file_content(self, path, expected):
        """Assert that ``path`` exists and its text content equals ``expected``."""
        self.assertTrue(os.path.exists(path))
        with open(path, 'r') as nf:
            self.assertEqual(nf.read(), expected)

    def _download_or_skip(self, url, root):
        """Download ``url`` into ``root``; skip the test if it raises ``URLError``.

        A ``RuntimeWarning`` is emitted before skipping so that an unreachable
        network remains visible in the test output.
        """
        try:
            utils.download_url(url, root)
        except URLError:
            msg = "could not download test file '{}'".format(url)
            warnings.warn(msg, RuntimeWarning)
            raise unittest.SkipTest(msg)

    def _check_tar_extraction(self, ext, mode):
        """Build a one-member tar archive (suffix ``ext``, tarfile ``mode``),
        extract it with ``utils.extract_archive`` and check the member."""
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile() as bf:
                bf.write(self._CONTENT.encode())
                # ``tarfile`` re-opens bf by name, so push the bytes to disk first.
                bf.flush()
                with tempfile.NamedTemporaryFile(suffix=ext) as f:
                    with tarfile.open(f.name, mode=mode) as zf:
                        zf.add(bf.name, arcname=self._MEMBER)
                    utils.extract_archive(f.name, temp_dir)
                    self._assert_file_content(
                        os.path.join(temp_dir, self._MEMBER), self._CONTENT)

    def test_check_md5(self):
        fpath = TEST_FILE
        self.assertTrue(utils.check_md5(fpath, self._TEST_FILE_MD5))
        self.assertFalse(utils.check_md5(fpath, ''))

    def test_check_integrity(self):
        existing_fpath = TEST_FILE
        nonexisting_fpath = ''
        self.assertTrue(utils.check_integrity(existing_fpath, self._TEST_FILE_MD5))
        self.assertFalse(utils.check_integrity(existing_fpath, ''))
        # With no MD5 given, an existing file passes and a missing one fails.
        self.assertTrue(utils.check_integrity(existing_fpath))
        self.assertFalse(utils.check_integrity(nonexisting_fpath))

    def test_download_url(self):
        with get_tmp_dir() as temp_dir:
            url = "http://github.com/pytorch/vision/archive/master.zip"
            self._download_or_skip(url, temp_dir)
            self.assertGreater(len(os.listdir(temp_dir)), 0)

    def test_download_url_retry_http(self):
        # NOTE(review): presumably meant to exercise a https -> http retry
        # inside download_url; this test only checks that a file arrived.
        with get_tmp_dir() as temp_dir:
            url = "https://github.com/pytorch/vision/archive/master.zip"
            self._download_or_skip(url, temp_dir)
            self.assertGreater(len(os.listdir(temp_dir)), 0)

    def test_download_url_dont_exist(self):
        with get_tmp_dir() as temp_dir:
            url = "http://github.com/pytorch/vision/archive/this_doesnt_exist.zip"
            with self.assertRaises(URLError):
                utils.download_url(url, temp_dir)

    # The extraction tests re-open a NamedTemporaryFile by name while it is
    # still open, which the ``tempfile`` docs say is not possible on Windows
    # (presumably why they are skipped there). Compare against 'win32'
    # exactly: ``'win' in sys.platform`` also matches 'darwin' and would
    # silently skip these tests on macOS.
    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
    def test_extract_zip(self):
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile(suffix='.zip') as f:
                with zipfile.ZipFile(f, 'w') as zf:
                    zf.writestr(self._MEMBER, self._CONTENT)
                # ZipFile wrote through our handle; make sure it reached disk
                # before extract_archive re-opens the file by name.
                f.flush()
                utils.extract_archive(f.name, temp_dir)
                self._assert_file_content(
                    os.path.join(temp_dir, self._MEMBER), self._CONTENT)

    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
    def test_extract_tar(self):
        for ext, mode in (('.tar', 'w'), ('.tar.gz', 'w:gz'), ('.tgz', 'w:gz')):
            # subTest reports which archive format failed, if any.
            with self.subTest(ext=ext):
                self._check_tar_extraction(ext, mode)

    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
    def test_extract_tar_xz(self):
        self._check_tar_extraction('.tar.xz', 'w:xz')

    @unittest.skipIf(sys.platform == 'win32', 'temporarily disabled on Windows')
    def test_extract_gzip(self):
        with get_tmp_dir() as temp_dir:
            with tempfile.NamedTemporaryFile(suffix='.gz') as f:
                with gzip.GzipFile(f.name, 'wb') as zf:
                    zf.write(self._CONTENT.encode())
                utils.extract_archive(f.name, temp_dir)
                # Expect the output to be named after the archive minus ``.gz``.
                f_name = os.path.join(
                    temp_dir, os.path.splitext(os.path.basename(f.name))[0])
                self._assert_file_content(f_name, self._CONTENT)

    def test_verify_str_arg(self):
        # Every call uses the (value, arg, valid_values) argument order. A
        # non-str value and a value outside valid_values must both raise.
        self.assertEqual("a", utils.verify_str_arg("a", "arg", ("a",)))
        self.assertRaises(ValueError, utils.verify_str_arg, 0, "arg", ("a",))
        self.assertRaises(ValueError, utils.verify_str_arg, "b", "arg", ("a",))
if __name__ == '__main__':
    # Only when this file is executed directly (not collected by a runner):
    # load and run every TestCase defined in this module.
    unittest.main(module='__main__')