-
Notifications
You must be signed in to change notification settings - Fork 40
Verify decoder outputs #728
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
5e06186
0ad8a3c
e4a6ceb
86c8a4f
f4f8973
dcbc1e8
68f96e1
8517aa2
31c6a5f
c813bfd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
import abc | ||
import json | ||
import subprocess | ||
import typing | ||
import urllib.request | ||
from concurrent.futures import ThreadPoolExecutor, wait | ||
from dataclasses import dataclass | ||
from dataclasses import dataclass, field | ||
from itertools import product | ||
from pathlib import Path | ||
|
||
|
@@ -23,6 +24,7 @@ | |
get_next_frame, | ||
seek_to_pts, | ||
) | ||
from torchcodec._frame import FrameBatch | ||
from torchcodec.decoders import VideoDecoder, VideoStreamMetadata | ||
|
||
torch._dynamo.config.cache_size_limit = 100 | ||
|
@@ -824,6 +826,42 @@ def convert_result_to_df_item( | |
return df_item | ||
|
||
|
||
@dataclass
class DecoderKind:
    """Registry entry describing one decoder implementation under benchmark.

    Attributes:
        display_name: Human-readable label shown in benchmark output.
        kind: The AbstractDecoder subclass to instantiate for this entry.
        default_options: Constructor keyword options applied unless the user
            overrides them; a fresh dict per instance via default_factory so
            entries never share (or accidentally mutate) one another's options.
    """

    display_name: str
    # Use the builtin generic `type[...]` (PEP 585) for consistency with the
    # builtin generics already used in this file (e.g. dict[str, str]).
    kind: type[AbstractDecoder]
    default_options: dict[str, str] = field(default_factory=dict)
|
||
|
||
# Maps a short CLI name to the DecoderKind describing how to build and label
# that decoder in benchmark output.
decoder_registry = {
    "decord": DecoderKind(display_name="DecordAccurate", kind=DecordAccurate),
    "decord_batch": DecoderKind(
        display_name="DecordAccurateBatch", kind=DecordAccurateBatch
    ),
    "torchcodec_core": DecoderKind(
        display_name="TorchCodecCore", kind=TorchCodecCore
    ),
    "torchcodec_core_batch": DecoderKind(
        display_name="TorchCodecCoreBatch", kind=TorchCodecCoreBatch
    ),
    "torchcodec_core_nonbatch": DecoderKind(
        display_name="TorchCodecCoreNonBatch", kind=TorchCodecCoreNonBatch
    ),
    "torchcodec_core_compiled": DecoderKind(
        display_name="TorchCodecCoreCompiled", kind=TorchCodecCoreCompiled
    ),
    "torchcodec_public": DecoderKind(
        display_name="TorchCodecPublic", kind=TorchCodecPublic
    ),
    "torchcodec_public_nonbatch": DecoderKind(
        display_name="TorchCodecPublicNonBatch", kind=TorchCodecPublicNonBatch
    ),
    "torchvision": DecoderKind(
        # We don't compare against TorchVision's "pyav" backend because it doesn't support
        # accurate seeks.
        display_name="TorchVision[backend=video_reader]",
        kind=TorchVision,
        default_options={"backend": "video_reader"},
    ),
    "torchaudio": DecoderKind(display_name="TorchAudio", kind=TorchAudioDecoder),
    "opencv": DecoderKind(
        display_name="OpenCV[backend=FFMPEG]",
        kind=OpenCVDecoder,
        default_options={"backend": "FFMPEG"},
    ),
}
|
||
|
||
def run_benchmarks( | ||
decoder_dict: dict[str, AbstractDecoder], | ||
video_files_paths: list[Path], | ||
|
@@ -986,3 +1024,99 @@ def run_benchmarks( | |
compare = benchmark.Compare(results) | ||
compare.print() | ||
return df_data | ||
|
||
|
||
def verify_outputs(decoders_to_run, video_paths, num_samples):
    """Verify each decoder's frames against a TorchCodecPublic baseline.

    For every video in ``video_paths``, decodes ``num_samples`` evenly-spaced
    frames plus the first ``num_samples`` sequential frames with every decoder
    in ``decoders_to_run`` and asserts they match the frames produced by a
    reference TorchCodecPublic decoder (within torch.testing.assert_close
    tolerances).

    Args:
        decoders_to_run: dict mapping display name -> AbstractDecoder instance.
        video_paths: iterable of video file paths to verify.
        num_samples: number of frames to sample per decoding strategy.

    Raises:
        AssertionError: on the first frame that does not match the baseline,
            so the process exits non-zero and CI can detect the failure from
            the exit code rather than by parsing stdout.
    """
    # Reuse the TorchCodecPublic stream_index option, if provided. Copy the
    # registry's defaults so we never mutate the shared
    # DecoderKind.default_options dict in place.
    options = dict(decoder_registry["torchcodec_public"].default_options)
    if torchcodec_decoder := next(
        (
            decoder
            for name, decoder in decoders_to_run.items()
            if "TorchCodecPublic" in name
        ),
        None,
    ):
        options["stream_index"] = str(torchcodec_decoder._stream_index)
    # Create a default TorchCodecPublic decoder to use as the baseline.
    torchcodec_public_decoder = TorchCodecPublic(**options)

    for video_file_path in video_paths:
        metadata = get_metadata(video_file_path)
        metadata_label = f"{metadata.codec} {metadata.width}x{metadata.height}, {metadata.duration_seconds}s {metadata.average_fps}fps"
        print(f"{metadata_label=}")

        # Evenly-spaced PTS values across the video's duration (NOT uniformly
        # random — every decoder must be asked for the same timestamps).
        duration = metadata.duration_seconds
        pts_list = [i * duration / num_samples for i in range(num_samples)]

        # Get the frames from TorchCodecPublic as the baseline.
        torchcodec_public_results = decode_and_adjust_frames(
            torchcodec_public_decoder,
            video_file_path,
            num_samples=num_samples,
            pts_list=pts_list,
        )

        # Compare each decoder's frames to the baseline, failing loudly on
        # the first mismatch instead of merely printing it.
        for decoder_name, decoder in decoders_to_run.items():
            print(f"video={video_file_path}, decoder={decoder_name}")

            frames = decode_and_adjust_frames(
                decoder,
                video_file_path,
                num_samples=num_samples,
                pts_list=pts_list,
            )
            assert len(torchcodec_public_results) == len(frames)
            for f1, f2 in zip(torchcodec_public_results, frames):
                # Validate that the frames are the same with a tolerance.
                torch.testing.assert_close(f1, f2)
            print(f"Results of baseline TorchCodecPublic and {decoder_name} match!")
|
||
|
||
def decode_and_adjust_frames(
    decoder, video_file_path, *, num_samples: int, pts_list: list[float]
):
    """Decode frames two ways and normalize their layout for comparison.

    Collects the frames at ``pts_list`` (non-sequential seeks) followed by the
    first ``num_samples`` sequential frames, then converts H,W,C output to
    C,H,W so frames from different decoders can be compared directly.
    """

    def _unwrap(result):
        # TorchCodec's batch APIs return a FrameBatch; pull out the raw
        # tensor data so every decoder yields plain frame tensors.
        return result.data if isinstance(result, FrameBatch) else result

    # Non-sequential frames via decode_frames, then sequential frames via
    # decode_first_n_frames, in that order.
    frames = list(_unwrap(decoder.decode_frames(video_file_path, pts_list)))
    frames += list(
        _unwrap(decoder.decode_first_n_frames(video_file_path, num_samples))
    )

    # If frames come back channels-last (H,W,C with 3 as the final dim),
    # permute to C,H,W for consistency with the other decoders.
    if frames[0].shape[-1] == 3:
        frames = [f.permute(-1, *range(f.dim() - 1)) for f in frames]
    return frames
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this option is most useful if it's mutually exclusive with running the actual benchmarks. That way someone can specify it to quickly test the benchmark's correctness. So here I think we should make running the benchmarks and printing the benchmark results the else branch here.