Skip to content

Commit ff35e16

Browse files
Dan-Flores and danielflores3 authored
Update C++ metadata names to match python (#707)
Co-authored-by: danielflores3 <[email protected]>
1 parent d49d72c commit ff35e16

File tree

8 files changed: +130 -115 lines changed

src/torchcodec/_core/Metadata.h

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,22 +25,22 @@ struct StreamMetadata {
2525
AVMediaType mediaType;
2626
std::optional<AVCodecID> codecId;
2727
std::optional<std::string> codecName;
28-
std::optional<double> durationSeconds;
29-
std::optional<double> beginStreamFromHeader;
30-
std::optional<int64_t> numFrames;
28+
std::optional<double> durationSecondsFromHeader;
29+
std::optional<double> beginStreamSecondsFromHeader;
30+
std::optional<int64_t> numFramesFromHeader;
3131
std::optional<int64_t> numKeyFrames;
32-
std::optional<double> averageFps;
32+
std::optional<double> averageFpsFromHeader;
3333
std::optional<double> bitRate;
3434

3535
// More accurate duration, obtained by scanning the file.
3636
// These presentation timestamps are in time base.
37-
std::optional<int64_t> minPtsFromScan;
38-
std::optional<int64_t> maxPtsFromScan;
37+
std::optional<int64_t> beginStreamPtsFromContent;
38+
std::optional<int64_t> endStreamPtsFromContent;
3939
// These presentation timestamps are in seconds.
40-
std::optional<double> minPtsSecondsFromScan;
41-
std::optional<double> maxPtsSecondsFromScan;
40+
std::optional<double> beginStreamPtsSecondsFromContent;
41+
std::optional<double> endStreamPtsSecondsFromContent;
4242
// This can be useful for index-based seeking.
43-
std::optional<int64_t> numFramesFromScan;
43+
std::optional<int64_t> numFramesFromContent;
4444

4545
// Video-only fields derived from the AVCodecContext.
4646
std::optional<int64_t> width;
@@ -58,7 +58,7 @@ struct ContainerMetadata {
5858
int numVideoStreams = 0;
5959
// Note that this is the container-level duration, which is usually the max
6060
// of all stream durations available in the container.
61-
std::optional<double> durationSeconds;
61+
std::optional<double> durationSecondsFromHeader;
6262
// Total BitRate level information at the container level in bit/s
6363
std::optional<double> bitRate;
6464
// If set, this is the index to the default audio stream.

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -125,22 +125,22 @@ void SingleStreamDecoder::initializeDecoder() {
125125

126126
int64_t frameCount = avStream->nb_frames;
127127
if (frameCount > 0) {
128-
streamMetadata.numFrames = frameCount;
128+
streamMetadata.numFramesFromHeader = frameCount;
129129
}
130130

131131
if (avStream->duration > 0 && avStream->time_base.den > 0) {
132-
streamMetadata.durationSeconds =
132+
streamMetadata.durationSecondsFromHeader =
133133
av_q2d(avStream->time_base) * avStream->duration;
134134
}
135135
if (avStream->start_time != AV_NOPTS_VALUE) {
136-
streamMetadata.beginStreamFromHeader =
136+
streamMetadata.beginStreamSecondsFromHeader =
137137
av_q2d(avStream->time_base) * avStream->start_time;
138138
}
139139

140140
if (avStream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) {
141141
double fps = av_q2d(avStream->r_frame_rate);
142142
if (fps > 0) {
143-
streamMetadata.averageFps = fps;
143+
streamMetadata.averageFpsFromHeader = fps;
144144
}
145145
containerMetadata_.numVideoStreams++;
146146
} else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
@@ -163,7 +163,7 @@ void SingleStreamDecoder::initializeDecoder() {
163163

164164
if (formatContext_->duration > 0) {
165165
AVRational defaultTimeBase{1, AV_TIME_BASE};
166-
containerMetadata_.durationSeconds =
166+
containerMetadata_.durationSecondsFromHeader =
167167
ptsToSeconds(formatContext_->duration, defaultTimeBase);
168168
}
169169

@@ -236,13 +236,14 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
236236
// record its relevant metadata.
237237
int streamIndex = packet->stream_index;
238238
auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
239-
streamMetadata.minPtsFromScan = std::min(
240-
streamMetadata.minPtsFromScan.value_or(INT64_MAX), getPtsOrDts(packet));
241-
streamMetadata.maxPtsFromScan = std::max(
242-
streamMetadata.maxPtsFromScan.value_or(INT64_MIN),
239+
streamMetadata.beginStreamPtsFromContent = std::min(
240+
streamMetadata.beginStreamPtsFromContent.value_or(INT64_MAX),
241+
getPtsOrDts(packet));
242+
streamMetadata.endStreamPtsFromContent = std::max(
243+
streamMetadata.endStreamPtsFromContent.value_or(INT64_MIN),
243244
getPtsOrDts(packet) + packet->duration);
244-
streamMetadata.numFramesFromScan =
245-
streamMetadata.numFramesFromScan.value_or(0) + 1;
245+
streamMetadata.numFramesFromContent =
246+
streamMetadata.numFramesFromContent.value_or(0) + 1;
246247

247248
// Note that we set the other value in this struct, nextPts, only after
248249
// we have scanned all packets and sorted by pts.
@@ -262,16 +263,17 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
262263
auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex];
263264
auto avStream = formatContext_->streams[streamIndex];
264265

265-
streamMetadata.numFramesFromScan =
266+
streamMetadata.numFramesFromContent =
266267
streamInfos_[streamIndex].allFrames.size();
267268

268-
if (streamMetadata.minPtsFromScan.has_value()) {
269-
streamMetadata.minPtsSecondsFromScan =
270-
*streamMetadata.minPtsFromScan * av_q2d(avStream->time_base);
269+
if (streamMetadata.beginStreamPtsFromContent.has_value()) {
270+
streamMetadata.beginStreamPtsSecondsFromContent =
271+
*streamMetadata.beginStreamPtsFromContent *
272+
av_q2d(avStream->time_base);
271273
}
272-
if (streamMetadata.maxPtsFromScan.has_value()) {
273-
streamMetadata.maxPtsSecondsFromScan =
274-
*streamMetadata.maxPtsFromScan * av_q2d(avStream->time_base);
274+
if (streamMetadata.endStreamPtsFromContent.has_value()) {
275+
streamMetadata.endStreamPtsSecondsFromContent =
276+
*streamMetadata.endStreamPtsFromContent * av_q2d(avStream->time_base);
275277
}
276278
}
277279

@@ -445,7 +447,7 @@ void SingleStreamDecoder::addVideoStream(
445447
containerMetadata_.allStreamMetadata[activeStreamIndex_];
446448

447449
if (seekMode_ == SeekMode::approximate &&
448-
!streamMetadata.averageFps.has_value()) {
450+
!streamMetadata.averageFpsFromHeader.has_value()) {
449451
throw std::runtime_error(
450452
"Seek mode is approximate, but stream " +
451453
std::to_string(activeStreamIndex_) +
@@ -1422,9 +1424,9 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
14221424
auto& streamMetadata =
14231425
containerMetadata_.allStreamMetadata[activeStreamIndex_];
14241426
TORCH_CHECK(
1425-
streamMetadata.averageFps.has_value(),
1427+
streamMetadata.averageFpsFromHeader.has_value(),
14261428
"Cannot use approximate mode since we couldn't find the average fps from the metadata.");
1427-
return std::floor(seconds * streamMetadata.averageFps.value());
1429+
return std::floor(seconds * streamMetadata.averageFpsFromHeader.value());
14281430
}
14291431
default:
14301432
throw std::runtime_error("Unknown SeekMode");
@@ -1449,9 +1451,9 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
14491451
auto& streamMetadata =
14501452
containerMetadata_.allStreamMetadata[activeStreamIndex_];
14511453
TORCH_CHECK(
1452-
streamMetadata.averageFps.has_value(),
1454+
streamMetadata.averageFpsFromHeader.has_value(),
14531455
"Cannot use approximate mode since we couldn't find the average fps from the metadata.");
1454-
return std::ceil(seconds * streamMetadata.averageFps.value());
1456+
return std::ceil(seconds * streamMetadata.averageFpsFromHeader.value());
14551457
}
14561458
default:
14571459
throw std::runtime_error("Unknown SeekMode");
@@ -1467,10 +1469,11 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
14671469
auto& streamMetadata =
14681470
containerMetadata_.allStreamMetadata[activeStreamIndex_];
14691471
TORCH_CHECK(
1470-
streamMetadata.averageFps.has_value(),
1472+
streamMetadata.averageFpsFromHeader.has_value(),
14711473
"Cannot use approximate mode since we couldn't find the average fps from the metadata.");
14721474
return secondsToClosestPts(
1473-
frameIndex / streamMetadata.averageFps.value(), streamInfo.timeBase);
1475+
frameIndex / streamMetadata.averageFpsFromHeader.value(),
1476+
streamInfo.timeBase);
14741477
}
14751478
default:
14761479
throw std::runtime_error("Unknown SeekMode");
@@ -1485,9 +1488,9 @@ std::optional<int64_t> SingleStreamDecoder::getNumFrames(
14851488
const StreamMetadata& streamMetadata) {
14861489
switch (seekMode_) {
14871490
case SeekMode::exact:
1488-
return streamMetadata.numFramesFromScan.value();
1491+
return streamMetadata.numFramesFromContent.value();
14891492
case SeekMode::approximate: {
1490-
return streamMetadata.numFrames;
1493+
return streamMetadata.numFramesFromHeader;
14911494
}
14921495
default:
14931496
throw std::runtime_error("Unknown SeekMode");
@@ -1498,7 +1501,7 @@ double SingleStreamDecoder::getMinSeconds(
14981501
const StreamMetadata& streamMetadata) {
14991502
switch (seekMode_) {
15001503
case SeekMode::exact:
1501-
return streamMetadata.minPtsSecondsFromScan.value();
1504+
return streamMetadata.beginStreamPtsSecondsFromContent.value();
15021505
case SeekMode::approximate:
15031506
return 0;
15041507
default:
@@ -1510,9 +1513,9 @@ std::optional<double> SingleStreamDecoder::getMaxSeconds(
15101513
const StreamMetadata& streamMetadata) {
15111514
switch (seekMode_) {
15121515
case SeekMode::exact:
1513-
return streamMetadata.maxPtsSecondsFromScan.value();
1516+
return streamMetadata.endStreamPtsSecondsFromContent.value();
15141517
case SeekMode::approximate: {
1515-
return streamMetadata.durationSeconds;
1518+
return streamMetadata.durationSecondsFromHeader;
15161519
}
15171520
default:
15181521
throw std::runtime_error("Unknown SeekMode");

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ class SingleStreamDecoder {
121121
//
122122
// Valid values for startSeconds and stopSeconds are:
123123
//
124-
// [minPtsSecondsFromScan, maxPtsSecondsFromScan)
124+
// [beginStreamPtsSecondsFromContent, endStreamPtsSecondsFromContent)
125125
FrameBatchOutput getFramesPlayedInRange(
126126
double startSeconds,
127127
double stopSeconds);

src/torchcodec/_core/_metadata.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -225,26 +225,28 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
225225
for stream_index in range(container_dict["numStreams"]):
226226
stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index))
227227
common_meta = dict(
228-
duration_seconds_from_header=stream_dict.get("durationSeconds"),
228+
duration_seconds_from_header=stream_dict.get("durationSecondsFromHeader"),
229229
bit_rate=stream_dict.get("bitRate"),
230-
begin_stream_seconds_from_header=stream_dict.get("beginStreamFromHeader"),
230+
begin_stream_seconds_from_header=stream_dict.get(
231+
"beginStreamSecondsFromHeader"
232+
),
231233
codec=stream_dict.get("codec"),
232234
stream_index=stream_index,
233235
)
234236
if stream_dict["mediaType"] == "video":
235237
streams_metadata.append(
236238
VideoStreamMetadata(
237239
begin_stream_seconds_from_content=stream_dict.get(
238-
"minPtsSecondsFromScan"
240+
"beginStreamSecondsFromContent"
239241
),
240242
end_stream_seconds_from_content=stream_dict.get(
241-
"maxPtsSecondsFromScan"
243+
"endStreamSecondsFromContent"
242244
),
243245
width=stream_dict.get("width"),
244246
height=stream_dict.get("height"),
245-
num_frames_from_header=stream_dict.get("numFrames"),
246-
num_frames_from_content=stream_dict.get("numFramesFromScan"),
247-
average_fps_from_header=stream_dict.get("averageFps"),
247+
num_frames_from_header=stream_dict.get("numFramesFromHeader"),
248+
num_frames_from_content=stream_dict.get("numFramesFromContent"),
249+
average_fps_from_header=stream_dict.get("averageFpsFromHeader"),
248250
**common_meta,
249251
)
250252
)
@@ -264,7 +266,7 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
264266
streams_metadata.append(StreamMetadata(**common_meta))
265267

266268
return ContainerMetadata(
267-
duration_seconds_from_header=container_dict.get("durationSeconds"),
269+
duration_seconds_from_header=container_dict.get("durationSecondsFromHeader"),
268270
bit_rate_from_header=container_dict.get("bitRate"),
269271
best_video_stream_index=container_dict.get("bestVideoStreamIndex"),
270272
best_audio_stream_index=container_dict.get("bestAudioStreamIndex"),

0 commit comments

Comments
 (0)