-
Notifications
You must be signed in to change notification settings - Fork 38
fix: Solve CUDA AV1 decoding #448
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
d92149c
b4825a7
7e6bc92
a29287c
5e82e48
70c1985
a9fa4bb
d49fde5
25437b0
c64b833
9a9b50f
e453891
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -42,11 +42,12 @@ git clone [email protected]:pytorch/torchcodec.git | |
cd torchcodec | ||
|
||
pip install -e ".[dev]" --no-build-isolation -vv | ||
# Or, for cuda support: ENABLE_CUDA=1 pip install -e ".[dev]" --no-build-isolation -vv | ||
``` | ||
|
||
### Running unit tests | ||
|
||
To run python tests run: | ||
To run python tests run (please make sure `torchvision` is installed): | ||
|
||
```bash | ||
pytest test -vvv | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -256,4 +256,36 @@ void convertAVFrameToDecodedOutputOnCuda( | |
<< " took: " << duration.count() << "us" << std::endl; | ||
} | ||
|
||
// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 | ||
void forceCudaCodec( | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's change the semantics of this function so that it returns the codec, if found. Then the caller is responsible for doing the assignment. That would make this signature: [code snippet lost in extraction]
Then, inside the function, when we find the right codec, we just return it. If we loop through all available codecs and never find it, we return [inline code lost in extraction]. |
||
const torch::Device& device, | ||
AVCodecPtr* codec, | ||
const AVCodecID& codecId) { | ||
if (device.type() != torch::kCUDA) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Replace this check with [inline code lost in extraction]. |
||
return; | ||
} | ||
|
||
const AVCodec* c; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's move this definition into the while loop condition itself. With the change in semantics, we no longer need to reference the [remainder lost in extraction; presumably the codec variable `c` after the loop]. |
||
void* i = NULL; | ||
bool found = false; | ||
|
||
while (!found && (c = av_codec_iterate(&i))) { | ||
const AVCodecHWConfig* config; | ||
|
||
if (c->id != codecId || !av_codec_is_decoder(c)) { | ||
continue; | ||
} | ||
|
||
for (int j = 0; config = avcodec_get_hw_config(c, j); j++) { | ||
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { | ||
found = true; | ||
} | ||
} | ||
} | ||
|
||
if (found) { | ||
*codec = c; | ||
} | ||
} | ||
|
||
} // namespace facebook::torchcodec |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -456,6 +456,12 @@ void VideoDecoder::addVideoStreamDecoder( | |
"Stream with index " + std::to_string(streamNumber) + | ||
" is not a video stream."); | ||
} | ||
|
||
if (options.device.type() == torch::kCUDA) { | ||
forceCudaCodec( | ||
options.device, &codec, streamInfo.stream->codecpar->codec_id); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. With the change in semantics, we can make this call: [code snippet lost in extraction]
|
||
} | ||
|
||
AVCodecContext* codecContext = avcodec_alloc_context3(codec); | ||
codecContext->thread_count = options.ffmpegThreadCount.value_or(0); | ||
TORCH_CHECK(codecContext != nullptr); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,7 +11,14 @@ | |
|
||
from torchcodec.decoders import _core, VideoDecoder | ||
|
||
from ..utils import assert_frames_equal, cpu_and_cuda, H265_VIDEO, in_fbcode, NASA_VIDEO | ||
from ..utils import ( | ||
assert_frames_equal, | ||
AV1_VIDEO, | ||
cpu_and_cuda, | ||
H265_VIDEO, | ||
in_fbcode, | ||
NASA_VIDEO, | ||
) | ||
|
||
|
||
class TestVideoDecoder: | ||
|
@@ -409,6 +416,17 @@ def test_get_frames_at_fails(self, device): | |
with pytest.raises(RuntimeError, match="Expected a value of type"): | ||
decoder.get_frames_at([0.3]) | ||
|
||
def test_get_frame_at_av1(self): | ||
# We don't parametrize with CUDA because the current GPUs on CI do not | ||
# support AV1: | ||
decoder = VideoDecoder(AV1_VIDEO.path, device="cpu") | ||
ref_frame11 = AV1_VIDEO.get_frame_data_by_index(10) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The convention we're following in these tests is that the variable name number matches the index number, so this should be `ref_frame10`. |
||
ref_frame_info11 = AV1_VIDEO.get_frame_info(10) | ||
decoded_frame11 = decoder.get_frame_at(10) | ||
assert decoded_frame11.duration_seconds == ref_frame_info11.duration_seconds | ||
assert decoded_frame11.pts_seconds == ref_frame_info11.pts_seconds | ||
assert_frames_equal(decoded_frame11.data, ref_frame11.to(device="cpu")) | ||
|
||
@pytest.mark.parametrize("device", cpu_and_cuda()) | ||
def test_get_frame_played_at(self, device): | ||
decoder = VideoDecoder(NASA_VIDEO.path, device=device) | ||
|
Uh oh!
There was an error while loading. Please reload this page.