Skip to content

Commit d8dde5c

Browse files
hugo-ijwscotts
andauthored
fix: Solve CUDA AV1 decoding (#448)
Co-authored-by: Scott Schneider <[email protected]>
1 parent 81de40e commit d8dde5c

File tree

11 files changed

+98
-3
lines changed

11 files changed

+98
-3
lines changed

.github/workflows/linux_cuda_wheel.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ jobs:
5656
build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 ENABLE_CUDA=1 python -m build --wheel -vvv --no-isolation"
5757

5858
install-and-test:
59-
runs-on: linux.4xlarge.nvidia.gpu
59+
runs-on: linux.g5.4xlarge.nvidia.gpu
6060
strategy:
6161
fail-fast: false
6262
matrix:

CONTRIBUTING.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,12 @@ git clone [email protected]:pytorch/torchcodec.git
4242
cd torchcodec
4343

4444
pip install -e ".[dev]" --no-build-isolation -vv
45+
# Or, for cuda support: ENABLE_CUDA=1 pip install -e ".[dev]" --no-build-isolation -vv
4546
```
4647

4748
### Running unit tests
4849

49-
To run python tests run:
50+
To run python tests run (please make sure `torchvision` is installed):
5051

5152
```bash
5253
pytest test -vvv

src/torchcodec/decoders/_core/CPUOnlyDevice.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,10 @@ void releaseContextOnCuda(
3535
throwUnsupportedDeviceError(device);
3636
}
3737

38+
std::optional<AVCodecPtr> findCudaCodec(
39+
const torch::Device& device,
40+
const AVCodecID& codecId) {
41+
throwUnsupportedDeviceError(device);
42+
}
43+
3844
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/CudaDevice.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,4 +256,33 @@ void convertAVFrameToDecodedOutputOnCuda(
256256
<< " took: " << duration.count() << "us" << std::endl;
257257
}
258258

259+
// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
260+
// we have to do this because of an FFmpeg bug where hardware decoding is not
261+
// appropriately set, so we just go off and find the matching codec for the CUDA
262+
// device
263+
std::optional<AVCodecPtr> findCudaCodec(
264+
const torch::Device& device,
265+
const AVCodecID& codecId) {
266+
throwErrorIfNonCudaDevice(device);
267+
268+
void* i = NULL;
269+
270+
AVCodecPtr c;
271+
while (c = av_codec_iterate(&i)) {
272+
const AVCodecHWConfig* config;
273+
274+
if (c->id != codecId || !av_codec_is_decoder(c)) {
275+
continue;
276+
}
277+
278+
for (int j = 0; config = avcodec_get_hw_config(c, j); j++) {
279+
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
280+
return c;
281+
}
282+
}
283+
}
284+
285+
return std::nullopt;
286+
}
287+
259288
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/DeviceInterface.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include <memory>
1111
#include <stdexcept>
1212
#include <string>
13+
#include "FFMPEGCommon.h"
1314
#include "src/torchcodec/decoders/_core/VideoDecoder.h"
1415

1516
extern "C" {
@@ -43,4 +44,8 @@ void releaseContextOnCuda(
4344
const torch::Device& device,
4445
AVCodecContext* codecContext);
4546

47+
std::optional<AVCodecPtr> findCudaCodec(
48+
const torch::Device& device,
49+
const AVCodecID& codecId);
50+
4651
} // namespace facebook::torchcodec

src/torchcodec/decoders/_core/VideoDecoder.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,12 @@ void VideoDecoder::addVideoStreamDecoder(
461461
"Stream with index " + std::to_string(streamNumber) +
462462
" is not a video stream.");
463463
}
464+
465+
if (options.device.type() == torch::kCUDA) {
466+
codec = findCudaCodec(options.device, streamInfo.stream->codecpar->codec_id)
467+
.value_or(codec);
468+
}
469+
464470
AVCodecContext* codecContext = avcodec_alloc_context3(codec);
465471
codecContext->thread_count = options.ffmpegThreadCount.value_or(0);
466472
TORCH_CHECK(codecContext != nullptr);

