Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/backend_model.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -825,10 +825,14 @@ TritonModel::SetConfiguredScheduler(
for (const auto& input : config_.input()) {
if (input.is_shape_tensor()) {
enforce_equal_shape_tensors.insert({input.name(), true});
} else if (
!input.allow_ragged_batch() &&
(triton::common::GetElementCount(input) == -1)) {
enforce_equal_shape_tensors.insert({input.name(), false});
} else {
int64_t element_count = 0;
RETURN_IF_ERROR(
GetElementCount(input.dims(), input.name(), &element_count));
if (!input.allow_ragged_batch() &&
(element_count == triton::common::WILDCARD_SIZE)) {
enforce_equal_shape_tensors.insert({input.name(), false});
}
}
}

Expand Down
34 changes: 9 additions & 25 deletions src/backend_model_instance.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -372,22 +372,10 @@ TritonModelInstance::GenerateWarmupData()
int64_t max_zero_byte_size = 0;
int64_t max_random_byte_size = 0;
for (const auto& input_meta : warmup_setting.inputs()) {
auto element_count =
triton::common::GetElementCount(input_meta.second.dims());
if (element_count == -1) {
return Status(
Status::Code::INVALID_ARG,
"warmup setting expects all variable-size dimensions are specified "
"for input '" +
input_meta.first + "'");
}

int64_t batch_byte_size =
element_count *
triton::common::GetDataTypeByteSize(input_meta.second.data_type());
if (batch_byte_size == 0) {
batch_byte_size = element_count * sizeof(int32_t);
}
int64_t batch_byte_size = 0;
RETURN_IF_ERROR(GetByteSize(
input_meta.second.data_type(), input_meta.second.dims(),
input_meta.first, &batch_byte_size));

switch (input_meta.second.input_data_type_case()) {
case inference::ModelWarmup_Input::InputDataTypeCase::kZeroData:
Expand Down Expand Up @@ -443,14 +431,10 @@ TritonModelInstance::GenerateWarmupData()
// Second pass to prepare original inputs.
std::vector<std::shared_ptr<InferenceRequest::Input>> input_sps;
for (const auto& input_meta : warmup_setting.inputs()) {
auto batch1_element_count =
triton::common::GetElementCount(input_meta.second.dims());
auto batch_byte_size =
batch1_element_count *
triton::common::GetDataTypeByteSize(input_meta.second.data_type());
if (batch_byte_size == 0) {
batch_byte_size = batch1_element_count * sizeof(int32_t);
}
size_t batch_byte_size = 0;
RETURN_IF_ERROR(GetByteSize(
input_meta.second.data_type(), input_meta.second.dims(),
input_meta.first, reinterpret_cast<int64_t*>(&batch_byte_size)));

const char* allocated_ptr;
switch (input_meta.second.input_data_type_case()) {
Expand Down
52 changes: 32 additions & 20 deletions src/infer_request.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -515,9 +515,16 @@ InferenceRequest::Release(
return Status::Success;
}

InferenceRequest*
InferenceRequest::CopyAsNull(const InferenceRequest& from)
Status
InferenceRequest::CopyAsNull(
const InferenceRequest& from, std::unique_ptr<InferenceRequest>* to)
{
if (to == nullptr) {
return Status(
Status::Code::INVALID_ARG, "InferenceRequest 'to' must not be null");
}
*to = nullptr;

// Create a copy of 'from' request with artificial inputs and no requested
// outputs. Maybe more efficient to share inputs and other metadata,
// but that binds the Null request with 'from' request's lifecycle.
Expand Down Expand Up @@ -587,10 +594,10 @@ InferenceRequest::CopyAsNull(const InferenceRequest& from)
}

if (input.second.DType() == inference::DataType::TYPE_STRING) {
int64_t element_count =
triton::common::GetElementCount(input.second.Shape());

size_t str_byte_size = static_cast<size_t>(4 * element_count);
size_t str_byte_size = 0;
RETURN_IF_ERROR(GetByteSize(
inference::DataType::TYPE_STRING, input.second.Shape(), input.first,
reinterpret_cast<int64_t*>(&str_byte_size)));
max_str_byte_size = std::max(str_byte_size, max_str_byte_size);
if (str_byte_size > max_byte_size) {
max_byte_size = str_byte_size;
Expand Down Expand Up @@ -638,11 +645,12 @@ InferenceRequest::CopyAsNull(const InferenceRequest& from)
if (input.first == *max_input_name) {
new_input->SetData(data);
} else {
if (inference::DataType::TYPE_STRING == input.second.DType()) {
new_input->AppendData(
data_base,
triton::common::GetElementCount(input.second.Shape()) * 4, mem_type,
mem_id);
if (input.second.DType() == inference::DataType::TYPE_STRING) {
int64_t str_byte_size = 0;
RETURN_IF_ERROR(GetByteSize(
inference::DataType::TYPE_STRING, input.second.Shape(), input.first,
&str_byte_size));
new_input->AppendData(data_base, str_byte_size, mem_type, mem_id);
} else {
new_input->AppendData(
data_base, input.second.Data()->TotalByteSize(), mem_type, mem_id);
Expand All @@ -662,7 +670,8 @@ InferenceRequest::CopyAsNull(const InferenceRequest& from)
std::make_pair(pr.second.Name(), std::addressof(pr.second)));
}

return lrequest.release();
*to = std::move(lrequest);
return Status::Success;
}

Status
Expand Down Expand Up @@ -844,8 +853,8 @@ InferenceRequest::LoadInputStates()
// Add the input states to the inference request.
if (sequence_states_ != nullptr) {
if (sequence_states_->IsNullRequest()) {
sequence_states_ =
SequenceStates::CopyAsNull(sequence_states_->NullSequenceStates());
RETURN_IF_ERROR(SequenceStates::CopyAsNull(
sequence_states_->NullSequenceStates(), &sequence_states_));
}
for (auto& input_state_pair : sequence_states_->InputStates()) {
auto& input_state = input_state_pair.second;
Expand Down Expand Up @@ -1173,14 +1182,14 @@ InferenceRequest::Normalize()
if (input_config->has_reshape()) {
std::deque<int64_t> variable_size_values;
for (int64_t idx = 0; idx < input_config->dims_size(); idx++) {
if (input_config->dims(idx) == -1) {
if (input_config->dims(idx) == triton::common::WILDCARD_DIM) {
variable_size_values.push_back((*shape)[idx]);
}
}

shape->clear();
for (const auto& dim : input_config->reshape().shape()) {
if (dim == -1) {
if (dim == triton::common::WILDCARD_DIM) {
shape->push_back(variable_size_values.front());
variable_size_values.pop_front();
} else {
Expand Down Expand Up @@ -1219,8 +1228,9 @@ InferenceRequest::Normalize()
const std::vector<int64_t>& input_dims =
input.IsShapeTensor() ? input.OriginalShape()
: input.ShapeWithBatchDim();
int64_t expected_byte_size =
triton::common::GetByteSize(data_type, input_dims);
int64_t expected_byte_size = 0;
RETURN_IF_ERROR(GetByteSize(
data_type, input_dims, input_name, &expected_byte_size));
const size_t& byte_size = input.Data()->TotalByteSize();
if ((byte_size > LLONG_MAX) ||
(static_cast<int64_t>(byte_size) != expected_byte_size)) {
Expand Down Expand Up @@ -1311,7 +1321,7 @@ InferenceRequest::ValidateBytesInputs(
{
const auto& input_dims = input.ShapeWithBatchDim();

int64_t element_count = triton::common::GetElementCount(input_dims);
int64_t element_count = 0;
int64_t element_checked = 0;
size_t remaining_element_size = 0;

Expand All @@ -1322,6 +1332,8 @@ InferenceRequest::ValidateBytesInputs(
size_t remaining_buffer_size = 0;
int64_t buffer_memory_id;

RETURN_IF_ERROR(GetElementCount(input_dims, input_name, &element_count));

// Validate elements until all buffers have been fully processed.
while (remaining_buffer_size || buffer_next_idx < buffer_count) {
// Get the next buffer if not currently processing one.
Expand Down
5 changes: 3 additions & 2 deletions src/infer_request.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -632,7 +632,8 @@ class InferenceRequest {
// required for the direct sequence batcher. The returned copy will
// contain only the minimum content required for a null request.
// The statistics of the copy will not be collected.
static InferenceRequest* CopyAsNull(const InferenceRequest& from);
static Status CopyAsNull(
const InferenceRequest& from, std::unique_ptr<InferenceRequest>* to);

uint64_t QueueStartNs() const { return queue_start_ns_; }
uint64_t CaptureQueueStartNs()
Expand Down
14 changes: 8 additions & 6 deletions src/model_config_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,11 @@ ValidateIOShape(
}
}

const int64_t dims_size = triton::common::GetElementCount(io.dims());
const int64_t reshape_size =
triton::common::GetElementCount(io.reshape().shape());
int64_t dims_size = 0;
int64_t reshape_size = 0;
RETURN_IF_ERROR(GetElementCount(io.dims(), "dims", &dims_size));
RETURN_IF_ERROR(
GetElementCount(io.reshape().shape(), "reshape", &reshape_size));

// dims and reshape must both have same element count
// or both have variable-size dimension.
Expand All @@ -372,12 +374,12 @@ ValidateIOShape(
// each pair of the trunks separated by variable-size dimension has
// the same element count. For instance, from [2, 4, -1, 6] to [8, -1, 1, 6]
// is valid reshape as 2 * 4 = 8 and 6 = 1 * 6.
if (dims_size == -1) {
if (dims_size == triton::common::WILDCARD_SIZE) {
std::vector<int64_t> dim_element_cnts;
std::vector<int64_t> reshape_element_cnts;
int64_t current_cnt = 1;
for (const auto& dim : io.dims()) {
if (dim != -1) {
if (dim != triton::common::WILDCARD_DIM) {
current_cnt *= dim;
} else {
dim_element_cnts.push_back(current_cnt);
Expand All @@ -388,7 +390,7 @@ ValidateIOShape(

current_cnt = 1;
for (const auto& dim : io.reshape().shape()) {
if (dim != -1) {
if (dim != triton::common::WILDCARD_DIM) {
current_cnt *= dim;
} else {
reshape_element_cnts.push_back(current_cnt);
Expand Down
58 changes: 58 additions & 0 deletions src/model_config_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -319,4 +319,62 @@ bool EquivalentInInstanceConfig(
std::string InstanceConfigSignature(
const inference::ModelInstanceGroup& instance_config);

template <typename T>
Status
GetElementCount(const T& dims, const std::string& name, int64_t* cnt)
{
  // Compute the element count for 'dims'. triton::common reports failure
  // through sentinel values which are translated into Status errors here;
  // 'cnt' always receives the raw result (matching legacy behavior on the
  // error paths, where callers must not read it).
  const int64_t count = triton::common::GetElementCount(dims);
  *cnt = count;
  if (count == triton::common::INVALID_SIZE) {
    return Status(
        Status::Code::INVALID_ARG,
        "tensor '" + name + "' contains an invalid dimension in shape " +
            triton::common::DimsListToString(dims));
  }
  if (count == triton::common::OVERFLOW_SIZE) {
    return Status(
        Status::Code::INVALID_ARG, "element count for tensor '" + name +
                                       "' exceeds maximum size of " +
                                       std::to_string(INT64_MAX));
  }
  return Status::Success;
}

template <typename T>
Status
GetByteSize(
    const inference::DataType& dtype, const T& dims, const std::string& name,
    int64_t* size)
{
  // Compute the total byte size of a tensor with shape 'dims' and datatype
  // 'dtype', returning an INVALID_ARG Status for invalid, variable-size
  // (wildcard), or overflowing shapes. On success '*size' is non-negative.
  int64_t byte_size = 0;
  if (dtype == inference::DataType::TYPE_STRING) {
    int64_t element_count = 0;
    RETURN_IF_ERROR(GetElementCount(dims, name, &element_count));

    // A wildcard (-1) element count would wrap through the unsigned
    // multiply below and be returned as a garbage size; reject it
    // explicitly. This also preserves the pre-existing behavior of callers
    // (e.g. warmup) that required all variable-size dimensions to be
    // resolved before a byte size can be computed.
    if (element_count == triton::common::WILDCARD_SIZE) {
      return Status(
          Status::Code::INVALID_ARG,
          "byte size cannot be computed for tensor '" + name +
              "' with variable-size dimensions in shape " +
              triton::common::DimsListToString(dims));
    }

    // Each string element is represented by a 4-byte length prefix, so the
    // total is element_count * 4. Guard the multiply against overflow.
    if (element_count > static_cast<int64_t>(INT64_MAX / sizeof(int32_t))) {
      return Status(
          Status::Code::INVALID_ARG, "byte size for tensor '" + name +
                                         "' exceeds maximum size of " +
                                         std::to_string(INT64_MAX));
    }
    // Keep the arithmetic signed to avoid an unsigned wrap-around.
    byte_size = element_count * static_cast<int64_t>(sizeof(int32_t));
  } else {
    byte_size = triton::common::GetByteSize(dtype, dims);
    if (byte_size == triton::common::INVALID_SIZE) {
      return Status(
          Status::Code::INVALID_ARG,
          "tensor '" + name + "' contains an invalid dimension " +
              triton::common::DimsListToString(dims));
    } else if (byte_size == triton::common::OVERFLOW_SIZE) {
      return Status(
          Status::Code::INVALID_ARG, "byte size for tensor '" + name +
                                         "' exceeds maximum size of " +
                                         std::to_string(INT64_MAX));
    } else if (byte_size == triton::common::WILDCARD_SIZE) {
      // Wildcard dims yield a wildcard byte size; surface it as an error
      // instead of returning a negative size with Status::Success.
      return Status(
          Status::Code::INVALID_ARG,
          "byte size cannot be computed for tensor '" + name +
              "' with variable-size dimensions in shape " +
              triton::common::DimsListToString(dims));
    }
  }
  *size = byte_size;
  return Status::Success;
}

}} // namespace triton::core
29 changes: 15 additions & 14 deletions src/sequence_batch_scheduler/sequence_batch_scheduler.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -384,13 +384,14 @@ SequenceBatchScheduler::GenerateInitialStateData(
auto state_dim = state.dims().begin();
for (; initial_state_dim != initial_state.dims().end();
initial_state_dim++, state_dim++) {
if (*initial_state_dim == -1) {
if (*initial_state_dim == triton::common::WILDCARD_DIM) {
return Status(
Status::Code::INVALID_ARG,
std::string("'initial_state' field for state input name '") +
state.input_name() + "' contains variable dimensions.");
} else {
if (*state_dim != -1 && *initial_state_dim != *state_dim) {
if (*state_dim != triton::common::WILDCARD_DIM &&
*initial_state_dim != *state_dim) {
return Status(
Status::Code::INVALID_ARG,
std::string("'initial_state' dim for input name '") +
Expand All @@ -404,15 +405,10 @@ SequenceBatchScheduler::GenerateInitialStateData(
}

// Calculate total memory byte size
auto element_count = triton::common::GetElementCount(initial_state.dims());
size_t dtype_byte_size =
triton::common::GetDataTypeByteSize(initial_state.data_type());
size_t total_byte_size = element_count * dtype_byte_size;

// Custom handling for TYPE_BYTES
if (dtype_byte_size == 0) {
total_byte_size = sizeof(int32_t) * element_count;
}
size_t total_byte_size = 0;
RETURN_IF_ERROR(GetByteSize(
initial_state.data_type(), initial_state.dims(), state.input_name(),
reinterpret_cast<int64_t*>(&total_byte_size)));
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Passing reinterpret_cast<int64_t*>(&total_byte_size) where total_byte_size is a size_t is undefined behavior and can lead to incorrect sizing. Use an int64_t temporary for GetByteSize(...) and then validate it is non-negative and fits in size_t before assigning.

Copilot uses AI. Check for mistakes.

switch (initial_state.state_data_case()) {
case inference::ModelSequenceBatching_InitialState::StateDataCase::
Expand Down Expand Up @@ -1757,8 +1753,13 @@ DirectSequenceBatch::BatcherThread(const int nice)
// Use null-request if necessary otherwise use the next
// request in the queue...
if (use_null_request) {
std::unique_ptr<InferenceRequest> ni(
InferenceRequest::CopyAsNull(*null_irequest));
std::unique_ptr<InferenceRequest> ni = nullptr;
Status status = InferenceRequest::CopyAsNull(*null_irequest, &ni);
if (!status.IsOk()) {
LOG_ERROR
<< "internal: unexpecting failure copying null request: "
<< status.Message();
}
// Note that when the not-ready control input of the
// request is "true" the model can't assume that any
// other inputs are meaningful, including CORRID. So we
Expand Down
6 changes: 3 additions & 3 deletions src/sequence_batch_scheduler/sequence_utils.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
Expand Down Expand Up @@ -43,8 +43,8 @@ IterativeSequencer::RescheduleRequest(
else if (!request->IsCancelled()) {
// Use a null request to trigger sequence batcher cancellation so
// additional request manipulation won't affect the actual request.
std::unique_ptr<InferenceRequest> ni(
InferenceRequest::CopyAsNull(*request));
std::unique_ptr<InferenceRequest> ni = nullptr;
RETURN_IF_ERROR(InferenceRequest::CopyAsNull(*request, &ni));
ni->SetCorrelationId(request->CorrelationId());
ni->SetFlags(TRITONSERVER_REQUEST_FLAG_SEQUENCE_END);
ni->Cancel();
Expand Down
Loading
Loading