Skip to content

Commit 314c3cd

Browse files
authored
Merge branch 'main' into align_metadata_names
2 parents 8136466 + 0f22b2b commit 314c3cd

File tree

3 files changed

+62
-27
lines changed

3 files changed

+62
-27
lines changed

src/torchcodec/_core/SingleStreamDecoder.cpp

Lines changed: 58 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -604,16 +604,22 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
604604
const auto& streamMetadata =
605605
containerMetadata_.allStreamMetadata[activeStreamIndex_];
606606
const auto& streamInfo = streamInfos_[activeStreamIndex_];
607-
int64_t numFrames = getNumFrames(streamMetadata);
608607
TORCH_CHECK(
609608
start >= 0, "Range start, " + std::to_string(start) + " is less than 0.");
610-
TORCH_CHECK(
611-
stop <= numFrames,
612-
"Range stop, " + std::to_string(stop) +
613-
", is more than the number of frames, " + std::to_string(numFrames));
614609
TORCH_CHECK(
615610
step > 0, "Step must be greater than 0; is " + std::to_string(step));
616611

612+
// Note that if we do not have the number of frames available in our metadata,
613+
// then we assume that the upper part of the range is valid.
614+
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
615+
if (numFrames.has_value()) {
616+
TORCH_CHECK(
617+
stop <= numFrames.value(),
618+
"Range stop, " + std::to_string(stop) +
619+
", is more than the number of frames, " +
620+
std::to_string(numFrames.value()));
621+
}
622+
617623
int64_t numOutputFrames = std::ceil((stop - start) / double(step));
618624
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
619625
FrameBatchOutput frameBatchOutput(
@@ -678,7 +684,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
678684
containerMetadata_.allStreamMetadata[activeStreamIndex_];
679685

680686
double minSeconds = getMinSeconds(streamMetadata);
681-
double maxSeconds = getMaxSeconds(streamMetadata);
687+
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
682688

683689
// The frame played at timestamp t and the one played at timestamp `t +
684690
// eps` are probably the same frame, with the same index. The easiest way to
@@ -689,10 +695,20 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
689695
for (size_t i = 0; i < timestamps.size(); ++i) {
690696
auto frameSeconds = timestamps[i];
691697
TORCH_CHECK(
692-
frameSeconds >= minSeconds && frameSeconds < maxSeconds,
698+
frameSeconds >= minSeconds,
693699
"frame pts is " + std::to_string(frameSeconds) +
694-
"; must be in range [" + std::to_string(minSeconds) + ", " +
695-
std::to_string(maxSeconds) + ").");
700+
"; must be greater than or equal to " + std::to_string(minSeconds) +
701+
".");
702+
703+
// Note that if we can't determine the maximum number of seconds from the
704+
// metadata, then we assume the frame's pts is valid.
705+
if (maxSeconds.has_value()) {
706+
TORCH_CHECK(
707+
frameSeconds < maxSeconds.value(),
708+
"frame pts is " + std::to_string(frameSeconds) +
709+
"; must be less than " + std::to_string(maxSeconds.value()) +
710+
".");
711+
}
696712

697713
frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
698714
}
@@ -739,17 +755,26 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
739755
}
740756

741757
double minSeconds = getMinSeconds(streamMetadata);
742-
double maxSeconds = getMaxSeconds(streamMetadata);
743758
TORCH_CHECK(
744-
startSeconds >= minSeconds && startSeconds < maxSeconds,
759+
startSeconds >= minSeconds,
745760
"Start seconds is " + std::to_string(startSeconds) +
746-
"; must be in range [" + std::to_string(minSeconds) + ", " +
747-
std::to_string(maxSeconds) + ").");
748-
TORCH_CHECK(
749-
stopSeconds <= maxSeconds,
750-
"Stop seconds (" + std::to_string(stopSeconds) +
751-
"; must be less than or equal to " + std::to_string(maxSeconds) +
752-
").");
761+
"; must be greater than or equal to " + std::to_string(minSeconds) +
762+
".");
763+
764+
// Note that if we can't determine the maximum seconds from the metadata, then
765+
// we assume upper range is valid.
766+
std::optional<double> maxSeconds = getMaxSeconds(streamMetadata);
767+
if (maxSeconds.has_value()) {
768+
TORCH_CHECK(
769+
startSeconds < maxSeconds.value(),
770+
"Start seconds is " + std::to_string(startSeconds) +
771+
"; must be less than " + std::to_string(maxSeconds.value()) + ".");
772+
TORCH_CHECK(
773+
stopSeconds <= maxSeconds.value(),
774+
"Stop seconds (" + std::to_string(stopSeconds) +
775+
"; must be less than or equal to " +
776+
std::to_string(maxSeconds.value()) + ").");
777+
}
753778