test/decoders/test_video_decoder.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,14 @@
1111

1212
from torchcodec.decoders import _core, VideoDecoder
1313

14-
from ..utils import assert_frames_equal, cpu_and_cuda, H265_VIDEO, in_fbcode, NASA_VIDEO
14+
from ..utils import (
15+
assert_frames_equal,
16+
AV1_VIDEO,
17+
cpu_and_cuda,
18+
H265_VIDEO,
19+
in_fbcode,
20+
NASA_VIDEO,
21+
)
1522

1623

1724
class TestVideoDecoder:
@@ -409,6 +416,16 @@ def test_get_frames_at_fails(self, device):
409416
with pytest.raises(RuntimeError, match="Expected a value of type"):
410417
decoder.get_frames_at([0.3])
411418

419+
@pytest.mark.parametrize("device", cpu_and_cuda())
420+
def test_get_frame_at_av1(self, device):
421+
decoder = VideoDecoder(AV1_VIDEO.path, device=device)
422+
ref_frame10 = AV1_VIDEO.get_frame_data_by_index(10)
423+
ref_frame_info10 = AV1_VIDEO.get_frame_info(10)
424+
decoded_frame10 = decoder.get_frame_at(10)
425+
assert decoded_frame10.duration_seconds == ref_frame_info10.duration_seconds
426+
assert decoded_frame10.pts_seconds == ref_frame_info10.pts_seconds
427+
assert_frames_equal(decoded_frame10.data, ref_frame10.to(device=device))
428+
412429
@pytest.mark.parametrize("device", cpu_and_cuda())
413430
def test_get_frame_played_at(self, device):
414431
decoder = VideoDecoder(NASA_VIDEO.path, device=device)

test/generate_reference_resources.sh

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,19 @@ do
6161
python3 "$TORCHCODEC_PATH/test/convert_image_to_tensor.py" "$bmp"
6262
rm -f "$bmp"
6363
done
64+
65+
# This video was generated by running the following:
66+
# ffmpeg -f lavfi -i testsrc=duration=5:size=640x360:rate=25,format=yuv420p -c:v libaom-av1 -crf 30 -colorspace bt709 -color_primaries bt709 -color_trc bt709 av1_video.mkv
67+
# Note that this video only has 1 stream, at index 0.
68+
VIDEO_PATH=$RESOURCES_DIR/av1_video.mkv
69+
FRAMES=(10)
70+
for frame in "${FRAMES[@]}"; do
71+
frame_name=$(printf "%06d" "$frame")
72+
ffmpeg -y -i "$VIDEO_PATH" -vf select="eq(n\,$frame)" -vsync vfr -q:v 2 "$VIDEO_PATH.stream0.frame$frame_name.bmp"
73+
done
74+
75+
for bmp in "$RESOURCES_DIR"/*.bmp
76+
do
77+
python3 "$TORCHCODEC_PATH/test/convert_image_to_tensor.py" "$bmp"
78+
rm -f "$bmp"
79+
done

test/resources/av1_video.mkv

16 KB
Binary file not shown.
Binary file not shown.

test/utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,3 +312,18 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor:
312312
},
313313
},
314314
)
315+
316+
AV1_VIDEO = TestVideo(
317+
filename="av1_video.mkv",
318+
default_stream_index=0,
319+
# This metadata is extracted manually.
320+
# $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/av1_video.mkv > out.json
321+
stream_infos={
322+
0: TestVideoStreamInfo(width=640, height=360, num_color_channels=3),
323+
},
324+
frames={
325+
0: {
326+
10: TestFrameInfo(pts_seconds=0.400000, duration_seconds=0.040000),
327+
},
328+
},
329+
)

0 commit comments

Comments
 (0)