Skip to content

fix: Solve CUDA AV1 decoding #448

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jan 13, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ git clone [email protected]:pytorch/torchcodec.git
cd torchcodec

pip install -e ".[dev]" --no-build-isolation -vv
# Or, for cuda support: ENABLE_CUDA=1 pip install -e ".[dev]" --no-build-isolation -vv
```

### Running unit tests

To run python tests run:
To run python tests run (please make sure `torchvision` is installed):

```bash
pytest test -vvv
Expand Down
7 changes: 7 additions & 0 deletions src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,11 @@ void releaseContextOnCuda(
throwUnsupportedDeviceError(device);
}

void forceCudaCodec(
const torch::Device& device,
AVCodecPtr* codec,
const AVCodecID& codecId) {
throwUnsupportedDeviceError(device);
}

} // namespace facebook::torchcodec
32 changes: 32 additions & 0 deletions src/torchcodec/decoders/_core/CudaDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,36 @@ void convertAVFrameToDecodedOutputOnCuda(
<< " took: " << duration.count() << "us" << std::endl;
}

// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
void forceCudaCodec(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's change the semantics of this function so that it returns the codec, if found. Then the caller is responsible for doing the assignment. That would make this signature:

std::optional<AVCodecPtr> findCudaCodec(
  const torch::Device& decive,
  const AVCodecID& codecId);

Then, inside the function, when we find the right codec, we just return it. If we loop through all available codecs and never find it, we return std::nullopt.

const torch::Device& device,
AVCodecPtr* codec,
const AVCodecID& codecId) {
if (device.type() != torch::kCUDA) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Replace this check with throwErrorIfNonCudaDevice(device). That's the convention used by the other functions in this file, and it also enforces that calling this function in a non-CUDA context is an error.

return;
}

const AVCodec* c;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's move this definition into the while loop condition itself. With the change in semantics, we no longer need to reference the AVCodec outside of a single loop iteration, so it's fine (and actually good) that its scope is limited to the while loop only.

void* i = NULL;
bool found = false;

while (!found && (c = av_codec_iterate(&i))) {
const AVCodecHWConfig* config;

if (c->id != codecId || !av_codec_is_decoder(c)) {
continue;
}

for (int j = 0; config = avcodec_get_hw_config(c, j); j++) {
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
found = true;
}
}
}

if (found) {
*codec = c;
}
}

} // namespace facebook::torchcodec
6 changes: 6 additions & 0 deletions src/torchcodec/decoders/_core/DeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <memory>
#include <stdexcept>
#include <string>
#include "FFMPEGCommon.h"
#include "src/torchcodec/decoders/_core/VideoDecoder.h"

extern "C" {
Expand Down Expand Up @@ -43,4 +44,9 @@ void releaseContextOnCuda(
const torch::Device& device,
AVCodecContext* codecContext);

void forceCudaCodec(
const torch::Device& device,
AVCodecPtr* codec,
const AVCodecID& codecId);

} // namespace facebook::torchcodec
6 changes: 6 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,12 @@ void VideoDecoder::addVideoStreamDecoder(
"Stream with index " + std::to_string(streamNumber) +
" is not a video stream.");
}

if (options.device.type() == torch::kCUDA) {
forceCudaCodec(
options.device, &codec, streamInfo.stream->codecpar->codec_id);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the change in semantics, we can make this call:

codec = findCudaCodec(options.device, streamInfo.stream->codecpar->codec_id).value_or(codec);

}

AVCodecContext* codecContext = avcodec_alloc_context3(codec);
codecContext->thread_count = options.ffmpegThreadCount.value_or(0);
TORCH_CHECK(codecContext != nullptr);
Expand Down
20 changes: 19 additions & 1 deletion test/decoders/test_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@

from torchcodec.decoders import _core, VideoDecoder

from ..utils import assert_frames_equal, cpu_and_cuda, H265_VIDEO, in_fbcode, NASA_VIDEO
from ..utils import (
assert_frames_equal,
AV1_VIDEO,
cpu_and_cuda,
H265_VIDEO,
in_fbcode,
NASA_VIDEO,
)


class TestVideoDecoder:
Expand Down Expand Up @@ -409,6 +416,17 @@ def test_get_frames_at_fails(self, device):
with pytest.raises(RuntimeError, match="Expected a value of type"):
decoder.get_frames_at([0.3])

def test_get_frame_at_av1(self):
# We don't parametrize with CUDA because the current GPUs on CI do not
# support AV1:
decoder = VideoDecoder(AV1_VIDEO.path, device="cpu")
ref_frame11 = AV1_VIDEO.get_frame_data_by_index(10)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The convention we're following in these tests is that the variable name number matches the index number, so this should be ref_frame10, ref_frame_info10 and decoded_frame10.

ref_frame_info11 = AV1_VIDEO.get_frame_info(10)
decoded_frame11 = decoder.get_frame_at(10)
assert decoded_frame11.duration_seconds == ref_frame_info11.duration_seconds
assert decoded_frame11.pts_seconds == ref_frame_info11.pts_seconds
assert_frames_equal(decoded_frame11.data, ref_frame11.to(device="cpu"))

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frame_played_at(self, device):
decoder = VideoDecoder(NASA_VIDEO.path, device=device)
Expand Down
16 changes: 16 additions & 0 deletions test/generate_reference_resources.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,19 @@ do
python3 "$TORCHCODEC_PATH/test/convert_image_to_tensor.py" "$bmp"
rm -f "$bmp"
done

# This video was generated by running the following:
# ffmpeg -f lavfi -i testsrc=duration=5:size=640x360:rate=25,format=yuv420p -c:v libaom-av1 -crf 30 -colorspace bt709 -color_primaries bt709 -color_trc bt709 av1_video.mkv
# Note that this video only has 1 stream, at index 0.
VIDEO_PATH=$RESOURCES_DIR/av1_video.mkv
FRAMES=(10)
for frame in "${FRAMES[@]}"; do
frame_name=$(printf "%06d" "$frame")
ffmpeg -y -i "$VIDEO_PATH" -vf select="eq(n\,$frame)" -vsync vfr -q:v 2 "$VIDEO_PATH.stream0.frame$frame_name.bmp"
done

for bmp in "$RESOURCES_DIR"/*.bmp
do
python3 "$TORCHCODEC_PATH/test/convert_image_to_tensor.py" "$bmp"
rm -f "$bmp"
done
Binary file added test/resources/av1_video.mkv
Binary file not shown.
Binary file not shown.
15 changes: 15 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,18 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor:
},
},
)

AV1_VIDEO = TestVideo(
filename="av1_video.mkv",
default_stream_index=0,
# This metadata is extracted manually.
# $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/av1_video.mkv > out.json
stream_infos={
0: TestVideoStreamInfo(width=640, height=360, num_color_channels=3),
},
frames={
0: {
10: TestFrameInfo(pts_seconds=0.400000, duration_seconds=0.040000),
},
},
)
Loading