Skip to content

Commit e4a6ceb

Browse files
committed
Test sequential and random frames
1 parent 0ad8a3c commit e4a6ceb

File tree

1 file changed

+35
-35
lines changed

1 file changed

+35
-35
lines changed

benchmarks/decoders/benchmark_decoders_library.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import torch
1616
import torch.utils.benchmark as benchmark
17+
1718
from torchcodec._core import (
1819
_add_video_stream,
1920
create_from_file,
@@ -23,7 +24,6 @@
2324
get_next_frame,
2425
seek_to_pts,
2526
)
26-
2727
from torchcodec._frame import FrameBatch
2828
from torchcodec.decoders import VideoDecoder, VideoStreamMetadata
2929

@@ -1028,44 +1028,34 @@ def run_benchmarks(
10281028

10291029

10301030
def verify_outputs(decoders_to_run, video_paths, num_samples):
1031+
# Import library to show frames that don't match
10311032
from tensorcat import tensorcat
10321033

1033-
# Add torchcodec_public to the list of decoders to run to use as a baseline
1034-
if matches := list(
1035-
filter(
1036-
lambda decoder_name: "TorchCodecPublic" in decoder_name,
1037-
decoders_to_run.keys(),
1038-
)
1039-
):
1040-
torchcodec_display_name = matches[0]
1041-
torchcodec_public_decoder = decoders_to_run[matches[0]]
1042-
# del decoders_to_run["TorchCodecPublic"]
1043-
print(f"Using {torchcodec_public_decoder}")
1044-
else:
1045-
torchcodec_display_name = decoder_registry["torchcodec_public"].display_name
1046-
options = decoder_registry["torchcodec_public"].default_options
1047-
kind = decoder_registry["torchcodec_public"].kind
1048-
torchcodec_public_decoder = kind(**options)
1049-
print("Adding TorchCodecPublic")
1034+
# Create TorchCodecPublic decoder to use as a baseline
1035+
torchcodec_display_name = decoder_registry["torchcodec_public"].display_name
1036+
options = decoder_registry["torchcodec_public"].default_options
1037+
kind = decoder_registry["torchcodec_public"].kind
1038+
torchcodec_public_decoder = kind(**options)
10501039

10511040
# Get frames using each decoder
10521041
for video_file_path in video_paths:
10531042
metadata = get_metadata(video_file_path)
10541043
metadata_label = f"{metadata.codec} {metadata.width}x{metadata.height}, {metadata.duration_seconds}s {metadata.average_fps}fps"
10551044
print(f"{metadata_label=}")
10561045

1057-
# Uncomment to use non-sequential frames
1046+
# Generate uniformly spaced (deterministic, not random) PTS values spanning the video duration
10581047
duration = metadata.duration_seconds
1059-
uniform_pts_list = [i * duration / num_samples for i in range(num_samples)]
1048+
pts_list = [i * duration / num_samples for i in range(num_samples)]
10601049

1061-
# Use TorchCodecPublic as the baseline
1050+
# Get the frames from TorchCodecPublic as the baseline
10621051
torchcodec_public_results = decode_and_adjust_frames(
10631052
torchcodec_public_decoder,
10641053
video_file_path,
10651054
num_samples=num_samples,
1066-
pts_list=uniform_pts_list,
1055+
pts_list=pts_list,
10671056
)
10681057

1058+
# Get the frames from each decoder
10691059
decoders_and_frames = []
10701060
for decoder_name, decoder in decoders_to_run.items():
10711061
print(f"video={video_file_path}, decoder={decoder_name}")
@@ -1074,20 +1064,18 @@ def verify_outputs(decoders_to_run, video_paths, num_samples):
10741064
decoder,
10751065
video_file_path,
10761066
num_samples=num_samples,
1077-
pts_list=uniform_pts_list,
1067+
pts_list=pts_list,
10781068
)
10791069
decoders_and_frames.append((decoder_name, frames))
10801070

10811071
# Compare the frames from all decoders to the frames from TorchCodecPublic
1082-
for i in range(0, len(decoders_and_frames)):
1083-
curr_decoder_name, curr_decoder_frames = decoders_and_frames[i]
1084-
print(f"Compare: {f"{torchcodec_display_name}"} and {curr_decoder_name}")
1072+
for curr_decoder_name, curr_decoder_frames in decoders_and_frames:
10851073
assert len(torchcodec_public_results) == len(curr_decoder_frames)
10861074
all_match = True
10871075
for f1, f2 in zip(torchcodec_public_results, curr_decoder_frames):
10881076
# Validate that the frames are the same with a tolerance
10891077
try:
1090-
torch.testing.assert_close(f1, f2) # , atol=1, rtol=0.01)
1078+
torch.testing.assert_close(f1, f2)
10911079
except Exception as e:
10921080
tensorcat(f1)
10931081
tensorcat(f2)
@@ -1103,13 +1091,25 @@ def verify_outputs(decoders_to_run, video_paths, num_samples):
11031091

11041092

11051093
def decode_and_adjust_frames(
1106-
decoder, video_file_path, *, num_samples: int, pts_list: list[float] | None
1094+
decoder, video_file_path, *, num_samples: int, pts_list: list[float]
11071095
):
1108-
frames = None
1109-
if pts_list:
1110-
frames = decoder.decode_frames(video_file_path, pts_list)
1111-
else:
1112-
frames = decoder.decode_first_n_frames(video_file_path, num_samples)
1113-
if isinstance(frames, FrameBatch):
1114-
frames = frames.data
1096+
frames = []
1097+
# Decode non-sequential frames using decode_frames function
1098+
random_frames = decoder.decode_frames(video_file_path, pts_list)
1099+
# Extract the frames from the FrameBatch if necessary
1100+
if isinstance(random_frames, FrameBatch):
1101+
random_frames = random_frames.data
1102+
frames.extend(random_frames)
1103+
1104+
# Decode sequential frames using decode_first_n_frames function
1105+
seq_frames = decoder.decode_first_n_frames(video_file_path, num_samples)
1106+
# Extract the frames from the FrameBatch if necessary
1107+
if isinstance(seq_frames, FrameBatch):
1108+
seq_frames = seq_frames.data
1109+
frames.extend(seq_frames)
1110+
1111+
# Check if frames are returned in H,W,C where 3 is the last dimension.
1112+
# If so, convert to C,H,W for consistency with other decoders.
1113+
if frames[0].shape[-1] == 3:
1114+
frames = [frame.permute(-1, *range(frame.dim() - 1)) for frame in frames]
11151115
return frames

0 commit comments

Comments
 (0)