Replacing throw with TORCH_CHECK #725

Merged: 6 commits, Jun 20, 2025
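For background (not part of the diff itself): `TORCH_CHECK(cond, msg...)` is the c10/PyTorch checking macro. When `cond` is false it throws a `c10::Error` whose message is built by concatenating the remaining arguments, so the manual string concatenation and `std::to_string` calls at the old `throw` sites become unnecessary, and `TORCH_CHECK(false, ...)` stands in for an unconditional throw. The minimal sketch below illustrates the before/after pattern; the function names and parameters (`openInput`, `framesForMode`) are hypothetical and not from the torchcodec codebase.

```cpp
#include <c10/util/Exception.h> // defines TORCH_CHECK (also pulled in by the usual torch headers)

#include <cstdint>
#include <string>

// Hypothetical helper, only to illustrate the before/after pattern.
void openInput(int status, const std::string& path) {
  // Before: manual throw with string concatenation.
  //   if (status < 0) {
  //     throw std::runtime_error(
  //         "Failed to open " + path + ": " + std::to_string(status));
  //   }
  // After: one TORCH_CHECK; the variadic arguments are concatenated into the
  // c10::Error message only when the check actually fails.
  TORCH_CHECK(status >= 0, "Failed to open ", path, ": ", status);
}

// TORCH_CHECK(false, ...) is an unconditional failure, which is how the PR
// replaces bare throws in default: branches (e.g. "Unknown SeekMode").
int64_t framesForMode(int mode, int64_t fromIndex, int64_t fromHeader) {
  switch (mode) {
    case 0:
      return fromIndex;
    case 1:
      return fromHeader;
    default:
      TORCH_CHECK(false, "Unknown mode: ", mode);
  }
}
```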
64 changes: 30 additions & 34 deletions src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -37,9 +37,8 @@ bool CpuDeviceInterface::DecodedFrameContext::operator!=(
CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
: DeviceInterface(device) {
TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
if (device_.type() != torch::kCPU) {
throw std::runtime_error("Unsupported device: " + device_.str());
}
TORCH_CHECK(
device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
}

// Note [preAllocatedOutputTensor with swscale and filtergraph]:
@@ -161,9 +160,10 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
frameOutput.data = outputTensor;
}
} else {
throw std::runtime_error(
"Invalid color conversion library: " +
std::to_string(static_cast<int>(colorConversionLibrary)));
TORCH_CHECK(
false,
"Invalid color conversion library: ",
static_cast<int>(colorConversionLibrary));
}
}

@@ -189,9 +189,8 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
const UniqueAVFrame& avFrame) {
int status = av_buffersrc_write_frame(
filterGraphContext_.sourceContext, avFrame.get());
if (status < AVSUCCESS) {
throw std::runtime_error("Failed to add frame to buffer source context");
}
TORCH_CHECK(
status >= AVSUCCESS, "Failed to add frame to buffer source context");

UniqueAVFrame filteredAVFrame(av_frame_alloc());
status = av_buffersink_get_frame(
@@ -241,11 +240,12 @@ void CpuDeviceInterface::createFilterGraph(
filterArgs.str().c_str(),
nullptr,
filterGraphContext_.filterGraph.get());
if (status < 0) {
throw std::runtime_error(
std::string("Failed to create filter graph: ") + filterArgs.str() +
": " + getFFMPEGErrorStringFromErrorCode(status));
}
TORCH_CHECK(
status >= 0,
"Failed to create filter graph: ",
filterArgs.str(),
": ",
getFFMPEGErrorStringFromErrorCode(status));

status = avfilter_graph_create_filter(
&filterGraphContext_.sinkContext,
@@ -254,11 +254,10 @@
nullptr,
nullptr,
filterGraphContext_.filterGraph.get());
if (status < 0) {
throw std::runtime_error(
"Failed to create filter graph: " +
getFFMPEGErrorStringFromErrorCode(status));
}
TORCH_CHECK(
status >= 0,
"Failed to create filter graph: ",
getFFMPEGErrorStringFromErrorCode(status));

enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};

@@ -268,11 +267,10 @@
pix_fmts,
AV_PIX_FMT_NONE,
AV_OPT_SEARCH_CHILDREN);
if (status < 0) {
throw std::runtime_error(
"Failed to set output pixel formats: " +
getFFMPEGErrorStringFromErrorCode(status));
}
TORCH_CHECK(
status >= 0,
"Failed to set output pixel formats: ",
getFFMPEGErrorStringFromErrorCode(status));

UniqueAVFilterInOut outputs(avfilter_inout_alloc());
UniqueAVFilterInOut inputs(avfilter_inout_alloc());
@@ -301,19 +299,17 @@ void CpuDeviceInterface::createFilterGraph(
nullptr);
outputs.reset(outputsTmp);
inputs.reset(inputsTmp);
if (status < 0) {
throw std::runtime_error(
"Failed to parse filter description: " +
getFFMPEGErrorStringFromErrorCode(status));
}
TORCH_CHECK(
status >= 0,
"Failed to parse filter description: ",
getFFMPEGErrorStringFromErrorCode(status));

status =
avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
if (status < 0) {
throw std::runtime_error(
"Failed to configure filter graph: " +
getFFMPEGErrorStringFromErrorCode(status));
}
TORCH_CHECK(
status >= 0,
"Failed to configure filter graph: ",
getFFMPEGErrorStringFromErrorCode(status));
}

void CpuDeviceInterface::createSwsContext(
5 changes: 2 additions & 3 deletions src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -166,9 +166,8 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
: DeviceInterface(device) {
TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!");
if (device_.type() != torch::kCUDA) {
throw std::runtime_error("Unsupported device: " + device_.str());
}
TORCH_CHECK(
device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());
}

CudaDeviceInterface::~CudaDeviceInterface() {
106 changes: 51 additions & 55 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -103,11 +103,10 @@ void SingleStreamDecoder::initializeDecoder() {
// which decodes a few frames to get missing info. For more, see:
// https://ffmpeg.org/doxygen/7.0/group__lavf__decoding.html
int status = avformat_find_stream_info(formatContext_.get(), nullptr);
if (status < 0) {
throw std::runtime_error(
"Failed to find stream info: " +
getFFMPEGErrorStringFromErrorCode(status));
}
TORCH_CHECK(
status >= 0,
"Failed to find stream info: ",
getFFMPEGErrorStringFromErrorCode(status));

for (unsigned int i = 0; i < formatContext_->nb_streams; i++) {
AVStream* avStream = formatContext_->streams[i];
@@ -222,11 +221,10 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
break;
}

if (status != AVSUCCESS) {
throw std::runtime_error(
"Failed to read frame from input file: " +
getFFMPEGErrorStringFromErrorCode(status));
}
TORCH_CHECK(
status == AVSUCCESS,
"Failed to read frame from input file: ",
getFFMPEGErrorStringFromErrorCode(status));

if (packet->flags & AV_PKT_FLAG_DISCARD) {
continue;
@@ -279,11 +277,10 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {

// Reset the seek-cursor back to the beginning.
int status = avformat_seek_file(formatContext_.get(), 0, INT64_MIN, 0, 0, 0);
if (status < 0) {
throw std::runtime_error(
"Could not seek file to pts=0: " +
getFFMPEGErrorStringFromErrorCode(status));
}
TORCH_CHECK(
status >= 0,
"Could not seek file to pts=0: ",
getFFMPEGErrorStringFromErrorCode(status));

// Sort all frames by their pts.
for (auto& [streamIndex, streamInfo] : streamInfos_) {
@@ -415,9 +412,7 @@ void SingleStreamDecoder::addStream(
}

retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
if (retVal < AVSUCCESS) {
throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal));
}
TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal));

codecContext->time_base = streamInfo.stream->time_base;
containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
@@ -446,11 +441,11 @@ void SingleStreamDecoder::addVideoStream(
auto& streamMetadata =
containerMetadata_.allStreamMetadata[activeStreamIndex_];

if (seekMode_ == SeekMode::approximate &&
!streamMetadata.averageFpsFromHeader.has_value()) {
throw std::runtime_error(
"Seek mode is approximate, but stream " +
std::to_string(activeStreamIndex_) +
if (seekMode_ == SeekMode::approximate) {
TORCH_CHECK(
streamMetadata.averageFpsFromHeader.has_value(),
"Seek mode is approximate, but stream ",
std::to_string(activeStreamIndex_),
" does not have an average fps in its metadata.");
}

@@ -1048,11 +1043,13 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
desiredPts,
desiredPts,
0);
if (status < 0) {
throw std::runtime_error(
"Could not seek file to pts=" + std::to_string(desiredPts) + ": " +
getFFMPEGErrorStringFromErrorCode(status));
}
TORCH_CHECK(
status >= 0,
"Could not seek file to pts=",
std::to_string(desiredPts),
": ",
getFFMPEGErrorStringFromErrorCode(status));

decodeStats_.numFlushes++;
avcodec_flush_buffers(streamInfo.codecContext.get());
}
@@ -1121,21 +1118,20 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
status = avcodec_send_packet(
streamInfo.codecContext.get(),
/*avpkt=*/nullptr);
if (status < AVSUCCESS) {
throw std::runtime_error(
"Could not flush decoder: " +
getFFMPEGErrorStringFromErrorCode(status));
}
TORCH_CHECK(
status >= AVSUCCESS,
"Could not flush decoder: ",
getFFMPEGErrorStringFromErrorCode(status));

reachedEOF = true;
break;
}

if (status < AVSUCCESS) {
throw std::runtime_error(
"Could not read frame from input file: " +
getFFMPEGErrorStringFromErrorCode(status));
}
TORCH_CHECK(
status >= AVSUCCESS,
"Could not read frame from input file: ",
getFFMPEGErrorStringFromErrorCode(status));

} while (packet->stream_index != activeStreamIndex_);

if (reachedEOF) {
@@ -1147,11 +1143,10 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
// We got a valid packet. Send it to the decoder, and we'll receive it in
// the next iteration.
status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
if (status < AVSUCCESS) {
throw std::runtime_error(
"Could not push packet to decoder: " +
getFFMPEGErrorStringFromErrorCode(status));
}
TORCH_CHECK(
status >= AVSUCCESS,
"Could not push packet to decoder: ",
getFFMPEGErrorStringFromErrorCode(status));

decodeStats_.numPacketsSentToDecoder++;
}
@@ -1162,8 +1157,9 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
"Requested next frame while there are no more frames left to "
"decode.");
}
throw std::runtime_error(
"Could not receive frame from decoder: " +
TORCH_CHECK(
false,
"Could not receive frame from decoder: ",
getFFMPEGErrorStringFromErrorCode(status));
}

@@ -1429,7 +1425,7 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
return std::floor(seconds * streamMetadata.averageFpsFromHeader.value());
}
default:
throw std::runtime_error("Unknown SeekMode");
TORCH_CHECK(false, "Unknown SeekMode");
}
}

@@ -1456,7 +1452,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
return std::ceil(seconds * streamMetadata.averageFpsFromHeader.value());
}
default:
throw std::runtime_error("Unknown SeekMode");
TORCH_CHECK(false, "Unknown SeekMode");
}
}

@@ -1476,7 +1472,7 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
streamInfo.timeBase);
}
default:
throw std::runtime_error("Unknown SeekMode");
TORCH_CHECK(false, "Unknown SeekMode");
}
}

@@ -1493,7 +1489,7 @@ std::optional<int64_t> SingleStreamDecoder::getNumFrames(
return streamMetadata.numFramesFromHeader;
}
default:
throw std::runtime_error("Unknown SeekMode");
TORCH_CHECK(false, "Unknown SeekMode");
}
}

@@ -1505,7 +1501,7 @@ double SingleStreamDecoder::getMinSeconds(
case SeekMode::approximate:
return 0;
default:
throw std::runtime_error("Unknown SeekMode");
TORCH_CHECK(false, "Unknown SeekMode");
}
}

@@ -1518,7 +1514,7 @@ std::optional<double> SingleStreamDecoder::getMaxSeconds(
return streamMetadata.durationSecondsFromHeader;
}
default:
throw std::runtime_error("Unknown SeekMode");
TORCH_CHECK(false, "Unknown SeekMode");
}
}

@@ -1552,10 +1548,10 @@ void SingleStreamDecoder::validateActiveStream(
}

void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
if (!scannedAllStreams_) {
throw std::runtime_error(
"Must scan all streams to update metadata before calling " + msg);
}
TORCH_CHECK(
scannedAllStreams_,
"Must scan all streams to update metadata before calling ",
msg);
}

void SingleStreamDecoder::validateFrameIndex(
7 changes: 5 additions & 2 deletions src/torchcodec/_core/custom_ops.cpp
@@ -243,8 +243,10 @@ void _add_video_stream(
videoStreamOptions.colorConversionLibrary =
ColorConversionLibrary::SWSCALE;
} else {
throw std::runtime_error(
"Invalid color_conversion_library=" + stdColorConversionLibrary +
TORCH_CHECK(
false,
"Invalid color_conversion_library=",
stdColorConversionLibrary,
". color_conversion_library must be either filtergraph or swscale.");
}
}
@@ -561,6 +563,7 @@ std::string get_stream_json_metadata(
throw std::out_of_range(
"stream_index out of bounds: " + std::to_string(stream_index));
}

auto streamMetadata = allStreamMetadata[stream_index];

std::map<std::string, std::string> map;