Skip to content

Commit 86905c1

Browse files
committed
Calculate num_frames, add mock tests
1 parent d88cf1a commit 86905c1

File tree

3 files changed

+109
-2
lines changed

3 files changed

+109
-2
lines changed

src/torchcodec/_core/_metadata.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,15 @@ def num_frames(self) -> Optional[int]:
129129
"""
130130
if self.num_frames_from_content is not None:
131131
return self.num_frames_from_content
132-
else:
132+
elif self.num_frames_from_header is not None:
133133
return self.num_frames_from_header
134+
elif (
135+
self.average_fps_from_header is not None
136+
and self.duration_seconds_from_header is not None
137+
):
138+
return int(self.average_fps_from_header * self.duration_seconds_from_header)
139+
else:
140+
return None
134141

135142
@property
136143
def average_fps(self) -> Optional[float]:

test/test_decoders.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
import contextlib
88
import gc
9+
import json
10+
from unittest.mock import patch
911

1012
import numpy
1113
import pytest
@@ -738,6 +740,76 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
738740
empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds
739741
)
740742

743+
@pytest.mark.parametrize("device", cpu_and_cuda())
744+
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
745+
@patch("torchcodec._core._metadata._get_stream_json_metadata")
746+
def test_get_frames_with_missing_num_frames_metadata(
747+
self, mock_get_stream_json_metadata, device, seek_mode
748+
):
749+
# Create a mock stream_dict to test that initializing VideoDecoder without
750+
# num_frames_from_header and num_frames_from_content calculates num_frames
751+
# using the average_fps and duration_seconds metadata.
752+
mock_stream_dict = {
753+
"averageFpsFromHeader": 29.97003,
754+
"beginStreamSecondsFromContent": 0.0,
755+
"beginStreamSecondsFromHeader": 0.0,
756+
"bitRate": 128783.0,
757+
"codec": "h264",
758+
"durationSecondsFromHeader": 13.013,
759+
"endStreamSecondsFromContent": 13.013,
760+
"width": 480,
761+
"height": 270,
762+
"mediaType": "video",
763+
"numFramesFromHeader": None,
764+
"numFramesFromContent": None,
765+
}
766+
# Set the return value of the mock to be the mock_stream_dict
767+
mock_get_stream_json_metadata.return_value = json.dumps(mock_stream_dict)
768+
769+
decoder = VideoDecoder(
770+
NASA_VIDEO.path,
771+
stream_index=3,
772+
device=device,
773+
seek_mode=seek_mode,
774+
)
775+
776+
assert decoder.metadata.num_frames_from_header is None
777+
assert decoder.metadata.num_frames_from_content is None
778+
assert decoder.metadata.duration_seconds is not None
779+
assert decoder.metadata.average_fps is not None
780+
assert decoder.metadata.num_frames == int(
781+
decoder.metadata.duration_seconds * decoder.metadata.average_fps
782+
)
783+
784+
# Test get_frames_in_range
785+
ref_frames9 = NASA_VIDEO.get_frame_data_by_range(
786+
start=9, stop=10, stream_index=3
787+
).to(device)
788+
frames9 = decoder.get_frames_in_range(start=9, stop=10)
789+
assert_frames_equal(ref_frames9, frames9.data)
790+
791+
# Test get_frame_at
792+
ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9, stream_index=3).to(device)
793+
frame9 = decoder.get_frame_at(9)
794+
torch.testing.assert_close(ref_frame9, frame9.data)
795+
796+
# Test get_frames_at
797+
indices = [0, 1, 25, 35]
798+
ref_frames = [
799+
NASA_VIDEO.get_frame_data_by_index(i, stream_index=3).to(device)
800+
for i in indices
801+
]
802+
frames = decoder.get_frames_at(indices)
803+
for ref, frame in zip(ref_frames, frames.data):
804+
torch.testing.assert_close(ref, frame)
805+
806+
# Test get_frames_played_in_range to get all frames
807+
assert decoder.metadata.end_stream_seconds is not None
808+
all_frames = decoder.get_frames_played_in_range(
809+
decoder.metadata.begin_stream_seconds, decoder.metadata.end_stream_seconds
810+
)
811+
assert_frames_equal(all_frames.data, decoder[:])
812+
741813
@pytest.mark.parametrize("dimension_order", ["NCHW", "NHWC"])
742814
@pytest.mark.parametrize(
743815
"frame_getter",

test/test_metadata.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def test_get_metadata_audio_file(metadata_getter):
119119

120120
@pytest.mark.parametrize(
121121
"num_frames_from_header, num_frames_from_content, expected_num_frames",
122-
[(None, 10, 10), (10, None, 10), (None, None, None)],
122+
[(10, 20, 20), (None, 10, 10), (10, None, 10)],
123123
)
124124
def test_num_frames_fallback(
125125
num_frames_from_header, num_frames_from_content, expected_num_frames
@@ -143,6 +143,34 @@ def test_num_frames_fallback(
143143
assert metadata.num_frames == expected_num_frames
144144

145145

146+
@pytest.mark.parametrize(
147+
"average_fps_from_header, duration_seconds_from_header, expected_num_frames",
148+
[(60, 10, 600), (60, None, None), (None, 10, None), (None, None, None)],
149+
)
150+
def test_calculate_num_frames_using_fps_and_duration(
151+
average_fps_from_header, duration_seconds_from_header, expected_num_frames
152+
):
153+
"""Check that if num_frames_from_content and num_frames_from_header are missing,
154+
`.num_frames` is calculated using average_fps_from_header and duration_seconds_from_header
155+
"""
156+
metadata = VideoStreamMetadata(
157+
duration_seconds_from_header=duration_seconds_from_header,
158+
bit_rate=123,
159+
num_frames_from_header=None, # None to test calculating num_frames
160+
num_frames_from_content=None, # None to test calculating num_frames
161+
begin_stream_seconds_from_header=0,
162+
begin_stream_seconds_from_content=0,
163+
end_stream_seconds_from_content=4,
164+
codec="whatever",
165+
width=123,
166+
height=321,
167+
average_fps_from_header=average_fps_from_header,
168+
stream_index=0,
169+
)
170+
171+
assert metadata.num_frames == expected_num_frames
172+
173+
146174
def test_repr():
147175
# Test for calls to print(), str(), etc. Useful to make sure we don't forget
148176
# to add additional @properties to __repr__

0 commit comments

Comments
 (0)