754779
// Note that we look at nextPts for a frame, and not its pts or duration.
755780
// Our abstract player displays frames starting at the pts for that frame
@@ -1459,7 +1484,7 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
14591484
// STREAM AND METADATA APIS
14601485
// --------------------------------------------------------------------------
14611486

1462-
int64_t SingleStreamDecoder::getNumFrames(
1487+
std::optional<int64_t> SingleStreamDecoder::getNumFrames(
14631488
const StreamMetadata& streamMetadata) {
14641489
switch (seekMode_) {
14651490
case SeekMode::exact:
@@ -1487,7 +1512,7 @@ double SingleStreamDecoder::getMinSeconds(
14871512
}
14881513
}
14891514

1490-
double SingleStreamDecoder::getMaxSeconds(
1515+
std::optional<double> SingleStreamDecoder::getMaxSeconds(
14911516
const StreamMetadata& streamMetadata) {
14921517
switch (seekMode_) {
14931518
case SeekMode::exact:
@@ -1542,12 +1567,22 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
15421567
void SingleStreamDecoder::validateFrameIndex(
15431568
const StreamMetadata& streamMetadata,
15441569
int64_t frameIndex) {
1545-
int64_t numFrames = getNumFrames(streamMetadata);
15461570
TORCH_CHECK(
1547-
frameIndex >= 0 && frameIndex < numFrames,
1571+
frameIndex >= 0,
15481572
"Invalid frame index=" + std::to_string(frameIndex) +
15491573
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
1550-
" numFrames=" + std::to_string(numFrames));
1574+
"; must be greater than or equal to 0");
1575+
1576+
// Note that if we do not have the number of frames available in our metadata,
1577+
// then we assume that the frameIndex is valid.
1578+
std::optional<int64_t> numFrames = getNumFrames(streamMetadata);
1579+
if (numFrames.has_value()) {
1580+
TORCH_CHECK(
1581+
frameIndex < numFrames.value(),
1582+
"Invalid frame index=" + std::to_string(frameIndex) +
1583+
" for streamIndex=" + std::to_string(streamMetadata.streamIndex) +
1584+
"; must be less than " + std::to_string(numFrames.value()));
1585+
}
15511586
}
15521587

15531588
// --------------------------------------------------------------------------

src/torchcodec/_core/SingleStreamDecoder.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -304,9 +304,9 @@ class SingleStreamDecoder {
304304
// index. Note that this index may be truncated for some files.
305305
int getBestStreamIndex(AVMediaType mediaType);
306306

307-
int64_t getNumFrames(const StreamMetadata& streamMetadata);
307+
std::optional<int64_t> getNumFrames(const StreamMetadata& streamMetadata);
308308
double getMinSeconds(const StreamMetadata& streamMetadata);
309-
double getMaxSeconds(const StreamMetadata& streamMetadata);
309+
std::optional<double> getMaxSeconds(const StreamMetadata& streamMetadata);
310310

311311
// --------------------------------------------------------------------------
312312
// VALIDATION UTILS

test/test_decoders.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -597,10 +597,10 @@ def test_get_frames_played_at(self, device, seek_mode):
597597
def test_get_frames_played_at_fails(self, device, seek_mode):
598598
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
599599

600-
with pytest.raises(RuntimeError, match="must be in range"):
600+
with pytest.raises(RuntimeError, match="must be greater than or equal to"):
601601
decoder.get_frames_played_at([-1])
602602

603-
with pytest.raises(RuntimeError, match="must be in range"):
603+
with pytest.raises(RuntimeError, match="must be less than"):
604604
decoder.get_frames_played_at([14])
605605

606606
with pytest.raises(RuntimeError, match="Expected a value of type"):

0 commit comments

Comments
 (0)