import torch
import torch.utils.benchmark as benchmark
+
from torchcodec._core import (
    _add_video_stream,
    create_from_file,
    get_next_frame,
    seek_to_pts,
)
-
from torchcodec._frame import FrameBatch
from torchcodec.decoders import VideoDecoder, VideoStreamMetadata
@@ -1028,44 +1028,34 @@ def run_benchmarks(


def verify_outputs(decoders_to_run, video_paths, num_samples):
+    # Import library to show frames that don't match
    from tensorcat import tensorcat

-    # Add torchcodec_public to the list of decoders to run to use as a baseline
-    if matches := list(
-        filter(
-            lambda decoder_name: "TorchCodecPublic" in decoder_name,
-            decoders_to_run.keys(),
-        )
-    ):
-        torchcodec_display_name = matches[0]
-        torchcodec_public_decoder = decoders_to_run[matches[0]]
-        # del decoders_to_run["TorchCodecPublic"]
-        print(f"Using {torchcodec_public_decoder}")
-    else:
-        torchcodec_display_name = decoder_registry["torchcodec_public"].display_name
-        options = decoder_registry["torchcodec_public"].default_options
-        kind = decoder_registry["torchcodec_public"].kind
-        torchcodec_public_decoder = kind(**options)
-        print("Adding TorchCodecPublic")
+    # Create TorchCodecPublic decoder to use as a baseline
+    torchcodec_display_name = decoder_registry["torchcodec_public"].display_name
+    options = decoder_registry["torchcodec_public"].default_options
+    kind = decoder_registry["torchcodec_public"].kind
+    torchcodec_public_decoder = kind(**options)

    # Get frames using each decoder
    for video_file_path in video_paths:
        metadata = get_metadata(video_file_path)
        metadata_label = f"{metadata.codec} {metadata.width}x{metadata.height}, {metadata.duration_seconds}s {metadata.average_fps}fps"
        print(f"{metadata_label=}")
-        # Uncomment to use non-sequential frames
+        # Generate uniformly spaced PTS
        duration = metadata.duration_seconds
-        uniform_pts_list = [i * duration / num_samples for i in range(num_samples)]
+        pts_list = [i * duration / num_samples for i in range(num_samples)]

-        # Use TorchCodecPublic as the baseline
+        # Get the frames from TorchCodecPublic as the baseline
        torchcodec_public_results = decode_and_adjust_frames(
            torchcodec_public_decoder,
            video_file_path,
            num_samples=num_samples,
-            pts_list=uniform_pts_list,
+            pts_list=pts_list,
        )

+        # Get the frames from each decoder
        decoders_and_frames = []
        for decoder_name, decoder in decoders_to_run.items():
            print(f"video={video_file_path}, decoder={decoder_name}")
@@ -1074,20 +1064,18 @@ def verify_outputs(decoders_to_run, video_paths, num_samples):
                decoder,
                video_file_path,
                num_samples=num_samples,
-                pts_list=uniform_pts_list,
+                pts_list=pts_list,
            )
            decoders_and_frames.append((decoder_name, frames))

        # Compare the frames from all decoders to the frames from TorchCodecPublic
-        for i in range(0, len(decoders_and_frames)):
-            curr_decoder_name, curr_decoder_frames = decoders_and_frames[i]
-            print(f"Compare: {f"{torchcodec_display_name}"} and {curr_decoder_name}")
+        for curr_decoder_name, curr_decoder_frames in decoders_and_frames:
            assert len(torchcodec_public_results) == len(curr_decoder_frames)
            all_match = True
            for f1, f2 in zip(torchcodec_public_results, curr_decoder_frames):
                # Validate that the frames are the same with a tolerance
                try:
-                    torch.testing.assert_close(f1, f2)  # , atol=1, rtol=0.01)
+                    torch.testing.assert_close(f1, f2)
                except Exception as e:
                    tensorcat(f1)
                    tensorcat(f2)
@@ -1103,13 +1091,25 @@ def verify_outputs(decoders_to_run, video_paths, num_samples):


def decode_and_adjust_frames(
-    decoder, video_file_path, *, num_samples: int, pts_list: list[float] | None
+    decoder, video_file_path, *, num_samples: int, pts_list: list[float]
):
-    frames = None
-    if pts_list:
-        frames = decoder.decode_frames(video_file_path, pts_list)
-    else:
-        frames = decoder.decode_first_n_frames(video_file_path, num_samples)
-    if isinstance(frames, FrameBatch):
-        frames = frames.data
+    frames = []
+    # Decode non-sequential frames using decode_frames function
+    random_frames = decoder.decode_frames(video_file_path, pts_list)
+    # Extract the frames from the FrameBatch if necessary
+    if isinstance(random_frames, FrameBatch):
+        random_frames = random_frames.data
+    frames.extend(random_frames)
+
+    # Decode sequential frames using decode_first_n_frames function
+    seq_frames = decoder.decode_first_n_frames(video_file_path, num_samples)
+    # Extract the frames from the FrameBatch if necessary
+    if isinstance(seq_frames, FrameBatch):
+        seq_frames = seq_frames.data
+    frames.extend(seq_frames)
+
+    # Check if frames are returned in H,W,C where 3 is the last dimension.
+    # If so, convert to C,H,W for consistency with other decoders.
+    if frames[0].shape[-1] == 3:
+        frames = [frame.permute(-1, *range(frame.dim() - 1)) for frame in frames]
    return frames
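
For context on the frame comparison in verify_outputs above, here is a minimal standalone sketch of torch.testing.assert_close using the explicit tolerances hinted at in the removed inline comment (atol=1, rtol=0.01). The frame shape and pixel values below are illustrative assumptions, not taken from this PR.

import torch

# Two almost-identical frames; the (C, H, W) shape and float values are
# illustrative assumptions only.
f1 = torch.zeros(3, 4, 4)
f2 = f1.clone()
f2[0, 0, 0] = 1.0  # a single pixel differs by one intensity level

# Passes: the difference is within atol=1, rtol=0.01.
torch.testing.assert_close(f1, f2, atol=1, rtol=0.01)

# With default float tolerances the same pair fails; in verify_outputs this
# is the case that triggers the tensorcat() visualization of both frames.
try:
    torch.testing.assert_close(f1, f2)
except AssertionError as err:
    print("frames differ:", err)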
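Similarly, a small sketch of the channels-last check at the end of decode_and_adjust_frames: a decoder that returns frames as H,W,C is converted to C,H,W so all decoders are compared in the same layout. The 480x640 frame size is an illustrative assumption.

import torch

# A dummy H,W,C (channels-last) frame; the 480x640 size is an assumption.
frame = torch.zeros(480, 640, 3, dtype=torch.uint8)

# Same check as in decode_and_adjust_frames: if the last dimension is 3,
# move it to the front, turning H,W,C into C,H,W.
if frame.shape[-1] == 3:
    frame = frame.permute(-1, *range(frame.dim() - 1))

print(frame.shape)  # torch.Size([3, 480, 640])

Since permute only returns a view, the conversion does not copy frame data, and torch.testing.assert_close compares values regardless of memory layout.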