Skip to content

Update C++ metadata names to match python #707

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jun 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/torchcodec/_core/Metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,22 @@ struct StreamMetadata {
AVMediaType mediaType;
std::optional<AVCodecID> codecId;
std::optional<std::string> codecName;
std::optional<double> durationSeconds;
std::optional<double> beginStreamFromHeader;
std::optional<int64_t> numFrames;
std::optional<double> durationSecondsFromHeader;
std::optional<double> beginStreamSecondsFromHeader;
std::optional<int64_t> numFramesFromHeader;
std::optional<int64_t> numKeyFrames;
std::optional<double> averageFps;
std::optional<double> averageFpsFromHeader;
std::optional<double> bitRate;

// More accurate duration, obtained by scanning the file.
// These presentation timestamps are in time base.
std::optional<int64_t> minPtsFromScan;
std::optional<int64_t> maxPtsFromScan;
std::optional<int64_t> beginStreamPtsFromContent;
std::optional<int64_t> endStreamPtsFromContent;
// These presentation timestamps are in seconds.
std::optional<double> minPtsSecondsFromScan;
std::optional<double> maxPtsSecondsFromScan;
std::optional<double> beginStreamPtsSecondsFromContent;
std::optional<double> endStreamPtsSecondsFromContent;
// This can be useful for index-based seeking.
std::optional<int64_t> numFramesFromScan;
std::optional<int64_t> numFramesFromContent;

// Video-only fields derived from the AVCodecContext.
std::optional<int64_t> width;
Expand All @@ -58,7 +58,7 @@ struct ContainerMetadata {
int numVideoStreams = 0;
// Note that this is the container-level duration, which is usually the max
// of all stream durations available in the container.
std::optional<double> durationSeconds;
std::optional<double> durationSecondsFromHeader;
// Total BitRate level information at the container level in bit/s
std::optional<double> bitRate;
// If set, this is the index to the default audio stream.
Expand Down
63 changes: 33 additions & 30 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,22 +125,22 @@ void SingleStreamDecoder::initializeDecoder() {

int64_t frameCount = avStream->nb_frames;
if (frameCount > 0) {
streamMetadata.numFrames = frameCount;
streamMetadata.numFramesFromHeader = frameCount;
}

if (avStream->duration > 0 && avStream->time_base.den > 0) {
streamMetadata.durationSeconds =
streamMetadata.durationSecondsFromHeader =
av_q2d(avStream->time_base) * avStream->duration;
}
if (avStream->start_time != AV_NOPTS_VALUE) {
streamMetadata.beginStreamFromHeader =
streamMetadata.beginStreamSecondsFromHeader =
av_q2d(avStream->time_base) * avStream->start_time;
}

if (avStream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) {
double fps = av_q2d(avStream->r_frame_rate);
if (fps > 0) {
streamMetadata.averageFps = fps;
streamMetadata.averageFpsFromHeader = fps;
}
containerMetadata_.numVideoStreams++;
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
Expand All @@ -163,7 +163,7 @@ void SingleStreamDecoder::initializeDecoder() {

if (formatContext_->duration > 0) {
AVRational defaultTimeBase{1, AV_TIME_BASE};
containerMetadata_.durationSeconds =
containerMetadata_.durationSecondsFromHeader =
ptsToSeconds(formatContext_->duration, defaultTimeBase);
}

Expand Down Expand Up @@ -236,13 +236,14 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
// record its relevant metadata.
int streamIndex = packet->stream_index;
auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
streamMetadata.minPtsFromScan = std::min(
streamMetadata.minPtsFromScan.value_or(INT64_MAX), getPtsOrDts(packet));
streamMetadata.maxPtsFromScan = std::max(
streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
streamMetadata.beginStreamPtsFromContent = std::min(
streamMetadata.beginStreamPtsFromContent.value_or(INT64_MAX),
getPtsOrDts(packet));
streamMetadata.endStreamPtsFromContent = std::max(
streamMetadata.endStreamPtsFromContent.value_or(INT64_MIN),
getPtsOrDts(packet) + packet->duration);
streamMetadata.numFramesFromScan =
streamMetadata.numFramesFromScan.value_or(0) + 1;
streamMetadata.numFramesFromContent =
streamMetadata.numFramesFromContent.value_or(0) + 1;

// Note that we set the other value in this struct, nextPts, only after
// we have scanned all packets and sorted by pts.
Expand All @@ -262,16 +263,17 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
auto avStream = formatContext_->streams[streamIndex];

streamMetadata.numFramesFromScan =
streamMetadata.numFramesFromContent =
streamInfos_[streamIndex].allFrames.size();

if (streamMetadata.minPtsFromScan.has_value()) {
streamMetadata.minPtsSecondsFromScan =
*streamMetadata.minPtsFromScan * av_q2d(avStream->time_base);
if (streamMetadata.beginStreamPtsFromContent.has_value()) {
streamMetadata.beginStreamPtsSecondsFromContent =
*streamMetadata.beginStreamPtsFromContent *
av_q2d(avStream->time_base);
}
if (streamMetadata.maxPtsFromScan.has_value()) {
streamMetadata.maxPtsSecondsFromScan =
*streamMetadata.maxPtsFromScan * av_q2d(avStream->time_base);
if (streamMetadata.endStreamPtsFromContent.has_value()) {
streamMetadata.endStreamPtsSecondsFromContent =
*streamMetadata.endStreamPtsFromContent * av_q2d(avStream->time_base);
}
}

Expand Down Expand Up @@ -445,7 +447,7 @@ void SingleStreamDecoder::addVideoStream(
containerMetadata_.allStreamMetadata[activeStreamIndex_];

if (seekMode_ == SeekMode::approximate &&
!streamMetadata.averageFps.has_value()) {
!streamMetadata.averageFpsFromHeader.has_value()) {
throw std::runtime_error(
"Seek mode is approximate, but stream " +
std::to_string(activeStreamIndex_) +
Expand Down Expand Up @@ -1422,9 +1424,9 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];
TORCH_CHECK(
streamMetadata.averageFps.has_value(),
streamMetadata.averageFpsFromHeader.has_value(),
"Cannot use approximate mode since we couldn't find the average fps from the metadata.");
return std::floor(seconds * streamMetadata.averageFps.value());
return std::floor(seconds * streamMetadata.averageFpsFromHeader.value());
}
default:
throw std::runtime_error("Unknown SeekMode");
Expand All @@ -1449,9 +1451,9 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];
TORCH_CHECK(
streamMetadata.averageFps.has_value(),
streamMetadata.averageFpsFromHeader.has_value(),
"Cannot use approximate mode since we couldn't find the average fps from the metadata.");
return std::ceil(seconds * streamMetadata.averageFps.value());
return std::ceil(seconds * streamMetadata.averageFpsFromHeader.value());
}
default:
throw std::runtime_error("Unknown SeekMode");
Expand All @@ -1467,10 +1469,11 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];
TORCH_CHECK(
streamMetadata.averageFps.has_value(),
streamMetadata.averageFpsFromHeader.has_value(),
"Cannot use approximate mode since we couldn't find the average fps from the metadata.");
return secondsToClosestPts(
frameIndex / streamMetadata.averageFps.value(), streamInfo.timeBase);
frameIndex / streamMetadata.averageFpsFromHeader.value(),
streamInfo.timeBase);
}
default:
throw std::runtime_error("Unknown SeekMode");
Expand All @@ -1485,9 +1488,9 @@ std::optional<int64_t> SingleStreamDecoder::getNumFrames(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::exact:
return streamMetadata.numFramesFromScan.value();
return streamMetadata.numFramesFromContent.value();
case SeekMode::approximate: {
return streamMetadata.numFrames;
return streamMetadata.numFramesFromHeader;
}
default:
throw std::runtime_error("Unknown SeekMode");
Expand All @@ -1498,7 +1501,7 @@ double SingleStreamDecoder::getMinSeconds(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::exact:
return streamMetadata.minPtsSecondsFromScan.value();
return streamMetadata.beginStreamPtsSecondsFromContent.value();
case SeekMode::approximate:
return 0;
default:
Expand All @@ -1510,9 +1513,9 @@ std::optional<double> SingleStreamDecoder::getMaxSeconds(
const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::exact:
return streamMetadata.maxPtsSecondsFromScan.value();
return streamMetadata.endStreamPtsSecondsFromContent.value();
case SeekMode::approximate: {
return streamMetadata.durationSeconds;
return streamMetadata.durationSecondsFromHeader;
}
default:
throw std::runtime_error("Unknown SeekMode");
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/SingleStreamDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class SingleStreamDecoder {
//
// Valid values for startSeconds and stopSeconds are:
//
// [minPtsSecondsFromScan, maxPtsSecondsFromScan)
// [beginStreamPtsSecondsFromContent, endStreamPtsSecondsFromContent)
FrameBatchOutput getFramesPlayedInRange(
double startSeconds,
double stopSeconds);
Expand Down
18 changes: 10 additions & 8 deletions src/torchcodec/_core/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,26 +225,28 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
for stream_index in range(container_dict["numStreams"]):
stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index))
common_meta = dict(
duration_seconds_from_header=stream_dict.get("durationSeconds"),
duration_seconds_from_header=stream_dict.get("durationSecondsFromHeader"),
bit_rate=stream_dict.get("bitRate"),
begin_stream_seconds_from_header=stream_dict.get("beginStreamFromHeader"),
begin_stream_seconds_from_header=stream_dict.get(
"beginStreamSecondsFromHeader"
),
codec=stream_dict.get("codec"),
stream_index=stream_index,
)
if stream_dict["mediaType"] == "video":
streams_metadata.append(
VideoStreamMetadata(
begin_stream_seconds_from_content=stream_dict.get(
"minPtsSecondsFromScan"
"beginStreamSecondsFromContent"
),
end_stream_seconds_from_content=stream_dict.get(
"maxPtsSecondsFromScan"
"endStreamSecondsFromContent"
),
width=stream_dict.get("width"),
height=stream_dict.get("height"),
num_frames_from_header=stream_dict.get("numFrames"),
num_frames_from_content=stream_dict.get("numFramesFromScan"),
average_fps_from_header=stream_dict.get("averageFps"),
num_frames_from_header=stream_dict.get("numFramesFromHeader"),
num_frames_from_content=stream_dict.get("numFramesFromContent"),
average_fps_from_header=stream_dict.get("averageFpsFromHeader"),
**common_meta,
)
)
Expand All @@ -264,7 +266,7 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
streams_metadata.append(StreamMetadata(**common_meta))

return ContainerMetadata(
duration_seconds_from_header=container_dict.get("durationSeconds"),
duration_seconds_from_header=container_dict.get("durationSecondsFromHeader"),
bit_rate_from_header=container_dict.get("bitRate"),
best_video_stream_index=container_dict.get("bestVideoStreamIndex"),
best_audio_stream_index=container_dict.get("bestAudioStreamIndex"),
Expand Down
Loading
Loading