
Commit 64917bc

stephenyan1231 authored and fmassa committed

Video transforms (#1353)

* video transforms
* [video transforms] in ToTensorVideo, divide value by 255.0
* [video transforms] fix a bug
* fix linting
* Make changes backwards-compatible

1 parent a15ff20 commit 64917bc
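For context, a minimal sketch of how the transforms added by this commit compose on a raw decoded clip; the clip shape and the normalization statistics are illustrative, not taken from the commit itself:

import torch
import torchvision.transforms as transforms

# A decoded video clip: (T, H, W, C), uint8 in [0, 255].
clip = torch.randint(0, 256, (16, 128, 171, 3), dtype=torch.uint8)

transform = transforms.Compose([
    transforms.ToTensorVideo(),              # -> (C, T, H, W), float, divided by 255.0
    transforms.RandomHorizontalFlipVideo(),  # flips along the width axis with p=0.5
    transforms.RandomCropVideo((112, 112)),
    transforms.NormalizeVideo(mean=(0.43, 0.40, 0.37), std=(0.23, 0.23, 0.22)),
])

out = transform(clip)
print(out.shape)  # torch.Size([3, 16, 112, 112])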

File tree

5 files changed: +468 -12 lines changed

test/test_transforms_video.py

+171
@@ -0,0 +1,171 @@
from __future__ import division
import torch
import torchvision.transforms as transforms
import unittest
import random
import numpy as np

try:
    from scipy import stats
except ImportError:
    stats = None


class Tester(unittest.TestCase):

    def test_random_crop_video(self):
        numFrames = random.randint(4, 128)
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.RandomCropVideo((oheight, owidth)),
        ])(clip)
        assert result.size(2) == oheight
        assert result.size(3) == owidth

        transforms.RandomCropVideo((oheight, owidth)).__repr__()

    def test_random_resized_crop_video(self):
        numFrames = random.randint(4, 128)
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.RandomResizedCropVideo((oheight, owidth)),
        ])(clip)
        assert result.size(2) == oheight
        assert result.size(3) == owidth

        transforms.RandomResizedCropVideo((oheight, owidth)).__repr__()

    def test_center_crop_video(self):
        numFrames = random.randint(4, 128)
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2

        clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255
        oh1 = (height - oheight) // 2
        ow1 = (width - owidth) // 2
        clipNarrow = clip[:, oh1:oh1 + oheight, ow1:ow1 + owidth, :]
        clipNarrow.fill_(0)
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.CenterCropVideo((oheight, owidth)),
        ])(clip)

        msg = "height: " + str(height) + " width: " \
            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
        self.assertEqual(result.sum().item(), 0, msg)

        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.CenterCropVideo((oheight, owidth)),
        ])(clip)
        sum1 = result.sum()

        msg = "height: " + str(height) + " width: " \
            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
        self.assertTrue(sum1.item() > 1, msg)

        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.CenterCropVideo((oheight, owidth)),
        ])(clip)
        sum2 = result.sum()

        msg = "height: " + str(height) + " width: " \
            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
        self.assertTrue(sum2.item() > 1, msg)
        self.assertTrue(sum2.item() > sum1.item(), msg)

    @unittest.skipIf(stats is None, 'scipy.stats is not available')
    def test_normalize_video(self):
        def samples_from_standard_normal(tensor):
            p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
            return p_value > 0.0001

        random_state = random.getstate()
        random.seed(42)
        for channels in [1, 3]:
            numFrames = random.randint(4, 128)
            height = random.randint(32, 256)
            width = random.randint(32, 256)
            mean = random.random()
            std = random.random()
            clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
            mean = [clip[c].mean().item() for c in range(channels)]
            std = [clip[c].std().item() for c in range(channels)]
            normalized = transforms.NormalizeVideo(mean, std)(clip)
            assert samples_from_standard_normal(normalized)
        random.setstate(random_state)

        # Checking the optional in-place behaviour
        tensor = torch.rand((3, 128, 16, 16))
        tensor_inplace = transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)(tensor)
        assert torch.equal(tensor, tensor_inplace)

        transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True).__repr__()

    def test_to_tensor_video(self):
        numFrames, height, width = 64, 4, 4
        trans = transforms.ToTensorVideo()

        with self.assertRaises(TypeError):
            trans(np.random.rand(numFrames, height, width, 1).tolist())
            trans(torch.rand((numFrames, height, width, 1), dtype=torch.float))

        with self.assertRaises(ValueError):
            trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8))
            trans(torch.ones((height, width, 3), dtype=torch.uint8))
            trans(torch.ones((width, 3), dtype=torch.uint8))
            trans(torch.ones((3), dtype=torch.uint8))

        trans.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats not available')
    def test_random_horizontal_flip_video(self):
        random_state = random.getstate()
        random.seed(42)
        clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
        hclip = clip.flip((-1))

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlipVideo()(clip)
            if torch.all(torch.eq(out, hclip)):
                num_horizontal += 1

        p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
        random.setstate(random_state)
        assert p_value > 0.0001

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlipVideo(p=0.7)(clip)
            if torch.all(torch.eq(out, hclip)):
                num_horizontal += 1

        p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
        random.setstate(random_state)
        assert p_value > 0.0001

        transforms.RandomHorizontalFlipVideo().__repr__()


if __name__ == '__main__':
    unittest.main()
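A note on the size assertions above: the tests feed clips as (T, H, W, C) uint8 tensors, and ToTensorVideo permutes them to (C, T, H, W), which is why result.size(2) and result.size(3) hold the cropped height and width. A minimal sketch of that layout check in isolation (shapes illustrative):

import torch
import torchvision.transforms as transforms

clip = torch.randint(0, 256, (8, 64, 64, 3), dtype=torch.uint8)  # (T, H, W, C)
out = transforms.ToTensorVideo()(clip)                           # (C, T, H, W)
assert out.shape == (3, 8, 64, 64)
assert out.dtype == torch.float32 and out.max().item() <= 1.0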

torchvision/transforms/__init__.py

+1
@@ -1 +1,2 @@
 from .transforms import *
+from .transforms_video import *
torchvision/transforms/functional_video.py

+101

@@ -0,0 +1,101 @@
import torch


def _is_tensor_video_clip(clip):
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))

    if not clip.ndimension() == 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())

    return True


def crop(clip, i, j, h, w):
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
    """
    assert len(clip.size()) == 4, "clip should be a 4D tensor"
    return clip[..., i:i + h, j:j + w]


def resize(clip, target_size, interpolation_mode):
    assert len(target_size) == 2, "target size should be tuple (height, width)"
    return torch.nn.functional.interpolate(
        clip, size=target_size, mode=interpolation_mode
    )


def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing to the video clip
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        i (int): i in (i, j), i.e. coordinates of the upper-left corner.
        j (int): j in (i, j), i.e. coordinates of the upper-left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    clip = crop(clip, i, j, h, w)
    clip = resize(clip, size, interpolation_mode)
    return clip


def center_crop(clip, crop_size):
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    assert h >= th and w >= tw, "height and width must be no smaller than crop_size"

    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return crop(clip, i, j, th, tw)


def to_tensor(clip):
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
    """
    _is_tensor_video_clip(clip)
    if not clip.dtype == torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    return clip.float().permute(3, 0, 1, 2) / 255.0


def normalize(clip, mean, std, inplace=False):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
    Returns:
        normalized clip (torch.tensor): Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    if not inplace:
        clip = clip.clone()
    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return clip


def hflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    return clip.flip((-1))
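To make the functional layer concrete, a minimal sketch of calling these helpers directly; it assumes the module above is importable as torchvision.transforms.functional_video, and the crop coordinates are illustrative:

import torch
from torchvision.transforms import functional_video as F

raw = torch.randint(0, 256, (8, 128, 171, 3), dtype=torch.uint8)  # decoded clip: (T, H, W, C)

clip = F.to_tensor(raw)  # (C, T, H, W), float in [0, 1]
clip = F.resized_crop(clip, i=8, j=16, h=112, w=112, size=(112, 112))
clip = F.normalize(clip, mean=[0.43, 0.40, 0.37], std=[0.23, 0.23, 0.22])
clip = F.hflip(clip)
print(clip.shape)  # torch.Size([3, 8, 112, 112])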

torchvision/transforms/transforms.py

+22 -12
@@ -40,6 +40,15 @@
 }


+def _get_image_size(img):
+    if F._is_pil_image(img):
+        return img.size
+    elif isinstance(img, torch.Tensor) and img.dim() > 2:
+        return img.shape[-2:][::-1]
+    else:
+        raise TypeError("Unexpected type {}".format(type(img)))
+
+
 class Compose(object):
     """Composes several transforms together.

@@ -444,7 +453,7 @@ def get_params(img, output_size):
         Returns:
             tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
         """
-        w, h = img.size
+        w, h = _get_image_size(img)
         th, tw = output_size
         if w == tw and h == th:
             return 0, 0, h, w
@@ -635,7 +644,8 @@ def get_params(img, scale, ratio):
             tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                 sized crop.
         """
-        area = img.size[0] * img.size[1]
+        width, height = _get_image_size(img)
+        area = height * width

         for attempt in range(10):
             target_area = random.uniform(*scale) * area
@@ -645,24 +655,24 @@ def get_params(img, scale, ratio):
             w = int(round(math.sqrt(target_area * aspect_ratio)))
             h = int(round(math.sqrt(target_area / aspect_ratio)))

-            if 0 < w <= img.size[0] and 0 < h <= img.size[1]:
-                i = random.randint(0, img.size[1] - h)
-                j = random.randint(0, img.size[0] - w)
+            if 0 < w <= width and 0 < h <= height:
+                i = random.randint(0, height - h)
+                j = random.randint(0, width - w)
                 return i, j, h, w

         # Fallback to central crop
-        in_ratio = img.size[0] / img.size[1]
+        in_ratio = float(width) / float(height)
         if (in_ratio < min(ratio)):
-            w = img.size[0]
+            w = width
             h = int(round(w / min(ratio)))
         elif (in_ratio > max(ratio)):
-            h = img.size[1]
+            h = height
             w = int(round(h * max(ratio)))
         else:  # whole image
-            w = img.size[0]
-            h = img.size[1]
-        i = (img.size[1] - h) // 2
-        j = (img.size[0] - w) // 2
+            w = width
+            h = height
+        i = (height - h) // 2
+        j = (width - w) // 2
         return i, j, h, w

     def __call__(self, img):
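The new _get_image_size helper is what lets get_params accept tensors as well as PIL Images: PIL's img.size is (width, height), and img.shape[-2:][::-1] reverses a tensor's trailing (H, W) dimensions to match that ordering. A quick sketch of the equivalence, assuming Pillow is available (shape values illustrative):

import torch
from PIL import Image

pil_img = Image.new("RGB", (171, 128))  # PIL reports size as (W, H)
tensor_img = torch.rand(3, 128, 171)    # tensor layout is (C, H, W)

print(pil_img.size)                 # (171, 128)
print(tensor_img.shape[-2:][::-1])  # torch.Size([171, 128]) -- same (W, H) ordering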

0 commit comments