diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp
index 8b3f29ca..4d0cbddf 100644
--- a/src/torchcodec/_core/CpuDeviceInterface.cpp
+++ b/src/torchcodec/_core/CpuDeviceInterface.cpp
@@ -37,9 +37,8 @@ bool CpuDeviceInterface::DecodedFrameContext::operator!=(
 CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
     : DeviceInterface(device) {
   TORCH_CHECK(g_cpu, "CpuDeviceInterface was not registered!");
-  if (device_.type() != torch::kCPU) {
-    throw std::runtime_error("Unsupported device: " + device_.str());
-  }
+  TORCH_CHECK(
+      device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
 }
 
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
@@ -161,9 +160,10 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
       frameOutput.data = outputTensor;
     }
   } else {
-    throw std::runtime_error(
-        "Invalid color conversion library: " +
-        std::to_string(static_cast<int>(colorConversionLibrary)));
+    TORCH_CHECK(
+        false,
+        "Invalid color conversion library: ",
+        static_cast<int>(colorConversionLibrary));
   }
 }
 
@@ -189,9 +189,8 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph(
     const UniqueAVFrame& avFrame) {
   int status = av_buffersrc_write_frame(
       filterGraphContext_.sourceContext, avFrame.get());
-  if (status < AVSUCCESS) {
-    throw std::runtime_error("Failed to add frame to buffer source context");
-  }
+  TORCH_CHECK(
+      status >= AVSUCCESS, "Failed to add frame to buffer source context");
 
   UniqueAVFrame filteredAVFrame(av_frame_alloc());
   status = av_buffersink_get_frame(
@@ -241,11 +240,12 @@ void CpuDeviceInterface::createFilterGraph(
       filterArgs.str().c_str(),
       nullptr,
       filterGraphContext_.filterGraph.get());
-  if (status < 0) {
-    throw std::runtime_error(
-        std::string("Failed to create filter graph: ") + filterArgs.str() +
-        ": " + getFFMPEGErrorStringFromErrorCode(status));
-  }
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to create filter graph: ",
+      filterArgs.str(),
+      ": ",
+      getFFMPEGErrorStringFromErrorCode(status));
 
   status = avfilter_graph_create_filter(
       &filterGraphContext_.sinkContext,
@@ -254,11 +254,10 @@ void CpuDeviceInterface::createFilterGraph(
       nullptr,
       nullptr,
       filterGraphContext_.filterGraph.get());
-  if (status < 0) {
-    throw std::runtime_error(
-        "Failed to create filter graph: " +
-        getFFMPEGErrorStringFromErrorCode(status));
-  }
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to create filter graph: ",
+      getFFMPEGErrorStringFromErrorCode(status));
 
   enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};
 
@@ -268,11 +267,10 @@ void CpuDeviceInterface::createFilterGraph(
       pix_fmts,
       AV_PIX_FMT_NONE,
       AV_OPT_SEARCH_CHILDREN);
-  if (status < 0) {
-    throw std::runtime_error(
-        "Failed to set output pixel formats: " +
-        getFFMPEGErrorStringFromErrorCode(status));
-  }
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to set output pixel formats: ",
+      getFFMPEGErrorStringFromErrorCode(status));
 
   UniqueAVFilterInOut outputs(avfilter_inout_alloc());
   UniqueAVFilterInOut inputs(avfilter_inout_alloc());
@@ -301,19 +299,17 @@ void CpuDeviceInterface::createFilterGraph(
       nullptr);
   outputs.reset(outputsTmp);
   inputs.reset(inputsTmp);
-  if (status < 0) {
-    throw std::runtime_error(
-        "Failed to parse filter description: " +
-        getFFMPEGErrorStringFromErrorCode(status));
-  }
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to parse filter description: ",
+      getFFMPEGErrorStringFromErrorCode(status));
 
   status =
       avfilter_graph_config(filterGraphContext_.filterGraph.get(), nullptr);
