fix: Create and refactor to use safe GetElementCount and GetByteSize APIs #475
Conversation
Pull request overview
This PR strengthens shape/size validation across core components by switching wildcard-dimension checks to shared triton::common sentinels and adding explicit handling for GetElementCount / GetByteSize overflow sentinels to prevent integer overflow in allocation and validation paths.
Changes:
- Replace hard-coded `-1` wildcard dimension checks with `triton::common::WILDCARD_DIM` / `WILDCARD_SIZE`.
- Add early `INVALID_ARG` returns when `GetElementCount` / `GetByteSize` indicate overflow via `triton::common::OVERFLOW_SIZE`.
- Add additional bounds checks before computing byte sizes in several state / warmup / request validation flows.
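The sentinel pattern these changes rely on can be sketched as follows. `WILDCARD_SIZE`, `OVERFLOW_SIZE`, and `SafeElementCount` are illustrative stand-ins here, not the actual triton::common definitions:

```cpp
#include <cstdint>
#include <vector>

// Illustrative sentinels standing in for the real triton::common
// constants referenced by the PR.
constexpr int64_t WILDCARD_SIZE = -1;
constexpr int64_t OVERFLOW_SIZE = -2;

// Simplified sketch of an overflow-safe element count: check each
// multiplication before performing it so the product can never wrap.
int64_t
SafeElementCount(const std::vector<int64_t>& dims)
{
  int64_t count = 1;
  for (int64_t d : dims) {
    if (d == WILDCARD_SIZE) {
      return WILDCARD_SIZE;  // shape contains a wildcard dimension
    }
    if (d != 0 && count > INT64_MAX / d) {
      return OVERFLOW_SIZE;  // product would exceed INT64_MAX
    }
    count *= d;
  }
  return count;
}
```

Callers then branch on the sentinel values instead of trusting a possibly wrapped product.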
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| src/sequence_state.cc | Converts wildcard dims to WILDCARD_DIM and adds overflow checks for state buffer sizing. |
| src/sequence_batch_scheduler/sequence_batch_scheduler.cc | Uses WILDCARD_DIM and adds overflow guards when computing initial-state memory requirements. |
| src/model_config_utils.cc | Adds overflow detection for reshape/dims element-count computations and updates wildcard dim handling. |
| src/infer_request.cc | Adds OVERFLOW_SIZE checks for input byte-size validation and bytes/string element validation. |
| src/backend_model_instance.cc | Adds overflow checks for warmup input sizing and refactors warmup byte-size calculation. |
| src/backend_model.cc | Adds overflow detection for input element-count calculation when configuring schedulers. |
Comments suppressed due to low confidence (1)
src/backend_model_instance.cc:395
- In the warmup first pass, `batch_byte_size` is computed via `element_count * GetDataTypeByteSize(...)` without guarding the multiplication. Even with the new `OVERFLOW_SIZE` check, a large-but-valid `element_count` can still overflow `int64_t` here (and then be used to size allocations / max comparisons). Consider adding the same `element_count > INT64_MAX / dtype_byte_size` style check used later in the second pass before computing `batch_byte_size`.
int64_t batch_byte_size =
element_count *
triton::common::GetDataTypeByteSize(input_meta.second.data_type());
if (batch_byte_size == 0) {
batch_byte_size = element_count * sizeof(int32_t);
}
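The divide-before-multiply guard the comment asks for could be factored as a small helper. `SafeMulByteSize` is a hypothetical name for illustration, not Triton code:

```cpp
#include <cstdint>

// Sketch of the guarded multiplication: reject the product up front
// if element_count * dtype_byte_size would overflow int64_t.
bool
SafeMulByteSize(int64_t element_count, int64_t dtype_byte_size, int64_t* out)
{
  if (element_count < 0 || dtype_byte_size <= 0) {
    return false;  // negative counts / non-positive sizes are invalid here
  }
  if (element_count > INT64_MAX / dtype_byte_size) {
    return false;  // multiplication would overflow
  }
  *out = element_count * dtype_byte_size;
  return true;
}
```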
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.
whoisj
left a comment
looking good, but I have a couple of questions.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.
Pull request overview
Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.
…into yinggeh/tri-737-psirt-triton-inference-servercore-integer-overflow
Pull request overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated 10 comments.
Comments suppressed due to low confidence (1)
src/sequence_batch_scheduler/sequence_batch_scheduler.cc:1769
- If `InferenceRequest::CopyAsNull(...)` fails, `ni` will remain null (it is explicitly set to nullptr) but the code continues and passes it into `SetControlTensors(...)` and `AddRequest(...)`, which will crash. Either return/abort the batch, `continue` without enqueueing, or convert this into a hard failure that safely stops the batcher thread.
std::unique_ptr<InferenceRequest> ni = nullptr;
Status status = InferenceRequest::CopyAsNull(*null_irequest, &ni);
if (!status.IsOk()) {
LOG_ERROR
<< "internal: unexpecting failure copying null request: "
<< status.Message();
}
// Note that when the not-ready control input of the
// request is "true" the model can't assume that any
// other inputs are meaningful, including CORRID. So we
// just use zero for that.
SetControlTensors(
ni, seq_slot, 0 /* corrid */, true /* not_ready */);
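The "continue without enqueueing" option can be illustrated with toy stand-ins; the `Status` and `Request` types below are simplified placeholders, not the real Triton classes:

```cpp
#include <memory>
#include <string>

// Toy stand-ins, just to show the skip-on-failure control flow.
struct Status {
  bool ok;
  std::string msg;
  bool IsOk() const { return ok; }
};

struct Request {
  int seq_slot = 0;
};

// Placeholder for InferenceRequest::CopyAsNull; fails on demand.
Status
CopyAsNull(bool should_fail, std::unique_ptr<Request>* out)
{
  if (should_fail) {
    return {false, "copy failed"};
  }
  out->reset(new Request());
  return {true, ""};
}

// Returns true only when the copied request was safely enqueued.
bool
EnqueueNullRequest(bool should_fail)
{
  std::unique_ptr<Request> ni;
  Status status = CopyAsNull(should_fail, &ni);
  if (!status.IsOk()) {
    // Log and skip: ni is null here, so passing it onward would crash.
    return false;
  }
  // ... SetControlTensors(ni, ...) / AddRequest(ni) would go here ...
  return true;
}
```

The key point is that the null `ni` never reaches the downstream calls when the copy fails.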
src/model_config_utils.h
Outdated
Status
GetElementCount(const T& dims, const std::string& name, int64_t* cnt)
{
  *cnt = triton::common::GetElementCount(dims);
  if (*cnt == triton::common::INVALID_SIZE) {
These template helpers dereference output pointers (cnt, size) without validating they are non-null. Since this is a header-level utility likely to be reused broadly, add a null check (and return INVALID_ARG) to avoid a potential segfault if a caller passes nullptr.
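A null-guarded variant might look like the following sketch. `GetElementCountChecked` is a hypothetical name, the `Status` struct is a toy stand-in for the real class, and overflow handling is elided for brevity:

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Toy stand-in for triton::core::Status, enough to show the guard.
struct Status {
  enum class Code { SUCCESS, INVALID_ARG };
  Code code = Code::SUCCESS;
  std::string msg;
  bool IsOk() const { return code == Code::SUCCESS; }
};

template <typename T>
Status
GetElementCountChecked(const T& dims, const std::string& name, int64_t* cnt)
{
  if (cnt == nullptr) {
    // The null check the comment asks for: fail instead of segfaulting.
    return {Status::Code::INVALID_ARG,
            "output element count pointer for tensor '" + name + "' is null"};
  }
  int64_t count = 1;
  for (int64_t d : dims) {
    count *= d;  // overflow/wildcard handling elided in this sketch
  }
  *cnt = count;
  return {};
}
```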
src/model_config_utils.cc
Outdated
RETURN_IF_ERROR(GetElementCount(io.dims(), "dims", &dims_size));
RETURN_IF_ERROR(
    GetElementCount(io.reshape().shape(), "reshape", &reshape_size));
In this validation path, GetElementCount(...) is called with hardcoded names ("dims" / "reshape"), so overflow/invalid-dimension errors will not include the actual tensor name (io.name()) and can be hard to debug. Consider passing a name that includes io.name() (for example, "<io.name()>.dims") or extending the helper to accept a message prefix.
Suggested change:
- RETURN_IF_ERROR(GetElementCount(io.dims(), "dims", &dims_size));
- RETURN_IF_ERROR(
-     GetElementCount(io.reshape().shape(), "reshape", &reshape_size));
+ RETURN_IF_ERROR(
+     GetElementCount(io.dims(), io.name() + ".dims", &dims_size));
+ RETURN_IF_ERROR(GetElementCount(
+     io.reshape().shape(), io.name() + ".reshape", &reshape_size));
template <typename T>
Status
GetElementCount(const T& dims, const std::string& name, int64_t* cnt)
{
  *cnt = triton::common::GetElementCount(dims);
  if (*cnt == triton::common::INVALID_SIZE) {
    return Status(
        Status::Code::INVALID_ARG,
        "tensor '" + name + "' contains an invalid dimension in shape " +
            triton::common::DimsListToString(dims));
  } else if (*cnt == triton::common::OVERFLOW_SIZE) {
    return Status(
        Status::Code::INVALID_ARG, "element count for tensor '" + name +
                                       "' exceeds maximum size of " +
                                       std::to_string(INT64_MAX));
  } else {
    return Status::Success;
  }
}

template <typename T>
Status
GetByteSize(
    const inference::DataType& dtype, const T& dims, const std::string& name,
    int64_t* size)
{
  int64_t byte_size = 0;
  if (dtype == inference::DataType::TYPE_STRING) {
    int64_t element_count = 0;
    RETURN_IF_ERROR(GetElementCount(dims, name, &element_count));

    // Total number of bytes required is equal to the element count
    // multiplied by 4.
    if (element_count > static_cast<int64_t>(INT64_MAX / sizeof(int32_t))) {
      return Status(
          Status::Code::INVALID_ARG, "byte size for tensor '" + name +
                                         "' exceeds maximum size of " +
                                         std::to_string(INT64_MAX));
    }
    byte_size = sizeof(int32_t) * element_count;
  } else {
    byte_size = triton::common::GetByteSize(dtype, dims);
    if (byte_size == triton::common::INVALID_SIZE) {
      return Status(
          Status::Code::INVALID_ARG,
          "tensor '" + name + "' contains an invalid dimension " +
              triton::common::DimsListToString(dims));
    } else if (byte_size == triton::common::OVERFLOW_SIZE) {
      return Status(
          Status::Code::INVALID_ARG, "byte size for tensor '" + name +
                                         "' exceeds maximum size of " +
                                         std::to_string(INT64_MAX));
    }
  }
  *size = byte_size;
  return Status::Success;
}
New safe size-checking helpers (GetElementCount / GetByteSize) introduce new error paths (INVALID_SIZE / OVERFLOW_SIZE / variable-dim rejection once added). There is existing unit test coverage in src/test/input_byte_size_test.cc; add targeted tests that exercise these new failure modes so regressions are caught.
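A targeted regression test could look like the following sketch, using plain asserts in place of the project's test framework; `CheckedElementCount` and the sentinel constant are illustrative stand-ins for the real helper and triton::common constant:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative sentinel mirroring triton::common::OVERFLOW_SIZE.
constexpr int64_t OVERFLOW_SIZE_SENTINEL = -2;

// Simplified stand-in for the overflow-aware element-count helper.
int64_t
CheckedElementCount(const std::vector<int64_t>& dims)
{
  int64_t count = 1;
  for (int64_t d : dims) {
    if (d != 0 && count > INT64_MAX / d) {
      return OVERFLOW_SIZE_SENTINEL;
    }
    count *= d;
  }
  return count;
}

// A shape whose element count would exceed INT64_MAX must map to the
// overflow sentinel instead of silently wrapping.
void
TestOverflowRejected()
{
  assert(CheckedElementCount({INT64_MAX, 2}) == OVERFLOW_SIZE_SENTINEL);
  assert(CheckedElementCount({2, 3}) == 6);
}
```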
int64_t byte_size = 0;
if (dtype == inference::DataType::TYPE_STRING) {
  int64_t element_count = 0;
  RETURN_IF_ERROR(GetElementCount(dims, name, &element_count));

  // Total number of bytes required is equal to the element count
  // multiplied by 4.
  if (element_count > static_cast<int64_t>(INT64_MAX / sizeof(int32_t))) {
    return Status(
        Status::Code::INVALID_ARG, "byte size for tensor '" + name +
                                       "' exceeds maximum size of " +
                                       std::to_string(INT64_MAX));
  }
  byte_size = sizeof(int32_t) * element_count;
} else {
  byte_size = triton::common::GetByteSize(dtype, dims);
  if (byte_size == triton::common::INVALID_SIZE) {
    return Status(
        Status::Code::INVALID_ARG,
        "tensor '" + name + "' contains an invalid dimension " +
            triton::common::DimsListToString(dims));
  } else if (byte_size == triton::common::OVERFLOW_SIZE) {
    return Status(
        Status::Code::INVALID_ARG, "byte size for tensor '" + name +
                                       "' exceeds maximum size of " +
                                       std::to_string(INT64_MAX));
  }
}
GetByteSize(...) returns Status::Success even when dims contains variable-size (wildcard) dimensions. In that case GetElementCount may yield WILDCARD_SIZE (-1) and this code can compute a negative byte_size (or accept a wildcard result from triton::common::GetByteSize), which can later be implicitly converted to size_t and cause huge allocations/corruption. Consider explicitly detecting wildcard dimensions / WILDCARD_SIZE and returning INVALID_ARG with a clear message when the byte size cannot be computed.
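An explicit pre-check for wildcard/invalid dimensions could be as simple as the following sketch; `ByteSizeComputable` is a hypothetical helper, and `-1` is assumed to match `triton::common::WILDCARD_DIM`:

```cpp
#include <cstdint>
#include <vector>

// Returns false when the shape contains a wildcard or invalid (negative)
// dimension, i.e. when a concrete byte size cannot be computed yet.
bool
ByteSizeComputable(const std::vector<int64_t>& dims)
{
  for (int64_t d : dims) {
    if (d < 0) {
      return false;  // wildcard (-1) or invalid dim: reject with INVALID_ARG
    }
  }
  return true;
}
```

Running this check before computing a byte size keeps a negative sentinel from ever being converted to a huge `size_t`.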
// Total number of bytes required is equal to the element count
// multiplied by 4.
if (element_count > static_cast<int64_t>(INT64_MAX / sizeof(int32_t))) {
  return Status(
      Status::Code::INVALID_ARG, "byte size for tensor '" + name +
                                     "' exceeds maximum size of " +
                                     std::to_string(INT64_MAX));
}
This header uses INT64_MAX and int32_t but does not include the standard headers that define them. Relying on transitive includes is brittle and can break compilation depending on include order; consider adding an explicit <cstdint> include (or switch to std::numeric_limits<int64_t>::max() with <limits>).
src/backend_model_instance.cc
Outdated
RETURN_IF_ERROR(GetByteSize(
    input_meta.second.data_type(), input_meta.second.dims(),
    input_meta.first, reinterpret_cast<int64_t*>(&batch_byte_size)));
Passing reinterpret_cast<int64_t*>(&batch_byte_size) where batch_byte_size is a size_t is undefined behavior (strict-aliasing/type-punning) and can miscompute sizes on platforms where size_t differs from int64_t. Use an int64_t temporary for GetByteSize(...) and then range-check and cast to size_t.
Suggested change:
- RETURN_IF_ERROR(GetByteSize(
-     input_meta.second.data_type(), input_meta.second.dims(),
-     input_meta.first, reinterpret_cast<int64_t*>(&batch_byte_size)));
+ int64_t batch_byte_size_int = 0;
+ RETURN_IF_ERROR(GetByteSize(
+     input_meta.second.data_type(), input_meta.second.dims(),
+     input_meta.first, &batch_byte_size_int));
+ if ((batch_byte_size_int < 0) ||
+     (batch_byte_size_int > static_cast<int64_t>(SIZE_MAX))) {
+   return Status(
+       Status::Code::INVALID_ARG,
+       lrequest->LogRequest() +
+           "warmup setting expects a valid, non-negative byte size "
+           "within size_t range");
+ }
+ batch_byte_size = static_cast<size_t>(batch_byte_size_int);
size_t total_byte_size = 0;
RETURN_IF_ERROR(GetByteSize(
    initial_state.data_type(), initial_state.dims(), state.input_name(),
    reinterpret_cast<int64_t*>(&total_byte_size)));
Passing reinterpret_cast<int64_t*>(&total_byte_size) where total_byte_size is a size_t is undefined behavior and can lead to incorrect sizing. Use an int64_t temporary for GetByteSize(...) and then validate it is non-negative and fits in size_t before assigning.
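The properly typed temporary pattern that these aliasing comments call for can be factored into a small helper like the sketch below; `ToSizeT` is an illustrative name, not a Triton API:

```cpp
#include <cstddef>
#include <cstdint>

// Range-check an int64_t byte size and convert it to size_t, instead of
// type-punning a size_t* through reinterpret_cast<int64_t*>.
bool
ToSizeT(int64_t byte_size, size_t* out)
{
  if (byte_size < 0 || static_cast<uint64_t>(byte_size) > SIZE_MAX) {
    return false;  // negative or too large to represent as size_t
  }
  *out = static_cast<size_t>(byte_size);
  return true;
}
```

The caller receives the size through a real `int64_t` temporary and only assigns to the `size_t` destination after this check succeeds, avoiding the undefined behavior entirely.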
Status status = InferenceRequest::CopyAsNull(*null_irequest, &ni);
if (!status.IsOk()) {
  LOG_ERROR
      << "internal: unexpecting failure copying null request: "
Typo in the log message: "unexpecting" should be "unexpected".
Suggested change:
- << "internal: unexpecting failure copying null request: "
+ << "internal: unexpected failure copying null request: "
src/infer_request.cc
Outdated
if (input.second.DType() == inference::DataType::TYPE_STRING) {
  int64_t element_count =
      triton::common::GetElementCount(input.second.Shape());

- size_t str_byte_size = static_cast<size_t>(4 * element_count);
+ size_t str_byte_size = 0;
+ RETURN_IF_ERROR(GetByteSize(
+     inference::DataType::TYPE_STRING, input.second.Shape(), input.first,
+     reinterpret_cast<int64_t*>(&str_byte_size)));
Passing reinterpret_cast<int64_t*>(&str_byte_size) where str_byte_size is a size_t is undefined behavior and can lead to mis-sized buffers. Use an int64_t temporary for GetByteSize(...) and then validate/cast to size_t.
Pull request overview
Copilot reviewed 10 out of 10 changed files in this pull request and generated no new comments.
Comments suppressed due to low confidence (1)
src/sequence_batch_scheduler/sequence_batch_scheduler.cc:1769
- If `InferenceRequest::CopyAsNull(*null_irequest, &ni)` fails, `ni` remains null but is still dereferenced in `SetControlTensors(ni, ...)` and later via `ni->SetSequenceStates(...)`, which would lead to a crash. This should either use `RETURN_IF_ERROR(...)` (and propagate/abort the batcher thread iteration appropriately) or explicitly handle the error case (e.g., skip adding the null request / continue the loop after logging).
std::unique_ptr<InferenceRequest> ni = nullptr;
Status status = InferenceRequest::CopyAsNull(*null_irequest, &ni);
if (!status.IsOk()) {
LOG_ERROR << "internal: unexpected failure copying null request: "
<< status.Message();
}
// Note that when the not-ready control input of the
// request is "true" the model can't assume that any
// other inputs are meaningful, including CORRID. So we
// just use zero for that.
SetControlTensors(
ni, seq_slot, 0 /* corrid */, true /* not_ready */);
What does the PR do?
Checklist
<commit_type>: <Title>
Commit Type: check the conventional commit type box here and add the label to the GitHub PR.
Related PRs:
Where should the reviewer start?
Test plan:
Caveats:
Background
Related Issues: (use one of the action keywords Closes / Fixes / Resolves / Relates to)