Skip to content

Commit ffac96c

Browse files
authored
Calculate num_frames if missing from metadata (#732)
1 parent 6907860 commit ffac96c

File tree

4 files changed

+172
-12
lines changed

4 files changed

+172
-12
lines changed

src/torchcodec/_core/_metadata.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -85,17 +85,27 @@ class VideoStreamMetadata(StreamMetadata):
8585
def duration_seconds(self) -> Optional[float]:
8686
"""Duration of the stream in seconds. We try to calculate the duration
8787
from the actual frames if a :term:`scan` was performed. Otherwise we
88-
fall back to ``duration_seconds_from_header``.
88+
fall back to ``duration_seconds_from_header``. If that value is also None,
89+
we instead calculate the duration from ``num_frames_from_header`` and
90+
``average_fps_from_header``.
8991
"""
9092
if (
91-
self.end_stream_seconds_from_content is None
92-
or self.begin_stream_seconds_from_content is None
93+
self.end_stream_seconds_from_content is not None
94+
and self.begin_stream_seconds_from_content is not None
9395
):
96+
return (
97+
self.end_stream_seconds_from_content
98+
- self.begin_stream_seconds_from_content
99+
)
100+
elif self.duration_seconds_from_header is not None:
94101
return self.duration_seconds_from_header
95-
return (
96-
self.end_stream_seconds_from_content
97-
- self.begin_stream_seconds_from_content
98-
)
102+
elif (
103+
self.num_frames_from_header is not None
104+
and self.average_fps_from_header is not None
105+
):
106+
return self.num_frames_from_header / self.average_fps_from_header
107+
else:
108+
return None
99109

100110
@property
101111
def begin_stream_seconds(self) -> float:
@@ -123,14 +133,22 @@ def end_stream_seconds(self) -> Optional[float]:
123133

124134
@property
125135
def num_frames(self) -> Optional[int]:
126-
"""Number of frames in the stream. This corresponds to
127-
``num_frames_from_content`` if a :term:`scan` was made, otherwise it
128-
corresponds to ``num_frames_from_header``.
136+
"""Number of frames in the stream (int or None).
137+
This corresponds to ``num_frames_from_content`` if a :term:`scan` was made,
138+
otherwise it corresponds to ``num_frames_from_header``. If that value is also
139+
None, the number of frames is calculated from ``duration_seconds_from_header`` and ``average_fps_from_header``.
129140
"""
130141
if self.num_frames_from_content is not None:
131142
return self.num_frames_from_content
132-
else:
143+
elif self.num_frames_from_header is not None:
133144
return self.num_frames_from_header
145+
elif (
146+
self.average_fps_from_header is not None
147+
and self.duration_seconds_from_header is not None
148+
):
149+
return int(self.average_fps_from_header * self.duration_seconds_from_header)
150+
else:
151+
return None
134152

135153
@property
136154
def average_fps(self) -> Optional[float]:

test/test_decoders.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import contextlib
88
import gc
9+
import json
10+
from unittest.mock import patch
911

1012
import numpy
1113
import pytest
@@ -738,6 +740,56 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
738740
empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds
739741
)
740742

743+
@pytest.mark.parametrize("device", cpu_and_cuda())
744+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
745+
@patch("torchcodec._core._metadata._get_stream_json_metadata")
746+
def test_get_frames_with_missing_num_frames_metadata(
747+
self, mock_get_stream_json_metadata, device, seek_mode
748+
):
749+
# Create a mock stream_dict to test that initializing VideoDecoder without
750+
# num_frames_from_header and num_frames_from_content calculates num_frames
751+
# using the average_fps and duration_seconds metadata.
752+
mock_stream_dict = {
753+
"averageFpsFromHeader": 29.97003,
754+
"beginStreamSecondsFromContent": 0.0,
755+
"beginStreamSecondsFromHeader": 0.0,
756+
"bitRate": 128783.0,
757+
"codec": "h264",
758+
"durationSecondsFromHeader": 13.013,
759+
"endStreamSecondsFromContent": 13.013,
760+
"width": 480,
761+
"height": 270,
762+
"mediaType": "video",
763+
"numFramesFromHeader": None,
764+
"numFramesFromContent": None,
765+
}
766+
# Set the return value of the mock to be the mock_stream_dict
767+
mock_get_stream_json_metadata.return_value = json.dumps(mock_stream_dict)
768+
769+
decoder = VideoDecoder(
770+
NASA_VIDEO.path,
771+
stream_index=3,
772+
device=device,
773+
seek_mode=seek_mode,
774+
)
775+
776+
assert decoder.metadata.num_frames_from_header is None
777+
assert decoder.metadata.num_frames_from_content is None
778+
assert decoder.metadata.duration_seconds is not None
779+
assert decoder.metadata.average_fps is not None
780+
assert decoder.metadata.num_frames == int(
781+
decoder.metadata.duration_seconds * decoder.metadata.average_fps
782+
)
783+
assert len(decoder) == 390
784+
785+
# Test get_frames_in_range Python logic which uses the num_frames metadata mocked earlier.
786+
# The frame is read at the C++ level.
787+
ref_frames9 = NASA_VIDEO.get_frame_data_by_range(
788+
start=9, stop=10, stream_index=3
789+
).to(device)
790+
frames9 = decoder.get_frames_in_range(start=9, stop=10)
791+
assert_frames_equal(ref_frames9, frames9.data)
792+
741793
@pytest.mark.parametrize("dimension_order", ["NCHW", "NHWC"])
742794
@pytest.mark.parametrize(
743795
"frame_getter",

test/test_metadata.py

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_get_metadata_audio_file(metadata_getter):
119119

120120
@pytest.mark.parametrize(
121121
"num_frames_from_header, num_frames_from_content, expected_num_frames",
122-
[(None, 10, 10), (10, None, 10), (None, None, None)],
122+
[(10, 20, 20), (None, 10, 10), (10, None, 10)],
123123
)
124124
def test_num_frames_fallback(
125125
num_frames_from_header, num_frames_from_content, expected_num_frames
@@ -143,6 +143,93 @@ def test_num_frames_fallback(
143143
assert metadata.num_frames == expected_num_frames
144144

145145

146+
@pytest.mark.parametrize(
147+
"average_fps_from_header, duration_seconds_from_header, expected_num_frames",
148+
[(60, 10, 600), (60, None, None), (None, 10, None), (None, None, None)],
149+
)
150+
def test_calculate_num_frames_using_fps_and_duration(
151+
average_fps_from_header, duration_seconds_from_header, expected_num_frames
152+
):
153+
"""Check that if num_frames_from_content and num_frames_from_header are missing,
154+
`.num_frames` is calculated using average_fps_from_header and duration_seconds_from_header
155+
"""
156+
metadata = VideoStreamMetadata(
157+
duration_seconds_from_header=duration_seconds_from_header,
158+
bit_rate=123,
159+
num_frames_from_header=None, # None to test calculating num_frames
160+
num_frames_from_content=None, # None to test calculating num_frames
161+
begin_stream_seconds_from_header=0,
162+
begin_stream_seconds_from_content=0,
163+
end_stream_seconds_from_content=4,
164+
codec="whatever",
165+
width=123,
166+
height=321,
167+
average_fps_from_header=average_fps_from_header,
168+
stream_index=0,
169+
)
170+
171+
assert metadata.num_frames == expected_num_frames
172+
173+
174+
@pytest.mark.parametrize(
175+
"duration_seconds_from_header, begin_stream_seconds_from_content, end_stream_seconds_from_content, expected_duration_seconds",
176+
[(60, 5, 20, 15), (60, 1, None, 60), (60, None, 1, 60), (None, 0, 10, 10)],
177+
)
178+
def test_duration_seconds_fallback(
179+
duration_seconds_from_header,
180+
begin_stream_seconds_from_content,
181+
end_stream_seconds_from_content,
182+
expected_duration_seconds,
183+
):
184+
"""Check that using begin_stream_seconds_from_content and end_stream_seconds_from_content to calculate `.duration_seconds`
185+
has priority. If either value is missing, duration_seconds_from_header is used.
186+
"""
187+
metadata = VideoStreamMetadata(
188+
duration_seconds_from_header=duration_seconds_from_header,
189+
bit_rate=123,
190+
num_frames_from_header=5,
191+
num_frames_from_content=10,
192+
begin_stream_seconds_from_header=0,
193+
begin_stream_seconds_from_content=begin_stream_seconds_from_content,
194+
end_stream_seconds_from_content=end_stream_seconds_from_content,
195+
codec="whatever",
196+
width=123,
197+
height=321,
198+
average_fps_from_header=5,
199+
stream_index=0,
200+
)
201+
202+
assert metadata.duration_seconds == expected_duration_seconds
203+
204+
205+
@pytest.mark.parametrize(
206+
"num_frames_from_header, average_fps_from_header, expected_duration_seconds",
207+
[(100, 10, 10), (100, None, None), (None, 10, None), (None, None, None)],
208+
)
209+
def test_calculate_duration_seconds_using_fps_and_num_frames(
210+
num_frames_from_header, average_fps_from_header, expected_duration_seconds
211+
):
212+
"""Check that duration_seconds is calculated using average_fps_from_header and num_frames_from_header
213+
if duration_seconds_from_header is missing.
214+
"""
215+
metadata = VideoStreamMetadata(
216+
duration_seconds_from_header=None, # None to test calculating duration_seconds
217+
bit_rate=123,
218+
num_frames_from_header=num_frames_from_header,
219+
num_frames_from_content=10,
220+
begin_stream_seconds_from_header=0,
221+
begin_stream_seconds_from_content=None, # None to test calculating duration_seconds
222+
end_stream_seconds_from_content=None, # None to test calculating duration_seconds
223+
codec="whatever",
224+
width=123,
225+
height=321,
226+
average_fps_from_header=average_fps_from_header,
227+
stream_index=0,
228+
)
229+
assert metadata.duration_seconds_from_header is None
230+
assert metadata.duration_seconds == expected_duration_seconds
231+
232+
146233
def test_repr():
147234
# Test for calls to print(), str(), etc. Useful to make sure we don't forget
148235
# to add additional @properties to __repr__

test/test_samplers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -592,6 +592,9 @@ def restore_metadata():
592592
with restore_metadata():
593593
decoder.metadata.end_stream_seconds_from_content = None
594594
decoder.metadata.duration_seconds_from_header = None
595+
decoder.metadata.num_frames_from_header = (
596+
None  # Set to None to prevent the fallback calculation
597+
)
595598
with pytest.raises(
596599
ValueError, match="Could not infer stream end from video metadata"
597600
):

0 commit comments

Comments
 (0)