|
6 | 6 |
|
7 | 7 | import contextlib
|
8 | 8 | import gc
|
| 9 | +import json |
| 10 | +from unittest.mock import patch |
9 | 11 |
|
10 | 12 | import numpy
|
11 | 13 | import pytest
|
@@ -738,6 +740,76 @@ def test_get_frames_in_range(self, stream_index, device, seek_mode):
|
738 | 740 | empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds
|
739 | 741 | )
|
740 | 742 |
|
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("seek_mode", ("exact", "approximate"))
@patch("torchcodec._core._metadata._get_stream_json_metadata")
def test_get_frames_with_missing_num_frames_metadata(
    self, mock_get_stream_json_metadata, device, seek_mode
):
    """Check frame retrieval when neither frame-count metadata field is set.

    The stream JSON metadata is mocked so that both ``numFramesFromHeader``
    and ``numFramesFromContent`` are ``None``. The test then verifies that
    ``num_frames`` equals ``int(duration_seconds * average_fps)`` and that
    the frame retrieval APIs still return the expected reference frames.

    Args:
        mock_get_stream_json_metadata: Mock injected by ``@patch``; its
            return value is set to the JSON-serialized fake stream metadata.
        device: Device under test, supplied by ``cpu_and_cuda()``.
        seek_mode: Either ``"exact"`` or ``"approximate"``.
    """
    # Single source of truth for the stream used by both the decoder and
    # the reference-frame lookups below.
    stream_index = 3

    # Stream metadata with both frame-count fields deliberately missing, so
    # num_frames has to be derived from average_fps and duration_seconds.
    mock_stream_dict = {
        "averageFpsFromHeader": 29.97003,
        "beginStreamSecondsFromContent": 0.0,
        "beginStreamSecondsFromHeader": 0.0,
        "bitRate": 128783.0,
        "codec": "h264",
        "durationSecondsFromHeader": 13.013,
        "endStreamSecondsFromContent": 13.013,
        "width": 480,
        "height": 270,
        "mediaType": "video",
        "numFramesFromHeader": None,
        "numFramesFromContent": None,
    }
    # The patched function is expected to return a JSON string, not a dict.
    mock_get_stream_json_metadata.return_value = json.dumps(mock_stream_dict)

    decoder = VideoDecoder(
        NASA_VIDEO.path,
        stream_index=stream_index,
        device=device,
        seek_mode=seek_mode,
    )

    metadata = decoder.metadata
    assert metadata.num_frames_from_header is None
    assert metadata.num_frames_from_content is None
    assert metadata.duration_seconds is not None
    assert metadata.average_fps is not None
    assert metadata.num_frames == int(
        metadata.duration_seconds * metadata.average_fps
    )

    # Test get_frames_in_range
    ref_frames9 = NASA_VIDEO.get_frame_data_by_range(
        start=9, stop=10, stream_index=stream_index
    ).to(device)
    frames9 = decoder.get_frames_in_range(start=9, stop=10)
    assert_frames_equal(ref_frames9, frames9.data)

    # Test get_frame_at
    ref_frame9 = NASA_VIDEO.get_frame_data_by_index(
        9, stream_index=stream_index
    ).to(device)
    frame9 = decoder.get_frame_at(9)
    torch.testing.assert_close(ref_frame9, frame9.data)

    # Test get_frames_at
    indices = [0, 1, 25, 35]
    ref_frames = [
        NASA_VIDEO.get_frame_data_by_index(i, stream_index=stream_index).to(device)
        for i in indices
    ]
    frames = decoder.get_frames_at(indices)
    # zip() stops at the shorter input, so check the batch size explicitly;
    # otherwise a decoder returning too few frames would pass unnoticed.
    assert len(frames.data) == len(indices)
    for ref, frame in zip(ref_frames, frames.data):
        torch.testing.assert_close(ref, frame)

    # Test get_frames_played_in_range to get all frames
    assert metadata.end_stream_seconds is not None
    all_frames = decoder.get_frames_played_in_range(
        metadata.begin_stream_seconds, metadata.end_stream_seconds
    )
    assert_frames_equal(all_frames.data, decoder[:])
| 812 | + |
741 | 813 | @pytest.mark.parametrize("dimension_order", ["NCHW", "NHWC"])
|
742 | 814 | @pytest.mark.parametrize(
|
743 | 815 | "frame_getter",
|
|
0 commit comments