
fix: Create and refactor to use safe GetElementCount and GetByteSize APIs#475

Open
yinggeh wants to merge 15 commits into main from
yinggeh/tri-737-psirt-triton-inference-servercore-integer-overflow

Conversation

@yinggeh
Contributor

@yinggeh yinggeh commented Feb 26, 2026

What does the PR do?

Checklist

  • PR title reflects the change and is of format <commit_type>: <Title>
  • Changes are described in the pull request.
  • Related issues are referenced.
  • Populated github labels field
  • Added test plan and verified test passes.
  • Verified that the PR passes existing CI.
  • Verified copyright is correct on all changed files.
  • Added succinct git squash message before merging ref.
  • All template sections are filled out.
  • Optional: Additional screenshots for behavior/output changes with before/after.

Commit Type:

Check the conventional commit type
box here and add the label to the github PR.

  • fix

Related PRs:

Where should the reviewer start?

Test plan:

  • CI Pipeline ID:

Caveats:

Background

Related Issues: (use one of the action keywords Closes / Fixes / Resolves / Relates to)

  • closes GitHub issue: #xxx

@yinggeh yinggeh self-assigned this Feb 26, 2026
@yinggeh yinggeh added the bug Something isn't working label Feb 26, 2026
@yinggeh yinggeh changed the title fix: Prevent integer overflow in GetElementCount and GetByteSize APIs fix: Prevent integer overflow calling GetElementCount and GetByteSize APIs Feb 26, 2026

Copilot AI left a comment


Pull request overview

This PR strengthens shape/size validation across core components by switching wildcard-dimension checks to shared triton::common sentinels and adding explicit handling for GetElementCount / GetByteSize overflow sentinels to prevent integer overflow in allocation and validation paths.

Changes:

  • Replace hard-coded -1 wildcard dimension checks with triton::common::WILDCARD_DIM / WILDCARD_SIZE.
  • Add early INVALID_ARG returns when GetElementCount / GetByteSize indicate overflow via triton::common::OVERFLOW_SIZE.
  • Add additional bounds checks before computing byte sizes in several state / warmup / request validation flows.
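
The sentinel-based approach described above can be sketched as follows. This is a minimal illustration, not the Triton implementation: the sentinel values and the `SafeElementCount` name are assumptions (triton::common defines its own constants), but the core technique is the same — check each multiplication against `INT64_MAX` before performing it instead of letting the product wrap.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Assumed sentinel values; triton::common defines its own constants.
constexpr int64_t WILDCARD_SIZE = -1;  // shape contains a wildcard dim
constexpr int64_t OVERFLOW_SIZE = -2;  // product would exceed INT64_MAX

// Overflow-safe element count: validates each multiplication before
// performing it rather than detecting wrap-around after the fact.
int64_t
SafeElementCount(const std::vector<int64_t>& dims)
{
  int64_t count = 1;
  for (int64_t d : dims) {
    if (d < 0) {
      return WILDCARD_SIZE;  // wildcard (variable) dimension present
    }
    if (d != 0 && count > INT64_MAX / d) {
      return OVERFLOW_SIZE;  // multiplication would overflow int64_t
    }
    count *= d;
  }
  return count;
}
```

Callers can then translate the sentinels into early INVALID_ARG returns, so a wrapped value never reaches an allocation size.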

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 5 comments.

Show a summary per file
File Description
src/sequence_state.cc Converts wildcard dims to WILDCARD_DIM and adds overflow checks for state buffer sizing.
src/sequence_batch_scheduler/sequence_batch_scheduler.cc Uses WILDCARD_DIM and adds overflow guards when computing initial-state memory requirements.
src/model_config_utils.cc Adds overflow detection for reshape/dims element-count computations and updates wildcard dim handling.
src/infer_request.cc Adds OVERFLOW_SIZE checks for input byte-size validation and bytes/string element validation.
src/backend_model_instance.cc Adds overflow checks for warmup input sizing and refactors warmup byte-size calculation.
src/backend_model.cc Adds overflow detection for input element-count calculation when configuring schedulers.
Comments suppressed due to low confidence (1)

src/backend_model_instance.cc:395

  • In the warmup first pass, batch_byte_size is computed via element_count * GetDataTypeByteSize(...) without guarding the multiplication. Even with the new OVERFLOW_SIZE check, a large-but-valid element_count can still overflow int64_t here (and then be used to size allocations / max comparisons). Consider adding the same element_count > INT64_MAX / dtype_byte_size style check used later in the second pass before computing batch_byte_size.
      int64_t batch_byte_size =
          element_count *
          triton::common::GetDataTypeByteSize(input_meta.second.data_type());
      if (batch_byte_size == 0) {
        batch_byte_size = element_count * sizeof(int32_t);
      }
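
The guard the comment recommends might look like the following sketch. The function name and the convention that a zero data-type size means the string case (sized as `int32_t` length prefixes) are hypothetical, not Triton APIs; the point is dividing `INT64_MAX` by the per-element size before multiplying.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical guard mirroring the check suggested for the warmup first
// pass: reject before multiplying. A dtype_byte_size of 0 models the
// fallback case, which is sized using sizeof(int32_t).
int64_t
GuardedBatchByteSize(int64_t element_count, int64_t dtype_byte_size)
{
  const int64_t unit = (dtype_byte_size == 0)
                           ? static_cast<int64_t>(sizeof(int32_t))
                           : dtype_byte_size;
  if (element_count > INT64_MAX / unit) {
    return -1;  // signal overflow to the caller
  }
  return element_count * unit;
}
```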


Copilot AI commented Feb 26, 2026

@yinggeh I've opened a new pull request, #476, to work on those changes. Once the pull request is ready, I'll request review from you.


Copilot AI left a comment


Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.



Copilot AI left a comment


Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 4 comments.


Contributor

@whoisj whoisj left a comment


looking good, but I have a couple of questions.


Copilot AI left a comment


Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.



Copilot AI left a comment


Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 8 comments.



whoisj
whoisj previously approved these changes Mar 2, 2026
@yinggeh yinggeh changed the title fix: Prevent integer overflow calling GetElementCount and GetByteSize APIs fix: Create and refactor to use safe GetElementCount and GetByteSize APIs Mar 6, 2026
@yinggeh yinggeh requested a review from Copilot March 6, 2026 21:34

Copilot AI left a comment


Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 10 comments.

Comments suppressed due to low confidence (1)

src/sequence_batch_scheduler/sequence_batch_scheduler.cc:1769

  • If InferenceRequest::CopyAsNull(...) fails, ni will remain null (it is explicitly set to nullptr) but the code continues and passes it into SetControlTensors(...) and AddRequest(...), which will crash. Either return/abort the batch, continue without enqueueing, or convert this into a hard failure that safely stops the batcher thread.
            std::unique_ptr<InferenceRequest> ni = nullptr;
            Status status = InferenceRequest::CopyAsNull(*null_irequest, &ni);
            if (!status.IsOk()) {
              LOG_ERROR
                  << "internal: unexpecting failure copying null request: "
                  << status.Message();
            }
            // Note that when the not-ready control input of the
            // request is "true" the model can't assume that any
            // other inputs are meaningful, including CORRID. So we
            // just use zero for that.
            SetControlTensors(
                ni, seq_slot, 0 /* corrid */, true /* not_ready */);



Comment on lines +323 to +327
Status
GetElementCount(const T& dims, const std::string& name, int64_t* cnt)
{
  *cnt = triton::common::GetElementCount(dims);
  if (*cnt == triton::common::INVALID_SIZE) {

Copilot AI Mar 6, 2026


These template helpers dereference output pointers (cnt, size) without validating they are non-null. Since this is a header-level utility likely to be reused broadly, add a null check (and return INVALID_ARG) to avoid a potential segfault if a caller passes nullptr.
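
A sketch of the suggested null guard. `Status` is reduced to a `bool` here and the function name is hypothetical; the real helper would return `Status(Status::Code::INVALID_ARG, ...)` instead of `false`.

```cpp
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

// Null-checked variant of the header-level helper: validates the output
// pointer before dereferencing, then computes an overflow-safe count.
bool
GetElementCountChecked(
    const std::vector<int64_t>& dims, const std::string& name, int64_t* cnt)
{
  if (cnt == nullptr) {
    return false;  // INVALID_ARG: output pointer must be non-null
  }
  int64_t count = 1;
  for (int64_t d : dims) {
    if (d < 0 || (d != 0 && count > INT64_MAX / d)) {
      return false;  // invalid dimension or overflow for tensor 'name'
    }
    count *= d;
  }
  *cnt = count;
  return true;
}
```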

Comment on lines +358 to +360
  RETURN_IF_ERROR(GetElementCount(io.dims(), "dims", &dims_size));
  RETURN_IF_ERROR(
      GetElementCount(io.reshape().shape(), "reshape", &reshape_size));

Copilot AI Mar 6, 2026


In this validation path, GetElementCount(...) is called with hardcoded names ("dims" / "reshape"), so overflow/invalid-dimension errors will not include the actual tensor name (io.name()) and can be hard to debug. Consider passing a name that includes io.name() (for example, "<io.name()>.dims") or extending the helper to accept a message prefix.

Suggested change
  RETURN_IF_ERROR(GetElementCount(io.dims(), "dims", &dims_size));
  RETURN_IF_ERROR(
      GetElementCount(io.reshape().shape(), "reshape", &reshape_size));
  RETURN_IF_ERROR(
      GetElementCount(io.dims(), io.name() + ".dims", &dims_size));
  RETURN_IF_ERROR(GetElementCount(
      io.reshape().shape(), io.name() + ".reshape", &reshape_size));

Comment on lines +322 to +378
template <typename T>
Status
GetElementCount(const T& dims, const std::string& name, int64_t* cnt)
{
  *cnt = triton::common::GetElementCount(dims);
  if (*cnt == triton::common::INVALID_SIZE) {
    return Status(
        Status::Code::INVALID_ARG,
        "tensor '" + name + "' contains an invalid dimension in shape " +
            triton::common::DimsListToString(dims));
  } else if (*cnt == triton::common::OVERFLOW_SIZE) {
    return Status(
        Status::Code::INVALID_ARG, "element count for tensor '" + name +
                                       "' exceeds maximum size of " +
                                       std::to_string(INT64_MAX));
  } else {
    return Status::Success;
  }
}

template <typename T>
Status
GetByteSize(
    const inference::DataType& dtype, const T& dims, const std::string& name,
    int64_t* size)
{
  int64_t byte_size = 0;
  if (dtype == inference::DataType::TYPE_STRING) {
    int64_t element_count = 0;
    RETURN_IF_ERROR(GetElementCount(dims, name, &element_count));

    // Total number of bytes required is equal to the element count
    // multiplied by 4.
    if (element_count > static_cast<int64_t>(INT64_MAX / sizeof(int32_t))) {
      return Status(
          Status::Code::INVALID_ARG, "byte size for tensor '" + name +
                                         "' exceeds maximum size of " +
                                         std::to_string(INT64_MAX));
    }
    byte_size = sizeof(int32_t) * element_count;
  } else {
    byte_size = triton::common::GetByteSize(dtype, dims);
    if (byte_size == triton::common::INVALID_SIZE) {
      return Status(
          Status::Code::INVALID_ARG,
          "tensor '" + name + "' contains an invalid dimension " +
              triton::common::DimsListToString(dims));
    } else if (byte_size == triton::common::OVERFLOW_SIZE) {
      return Status(
          Status::Code::INVALID_ARG, "byte size for tensor '" + name +
                                         "' exceeds maximum size of " +
                                         std::to_string(INT64_MAX));
    }
  }
  *size = byte_size;
  return Status::Success;
}

Copilot AI Mar 6, 2026


New safe size-checking helpers (GetElementCount / GetByteSize) introduce new error paths (INVALID_SIZE / OVERFLOW_SIZE / variable-dim rejection once added). There is existing unit test coverage in src/test/input_byte_size_test.cc; add targeted tests that exercise these new failure modes so regressions are caught.

Comment on lines +348 to +375
  int64_t byte_size = 0;
  if (dtype == inference::DataType::TYPE_STRING) {
    int64_t element_count = 0;
    RETURN_IF_ERROR(GetElementCount(dims, name, &element_count));

    // Total number of bytes required is equal to the element count
    // multiplied by 4.
    if (element_count > static_cast<int64_t>(INT64_MAX / sizeof(int32_t))) {
      return Status(
          Status::Code::INVALID_ARG, "byte size for tensor '" + name +
                                         "' exceeds maximum size of " +
                                         std::to_string(INT64_MAX));
    }
    byte_size = sizeof(int32_t) * element_count;
  } else {
    byte_size = triton::common::GetByteSize(dtype, dims);
    if (byte_size == triton::common::INVALID_SIZE) {
      return Status(
          Status::Code::INVALID_ARG,
          "tensor '" + name + "' contains an invalid dimension " +
              triton::common::DimsListToString(dims));
    } else if (byte_size == triton::common::OVERFLOW_SIZE) {
      return Status(
          Status::Code::INVALID_ARG, "byte size for tensor '" + name +
                                         "' exceeds maximum size of " +
                                         std::to_string(INT64_MAX));
    }
  }

Copilot AI Mar 6, 2026


GetByteSize(...) returns Status::Success even when dims contains variable-size (wildcard) dimensions. In that case GetElementCount may yield WILDCARD_SIZE (-1) and this code can compute a negative byte_size (or accept a wildcard result from triton::common::GetByteSize), which can later be implicitly converted to size_t and cause huge allocations/corruption. Consider explicitly detecting wildcard dimensions / WILDCARD_SIZE and returning INVALID_ARG with a clear message when the byte size cannot be computed.
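
The fix this comment asks for might be sketched as follows. The function name and the `-1` failure convention are illustrative (real code would return INVALID_ARG with a message); the key behavior is refusing to compute a byte size when any dimension is a wildcard, so a negative element count can never reach a `size_t` conversion.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Byte-size computation that explicitly rejects wildcard (negative) dims
// and overflow instead of silently producing a negative result.
int64_t
ByteSizeNoWildcards(const std::vector<int64_t>& dims, int64_t dtype_size)
{
  int64_t count = 1;
  for (int64_t d : dims) {
    if (d < 0) {
      return -1;  // wildcard dimension: byte size cannot be computed
    }
    if (d != 0 && count > INT64_MAX / d) {
      return -1;  // element count would overflow int64_t
    }
    count *= d;
  }
  if (dtype_size > 0 && count > INT64_MAX / dtype_size) {
    return -1;  // byte size would overflow int64_t
  }
  return count * dtype_size;
}
```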

Comment on lines +353 to +360
  // Total number of bytes required is equal to the element count
  // multiplied by 4.
  if (element_count > static_cast<int64_t>(INT64_MAX / sizeof(int32_t))) {
    return Status(
        Status::Code::INVALID_ARG, "byte size for tensor '" + name +
                                       "' exceeds maximum size of " +
                                       std::to_string(INT64_MAX));
  }

Copilot AI Mar 6, 2026


This header uses INT64_MAX and int32_t but does not include the standard headers that define them. Relying on transitive includes is brittle and can break compilation depending on include order; consider adding an explicit <cstdint> include (or switch to std::numeric_limits<int64_t>::max() with <limits>).

Comment on lines +435 to +437
  RETURN_IF_ERROR(GetByteSize(
      input_meta.second.data_type(), input_meta.second.dims(),
      input_meta.first, reinterpret_cast<int64_t*>(&batch_byte_size)));

Copilot AI Mar 6, 2026


Passing reinterpret_cast<int64_t*>(&batch_byte_size) where batch_byte_size is a size_t is undefined behavior (strict-aliasing/type-punning) and can miscompute sizes on platforms where size_t differs from int64_t. Use an int64_t temporary for GetByteSize(...) and then range-check and cast to size_t.

Suggested change
  RETURN_IF_ERROR(GetByteSize(
      input_meta.second.data_type(), input_meta.second.dims(),
      input_meta.first, reinterpret_cast<int64_t*>(&batch_byte_size)));
  int64_t batch_byte_size_int = 0;
  RETURN_IF_ERROR(GetByteSize(
      input_meta.second.data_type(), input_meta.second.dims(),
      input_meta.first, &batch_byte_size_int));
  if ((batch_byte_size_int < 0) ||
      (batch_byte_size_int > static_cast<int64_t>(SIZE_MAX))) {
    return Status(
        Status::Code::INVALID_ARG,
        lrequest->LogRequest() +
            "warmup setting expects a valid, non-negative byte size "
            "within size_t range");
  }
  batch_byte_size = static_cast<size_t>(batch_byte_size_int);

Comment on lines +408 to +411
  size_t total_byte_size = 0;
  RETURN_IF_ERROR(GetByteSize(
      initial_state.data_type(), initial_state.dims(), state.input_name(),
      reinterpret_cast<int64_t*>(&total_byte_size)));

Copilot AI Mar 6, 2026


Passing reinterpret_cast<int64_t*>(&total_byte_size) where total_byte_size is a size_t is undefined behavior and can lead to incorrect sizing. Use an int64_t temporary for GetByteSize(...) and then validate it is non-negative and fits in size_t before assigning.
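
The safe pattern this comment (and the suppressed-aliasing comments nearby) recommends can be sketched as a small conversion helper. The name `ToSizeT` is hypothetical; the point is receiving the helper's result as `int64_t`, validating it, and only then casting — never punning a `size_t*` as an `int64_t*`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>

// Validate an int64_t byte size and convert it to size_t, instead of
// reinterpret_cast-ing a size_t* to int64_t* (undefined behavior).
bool
ToSizeT(int64_t byte_size, size_t* out)
{
  if (byte_size < 0) {
    return false;  // negative sizes are invalid
  }
  const uint64_t max_size_t =
      static_cast<uint64_t>(std::numeric_limits<size_t>::max());
  if (static_cast<uint64_t>(byte_size) > max_size_t) {
    return false;  // would not fit in size_t (relevant on 32-bit targets)
  }
  *out = static_cast<size_t>(byte_size);
  return true;
}
```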

Status status = InferenceRequest::CopyAsNull(*null_irequest, &ni);
if (!status.IsOk()) {
  LOG_ERROR
      << "internal: unexpecting failure copying null request: "

Copilot AI Mar 6, 2026


Typo in the log message: "unexpecting" should be "unexpected".

Suggested change
<< "internal: unexpecting failure copying null request: "
<< "internal: unexpected failure copying null request: "

Comment on lines +596 to +600
if (input.second.DType() == inference::DataType::TYPE_STRING) {
int64_t element_count =
triton::common::GetElementCount(input.second.Shape());

size_t str_byte_size = static_cast<size_t>(4 * element_count);
size_t str_byte_size = 0;
RETURN_IF_ERROR(GetByteSize(
inference::DataType::TYPE_STRING, input.second.Shape(), input.first,
reinterpret_cast<int64_t*>(&str_byte_size)));

Copilot AI Mar 6, 2026


Passing reinterpret_cast<int64_t*>(&str_byte_size) where str_byte_size is a size_t is undefined behavior and can lead to mis-sized buffers. Use an int64_t temporary for GetByteSize(...) and then validate/cast to size_t.


Copilot AI left a comment


Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated no new comments.

Comments suppressed due to low confidence (1)

src/sequence_batch_scheduler/sequence_batch_scheduler.cc:1769

  • If InferenceRequest::CopyAsNull(*null_irequest, &ni) fails, ni remains null but is still dereferenced in SetControlTensors(ni, ...) and later via ni->SetSequenceStates(...), which would lead to a crash. This should either use RETURN_IF_ERROR(...) (and propagate/abort the batcher thread iteration appropriately) or explicitly handle the error case (e.g., skip adding the null request / continue the loop after logging).
            std::unique_ptr<InferenceRequest> ni = nullptr;
            Status status = InferenceRequest::CopyAsNull(*null_irequest, &ni);
            if (!status.IsOk()) {
              LOG_ERROR << "internal: unexpected failure copying null request: "
                        << status.Message();
            }
            // Note that when the not-ready control input of the
            // request is "true" the model can't assume that any
            // other inputs are meaningful, including CORRID. So we
            // just use zero for that.
            SetControlTensors(
                ni, seq_slot, 0 /* corrid */, true /* not_ready */);


