Skip to content

Commit e654189

Browse files
authored
Support negative index in SimpleVideoDecoder (#743)
1 parent 1f8e02e commit e654189

File tree

2 files changed

+41
-4
lines changed

2 files changed

+41
-4
lines changed

src/torchcodec/decoders/_video_decoder.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,8 @@ def get_frame_at(self, index: int) -> Frame:
195195
Returns:
196196
Frame: The frame at the given index.
197197
"""
198+
if index < 0:
199+
index += self._num_frames
198200

199201
if not 0 <= index < self._num_frames:
200202
raise IndexError(
@@ -218,6 +220,9 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
218220
Returns:
219221
FrameBatch: The frames at the given indices.
220222
"""
223+
indices = [
224+
index if index >= 0 else index + self._num_frames for index in indices
225+
]
221226

222227
data, pts_seconds, duration_seconds = core.get_frames_at_indices(
223228
self._decoder, frame_indices=indices

test/test_decoders.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,19 @@ def test_getitem_slice(self, device, seek_mode):
328328
)
329329
assert_frames_equal(ref386_389, slice386_389)
330330

331+
# slices with upper bound greater than len(decoder) are supported
332+
slice387_389 = decoder[-3:10000].to(device)
333+
assert slice387_389.shape == torch.Size(
334+
[
335+
3,
336+
NASA_VIDEO.num_color_channels,
337+
NASA_VIDEO.height,
338+
NASA_VIDEO.width,
339+
]
340+
)
341+
ref387_389 = NASA_VIDEO.get_frame_data_by_range(387, 390).to(device)
342+
assert_frames_equal(ref387_389, slice387_389)
343+
331344
# an empty range is valid!
332345
empty_frame = decoder[5:5]
333346
assert_frames_equal(empty_frame, NASA_VIDEO.empty_chw_tensor.to(device))
@@ -437,6 +450,11 @@ def test_get_frame_at(self, device, seek_mode):
437450
expected_frame_info.duration_seconds, rel=1e-3
438451
)
439452

453+
# test negative frame index
454+
frame_minus1 = decoder.get_frame_at(-1)
455+
ref_frame_minus1 = NASA_VIDEO.get_frame_data_by_index(389).to(device)
456+
assert_frames_equal(ref_frame_minus1, frame_minus1.data)
457+
440458
# test numpy.int64
441459
frame9 = decoder.get_frame_at(numpy.int64(9))
442460
assert_frames_equal(ref_frame9, frame9.data)
@@ -470,7 +488,7 @@ def test_get_frame_at_fails(self, device, seek_mode):
470488
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
471489

472490
with pytest.raises(IndexError, match="out of bounds"):
473-
frame = decoder.get_frame_at(-1) # noqa
491+
frame = decoder.get_frame_at(-10000) # noqa
474492

475493
with pytest.raises(IndexError, match="out of bounds"):
476494
frame = decoder.get_frame_at(10000) # noqa
@@ -480,7 +498,8 @@ def test_get_frame_at_fails(self, device, seek_mode):
480498
def test_get_frames_at(self, device, seek_mode):
481499
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
482500

483-
frames = decoder.get_frames_at([35, 25])
501+
# test positive and negative frame index
502+
frames = decoder.get_frames_at([35, 25, -1, -2])
484503

485504
assert isinstance(frames, FrameBatch)
486505

@@ -490,12 +509,20 @@ def test_get_frames_at(self, device, seek_mode):
490509
assert_frames_equal(
491510
frames[1].data, NASA_VIDEO.get_frame_data_by_index(25).to(device)
492511
)
512+
assert_frames_equal(
513+
frames[2].data, NASA_VIDEO.get_frame_data_by_index(389).to(device)
514+
)
515+
assert_frames_equal(
516+
frames[3].data, NASA_VIDEO.get_frame_data_by_index(388).to(device)
517+
)
493518

494519
assert frames.pts_seconds.device.type == "cpu"
495520
expected_pts_seconds = torch.tensor(
496521
[
497522
NASA_VIDEO.get_frame_info(35).pts_seconds,
498523
NASA_VIDEO.get_frame_info(25).pts_seconds,
524+
NASA_VIDEO.get_frame_info(389).pts_seconds,
525+
NASA_VIDEO.get_frame_info(388).pts_seconds,
499526
],
500527
dtype=torch.float64,
501528
)
@@ -508,6 +535,8 @@ def test_get_frames_at(self, device, seek_mode):
508535
[
509536
NASA_VIDEO.get_frame_info(35).duration_seconds,
510537
NASA_VIDEO.get_frame_info(25).duration_seconds,
538+
NASA_VIDEO.get_frame_info(389).duration_seconds,
539+
NASA_VIDEO.get_frame_info(388).duration_seconds,
511540
],
512541
dtype=torch.float64,
513542
)
@@ -520,8 +549,11 @@ def test_get_frames_at(self, device, seek_mode):
520549
def test_get_frames_at_fails(self, device, seek_mode):
521550
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode)
522551

523-
with pytest.raises(RuntimeError, match="Invalid frame index=-1"):
524-
decoder.get_frames_at([-1])
552+
expected_converted_index = -10000 + len(decoder)
553+
with pytest.raises(
554+
RuntimeError, match=f"Invalid frame index={expected_converted_index}"
555+
):
556+
decoder.get_frames_at([-10000])
525557

526558
with pytest.raises(RuntimeError, match="Invalid frame index=390"):
527559
decoder.get_frames_at([390])

0 commit comments

Comments
 (0)