|
14 | 14 |
|
15 | 15 | import torch
|
16 | 16 | import torch.utils.benchmark as benchmark
|
17 |
| - |
18 | 17 | from torchcodec._core import (
|
19 | 18 | _add_video_stream,
|
20 | 19 | create_from_file,
|
|
24 | 23 | get_next_frame,
|
25 | 24 | seek_to_pts,
|
26 | 25 | )
|
| 26 | + |
| 27 | +from torchcodec._frame import FrameBatch |
27 | 28 | from torchcodec.decoders import VideoDecoder, VideoStreamMetadata
|
28 | 29 |
|
29 | 30 | torch._dynamo.config.cache_size_limit = 100
|
@@ -1028,59 +1029,87 @@ def run_benchmarks(
|
1028 | 1029 |
|
1029 | 1030 | def verify_outputs(decoders_to_run, video_paths, num_samples):
|
1030 | 1031 | from tensorcat import tensorcat
|
1031 |
| - from torchcodec._frame import FrameBatch |
1032 | 1032 |
|
1033 |
| - # Get frames using a decoder |
| 1033 | + # Add torchcodec_public to the list of decoders to run to use as a baseline |
| 1034 | + if matches := list( |
| 1035 | + filter( |
| 1036 | + lambda decoder_name: "TorchCodecPublic" in decoder_name, |
| 1037 | + decoders_to_run.keys(), |
| 1038 | + ) |
| 1039 | + ): |
| 1040 | + torchcodec_display_name = matches[0] |
| 1041 | + torchcodec_public_decoder = decoders_to_run[matches[0]] |
| 1042 | + # del decoders_to_run["TorchCodecPublic"] |
| 1043 | + print(f"Using {torchcodec_public_decoder}") |
| 1044 | + else: |
| 1045 | + torchcodec_display_name = decoder_registry["torchcodec_public"].display_name |
| 1046 | + options = decoder_registry["torchcodec_public"].default_options |
| 1047 | + kind = decoder_registry["torchcodec_public"].kind |
| 1048 | + torchcodec_public_decoder = kind(**options) |
| 1049 | + print("Adding TorchCodecPublic") |
| 1050 | + |
| 1051 | + # Get frames using each decoder |
1034 | 1052 | for video_file_path in video_paths:
|
1035 | 1053 | metadata = get_metadata(video_file_path)
|
1036 | 1054 | metadata_label = f"{metadata.codec} {metadata.width}x{metadata.height}, {metadata.duration_seconds}s {metadata.average_fps}fps"
|
1037 | 1055 | print(f"{metadata_label=}")
|
1038 | 1056 |
|
| 1057 | + # Uncomment to use non-sequential frames |
1039 | 1058 | duration = metadata.duration_seconds
|
1040 | 1059 | uniform_pts_list = [i * duration / num_samples for i in range(num_samples)]
|
1041 | 1060 |
|
| 1061 | + # Use TorchCodecPublic as the baseline |
| 1062 | + torchcodec_public_results = decode_and_adjust_frames( |
| 1063 | + torchcodec_public_decoder, |
| 1064 | + video_file_path, |
| 1065 | + num_samples=num_samples, |
| 1066 | + pts_list=uniform_pts_list, |
| 1067 | + ) |
| 1068 | + |
1042 | 1069 | decoders_and_frames = []
|
1043 | 1070 | for decoder_name, decoder in decoders_to_run.items():
|
1044 | 1071 | print(f"video={video_file_path}, decoder={decoder_name}")
|
1045 | 1072 |
|
1046 |
| - # Decode random or uniform frames |
1047 |
| - new_frames = decoder.decode_frames(video_file_path, uniform_pts_list) |
1048 |
| - if isinstance(new_frames, FrameBatch): |
1049 |
| - new_frames = new_frames.data |
1050 |
| - decoders_and_frames.append((decoder_name, new_frames)) |
1051 |
| - |
1052 |
| - # Decode the first n frames |
1053 |
| - # new_frames = decoder.decode_first_n_frames(video_file_path, num_samples) |
1054 |
| - # if isinstance(new_frames, FrameBatch): |
1055 |
| - # new_frames = new_frames.data |
1056 |
| - # decoders_and_frames.append((decoder_name, new_frames)) |
1057 |
| - |
1058 |
| - if len(decoders_and_frames) == 1: |
1059 |
| - # Display the frames if only 1 decoder passed in |
1060 |
| - baseline = decoders_and_frames[0] |
1061 |
| - for frame in baseline[1]: |
1062 |
| - tensorcat(frame) |
1063 |
| - else: |
1064 |
| - # Transitively compare the frames from all decoders |
1065 |
| - num_decoders = len(decoders_and_frames) |
1066 |
| - prev_decoder = decoders_and_frames[-1] |
1067 |
| - for i in range(0, num_decoders): |
1068 |
| - all_match = True |
1069 |
| - curr_decoder = decoders_and_frames[i] |
1070 |
| - print(f"Compare: {prev_decoder[0]} and {curr_decoder[0]}") |
1071 |
| - assert len(prev_decoder[1]) == len(curr_decoder[1]) |
1072 |
| - for f1, f2 in zip(curr_decoder[1], prev_decoder[1]): |
1073 |
| - # Validate that the frames are the same with a tolerance |
1074 |
| - try: |
1075 |
| - torch.testing.assert_close(f1, f2) |
1076 |
| - except Exception as e: |
1077 |
| - tensorcat(f1) |
1078 |
| - tensorcat(f2) |
1079 |
| - all_match = False |
1080 |
| - print( |
1081 |
| - f"Error while comparing {curr_decoder[0]} and {prev_decoder[0]}: {e}" |
1082 |
| - ) |
1083 |
| - break |
1084 |
| - prev_decoder = curr_decoder |
1085 |
| - if all_match: |
1086 |
| - print(f"Results of {curr_decoder[0]} and {prev_decoder[0]} match") |
| 1073 | + frames = decode_and_adjust_frames( |
| 1074 | + decoder, |
| 1075 | + video_file_path, |
| 1076 | + num_samples=num_samples, |
| 1077 | + pts_list=uniform_pts_list, |
| 1078 | + ) |
| 1079 | + decoders_and_frames.append((decoder_name, frames)) |
| 1080 | + |
| 1081 | + # Compare the frames from all decoders to the frames from TorchCodecPublic |
| 1082 | + for i in range(0, len(decoders_and_frames)): |
| 1083 | + curr_decoder_name, curr_decoder_frames = decoders_and_frames[i] |
| 1084 | + print(f"Compare: {f"{torchcodec_display_name}"} and {curr_decoder_name}") |
| 1085 | + assert len(torchcodec_public_results) == len(curr_decoder_frames) |
| 1086 | + all_match = True |
| 1087 | + for f1, f2 in zip(torchcodec_public_results, curr_decoder_frames): |
| 1088 | + # Validate that the frames are the same with a tolerance |
| 1089 | + try: |
| 1090 | + torch.testing.assert_close(f1, f2) # , atol=1, rtol=0.01) |
| 1091 | + except Exception as e: |
| 1092 | + tensorcat(f1) |
| 1093 | + tensorcat(f2) |
| 1094 | + all_match = False |
| 1095 | + print( |
| 1096 | + f"Error while comparing {torchcodec_display_name} and {curr_decoder_name}: {e}" |
| 1097 | + ) |
| 1098 | + break |
| 1099 | + if all_match: |
| 1100 | + print( |
| 1101 | + f"Results of {torchcodec_display_name} and {curr_decoder_name} match!" |
| 1102 | + ) |
| 1103 | + |
| 1104 | + |
| 1105 | +def decode_and_adjust_frames( |
| 1106 | + decoder, video_file_path, *, num_samples: int, pts_list: list[float] | None |
| 1107 | +): |
| 1108 | + frames = None |
| 1109 | + if pts_list: |
| 1110 | + frames = decoder.decode_frames(video_file_path, pts_list) |
| 1111 | + else: |
| 1112 | + frames = decoder.decode_first_n_frames(video_file_path, num_samples) |
| 1113 | + if isinstance(frames, FrameBatch): |
| 1114 | + frames = frames.data |
| 1115 | + return frames |
0 commit comments