-  if (status < 0) {
-    throw std::runtime_error(
-        "Failed to configure filter graph: " +
-        getFFMPEGErrorStringFromErrorCode(status));
-  }
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to configure filter graph: ",
+      getFFMPEGErrorStringFromErrorCode(status));
 }
 
 void CpuDeviceInterface::createSwsContext(
diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp
index a5ebde8d..8086d0b4 100644
--- a/src/torchcodec/_core/CudaDeviceInterface.cpp
+++ b/src/torchcodec/_core/CudaDeviceInterface.cpp
@@ -166,9 +166,8 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
 CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device)
     : DeviceInterface(device) {
   TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!");
-  if (device_.type() != torch::kCUDA) {
-    throw std::runtime_error("Unsupported device: " + device_.str());
-  }
+  TORCH_CHECK(
+      device_.type() == torch::kCUDA, "Unsupported device: ", device_.str());
 }
 
 CudaDeviceInterface::~CudaDeviceInterface() {
diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp
index a66281cd..02a2d44c 100644
--- a/src/torchcodec/_core/SingleStreamDecoder.cpp
+++ b/src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -103,11 +103,10 @@ void SingleStreamDecoder::initializeDecoder() {
   // which decodes a few frames to get missing info. For more, see:
   // https://ffmpeg.org/doxygen/7.0/group__lavf__decoding.html
   int status = avformat_find_stream_info(formatContext_.get(), nullptr);
-  if (status < 0) {
-    throw std::runtime_error(
-        "Failed to find stream info: " +
-        getFFMPEGErrorStringFromErrorCode(status));
-  }
+  TORCH_CHECK(
+      status >= 0,
+      "Failed to find stream info: ",
+      getFFMPEGErrorStringFromErrorCode(status));
 
   for (unsigned int i = 0; i < formatContext_->nb_streams; i++) {
     AVStream* avStream = formatContext_->streams[i];
@@ -222,11 +221,10 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
       break;
     }
 
-    if (status != AVSUCCESS) {
-      throw std::runtime_error(
-          "Failed to read frame from input file: " +
-          getFFMPEGErrorStringFromErrorCode(status));
-    }
+    TORCH_CHECK(
+        status == AVSUCCESS,
+        "Failed to read frame from input file: ",
+        getFFMPEGErrorStringFromErrorCode(status));
 
     if (packet->flags & AV_PKT_FLAG_DISCARD) {
       continue;
@@ -279,11 +277,10 @@ void SingleStreamDecoder::scanFileAndUpdateMetadataAndIndex() {
 
   // Reset the seek-cursor back to the beginning.
   int status = avformat_seek_file(formatContext_.get(), 0, INT64_MIN, 0, 0, 0);
-  if (status < 0) {
-    throw std::runtime_error(
-        "Could not seek file to pts=0: " +
-        getFFMPEGErrorStringFromErrorCode(status));
-  }
+  TORCH_CHECK(
+      status >= 0,
+      "Could not seek file to pts=0: ",
+      getFFMPEGErrorStringFromErrorCode(status));
 
   // Sort all frames by their pts.
   for (auto& [streamIndex, streamInfo] : streamInfos_) {
@@ -415,9 +412,7 @@ void SingleStreamDecoder::addStream(
   }
 
   retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr);
-  if (retVal < AVSUCCESS) {
-    throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal));
-  }
+  TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal));
 
   codecContext->time_base = streamInfo.stream->time_base;
   containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName =
@@ -446,11 +441,11 @@ void SingleStreamDecoder::addVideoStream(
 
   auto& streamMetadata =
       containerMetadata_.allStreamMetadata[activeStreamIndex_];
-  if (seekMode_ == SeekMode::approximate &&
-      !streamMetadata.averageFpsFromHeader.has_value()) {
-    throw std::runtime_error(
-        "Seek mode is approximate, but stream " +
-        std::to_string(activeStreamIndex_) +
+  if (seekMode_ == SeekMode::approximate) {
+    TORCH_CHECK(
+        streamMetadata.averageFpsFromHeader.has_value(),
+        "Seek mode is approximate, but stream ",
+        std::to_string(activeStreamIndex_),
         " does not have an average fps in its metadata.");
   }
 
@@ -1048,11 +1043,13 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() {
         desiredPts,
         desiredPts,
         0);
-    if (status < 0) {
-      throw std::runtime_error(
-          "Could not seek file to pts=" + std::to_string(desiredPts) + ": " +
-          getFFMPEGErrorStringFromErrorCode(status));
-    }
+    TORCH_CHECK(
+        status >= 0,
+        "Could not seek file to pts=",
+        std::to_string(desiredPts),
+        ": ",
+        getFFMPEGErrorStringFromErrorCode(status));
+
     decodeStats_.numFlushes++;
     avcodec_flush_buffers(streamInfo.codecContext.get());
   }
@@ -1121,21 +1118,20 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
       status = avcodec_send_packet(
           streamInfo.codecContext.get(), /*avpkt=*/nullptr);
-      if (status < AVSUCCESS) {
-        throw std::runtime_error(
-            "Could not flush decoder: " +
-            getFFMPEGErrorStringFromErrorCode(status));
-      }
+      TORCH_CHECK(
+          status >= AVSUCCESS,
+          "Could not flush decoder: ",
+          getFFMPEGErrorStringFromErrorCode(status));
 
       reachedEOF = true;
       break;
     }
 
-    if (status < AVSUCCESS) {
-      throw std::runtime_error(
-          "Could not read frame from input file: " +
-          getFFMPEGErrorStringFromErrorCode(status));
-    }
+    TORCH_CHECK(
+        status >= AVSUCCESS,
+        "Could not read frame from input file: ",
+        getFFMPEGErrorStringFromErrorCode(status));
+
   } while (packet->stream_index != activeStreamIndex_);
 
   if (reachedEOF) {
@@ -1147,11 +1143,10 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
     // We got a valid packet. Send it to the decoder, and we'll receive it in
     // the next iteration.
     status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get());
