Verify decoder outputs #728

Merged on Jun 17, 2025 (10 commits)

Changes from 6 commits

61 changes: 11 additions & 50 deletions benchmarks/decoders/benchmark_decoders.py
@@ -8,66 +8,18 @@
 import importlib.resources
 import os
 import platform
-import typing
-from dataclasses import dataclass, field
 from pathlib import Path

 import torch

 from benchmark_decoders_library import (
-    AbstractDecoder,
-    DecordAccurate,
-    DecordAccurateBatch,
-    OpenCVDecoder,
+    decoder_registry,
     plot_data,
     run_benchmarks,
-    TorchAudioDecoder,
-    TorchCodecCore,
-    TorchCodecCoreBatch,
-    TorchCodecCoreCompiled,
-    TorchCodecCoreNonBatch,
-    TorchCodecPublic,
-    TorchCodecPublicNonBatch,
-    TorchVision,
+    verify_outputs,
 )
-
-
-@dataclass
-class DecoderKind:
-    display_name: str
-    kind: typing.Type[AbstractDecoder]
-    default_options: dict[str, str] = field(default_factory=dict)
-
-
-decoder_registry = {
-    "decord": DecoderKind("DecordAccurate", DecordAccurate),
-    "decord_batch": DecoderKind("DecordAccurateBatch", DecordAccurateBatch),
-    "torchcodec_core": DecoderKind("TorchCodecCore", TorchCodecCore),
-    "torchcodec_core_batch": DecoderKind("TorchCodecCoreBatch", TorchCodecCoreBatch),
-    "torchcodec_core_nonbatch": DecoderKind(
-        "TorchCodecCoreNonBatch", TorchCodecCoreNonBatch
-    ),
-    "torchcodec_core_compiled": DecoderKind(
-        "TorchCodecCoreCompiled", TorchCodecCoreCompiled
-    ),
-    "torchcodec_public": DecoderKind("TorchCodecPublic", TorchCodecPublic),
-    "torchcodec_public_nonbatch": DecoderKind(
-        "TorchCodecPublicNonBatch", TorchCodecPublicNonBatch
-    ),
-    "torchvision": DecoderKind(
-        # We don't compare against TorchVision's "pyav" backend because it doesn't support
-        # accurate seeks.
-        "TorchVision[backend=video_reader]",
-        TorchVision,
-        {"backend": "video_reader"},
-    ),
-    "torchaudio": DecoderKind("TorchAudio", TorchAudioDecoder),
-    "opencv": DecoderKind(
-        "OpenCV[backend=FFMPEG]", OpenCVDecoder, {"backend": "FFMPEG"}
-    ),
-}


 def in_fbcode() -> bool:
     return "FB_PAR_RUNTIME_FILES" in os.environ

@@ -148,6 +100,12 @@ def main() -> None:
         type=str,
         default="benchmarks.png",
     )
+    parser.add_argument(
+        "--verify-outputs",
+        help="Verify that the outputs of the decoders are the same",
+        default=False,
+        action=argparse.BooleanOptionalAction,
+    )

     args = parser.parse_args()
     specified_decoders = set(args.decoders.split(","))
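
As an aside on the flag above: argparse.BooleanOptionalAction (Python 3.9+) automatically registers a paired negative flag. A minimal, self-contained sketch of the behavior, separate from the diff:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--verify-outputs", default=False, action=argparse.BooleanOptionalAction
    )
    # The action also generates --no-verify-outputs for explicitly disabling the flag.
    print(parser.parse_args(["--verify-outputs"]).verify_outputs)     # True
    print(parser.parse_args(["--no-verify-outputs"]).verify_outputs)  # False
    print(parser.parse_args([]).verify_outputs)                       # False (default)
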
@@ -177,6 +135,9 @@ def main() -> None:
         if entry.is_file() and entry.name.endswith(".mp4"):
             video_paths.append(entry.path)

+    if args.verify_outputs:
+        verify_outputs(decoders_to_run, video_paths, num_uniform_samples)
+
Contributor:
I think this option is most useful if it's mutually exclusive with running the actual benchmarks. That way someone can specify it to quickly check the benchmark's correctness. So I think we should make running the benchmarks and printing the results the else branch here.
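
A minimal sketch of the suggested control flow (hypothetical: run_benchmarks takes more arguments than shown, and the plot_data call and plot-path argument name are assumptions):

    if args.verify_outputs:
        # Correctness check only; skip the timing runs entirely.
        verify_outputs(decoders_to_run, video_paths, num_uniform_samples)
    else:
        # Normal benchmarking path (argument list abbreviated).
        results = run_benchmarks(
            decoders_to_run,
            video_paths,
            num_uniform_samples,
        )
        plot_data(results, args.plot_path)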

     results = run_benchmarks(
         decoders_to_run,
         video_paths,
136 changes: 135 additions & 1 deletion benchmarks/decoders/benchmark_decoders_library.py
@@ -1,9 +1,10 @@
 import abc
 import json
 import subprocess
+import typing
 import urllib.request
 from concurrent.futures import ThreadPoolExecutor, wait
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from itertools import product
 from pathlib import Path

@@ -23,6 +24,7 @@
     get_next_frame,
     seek_to_pts,
 )
+from torchcodec._frame import FrameBatch
 from torchcodec.decoders import VideoDecoder, VideoStreamMetadata

 torch._dynamo.config.cache_size_limit = 100
@@ -824,6 +826,42 @@ def convert_result_to_df_item(
     return df_item


+@dataclass
+class DecoderKind:
+    display_name: str
+    kind: typing.Type[AbstractDecoder]
+    default_options: dict[str, str] = field(default_factory=dict)
+
+
+decoder_registry = {
+    "decord": DecoderKind("DecordAccurate", DecordAccurate),
+    "decord_batch": DecoderKind("DecordAccurateBatch", DecordAccurateBatch),
+    "torchcodec_core": DecoderKind("TorchCodecCore", TorchCodecCore),
+    "torchcodec_core_batch": DecoderKind("TorchCodecCoreBatch", TorchCodecCoreBatch),
+    "torchcodec_core_nonbatch": DecoderKind(
+        "TorchCodecCoreNonBatch", TorchCodecCoreNonBatch
+    ),
+    "torchcodec_core_compiled": DecoderKind(
+        "TorchCodecCoreCompiled", TorchCodecCoreCompiled
+    ),
+    "torchcodec_public": DecoderKind("TorchCodecPublic", TorchCodecPublic),
+    "torchcodec_public_nonbatch": DecoderKind(
+        "TorchCodecPublicNonBatch", TorchCodecPublicNonBatch
+    ),
+    "torchvision": DecoderKind(
+        # We don't compare against TorchVision's "pyav" backend because it doesn't support
+        # accurate seeks.
+        "TorchVision[backend=video_reader]",
+        TorchVision,
+        {"backend": "video_reader"},
+    ),
+    "torchaudio": DecoderKind("TorchAudio", TorchAudioDecoder),
+    "opencv": DecoderKind(
+        "OpenCV[backend=FFMPEG]", OpenCVDecoder, {"backend": "FFMPEG"}
+    ),
+}
+
+
 def run_benchmarks(
     decoder_dict: dict[str, AbstractDecoder],
     video_files_paths: list[Path],
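
To make the registry's role concrete, a hypothetical lookup (the stream_index option comes up in the review discussion below; the value here is illustrative):

    # Build a decoder from a registry entry plus a user-supplied option.
    entry = decoder_registry["torchcodec_public"]
    options = dict(entry.default_options)  # copy, so the shared defaults stay untouched
    options["stream_index"] = "0"          # illustrative value
    decoder = entry.kind(**options)        # instantiates TorchCodecPublic
    print(entry.display_name)              # TorchCodecPublic
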
@@ -986,3 +1024,99 @@ def run_benchmarks(
     compare = benchmark.Compare(results)
     compare.print()
     return df_data
+
+
+def verify_outputs(decoders_to_run, video_paths, num_samples):
+    # Import library to show frames that don't match
+    from tensorcat import tensorcat
+
+    # Reuse TorchCodecPublic decoder stream_index option, if provided.
+    options = decoder_registry["torchcodec_public"].default_options
+    if torchcodec_decoder := next(
+        (
+            decoder
+            for name, decoder in decoders_to_run.items()
+            if "TorchCodecPublic" in name
+        ),
+        None,
+    ):
+        options["stream_index"] = str(torchcodec_decoder._stream_index)
+    # Create default TorchCodecPublic decoder to use as a baseline
Contributor:
This means that the reference decoder will be subject to the options that the user provides, such as seek_mode. I think we shouldn't try to use the options the user provided, but instead decide what the reference decoder is, and always use that. That means that we probably shouldn't use the default options, but decide what options we want to use. Regarding seek_mode, I think we should probably use exact as the reference.

Contributor Author:

One catch: given the new 'stream_index' option, we rely on the user to indicate which stream the benchmarks should compare. Without it, the outputs of OpenCV and TorchCodecPublic would not match.

I agree that exact should be used as the reference. I've updated the code to only reuse the stream_index argument.

Contributor:

Oh! But we want to use the stream_index that the user provided. Good call. :)
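
A sketch of the construction agreed on in this thread (hedged: whether TorchCodecPublic accepts seek_mode this way is an assumption based on the discussion above):

    # Fix the reference decoder's options, reusing only the user's stream_index.
    ref_options = {"seek_mode": "exact"}  # exact seeks as the reference
    if "stream_index" in options:         # `options` as in the PR code above
        ref_options["stream_index"] = options["stream_index"]
    torchcodec_public_decoder = TorchCodecPublic(**ref_options)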

+    torchcodec_public_decoder = TorchCodecPublic(**options)
+
+    # Get frames using each decoder
+    for video_file_path in video_paths:
+        metadata = get_metadata(video_file_path)
+        metadata_label = f"{metadata.codec} {metadata.width}x{metadata.height}, {metadata.duration_seconds}s {metadata.average_fps}fps"
+        print(f"{metadata_label=}")
+
+        # Generate uniformly random PTS
+        duration = metadata.duration_seconds
+        pts_list = [i * duration / num_samples for i in range(num_samples)]
+
Contributor:

Technically these are uniformly spaced (or just evenly spaced) PTS values, not uniformly random. Uniformly random would look like pts_list = (torch.rand(num_samples) * duration).tolist().
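
For contrast, a standalone sketch of both sampling strategies:

    import torch

    duration, num_samples = 10.0, 5

    # Evenly spaced PTS values, as the current code computes:
    evenly_spaced = [i * duration / num_samples for i in range(num_samples)]

    # Uniformly random PTS values, as suggested above:
    uniformly_random = (torch.rand(num_samples) * duration).tolist()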

+        # Get the frames from TorchCodecPublic as the baseline
+        torchcodec_public_results = decode_and_adjust_frames(
+            torchcodec_public_decoder,
+            video_file_path,
+            num_samples=num_samples,
+            pts_list=pts_list,
+        )
+
+        # Get the frames from each decoder
+        decoders_and_frames = []
+        for decoder_name, decoder in decoders_to_run.items():
+            print(f"video={video_file_path}, decoder={decoder_name}")
+
+            frames = decode_and_adjust_frames(
+                decoder,
+                video_file_path,
+                num_samples=num_samples,
+                pts_list=pts_list,
+            )
+            decoders_and_frames.append((decoder_name, frames))
+
Contributor:

Since we just want to assert that the frames are close (see comment below), I think we can simplify even further and just do that assertion here. Then we don't need to do a separate loop over decoders_and_frames. In fact, I think we don't even need to record decoders_and_frames, as we don't need to remember the frames after the assertion.
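
A sketch of the simplified loop (names follow the PR's code; the surrounding per-video loop is unchanged):

    for decoder_name, decoder in decoders_to_run.items():
        frames = decode_and_adjust_frames(
            decoder, video_file_path, num_samples=num_samples, pts_list=pts_list
        )
        # Assert immediately; no need to accumulate (decoder_name, frames) pairs.
        assert len(torchcodec_public_results) == len(frames)
        for f1, f2 in zip(torchcodec_public_results, frames):
            torch.testing.assert_close(f1, f2)
        print(f"Results of baseline TorchCodecPublic and {decoder_name} match!")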

Contributor Author:

Good point, I was able to significantly simplify this portion.

+        # Compare the frames from all decoders to the frames from TorchCodecPublic
+        for curr_decoder_name, curr_decoder_frames in decoders_and_frames:
+            assert len(torchcodec_public_results) == len(curr_decoder_frames)
+            all_match = True
+            for f1, f2 in zip(torchcodec_public_results, curr_decoder_frames):
+                # Validate that the frames are the same with a tolerance
+                try:
+                    torch.testing.assert_close(f1, f2)
+                except Exception as e:
+                    tensorcat(f1)
Contributor Author:

This library is useful for visually comparing frames, but I am open to removing it since it is not necessary.

Contributor:

Yeah, I think we should remove it. Even though most users won't run the benchmarks, we still want to keep dependencies down.

I also think we can simplify this part and just do the plain assertion. For this kind of validation, it's generally better to get a failure than to just print the result to stdout. Our CI, and a lot of other monitoring infra, uses the process exit code to determine whether there was a problem. Doing what we're currently doing would mean we'd have to parse stdout to determine the status if we wanted to hook this up to CI.
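
To illustrate why the exit code matters, a standalone sketch, independent of the PR's code:

    import subprocess
    import sys

    # An uncaught AssertionError makes the interpreter exit non-zero, which CI
    # can detect directly, with no stdout parsing required.
    result = subprocess.run([sys.executable, "-c", "assert 1 == 2"])
    print(result.returncode)  # non-zero on failure; 0 when the assertion holds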

Contributor Author:

Thanks for the feedback, I'll remove the library and keep that in mind going forward.

+                    tensorcat(f2)
+                    all_match = False
+                    print(
+                        f"Error while comparing baseline TorchCodecPublic and {curr_decoder_name}: {e}"
+                    )
+                    break
+            if all_match:
+                print(
+                    f"Results of baseline TorchCodecPublic and {curr_decoder_name} match!"
+                )
+
+
+def decode_and_adjust_frames(
+    decoder, video_file_path, *, num_samples: int, pts_list: list[float]
+):
+    frames = []
+    # Decode non-sequential frames using decode_frames function
+    random_frames = decoder.decode_frames(video_file_path, pts_list)
+    # Extract the frames from the FrameBatch if necessary
+    if isinstance(random_frames, FrameBatch):
Contributor:

On comments: often the what is something a reader can easily deduce. It's the why that usually needs a comment. In this case, the why is that TorchCodec's batch APIs return a FrameBatch. We sometimes use these APIs in our experiments, and we just return that directly. But for all other decoders, we just return a list of frames.

+        random_frames = random_frames.data
+    frames.extend(random_frames)
+
+    # Decode sequential frames using decode_first_n_frames function
+    seq_frames = decoder.decode_first_n_frames(video_file_path, num_samples)
+    # Extract the frames from the FrameBatch if necessary
+    if isinstance(seq_frames, FrameBatch):
+        seq_frames = seq_frames.data
+    frames.extend(seq_frames)
+
+    # Check if frames are returned in H,W,C where 3 is the last dimension.
+    # If so, convert to C,H,W for consistency with other decoders.
+    if frames[0].shape[-1] == 3:
+        frames = [frame.permute(-1, *range(frame.dim() - 1)) for frame in frames]
+    return frames
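
To illustrate the layout conversion above, a standalone sketch using the same permute pattern:

    import torch

    # An H,W,C frame (as OpenCV-style decoders return it), converted to C,H,W:
    frame_hwc = torch.zeros(480, 640, 3)
    frame_chw = frame_hwc.permute(-1, *range(frame_hwc.dim() - 1))
    print(frame_chw.shape)  # torch.Size([3, 480, 640])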