Skip to content

Commit 0ad8a3c

Browse files
committed
wip, comppare w tc_public
1 parent 5e06186 commit 0ad8a3c

File tree

1 file changed

+73
-44
lines changed

1 file changed

+73
-44
lines changed

benchmarks/decoders/benchmark_decoders_library.py

Lines changed: 73 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import torch
1616
import torch.utils.benchmark as benchmark
17-
1817
from torchcodec._core import (
1918
_add_video_stream,
2019
create_from_file,
@@ -24,6 +23,8 @@
2423
get_next_frame,
2524
seek_to_pts,
2625
)
26+
27+
from torchcodec._frame import FrameBatch
2728
from torchcodec.decoders import VideoDecoder, VideoStreamMetadata
2829

2930
torch._dynamo.config.cache_size_limit = 100
@@ -1028,59 +1029,87 @@ def run_benchmarks(
10281029

10291030
def verify_outputs(decoders_to_run, video_paths, num_samples):
10301031
from tensorcat import tensorcat
1031-
from torchcodec._frame import FrameBatch
10321032

1033-
# Get frames using a decoder
1033+
# Add torchcodec_public to the list of decoders to run to use as a baseline
1034+
if matches := list(
1035+
filter(
1036+
lambda decoder_name: "TorchCodecPublic" in decoder_name,
1037+
decoders_to_run.keys(),
1038+
)
1039+
):
1040+
torchcodec_display_name = matches[0]
1041+
torchcodec_public_decoder = decoders_to_run[matches[0]]
1042+
# del decoders_to_run["TorchCodecPublic"]
1043+
print(f"Using {torchcodec_public_decoder}")
1044+
else:
1045+
torchcodec_display_name = decoder_registry["torchcodec_public"].display_name
1046+
options = decoder_registry["torchcodec_public"].default_options
1047+
kind = decoder_registry["torchcodec_public"].kind
1048+
torchcodec_public_decoder = kind(**options)
1049+
print("Adding TorchCodecPublic")
1050+
1051+
# Get frames using each decoder
10341052
for video_file_path in video_paths:
10351053
metadata = get_metadata(video_file_path)
10361054
metadata_label = f"{metadata.codec} {metadata.width}x{metadata.height}, {metadata.duration_seconds}s {metadata.average_fps}fps"
10371055
print(f"{metadata_label=}")
10381056

1057+
# Uncomment to use non-sequential frames
10391058
duration = metadata.duration_seconds
10401059
uniform_pts_list = [i * duration / num_samples for i in range(num_samples)]
10411060

1061+
# Use TorchCodecPublic as the baseline
1062+
torchcodec_public_results = decode_and_adjust_frames(
1063+
torchcodec_public_decoder,
1064+
video_file_path,
1065+
num_samples=num_samples,
1066+
pts_list=uniform_pts_list,
1067+
)
1068+
10421069
decoders_and_frames = []
10431070
for decoder_name, decoder in decoders_to_run.items():
10441071
print(f"video={video_file_path}, decoder={decoder_name}")
10451072

1046-
# Decode random or uniform frames
1047-
new_frames = decoder.decode_frames(video_file_path, uniform_pts_list)
1048-
if isinstance(new_frames, FrameBatch):
1049-
new_frames = new_frames.data
1050-
decoders_and_frames.append((decoder_name, new_frames))
1051-
1052-
# Decode the first n frames
1053-
# new_frames = decoder.decode_first_n_frames(video_file_path, num_samples)
1054-
# if isinstance(new_frames, FrameBatch):
1055-
# new_frames = new_frames.data
1056-
# decoders_and_frames.append((decoder_name, new_frames))
1057-
1058-
if len(decoders_and_frames) == 1:
1059-
# Display the frames if only 1 decoder passed in
1060-
baseline = decoders_and_frames[0]
1061-
for frame in baseline[1]:
1062-
tensorcat(frame)
1063-
else:
1064-
# Transitively compare the frames from all decoders
1065-
num_decoders = len(decoders_and_frames)
1066-
prev_decoder = decoders_and_frames[-1]
1067-
for i in range(0, num_decoders):
1068-
all_match = True
1069-
curr_decoder = decoders_and_frames[i]
1070-
print(f"Compare: {prev_decoder[0]} and {curr_decoder[0]}")
1071-
assert len(prev_decoder[1]) == len(curr_decoder[1])
1072-
for f1, f2 in zip(curr_decoder[1], prev_decoder[1]):
1073-
# Validate that the frames are the same with a tolerance
1074-
try:
1075-
torch.testing.assert_close(f1, f2)
1076-
except Exception as e:
1077-
tensorcat(f1)
1078-
tensorcat(f2)
1079-
all_match = False
1080-
print(
1081-
f"Error while comparing {curr_decoder[0]} and {prev_decoder[0]}: {e}"
1082-
)
1083-
break
1084-
prev_decoder = curr_decoder
1085-
if all_match:
1086-
print(f"Results of {curr_decoder[0]} and {prev_decoder[0]} match")
1073+
frames = decode_and_adjust_frames(
1074+
decoder,
1075+
video_file_path,
1076+
num_samples=num_samples,
1077+
pts_list=uniform_pts_list,
1078+
)
1079+
decoders_and_frames.append((decoder_name, frames))
1080+
1081+
# Compare the frames from all decoders to the frames from TorchCodecPublic
1082+
for i in range(0, len(decoders_and_frames)):
1083+
curr_decoder_name, curr_decoder_frames = decoders_and_frames[i]
1084+
print(f"Compare: {f"{torchcodec_display_name}"} and {curr_decoder_name}")
1085+
assert len(torchcodec_public_results) == len(curr_decoder_frames)
1086+
all_match = True
1087+
for f1, f2 in zip(torchcodec_public_results, curr_decoder_frames):
1088+
# Validate that the frames are the same with a tolerance
1089+
try:
1090+
torch.testing.assert_close(f1, f2) # , atol=1, rtol=0.01)
1091+
except Exception as e:
1092+
tensorcat(f1)
1093+
tensorcat(f2)
1094+
all_match = False
1095+
print(
1096+
f"Error while comparing {torchcodec_display_name} and {curr_decoder_name}: {e}"
1097+
)
1098+
break
1099+
if all_match:
1100+
print(
1101+
f"Results of {torchcodec_display_name} and {curr_decoder_name} match!"
1102+
)
1103+
1104+
1105+
def decode_and_adjust_frames(
1106+
decoder, video_file_path, *, num_samples: int, pts_list: list[float] | None
1107+
):
1108+
frames = None
1109+
if pts_list:
1110+
frames = decoder.decode_frames(video_file_path, pts_list)
1111+
else:
1112+
frames = decoder.decode_first_n_frames(video_file_path, num_samples)
1113+
if isinstance(frames, FrameBatch):
1114+
frames = frames.data
1115+
return frames

0 commit comments

Comments
 (0)