-    if (status < AVSUCCESS) {
-      throw std::runtime_error(
-          "Could not push packet to decoder: " +
-          getFFMPEGErrorStringFromErrorCode(status));
-    }
+    TORCH_CHECK(
+        status >= AVSUCCESS,
+        "Could not push packet to decoder: ",
+        getFFMPEGErrorStringFromErrorCode(status));
 
     decodeStats_.numPacketsSentToDecoder++;
   }
@@ -1162,8 +1157,9 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame(
           "Requested next frame while there are no more frames left to "
          "decode.");
     }
-    throw std::runtime_error(
-        "Could not receive frame from decoder: " +
+    TORCH_CHECK(
+        false,
+        "Could not receive frame from decoder: ",
         getFFMPEGErrorStringFromErrorCode(status));
   }
 
@@ -1429,7 +1425,7 @@ int64_t SingleStreamDecoder::secondsToIndexLowerBound(double seconds) {
       return std::floor(seconds * streamMetadata.averageFpsFromHeader.value());
     }
     default:
-      throw std::runtime_error("Unknown SeekMode");
+      TORCH_CHECK(false, "Unknown SeekMode");
   }
 }
 
@@ -1456,7 +1452,7 @@ int64_t SingleStreamDecoder::secondsToIndexUpperBound(double seconds) {
       return std::ceil(seconds * streamMetadata.averageFpsFromHeader.value());
     }
     default:
-      throw std::runtime_error("Unknown SeekMode");
+      TORCH_CHECK(false, "Unknown SeekMode");
   }
 }
 
@@ -1476,7 +1472,7 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) {
           streamInfo.timeBase);
     }
     default:
-      throw std::runtime_error("Unknown SeekMode");
+      TORCH_CHECK(false, "Unknown SeekMode");
   }
 }
 
@@ -1493,7 +1489,7 @@ std::optional<int64_t> SingleStreamDecoder::getNumFrames(
       return streamMetadata.numFramesFromHeader;
     }
     default:
-      throw std::runtime_error("Unknown SeekMode");
+      TORCH_CHECK(false, "Unknown SeekMode");
   }
 }
 
@@ -1505,7 +1501,7 @@ double SingleStreamDecoder::getMinSeconds(
     case SeekMode::approximate:
      return 0;
    default:
-      throw std::runtime_error("Unknown SeekMode");
+      TORCH_CHECK(false, "Unknown SeekMode");
  }
 }
 
@@ -1518,7 +1514,7 @@ std::optional<double> SingleStreamDecoder::getMaxSeconds(
      return streamMetadata.durationSecondsFromHeader;
    }
    default:
-      throw std::runtime_error("Unknown SeekMode");
+      TORCH_CHECK(false, "Unknown SeekMode");
  }
 }
 
@@ -1552,10 +1548,10 @@ void SingleStreamDecoder::validateActiveStream(
 }
 
 void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) {
-  if (!scannedAllStreams_) {
-    throw std::runtime_error(
-        "Must scan all streams to update metadata before calling " + msg);
-  }
+  TORCH_CHECK(
+      scannedAllStreams_,
+      "Must scan all streams to update metadata before calling ",
+      msg);
 }
 
 void SingleStreamDecoder::validateFrameIndex(
diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp
index 03718698..e838b090 100644
--- a/src/torchcodec/_core/custom_ops.cpp
+++ b/src/torchcodec/_core/custom_ops.cpp
@@ -243,8 +243,10 @@ void _add_video_stream(
       videoStreamOptions.colorConversionLibrary =
           ColorConversionLibrary::SWSCALE;
     } else {
-      throw std::runtime_error(
-          "Invalid color_conversion_library=" + stdColorConversionLibrary +
+      TORCH_CHECK(
+          false,
+          "Invalid color_conversion_library=",
+          stdColorConversionLibrary,
           ". color_conversion_library must be either filtergraph or swscale.");
     }
   }
@@ -561,6 +563,7 @@ std::string get_stream_json_metadata(
     throw std::out_of_range(
         "stream_index out of bounds: " + std::to_string(stream_index));
   }
+  auto streamMetadata = allStreamMetadata[stream_index];
 
   std::map<std::string, std::string> map;