@@ -604,16 +604,22 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange(
604
604
const auto & streamMetadata =
605
605
containerMetadata_.allStreamMetadata [activeStreamIndex_];
606
606
const auto & streamInfo = streamInfos_[activeStreamIndex_];
607
- int64_t numFrames = getNumFrames (streamMetadata);
608
607
TORCH_CHECK (
609
608
start >= 0 , " Range start, " + std::to_string (start) + " is less than 0." );
610
- TORCH_CHECK (
611
- stop <= numFrames,
612
- " Range stop, " + std::to_string (stop) +
613
- " , is more than the number of frames, " + std::to_string (numFrames));
614
609
TORCH_CHECK (
615
610
step > 0 , " Step must be greater than 0; is " + std::to_string (step));
616
611
612
+ // Note that if we do not have the number of frames available in our metadata,
613
+ // then we assume that the upper part of the range is valid.
614
+ std::optional<int64_t > numFrames = getNumFrames (streamMetadata);
615
+ if (numFrames.has_value ()) {
616
+ TORCH_CHECK (
617
+ stop <= numFrames.value (),
618
+ " Range stop, " + std::to_string (stop) +
619
+ " , is more than the number of frames, " +
620
+ std::to_string (numFrames.value ()));
621
+ }
622
+
617
623
int64_t numOutputFrames = std::ceil ((stop - start) / double (step));
618
624
const auto & videoStreamOptions = streamInfo.videoStreamOptions ;
619
625
FrameBatchOutput frameBatchOutput (
@@ -678,7 +684,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
678
684
containerMetadata_.allStreamMetadata [activeStreamIndex_];
679
685
680
686
double minSeconds = getMinSeconds (streamMetadata);
681
- double maxSeconds = getMaxSeconds (streamMetadata);
687
+ std::optional< double > maxSeconds = getMaxSeconds (streamMetadata);
682
688
683
689
// The frame played at timestamp t and the one played at timestamp `t +
684
690
// eps` are probably the same frame, with the same index. The easiest way to
@@ -689,10 +695,20 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
689
695
for (size_t i = 0 ; i < timestamps.size (); ++i) {
690
696
auto frameSeconds = timestamps[i];
691
697
TORCH_CHECK (
692
- frameSeconds >= minSeconds && frameSeconds < maxSeconds ,
698
+ frameSeconds >= minSeconds,
693
699
" frame pts is " + std::to_string (frameSeconds) +
694
- " ; must be in range [" + std::to_string (minSeconds) + " , " +
695
- std::to_string (maxSeconds) + " )." );
700
+ " ; must be greater than or equal to " + std::to_string (minSeconds) +
701
+ " ." );
702
+
703
+ // Note that if we can't determine the maximum number of seconds from the
704
+ // metadata, then we assume the frame's pts is valid.
705
+ if (maxSeconds.has_value ()) {
706
+ TORCH_CHECK (
707
+ frameSeconds < maxSeconds.value (),
708
+ " frame pts is " + std::to_string (frameSeconds) +
709
+ " ; must be less than " + std::to_string (maxSeconds.value ()) +
710
+ " ." );
711
+ }
696
712
697
713
frameIndices[i] = secondsToIndexLowerBound (frameSeconds);
698
714
}
@@ -739,17 +755,26 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
739
755
}
740
756
741
757
double minSeconds = getMinSeconds (streamMetadata);
742
- double maxSeconds = getMaxSeconds (streamMetadata);
743
758
TORCH_CHECK (
744
- startSeconds >= minSeconds && startSeconds < maxSeconds ,
759
+ startSeconds >= minSeconds,
745
760
" Start seconds is " + std::to_string (startSeconds) +
746
- " ; must be in range [" + std::to_string (minSeconds) + " , " +
747
- std::to_string (maxSeconds) + " )." );
748
- TORCH_CHECK (
749
- stopSeconds <= maxSeconds,
750
- " Stop seconds (" + std::to_string (stopSeconds) +
751
- " ; must be less than or equal to " + std::to_string (maxSeconds) +
752
- " )." );
761
+ " ; must be greater than or equal to " + std::to_string (minSeconds) +
762
+ " ." );
763
+
764
+ // Note that if we can't determine the maximum seconds from the metadata, then
765
+ // we assume upper range is valid.
766
+ std::optional<double > maxSeconds = getMaxSeconds (streamMetadata);
767
+ if (maxSeconds.has_value ()) {
768
+ TORCH_CHECK (
769
+ startSeconds < maxSeconds.value (),
770
+ " Start seconds is " + std::to_string (startSeconds) +
771
+ " ; must be less than " + std::to_string (maxSeconds.value ()) + " ." );
772
+ TORCH_CHECK (
773
+ stopSeconds <= maxSeconds.value (),
774
+ " Stop seconds (" + std::to_string (stopSeconds) +
775
+ " ; must be less than or equal to " +
776
+ std::to_string (maxSeconds.value ()) + " )." );
777
+ }
753
778
754
779
// Note that we look at nextPts for a frame, and not its pts or duration.
755
780
// Our abstract player displays frames starting at the pts for that frame
@@ -1459,7 +1484,7 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
1459
1484
// STREAM AND METADATA APIS
1460
1485
// --------------------------------------------------------------------------
1461
1486
1462
- int64_t SingleStreamDecoder::getNumFrames (
1487
+ std::optional< int64_t > SingleStreamDecoder::getNumFrames (
1463
1488
const StreamMetadata& streamMetadata) {
1464
1489
switch (seekMode_) {
1465
1490
case SeekMode::exact:
@@ -1487,7 +1512,7 @@ double SingleStreamDecoder::getMinSeconds(
1487
1512
}
1488
1513
}
1489
1514
1490
- double SingleStreamDecoder::getMaxSeconds (
1515
+ std::optional< double > SingleStreamDecoder::getMaxSeconds (
1491
1516
const StreamMetadata& streamMetadata) {
1492
1517
switch (seekMode_) {
1493
1518
case SeekMode::exact:
@@ -1542,12 +1567,22 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
1542
1567
void SingleStreamDecoder::validateFrameIndex (
1543
1568
const StreamMetadata& streamMetadata,
1544
1569
int64_t frameIndex) {
1545
- int64_t numFrames = getNumFrames (streamMetadata);
1546
1570
TORCH_CHECK (
1547
- frameIndex >= 0 && frameIndex < numFrames ,
1571
+ frameIndex >= 0 ,
1548
1572
" Invalid frame index=" + std::to_string (frameIndex) +
1549
1573
" for streamIndex=" + std::to_string (streamMetadata.streamIndex ) +
1550
- " numFrames=" + std::to_string (numFrames));
1574
+ " ; must be greater than or equal to 0" );
1575
+
1576
+ // Note that if we do not have the number of frames available in our metadata,
1577
+ // then we assume that the frameIndex is valid.
1578
+ std::optional<int64_t > numFrames = getNumFrames (streamMetadata);
1579
+ if (numFrames.has_value ()) {
1580
+ TORCH_CHECK (
1581
+ frameIndex < numFrames.value (),
1582
+ " Invalid frame index=" + std::to_string (frameIndex) +
1583
+ " for streamIndex=" + std::to_string (streamMetadata.streamIndex ) +
1584
+ " ; must be less than " + std::to_string (numFrames.value ()));
1585
+ }
1551
1586
}
1552
1587
1553
1588
// --------------------------------------------------------------------------
0 commit comments