Skip to content

Commit 2287c8f

Browse files
authored
Optimize read_video_timestamps for some formats (#1168)
* Optimize read_video_timestamps for some formats * Add some tests
1 parent 59c97d7 commit 2287c8f

File tree

2 files changed

+33
-7
lines changed

2 files changed

+33
-7
lines changed

test/test_io.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,19 @@ def temp_video(num_frames, height, width, fps, lossless=False, video_codec=None,
5252
yield f.name, data
5353

5454

55+
@unittest.skipIf(av is None, "PyAV unavailable")
5556
class Tester(unittest.TestCase):
5657
# compression adds artifacts, thus we add a tolerance of
5758
# 6 in 0-255 range
5859
TOLERANCE = 6
5960

60-
@unittest.skipIf(av is None, "PyAV unavailable")
6161
def test_write_read_video(self):
6262
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
6363
lv, _, info = io.read_video(f_name)
6464

6565
self.assertTrue(data.equal(lv))
6666
self.assertEqual(info["video_fps"], 5)
6767

68-
@unittest.skipIf(av is None, "PyAV unavailable")
6968
def test_read_timestamps(self):
7069
with temp_video(10, 300, 300, 5) as (f_name, data):
7170
pts, _ = io.read_video_timestamps(f_name)
@@ -81,7 +80,6 @@ def test_read_timestamps(self):
8180

8281
self.assertEqual(pts, expected_pts)
8382

84-
@unittest.skipIf(av is None, "PyAV unavailable")
8583
def test_read_partial_video(self):
8684
with temp_video(10, 300, 300, 5, lossless=True) as (f_name, data):
8785
pts, _ = io.read_video_timestamps(f_name)
@@ -96,7 +94,6 @@ def test_read_partial_video(self):
9694
self.assertEqual(len(lv), 4)
9795
self.assertTrue(data[4:8].equal(lv))
9896

99-
@unittest.skipIf(av is None, "PyAV unavailable")
10097
def test_read_partial_video_bframes(self):
10198
# do not use lossless encoding, to test the presence of B-frames
10299
options = {'bframes': '16', 'keyint': '10', 'min-keyint': '4'}
@@ -113,7 +110,6 @@ def test_read_partial_video_bframes(self):
113110
self.assertEqual(len(lv), 4)
114111
self.assertTrue((data[4:8].float() - lv.float()).abs().max() < self.TOLERANCE)
115112

116-
@unittest.skipIf(av is None, "PyAV unavailable")
117113
def test_read_packed_b_frames_divx_file(self):
118114
with get_tmp_dir() as temp_dir:
119115
name = "hmdb51_Turnk_r_Pippi_Michel_cartwheel_f_cm_np2_le_med_6.avi"
@@ -129,6 +125,23 @@ def test_read_packed_b_frames_divx_file(self):
129125
warnings.warn(msg, RuntimeWarning)
130126
raise unittest.SkipTest(msg)
131127

128+
def test_read_timestamps_from_packet(self):
129+
with temp_video(10, 300, 300, 5, video_codec='mpeg4') as (f_name, data):
130+
pts, _ = io.read_video_timestamps(f_name)
131+
132+
# note: not all formats/codecs provide accurate information for computing the
133+
# timestamps. For the format that we use here, this information is available,
134+
# so we use it as a baseline
135+
container = av.open(f_name)
136+
stream = container.streams[0]
137+
# make sure we went through the optimized codepath
138+
self.assertIn(b'Lavc', stream.codec_context.extradata)
139+
pts_step = int(round(float(1 / (stream.average_rate * stream.time_base))))
140+
num_frames = int(round(float(stream.average_rate * stream.time_base * stream.duration)))
141+
expected_pts = [i * pts_step for i in range(num_frames)]
142+
143+
self.assertEqual(pts, expected_pts)
144+
132145
# TODO add tests for audio
133146

134147

torchvision/io/video.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,15 @@ def read_video(filename, start_pts=0, end_pts=None):
185185
return vframes, aframes, info
186186

187187

188+
def _can_read_timestamps_from_packets(container):
189+
extradata = container.streams[0].codec_context.extradata
190+
if extradata is None:
191+
return False
192+
if b"Lavc" in extradata:
193+
return True
194+
return False
195+
196+
188197
def read_video_timestamps(filename):
189198
"""
190199
List the video frames timestamps.
@@ -205,8 +214,12 @@ def read_video_timestamps(filename):
205214
video_frames = []
206215
video_fps = None
207216
if container.streams.video:
208-
video_frames = _read_from_stream(container, 0, float("inf"),
209-
container.streams.video[0], {'video': 0})
217+
if _can_read_timestamps_from_packets(container):
218+
# fast path
219+
video_frames = [x for x in container.demux(video=0) if x.pts is not None]
220+
else:
221+
video_frames = _read_from_stream(container, 0, float("inf"),
222+
container.streams.video[0], {'video': 0})
210223
video_fps = float(container.streams.video[0].average_rate)
211224
container.close()
212225
return [x.pts for x in video_frames], video_fps

0 commit comments

Comments
 (0)