From fe9055dcb33e4bcb493885ab2a378c2f21850f6a Mon Sep 17 00:00:00 2001 From: Yingge He Date: Wed, 25 Feb 2026 20:13:07 -0800 Subject: [PATCH 01/15] Check overflow --- src/backend_model.cc | 19 +++++++++---- src/backend_model_instance.cc | 27 ++++++++++++++----- src/infer_request.cc | 21 ++++++++++++--- src/model_config_utils.cc | 17 +++++++++--- .../sequence_batch_scheduler.cc | 19 ++++++++++--- src/sequence_state.cc | 19 +++++++++++-- 6 files changed, 98 insertions(+), 24 deletions(-) diff --git a/src/backend_model.cc b/src/backend_model.cc index cefce6b59..7fac94b4e 100644 --- a/src/backend_model.cc +++ b/src/backend_model.cc @@ -1,4 +1,4 @@ -// Copyright 2020-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -825,10 +825,19 @@ TritonModel::SetConfiguredScheduler( for (const auto& input : config_.input()) { if (input.is_shape_tensor()) { enforce_equal_shape_tensors.insert({input.name(), true}); - } else if ( - !input.allow_ragged_batch() && - (triton::common::GetElementCount(input) == -1)) { - enforce_equal_shape_tensors.insert({input.name(), false}); + } else { + auto element_count = triton::common::GetElementCount(input); + if (element_count == triton::common::OVERFLOW_SIZE) { + return Status( + Status::Code::INVALID_ARG, + "input '" + input.name() + + "' causes total element count to exceed maximum size of " + + std::to_string(INT64_MAX)); + } + if (!input.allow_ragged_batch() && + (element_count == triton::common::WILDCARD_SIZE)) { + enforce_equal_shape_tensors.insert({input.name(), false}); + } } } diff --git a/src/backend_model_instance.cc b/src/backend_model_instance.cc index b5f595c87..0973403b4 100644 --- a/src/backend_model_instance.cc +++ b/src/backend_model_instance.cc @@ -1,4 +1,4 @@ -// Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -374,14 +374,19 @@ TritonModelInstance::GenerateWarmupData() for (const auto& input_meta : warmup_setting.inputs()) { auto element_count = triton::common::GetElementCount(input_meta.second.dims()); - if (element_count == -1) { + if (element_count == triton::common::WILDCARD_SIZE) { return Status( Status::Code::INVALID_ARG, "warmup setting expects all variable-size dimensions are specified " "for input '" + input_meta.first + "'"); + } else if (element_count == triton::common::OVERFLOW_SIZE) { + return Status( + Status::Code::INVALID_ARG, + "warmup setting for input '" + input_meta.first + + "' causes total element count to exceed maximum size of " + + std::to_string(INT64_MAX)); } - int64_t batch_byte_size = element_count * triton::common::GetDataTypeByteSize(input_meta.second.data_type()); @@ -445,12 +450,20 @@ TritonModelInstance::GenerateWarmupData() for (const auto& input_meta : warmup_setting.inputs()) { auto batch1_element_count = triton::common::GetElementCount(input_meta.second.dims()); - auto batch_byte_size = - batch1_element_count * + auto dtype_byte_size = triton::common::GetDataTypeByteSize(input_meta.second.data_type()); - if (batch_byte_size == 0) { - batch_byte_size = batch1_element_count * sizeof(int32_t); + dtype_byte_size = + dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size; + if (batch1_element_count == triton::common::OVERFLOW_SIZE || + (batch1_element_count > + INT64_MAX / static_cast(dtype_byte_size))) { + return Status( + Status::Code::INVALID_ARG, + "warmup setting for input '" + input_meta.first + + "' causes total element count to exceed maximum size of " + + std::to_string(INT64_MAX)); } + auto batch_byte_size = batch1_element_count * dtype_byte_size; const char* allocated_ptr; switch (input_meta.second.input_data_type_case()) { diff --git a/src/infer_request.cc b/src/infer_request.cc index 41074effc..e9e4f3446 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -1,4 +1,4 @@ -// Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -1173,14 +1173,14 @@ InferenceRequest::Normalize() if (input_config->has_reshape()) { std::deque variable_size_values; for (int64_t idx = 0; idx < input_config->dims_size(); idx++) { - if (input_config->dims(idx) == -1) { + if (input_config->dims(idx) == triton::common::WILDCARD_DIM) { variable_size_values.push_back((*shape)[idx]); } } shape->clear(); for (const auto& dim : input_config->reshape().shape()) { - if (dim == -1) { + if (dim == triton::common::WILDCARD_DIM) { shape->push_back(variable_size_values.front()); variable_size_values.pop_front(); } else { @@ -1222,6 +1222,13 @@ InferenceRequest::Normalize() int64_t expected_byte_size = triton::common::GetByteSize(data_type, input_dims); const size_t& byte_size = input.Data()->TotalByteSize(); + if (expected_byte_size == triton::common::OVERFLOW_SIZE) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input '" + input_name + + "' causes total byte size to exceed maximum size of " + + std::to_string(INT64_MAX)); + } if ((byte_size > LLONG_MAX) || (static_cast(byte_size) != expected_byte_size)) { return Status( @@ -1322,6 +1329,14 @@ InferenceRequest::ValidateBytesInputs( size_t remaining_buffer_size = 0; int64_t buffer_memory_id; + if (element_count == triton::common::OVERFLOW_SIZE) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input '" + input_name + + "' causes total element count to exceed maximum size of " + + std::to_string(INT64_MAX)); + } + // Validate elements until all buffers have been fully processed. while (remaining_buffer_size || buffer_next_idx < buffer_count) { // Get the next buffer if not currently processing one. diff --git a/src/model_config_utils.cc b/src/model_config_utils.cc index a182fe397..6bb76c6bc 100644 --- a/src/model_config_utils.cc +++ b/src/model_config_utils.cc @@ -1,4 +1,4 @@ -// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -357,6 +357,15 @@ ValidateIOShape( const int64_t reshape_size = triton::common::GetElementCount(io.reshape().shape()); + if (dims_size == triton::common::OVERFLOW_SIZE || + reshape_size == triton::common::OVERFLOW_SIZE) { + return Status( + Status::Code::INVALID_ARG, + message_prefix_with_name + + "causes total element count to exceed maximum size of " + + std::to_string(INT64_MAX)); + } + // dims and reshape must both have same element count // or both have variable-size dimension. // Special case for empty reshape... expect dims to have element @@ -372,12 +381,12 @@ ValidateIOShape( // each pair of the trunks separated by variable-size dimension has // the same element count. For instance, from [2, 4, -1, 6] to [8, -1, 1, 6] // is valid reshape as 2 * 4 = 8 and 6 = 1 * 6. - if (dims_size == -1) { + if (dims_size == triton::common::WILDCARD_DIM) { std::vector dim_element_cnts; std::vector reshape_element_cnts; int64_t current_cnt = 1; for (const auto& dim : io.dims()) { - if (dim != -1) { + if (dim != triton::common::WILDCARD_DIM) { current_cnt *= dim; } else { dim_element_cnts.push_back(current_cnt); @@ -388,7 +397,7 @@ ValidateIOShape( current_cnt = 1; for (const auto& dim : io.reshape().shape()) { - if (dim != -1) { + if (dim != triton::common::WILDCARD_DIM) { current_cnt *= dim; } else { reshape_element_cnts.push_back(current_cnt); diff --git a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc index 45e9c037c..263d213ac 100644 --- a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc +++ b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc @@ -1,4 +1,4 @@ -// Copyright 2018-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -384,13 +384,14 @@ SequenceBatchScheduler::GenerateInitialStateData( auto state_dim = state.dims().begin(); for (; initial_state_dim != initial_state.dims().end(); initial_state_dim++, state_dim++) { - if (*initial_state_dim == -1) { + if (*initial_state_dim == triton::common::WILDCARD_DIM) { return Status( Status::Code::INVALID_ARG, std::string("'initial_state' field for state input name '") + state.input_name() + "' contains variable dimensions."); } else { - if (*state_dim != -1 && *initial_state_dim != *state_dim) { + if (*state_dim != triton::common::WILDCARD_DIM && + *initial_state_dim != *state_dim) { return Status( Status::Code::INVALID_ARG, std::string("'initial_state' dim for input name '") + @@ -409,6 +410,18 @@ SequenceBatchScheduler::GenerateInitialStateData( triton::common::GetDataTypeByteSize(initial_state.data_type()); size_t total_byte_size = element_count * dtype_byte_size; + if (element_count == triton::common::OVERFLOW_SIZE || + (dtype_byte_size != 0 && + (element_count > INT64_MAX / (static_cast(dtype_byte_size)))) || + (total_byte_size > INT64_MAX / sizeof(int32_t))) { + return Status( + Status::Code::INVALID_ARG, + std::string("'initial_state' field for state input name '") + + state.input_name() + + "' causes total element count to exceed maximum size of " + + std::to_string(INT64_MAX)); + } + // Custom handling for TYPE_BYTES if (dtype_byte_size == 0) { total_byte_size = sizeof(int32_t) * element_count; diff --git a/src/sequence_state.cc b/src/sequence_state.cc index e1c4dc13d..23b7a3abe 100644 --- a/src/sequence_state.cc +++ b/src/sequence_state.cc @@ -1,4 +1,4 @@ -// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -159,7 +159,7 @@ SequenceStates::Initialize( // Convert the variable dimensions to 1 for the first request. for (auto& dim : state_config.dims()) { - if (dim == -1) { + if (dim == triton::common::WILDCARD_DIM) { dims.push_back(1); } else { dims.push_back(dim); @@ -212,12 +212,27 @@ SequenceStates::Initialize( size_t state_size; if (state.second.data_type() == inference::DataType::TYPE_STRING) { auto element_count = triton::common::GetElementCount(dims); + if (element_count == triton::common::OVERFLOW_SIZE || + (element_count > INT64_MAX / 4)) { + return Status( + Status::Code::INVALID_ARG, + "state '" + state_config.input_name() + + "' causes total element count to exceed maximum size of " + + std::to_string(INT64_MAX)); + } // Total number of bytes required is equal to the element count // multiplied by 4. state_size = 4 * element_count; } else { state_size = triton::common::GetByteSize(state.second.data_type(), dims); + if (state_size == static_cast(triton::common::OVERFLOW_SIZE)) { + return Status( + Status::Code::INVALID_ARG, + "state '" + state_config.input_name() + + "' causes total byte size to exceed maximum size of " + + std::to_string(INT64_MAX)); + } } if (use_growable_memory) { std::unique_ptr growable_memory; From 78be987e79402ac1b1576fd794699a9139a3896e Mon Sep 17 00:00:00 2001 From: Yingge He Date: Thu, 26 Feb 2026 02:42:27 -0800 Subject: [PATCH 02/15] Address copilot comments --- src/model_config_utils.cc | 2 +- .../sequence_batch_scheduler.cc | 14 ++++---------- src/sequence_state.cc | 5 +++-- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/src/model_config_utils.cc b/src/model_config_utils.cc index 6bb76c6bc..7bc4415e4 100644 --- a/src/model_config_utils.cc +++ b/src/model_config_utils.cc @@ -381,7 +381,7 @@ ValidateIOShape( // each pair of the trunks separated by variable-size dimension has // the same element count. For instance, from [2, 4, -1, 6] to [8, -1, 1, 6] // is valid reshape as 2 * 4 = 8 and 6 = 1 * 6. - if (dims_size == triton::common::WILDCARD_DIM) { + if (dims_size == triton::common::WILDCARD_SIZE) { std::vector dim_element_cnts; std::vector reshape_element_cnts; int64_t current_cnt = 1; diff --git a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc index 263d213ac..81718c368 100644 --- a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc +++ b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc @@ -408,24 +408,18 @@ SequenceBatchScheduler::GenerateInitialStateData( auto element_count = triton::common::GetElementCount(initial_state.dims()); size_t dtype_byte_size = triton::common::GetDataTypeByteSize(initial_state.data_type()); - size_t total_byte_size = element_count * dtype_byte_size; - + dtype_byte_size = dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size; if (element_count == triton::common::OVERFLOW_SIZE || - (dtype_byte_size != 0 && - (element_count > INT64_MAX / (static_cast(dtype_byte_size)))) || - (total_byte_size > INT64_MAX / sizeof(int32_t))) { + (static_cast(element_count) > SIZE_MAX / dtype_byte_size)) { return Status( Status::Code::INVALID_ARG, std::string("'initial_state' field for state input name '") + state.input_name() + "' causes total element count to exceed maximum size of " + - std::to_string(INT64_MAX)); + std::to_string(SIZE_MAX)); } - // Custom handling for TYPE_BYTES - if (dtype_byte_size == 0) { - total_byte_size = sizeof(int32_t) * element_count; - } + size_t total_byte_size = static_cast(element_count) * dtype_byte_size; switch (initial_state.state_data_case()) { case inference::ModelSequenceBatching_InitialState::StateDataCase:: diff --git a/src/sequence_state.cc b/src/sequence_state.cc index 23b7a3abe..e94e98fcd 100644 --- a/src/sequence_state.cc +++ b/src/sequence_state.cc @@ -224,15 +224,16 @@ SequenceStates::Initialize( // multiplied by 4. state_size = 4 * element_count; } else { - state_size = + auto byte_size = triton::common::GetByteSize(state.second.data_type(), dims); - if (state_size == static_cast(triton::common::OVERFLOW_SIZE)) { + if (byte_size == triton::common::OVERFLOW_SIZE) { return Status( Status::Code::INVALID_ARG, "state '" + state_config.input_name() + "' causes total byte size to exceed maximum size of " + std::to_string(INT64_MAX)); } + state_size = static_cast(byte_size); } if (use_growable_memory) { std::unique_ptr growable_memory; From 82d93f15e8de83b896e04831b6bcabba8cfd2bb7 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Thu, 26 Feb 2026 03:02:46 -0800 Subject: [PATCH 03/15] Change error message --- src/backend_model.cc | 3 ++- src/backend_model_instance.cc | 21 ++++++++++++------- .../sequence_batch_scheduler.cc | 3 ++- src/sequence_state.cc | 3 ++- 4 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/backend_model.cc b/src/backend_model.cc index 7fac94b4e..066e263d7 100644 --- a/src/backend_model.cc +++ b/src/backend_model.cc @@ -831,7 +831,8 @@ TritonModel::SetConfiguredScheduler( return Status( Status::Code::INVALID_ARG, "input '" + input.name() + - "' causes total element count to exceed maximum size of " + + "' causes total element count or byte size to exceed maximum " + "size of " + std::to_string(INT64_MAX)); } if (!input.allow_ragged_batch() && diff --git a/src/backend_model_instance.cc b/src/backend_model_instance.cc index 0973403b4..a742ed366 100644 --- a/src/backend_model_instance.cc +++ b/src/backend_model_instance.cc @@ -374,25 +374,29 @@ TritonModelInstance::GenerateWarmupData() for (const auto& input_meta : warmup_setting.inputs()) { auto element_count = triton::common::GetElementCount(input_meta.second.dims()); + size_t dtype_byte_size = + triton::common::GetDataTypeByteSize(input_meta.second.data_type()); + dtype_byte_size = + dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size; + if (element_count == triton::common::WILDCARD_SIZE) { return Status( Status::Code::INVALID_ARG, "warmup setting expects all variable-size dimensions are specified " "for input '" + input_meta.first + "'"); - } else if (element_count == triton::common::OVERFLOW_SIZE) { + } else if ( + element_count == triton::common::OVERFLOW_SIZE || + (element_count > INT64_MAX / static_cast(dtype_byte_size))) { return Status( Status::Code::INVALID_ARG, "warmup setting for input '" + input_meta.first + - "' causes total element count to exceed maximum size of " + + "' causes total element count or byte size to exceed maximum " + "size of " + std::to_string(INT64_MAX)); } int64_t batch_byte_size = - element_count * - triton::common::GetDataTypeByteSize(input_meta.second.data_type()); - if (batch_byte_size == 0) { - batch_byte_size = element_count * sizeof(int32_t); - } + element_count * static_cast(dtype_byte_size); switch (input_meta.second.input_data_type_case()) { case inference::ModelWarmup_Input::InputDataTypeCase::kZeroData: @@ -460,7 +464,8 @@ TritonModelInstance::GenerateWarmupData() return Status( Status::Code::INVALID_ARG, "warmup setting for input '" + input_meta.first + - "' causes total element count to exceed maximum size of " + + "' causes total element count or byte size to exceed maximum " + "size of " + std::to_string(INT64_MAX)); } auto batch_byte_size = batch1_element_count * dtype_byte_size; diff --git a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc index 81718c368..e6cc57907 100644 --- a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc +++ b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc @@ -415,7 +415,8 @@ SequenceBatchScheduler::GenerateInitialStateData( Status::Code::INVALID_ARG, std::string("'initial_state' field for state input name '") + state.input_name() + - "' causes total element count to exceed maximum size of " + + "' causes total element count or byte size to exceed maximum size " + "of " + std::to_string(SIZE_MAX)); } diff --git a/src/sequence_state.cc b/src/sequence_state.cc index e94e98fcd..77a1f845c 100644 --- a/src/sequence_state.cc +++ b/src/sequence_state.cc @@ -217,7 +217,8 @@ SequenceStates::Initialize( return Status( Status::Code::INVALID_ARG, "state '" + state_config.input_name() + - "' causes total element count to exceed maximum size of " + + "' causes total element count or byte size to exceed maximum " + "size of " + std::to_string(INT64_MAX)); } // Total number of bytes required is equal to the element count From 300957aa9c58901a9c491b19b9889606d1d61685 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 27 Feb 2026 20:41:45 -0800 Subject: [PATCH 04/15] Address review comments --- src/backend_model.cc | 8 ++++++- src/backend_model_instance.cc | 18 ++++++++++++-- src/infer_request.cc | 24 +++++++++++++++---- src/model_config_utils.cc | 12 +++++++++- .../sequence_batch_scheduler.cc | 10 +++++++- src/sequence_state.cc | 16 +++++++++---- 6 files changed, 74 insertions(+), 14 deletions(-) diff --git a/src/backend_model.cc b/src/backend_model.cc index 066e263d7..ba68366ec 100644 --- a/src/backend_model.cc +++ b/src/backend_model.cc @@ -827,7 +827,13 @@ TritonModel::SetConfiguredScheduler( enforce_equal_shape_tensors.insert({input.name(), true}); } else { auto element_count = triton::common::GetElementCount(input); - if (element_count == triton::common::OVERFLOW_SIZE) { + if (element_count == triton::common::INVALID_SIZE) { + return Status( + Status::Code::INVALID_ARG, + "input '" + input.name() + "' shape " + + triton::common::DimsListToString(input.dims()) + + " contains an invalid dimension"); + } else if (element_count == triton::common::OVERFLOW_SIZE) { return Status( Status::Code::INVALID_ARG, "input '" + input.name() + diff --git a/src/backend_model_instance.cc b/src/backend_model_instance.cc index a742ed366..c83d11b8e 100644 --- a/src/backend_model_instance.cc +++ b/src/backend_model_instance.cc @@ -379,7 +379,13 @@ TritonModelInstance::GenerateWarmupData() dtype_byte_size = dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size; - if (element_count == triton::common::WILDCARD_SIZE) { + if (element_count == triton::common::INVALID_SIZE) { + return Status( + Status::Code::INVALID_ARG, + "warmup setting for input '" + input_meta.first + "' shape " + + triton::common::DimsListToString(input_meta.second.dims()) + + " contains an invalid dimension"); + } else if (element_count == triton::common::WILDCARD_SIZE) { return Status( Status::Code::INVALID_ARG, "warmup setting expects all variable-size dimensions are specified " @@ -458,7 +464,15 @@ TritonModelInstance::GenerateWarmupData() triton::common::GetDataTypeByteSize(input_meta.second.data_type()); dtype_byte_size = dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size; - if (batch1_element_count == triton::common::OVERFLOW_SIZE || + + if (batch1_element_count == triton::common::INVALID_SIZE) { + return Status( + Status::Code::INVALID_ARG, + "warmup setting for input '" + input_meta.first + "' shape " + + triton::common::DimsListToString(input_meta.second.dims()) + + " contains an invalid dimension"); + } else if ( + batch1_element_count == triton::common::OVERFLOW_SIZE || (batch1_element_count > INT64_MAX / static_cast(dtype_byte_size))) { return Status( diff --git a/src/infer_request.cc b/src/infer_request.cc index e9e4f3446..7d2ea062b 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -590,7 +590,8 @@ InferenceRequest::CopyAsNull(const InferenceRequest& from) int64_t element_count = triton::common::GetElementCount(input.second.Shape()); - size_t str_byte_size = static_cast(4 * element_count); + size_t str_byte_size = + static_cast(sizeof(int32_t) * element_count); max_str_byte_size = std::max(str_byte_size, max_str_byte_size); if (str_byte_size > max_byte_size) { max_byte_size = str_byte_size; @@ -641,8 +642,9 @@ InferenceRequest::CopyAsNull(const InferenceRequest& from) if (inference::DataType::TYPE_STRING == input.second.DType()) { new_input->AppendData( data_base, - triton::common::GetElementCount(input.second.Shape()) * 4, mem_type, - mem_id); + triton::common::GetElementCount(input.second.Shape()) * + sizeof(int32_t), + mem_type, mem_id); } else { new_input->AppendData( data_base, input.second.Data()->TotalByteSize(), mem_type, mem_id); @@ -1222,7 +1224,13 @@ InferenceRequest::Normalize() int64_t expected_byte_size = triton::common::GetByteSize(data_type, input_dims); const size_t& byte_size = input.Data()->TotalByteSize(); - if (expected_byte_size == triton::common::OVERFLOW_SIZE) { + if (expected_byte_size == triton::common::INVALID_SIZE) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input '" + input_name + "' shape " + + triton::common::DimsListToString(input_dims) + + " contains an invalid dimension"); + } else if (expected_byte_size == triton::common::OVERFLOW_SIZE) { return Status( Status::Code::INVALID_ARG, LogRequest() + "input '" + input_name + @@ -1329,7 +1337,13 @@ InferenceRequest::ValidateBytesInputs( size_t remaining_buffer_size = 0; int64_t buffer_memory_id; - if (element_count == triton::common::OVERFLOW_SIZE) { + if (element_count == triton::common::INVALID_SIZE) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input '" + input_name + "' shape " + + triton::common::DimsListToString(input_dims) + + " contains an invalid dimension"); + } else if (element_count == triton::common::OVERFLOW_SIZE) { return Status( Status::Code::INVALID_ARG, LogRequest() + "input '" + input_name + diff --git a/src/model_config_utils.cc b/src/model_config_utils.cc index 7bc4415e4..293eca880 100644 --- a/src/model_config_utils.cc +++ b/src/model_config_utils.cc @@ -357,7 +357,17 @@ ValidateIOShape( const int64_t reshape_size = triton::common::GetElementCount(io.reshape().shape()); - if (dims_size == triton::common::OVERFLOW_SIZE || + if (dims_size == triton::common::INVALID_SIZE || + reshape_size == triton::common::INVALID_SIZE) { + return Status( + Status::Code::INVALID_ARG, + message_prefix_with_name + "shape " + + triton::common::DimsListToString(io.dims()) + + " or reshaped shape " + + triton::common::DimsListToString(io.reshape().shape()) + + " contains an invalid dimension"); + } else if ( + dims_size == triton::common::OVERFLOW_SIZE || reshape_size == triton::common::OVERFLOW_SIZE) { return Status( Status::Code::INVALID_ARG, diff --git a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc index e6cc57907..73bca49c0 100644 --- a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc +++ b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc @@ -409,7 +409,15 @@ SequenceBatchScheduler::GenerateInitialStateData( size_t dtype_byte_size = triton::common::GetDataTypeByteSize(initial_state.data_type()); dtype_byte_size = dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size; - if (element_count == triton::common::OVERFLOW_SIZE || + if (element_count == triton::common::INVALID_SIZE) { + return Status( + Status::Code::INVALID_ARG, + std::string("'initial_state' field for state input name '") + + state.input_name() + "' shape " + + triton::common::DimsListToString(initial_state.dims()) + + " contains an invalid dimension"); + } else if ( + element_count == triton::common::OVERFLOW_SIZE || (static_cast(element_count) > SIZE_MAX / dtype_byte_size)) { return Status( Status::Code::INVALID_ARG, diff --git a/src/sequence_state.cc b/src/sequence_state.cc index 77a1f845c..1b35092d7 100644 --- a/src/sequence_state.cc +++ b/src/sequence_state.cc @@ -212,8 +212,16 @@ SequenceStates::Initialize( size_t state_size; if (state.second.data_type() == inference::DataType::TYPE_STRING) { auto element_count = triton::common::GetElementCount(dims); - if (element_count == triton::common::OVERFLOW_SIZE || - (element_count > INT64_MAX / 4)) { + if (element_count == triton::common::INVALID_SIZE) { + return Status( + Status::Code::INVALID_ARG, + "state '" + state_config.input_name() + "' shape " + + triton::common::DimsListToString(dims) + + " contains an invalid dimension"); + } else if ( + element_count == triton::common::OVERFLOW_SIZE || + (element_count > + INT64_MAX / static_cast(sizeof(int32_t)))) { return Status( Status::Code::INVALID_ARG, "state '" + state_config.input_name() + @@ -223,7 +231,7 @@ SequenceStates::Initialize( } // Total number of bytes required is equal to the element count // multiplied by 4. - state_size = 4 * element_count; + state_size = sizeof(int32_t) * element_count; } else { auto byte_size = triton::common::GetByteSize(state.second.data_type(), dims); @@ -414,7 +422,7 @@ SequenceStates::CopyAsNull(const std::shared_ptr& from) // Use all-zero input states for null requests. auto element_count = triton::common::GetElementCount(from_input_state_tensor->Shape()); - auto state_size = 4 * element_count; + auto state_size = sizeof(int32_t) * element_count; data = std::make_shared( state_size, TRITONSERVER_MEMORY_CPU, 0); } else { From c0b41cecaf72f8334111aa80e6d4514058706912 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 27 Feb 2026 21:09:27 -0800 Subject: [PATCH 05/15] Address comments --- src/backend_model.cc | 3 +-- src/backend_model_instance.cc | 3 ++- src/sequence_state.cc | 8 +++++++- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/backend_model.cc b/src/backend_model.cc index ba68366ec..4814c1db7 100644 --- a/src/backend_model.cc +++ b/src/backend_model.cc @@ -837,8 +837,7 @@ TritonModel::SetConfiguredScheduler( return Status( Status::Code::INVALID_ARG, "input '" + input.name() + - "' causes total element count or byte size to exceed maximum " - "size of " + + "' causes total element count to exceed maximum size of " + std::to_string(INT64_MAX)); } if (!input.allow_ragged_batch() && diff --git a/src/backend_model_instance.cc b/src/backend_model_instance.cc index c83d11b8e..e4883a189 100644 --- a/src/backend_model_instance.cc +++ b/src/backend_model_instance.cc @@ -482,7 +482,8 @@ TritonModelInstance::GenerateWarmupData() "size of " + std::to_string(INT64_MAX)); } - auto batch_byte_size = batch1_element_count * dtype_byte_size; + auto batch_byte_size = + batch1_element_count * static_cast(dtype_byte_size); const char* allocated_ptr; switch (input_meta.second.input_data_type_case()) { diff --git a/src/sequence_state.cc b/src/sequence_state.cc index 1b35092d7..2e9ab2a1a 100644 --- a/src/sequence_state.cc +++ b/src/sequence_state.cc @@ -235,7 +235,13 @@ SequenceStates::Initialize( } else { auto byte_size = triton::common::GetByteSize(state.second.data_type(), dims); - if (byte_size == triton::common::OVERFLOW_SIZE) { + if (byte_size == triton::common::INVALID_SIZE) { + return Status( + Status::Code::INVALID_ARG, + "state '" + state_config.input_name() + "' shape " + + triton::common::DimsListToString(dims) + + " contains an invalid dimension"); + } else if (byte_size == triton::common::OVERFLOW_SIZE) { return Status( Status::Code::INVALID_ARG, "state '" + state_config.input_name() + From d43597a94cc94e3ab7ff055fc5c0b3fa7e51dc83 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 27 Feb 2026 21:22:44 -0800 Subject: [PATCH 06/15] Address comments --- src/sequence_state.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sequence_state.cc b/src/sequence_state.cc index 2e9ab2a1a..5a87c440a 100644 --- a/src/sequence_state.cc +++ b/src/sequence_state.cc @@ -231,7 +231,7 @@ SequenceStates::Initialize( } // Total number of bytes required is equal to the element count // multiplied by 4. - state_size = sizeof(int32_t) * element_count; + state_size = sizeof(int32_t) * static_cast(element_count); } else { auto byte_size = triton::common::GetByteSize(state.second.data_type(), dims); From eb9d8ca9e1718bef9a83fa019a081b5c11b217d1 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 27 Feb 2026 21:26:53 -0800 Subject: [PATCH 07/15] Address comment --- src/sequence_state.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/sequence_state.cc b/src/sequence_state.cc index 5a87c440a..8d8e1169e 100644 --- a/src/sequence_state.cc +++ b/src/sequence_state.cc @@ -220,8 +220,7 @@ SequenceStates::Initialize( " contains an invalid dimension"); } else if ( element_count == triton::common::OVERFLOW_SIZE || - (element_count > - INT64_MAX / static_cast(sizeof(int32_t)))) { + (static_cast(element_count) > SIZE_MAX / sizeof(int32_t))) { return Status( Status::Code::INVALID_ARG, "state '" + state_config.input_name() + From 234426861b8afd0457351566c2226096bfb43362 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 6 Mar 2026 11:46:56 -0800 Subject: [PATCH 08/15] Refactor to use safe function calls --- src/backend_model.cc | 17 +--- src/backend_model_instance.cc | 65 +++++---------- src/infer_request.cc | 82 ++++++++++--------- src/infer_request.h | 3 +- src/model_config_utils.cc | 27 ++---- src/model_config_utils.h | 43 +++++++++- .../sequence_batch_scheduler.cc | 32 +++----- .../sequence_utils.cc | 4 +- src/sequence_state.cc | 75 ++++++++--------- src/sequence_state.h | 8 +- 10 files changed, 172 insertions(+), 184 deletions(-) diff --git a/src/backend_model.cc b/src/backend_model.cc index 4814c1db7..c3b0fc2dc 100644 --- a/src/backend_model.cc +++ b/src/backend_model.cc @@ -826,20 +826,9 @@ TritonModel::SetConfiguredScheduler( if (input.is_shape_tensor()) { enforce_equal_shape_tensors.insert({input.name(), true}); } else { - auto element_count = triton::common::GetElementCount(input); - if (element_count == triton::common::INVALID_SIZE) { - return Status( - Status::Code::INVALID_ARG, - "input '" + input.name() + "' shape " + - triton::common::DimsListToString(input.dims()) + - " contains an invalid dimension"); - } else if (element_count == triton::common::OVERFLOW_SIZE) { - return Status( - Status::Code::INVALID_ARG, - "input '" + input.name() + - "' causes total element count to exceed maximum size of " + - std::to_string(INT64_MAX)); - } + int64_t element_count = 0; + RETURN_IF_ERROR( + GetElementCount(input.dims(), input.name(), &element_count)); if (!input.allow_ragged_batch() && (element_count == triton::common::WILDCARD_SIZE)) { enforce_equal_shape_tensors.insert({input.name(), false}); diff --git a/src/backend_model_instance.cc b/src/backend_model_instance.cc index e4883a189..73710339f 100644 --- a/src/backend_model_instance.cc +++ b/src/backend_model_instance.cc @@ -372,37 +372,20 @@ TritonModelInstance::GenerateWarmupData() int64_t max_zero_byte_size = 0; int64_t max_random_byte_size = 0; for (const auto& input_meta : warmup_setting.inputs()) { - auto element_count = - triton::common::GetElementCount(input_meta.second.dims()); - size_t dtype_byte_size = + int64_t element_count = 0; + RETURN_IF_ERROR(GetElementCount( + input_meta.second.dims(), input_meta.first, &element_count)); + int64_t dtype_byte_size = triton::common::GetDataTypeByteSize(input_meta.second.data_type()); dtype_byte_size = dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size; - - if (element_count == triton::common::INVALID_SIZE) { - return Status( - Status::Code::INVALID_ARG, - "warmup setting for input '" + input_meta.first + "' shape " + - triton::common::DimsListToString(input_meta.second.dims()) + - " contains an invalid dimension"); - } else if (element_count == triton::common::WILDCARD_SIZE) { + if (element_count > INT64_MAX / dtype_byte_size) { return Status( Status::Code::INVALID_ARG, - "warmup setting expects all variable-size dimensions are specified " - "for input '" + - input_meta.first + "'"); - } else if ( - element_count == triton::common::OVERFLOW_SIZE || - (element_count > INT64_MAX / static_cast(dtype_byte_size))) { - return Status( - Status::Code::INVALID_ARG, - "warmup setting for input '" + input_meta.first + - "' causes total element count or byte size to exceed maximum " - "size of " + - std::to_string(INT64_MAX)); + "element count for input '" + input_meta.first + + "' exceeds maximum size of " + std::to_string(INT64_MAX)); } - int64_t batch_byte_size = - element_count * static_cast(dtype_byte_size); + int64_t batch_byte_size = element_count * dtype_byte_size; switch (input_meta.second.input_data_type_case()) { case inference::ModelWarmup_Input::InputDataTypeCase::kZeroData: @@ -458,32 +441,24 @@ TritonModelInstance::GenerateWarmupData() // Second pass to prepare original inputs. std::vector> input_sps; for (const auto& input_meta : warmup_setting.inputs()) { - auto batch1_element_count = - triton::common::GetElementCount(input_meta.second.dims()); - auto dtype_byte_size = - triton::common::GetDataTypeByteSize(input_meta.second.data_type()); + int64_t batch1_element_count, dtype_byte_size = 0; + RETURN_IF_ERROR(GetElementCount( + input_meta.second.dims(), input_meta.first, &batch1_element_count)); + RETURN_IF_ERROR(GetByteSize( + input_meta.second.data_type(), input_meta.second.dims(), + input_meta.first, &dtype_byte_size)); dtype_byte_size = dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size; - if (batch1_element_count == triton::common::INVALID_SIZE) { - return Status( - Status::Code::INVALID_ARG, - "warmup setting for input '" + input_meta.first + "' shape " + - triton::common::DimsListToString(input_meta.second.dims()) + - " contains an invalid dimension"); - } else if ( - batch1_element_count == triton::common::OVERFLOW_SIZE || - (batch1_element_count > - INT64_MAX / static_cast(dtype_byte_size))) { + if (static_cast(batch1_element_count) > + SIZE_MAX / dtype_byte_size) { return Status( Status::Code::INVALID_ARG, - "warmup setting for input '" + input_meta.first + - "' causes total element count or byte size to exceed maximum " - "size of " + - std::to_string(INT64_MAX)); + "element count for input '" + input_meta.first + + "' exceeds maximum size of " + std::to_string(SIZE_MAX)); } - auto batch_byte_size = - batch1_element_count * static_cast(dtype_byte_size); + size_t batch_byte_size = + static_cast(batch1_element_count) * dtype_byte_size; const char* allocated_ptr; switch (input_meta.second.input_data_type_case()) { diff --git a/src/infer_request.cc b/src/infer_request.cc index 7d2ea062b..2b310400b 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -515,9 +515,16 @@ InferenceRequest::Release( return Status::Success; } -InferenceRequest* -InferenceRequest::CopyAsNull(const InferenceRequest& from) +Status +InferenceRequest::CopyAsNull( + const InferenceRequest& from, std::unique_ptr* to) { + if (to == nullptr) { + return Status( + Status::Code::INVALID_ARG, "InferenceRequest 'to' must not be null"); + } + *to = nullptr; + // Create a copy of 'from' request with artificial inputs and no requested // outputs. Maybe more efficient to share inputs and other metadata, // but that binds the Null request with 'from' request's lifecycle. @@ -587,9 +594,15 @@ InferenceRequest::CopyAsNull(const InferenceRequest& from) } if (input.second.DType() == inference::DataType::TYPE_STRING) { - int64_t element_count = - triton::common::GetElementCount(input.second.Shape()); - + int64_t element_count = 0; + RETURN_IF_ERROR( + GetElementCount(input.second.Shape(), input.first, &element_count)); + if (static_cast(element_count) > SIZE_MAX / sizeof(int32_t)) { + return Status( + Status::Code::INVALID_ARG, + "element count for input '" + input.first + + "' exceeds maximum size of " + std::to_string(SIZE_MAX)); + } size_t str_byte_size = static_cast(sizeof(int32_t) * element_count); max_str_byte_size = std::max(str_byte_size, max_str_byte_size); @@ -640,10 +653,17 @@ InferenceRequest::CopyAsNull(const InferenceRequest& from) new_input->SetData(data); } else { if (inference::DataType::TYPE_STRING == input.second.DType()) { + int64_t element_count = 0; + RETURN_IF_ERROR( + GetElementCount(input.second.Shape(), input.first, &element_count)); + if (static_cast(element_count) > SIZE_MAX / sizeof(int32_t)) { + return Status( + Status::Code::INVALID_ARG, + "element count for input '" + input.first + + "' exceeds maximum size of " + std::to_string(SIZE_MAX)); + } new_input->AppendData( - data_base, - triton::common::GetElementCount(input.second.Shape()) * - sizeof(int32_t), + data_base, static_cast(element_count) * sizeof(int32_t), mem_type, mem_id); } else { new_input->AppendData( @@ -664,7 +684,8 @@ InferenceRequest::CopyAsNull(const InferenceRequest& from) std::make_pair(pr.second.Name(), std::addressof(pr.second))); } - return lrequest.release(); + *to = std::move(lrequest); + return Status::Success; } Status @@ -846,8 +867,13 @@ InferenceRequest::LoadInputStates() // Add the input states to the inference request. if (sequence_states_ != nullptr) { if (sequence_states_->IsNullRequest()) { - sequence_states_ = - SequenceStates::CopyAsNull(sequence_states_->NullSequenceStates()); + std::shared_ptr copied; + Status status = SequenceStates::CopyAsNull( + sequence_states_->NullSequenceStates(), &copied); + if (!status.IsOk()) { + return status; + } + sequence_states_ = copied; } for (auto& input_state_pair : sequence_states_->InputStates()) { auto& input_state = input_state_pair.second; @@ -1221,22 +1247,10 @@ InferenceRequest::Normalize() const std::vector& input_dims = input.IsShapeTensor() ? input.OriginalShape() : input.ShapeWithBatchDim(); - int64_t expected_byte_size = - triton::common::GetByteSize(data_type, input_dims); + int64_t expected_byte_size = 0; + RETURN_IF_ERROR(GetByteSize( + data_type, input_dims, input_name, &expected_byte_size)); const size_t& byte_size = input.Data()->TotalByteSize(); - if (expected_byte_size == triton::common::INVALID_SIZE) { - return Status( - Status::Code::INVALID_ARG, - LogRequest() + "input '" + input_name + "' shape " + - triton::common::DimsListToString(input_dims) + - " contains an invalid dimension"); - } else if (expected_byte_size == triton::common::OVERFLOW_SIZE) { - return Status( - Status::Code::INVALID_ARG, - LogRequest() + "input '" + input_name + - "' causes total byte size to exceed maximum size of " + - std::to_string(INT64_MAX)); - } if ((byte_size > LLONG_MAX) || (static_cast(byte_size) != expected_byte_size)) { return Status( @@ -1326,7 +1340,7 @@ InferenceRequest::ValidateBytesInputs( { const auto& input_dims = input.ShapeWithBatchDim(); - int64_t element_count = triton::common::GetElementCount(input_dims); + int64_t element_count = 0; int64_t element_checked = 0; size_t remaining_element_size = 0; @@ -1337,19 +1351,7 @@ InferenceRequest::ValidateBytesInputs( size_t remaining_buffer_size = 0; int64_t buffer_memory_id; - if (element_count == triton::common::INVALID_SIZE) { - return Status( - Status::Code::INVALID_ARG, - LogRequest() + "input '" + input_name + "' shape " + - triton::common::DimsListToString(input_dims) + - " contains an invalid dimension"); - } else if (element_count == triton::common::OVERFLOW_SIZE) { - return Status( - Status::Code::INVALID_ARG, - LogRequest() + "input '" + input_name + - "' causes total element count to exceed maximum size of " + - std::to_string(INT64_MAX)); - } + RETURN_IF_ERROR(GetElementCount(input_dims, input_name, &element_count)); // Validate elements until all buffers have been fully processed. while (remaining_buffer_size || buffer_next_idx < buffer_count) { diff --git a/src/infer_request.h b/src/infer_request.h index 1c7e83d6d..e5c9fff87 100644 --- a/src/infer_request.h +++ b/src/infer_request.h @@ -632,7 +632,8 @@ class InferenceRequest { // required for the direct sequence batcher. The returned copy will // contain only the minimum content required for a null request. // The statistics of the copy will not be collected. - static InferenceRequest* CopyAsNull(const InferenceRequest& from); + static Status CopyAsNull( + const InferenceRequest& from, std::unique_ptr* to); uint64_t QueueStartNs() const { return queue_start_ns_; } uint64_t CaptureQueueStartNs() diff --git a/src/model_config_utils.cc b/src/model_config_utils.cc index 293eca880..045a70ed2 100644 --- a/src/model_config_utils.cc +++ b/src/model_config_utils.cc @@ -353,28 +353,11 @@ ValidateIOShape( } } - const int64_t dims_size = triton::common::GetElementCount(io.dims()); - const int64_t reshape_size = - triton::common::GetElementCount(io.reshape().shape()); - - if (dims_size == triton::common::INVALID_SIZE || - reshape_size == triton::common::INVALID_SIZE) { - return Status( - Status::Code::INVALID_ARG, - message_prefix_with_name + "shape " + - triton::common::DimsListToString(io.dims()) + - " or reshaped shape " + - triton::common::DimsListToString(io.reshape().shape()) + - " contains an invalid dimension"); - } else if ( - dims_size == triton::common::OVERFLOW_SIZE || - reshape_size == triton::common::OVERFLOW_SIZE) { - return Status( - Status::Code::INVALID_ARG, - message_prefix_with_name + - "causes total element count to exceed maximum size of " + - std::to_string(INT64_MAX)); - } + int64_t dims_size = 0; + int64_t reshape_size = 0; + RETURN_IF_ERROR(GetElementCount(io.dims(), "dims", &dims_size)); + RETURN_IF_ERROR( + GetElementCount(io.reshape().shape(), "reshape", &reshape_size)); // dims and reshape must both have same element count // or both have variable-size dimension. diff --git a/src/model_config_utils.h b/src/model_config_utils.h index ba4d75636..891a2345e 100644 --- a/src/model_config_utils.h +++ b/src/model_config_utils.h @@ -1,4 +1,4 @@ -// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -318,4 +318,45 @@ bool EquivalentInInstanceConfig( std::string InstanceConfigSignature( const inference::ModelInstanceGroup& instance_config); +template +Status +GetElementCount(const T& dims, const std::string& name, int64_t* cnt) +{ + *cnt = triton::common::GetElementCount(dims); + if (*cnt == triton::common::INVALID_SIZE) { + return Status( + Status::Code::INVALID_ARG, + "tensor '" + name + "' contains an invalid dimension in shape " + + triton::common::DimsListToString(dims)); + } else if (*cnt == triton::common::OVERFLOW_SIZE) { + return Status( + Status::Code::INVALID_ARG, "element count for tensor '" + name + + "' exceeds maximum size of " + + std::to_string(INT64_MAX)); + } else { + return Status::Success; + } +} + +template +Status +GetByteSize( + const inference::DataType& dtype, const T& dims, const std::string& name, + int64_t* size) +{ + *size = triton::common::GetByteSize(dtype, dims); + if (*size == triton::common::INVALID_SIZE) { + return Status( + Status::Code::INVALID_ARG, "tensor '" + name + + "' contains an invalid dimension " + + triton::common::DimsListToString(dims)); + } else if (*size == triton::common::OVERFLOW_SIZE) { + return Status( + Status::Code::INVALID_ARG, "byte size for tensor '" + name + + "' exceeds maximum size of " + + std::to_string(INT64_MAX)); + } else { + return Status::Success; + } +} }} // namespace triton::core diff --git a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc index 73bca49c0..df93594e2 100644 --- a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc +++ b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc @@ -405,29 +405,18 @@ SequenceBatchScheduler::GenerateInitialStateData( } // Calculate total memory byte size - auto element_count = triton::common::GetElementCount(initial_state.dims()); + int64_t element_count = 0; + RETURN_IF_ERROR(GetElementCount( + initial_state.dims(), state.input_name(), &element_count)); size_t dtype_byte_size = triton::common::GetDataTypeByteSize(initial_state.data_type()); dtype_byte_size = dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size; - if (element_count == triton::common::INVALID_SIZE) { + if (static_cast(element_count) > SIZE_MAX / dtype_byte_size) { return Status( Status::Code::INVALID_ARG, - std::string("'initial_state' field for state input name '") + - state.input_name() + "' shape " + - triton::common::DimsListToString(initial_state.dims()) + - " contains an invalid dimension"); - } else if ( - element_count == triton::common::OVERFLOW_SIZE || - (static_cast(element_count) > SIZE_MAX / dtype_byte_size)) { - return Status( - Status::Code::INVALID_ARG, - std::string("'initial_state' field for state input name '") + - state.input_name() + - "' causes total element count or byte size to exceed maximum size " - "of " + - std::to_string(SIZE_MAX)); + "element count for input '" + state.input_name() + + "' exceeds maximum size of " + std::to_string(SIZE_MAX)); } - size_t total_byte_size = static_cast(element_count) * dtype_byte_size; switch (initial_state.state_data_case()) { @@ -1773,8 +1762,13 @@ DirectSequenceBatch::BatcherThread(const int nice) // Use null-request if necessary otherwise use the next // request in the queue... if (use_null_request) { - std::unique_ptr ni( - InferenceRequest::CopyAsNull(*null_irequest)); + std::unique_ptr ni = nullptr; + Status status = InferenceRequest::CopyAsNull(*null_irequest, &ni); + if (!status.IsOk()) { + LOG_ERROR + << "internal: unexpecting failure copying null request: " + << status.Message(); + } // Note that when the not-ready control input of the // request is "true" the model can't assume that any // other inputs are meaningful, including CORRID. So we diff --git a/src/sequence_batch_scheduler/sequence_utils.cc b/src/sequence_batch_scheduler/sequence_utils.cc index c916ccca7..b469f1472 100644 --- a/src/sequence_batch_scheduler/sequence_utils.cc +++ b/src/sequence_batch_scheduler/sequence_utils.cc @@ -43,8 +43,8 @@ IterativeSequencer::RescheduleRequest( else if (!request->IsCancelled()) { // Use a null request to trigger sequence batcher cancellation so // additional request manipulation won't affect the actual request. - std::unique_ptr ni( - InferenceRequest::CopyAsNull(*request)); + std::unique_ptr ni = nullptr; + RETURN_IF_ERROR(InferenceRequest::CopyAsNull(*request, &ni)); ni->SetCorrelationId(request->CorrelationId()); ni->SetFlags(TRITONSERVER_REQUEST_FLAG_SEQUENCE_END); ni->Cancel(); diff --git a/src/sequence_state.cc b/src/sequence_state.cc index 8d8e1169e..1939124f5 100644 --- a/src/sequence_state.cc +++ b/src/sequence_state.cc @@ -26,8 +26,11 @@ #include "sequence_state.h" +#include + #include "cuda_utils.h" #include "memory.h" +#include "model_config_utils.h" #include "triton/common/logging.h" namespace triton { namespace core { @@ -211,42 +214,23 @@ SequenceStates::Initialize( } else { size_t state_size; if (state.second.data_type() == inference::DataType::TYPE_STRING) { - auto element_count = triton::common::GetElementCount(dims); - if (element_count == triton::common::INVALID_SIZE) { - return Status( - Status::Code::INVALID_ARG, - "state '" + state_config.input_name() + "' shape " + - triton::common::DimsListToString(dims) + - " contains an invalid dimension"); - } else if ( - element_count == triton::common::OVERFLOW_SIZE || - (static_cast(element_count) > SIZE_MAX / sizeof(int32_t))) { - return Status( - Status::Code::INVALID_ARG, - "state '" + state_config.input_name() + - "' causes total element count or byte size to exceed maximum " - "size of " + - std::to_string(INT64_MAX)); - } + int64_t element_count = 0; + RETURN_IF_ERROR( + GetElementCount(dims, state_config.input_name(), &element_count)); // Total number of bytes required is equal to the element count // multiplied by 4. - state_size = sizeof(int32_t) * static_cast(element_count); - } else { - auto byte_size = - triton::common::GetByteSize(state.second.data_type(), dims); - if (byte_size == triton::common::INVALID_SIZE) { - return Status( - Status::Code::INVALID_ARG, - "state '" + state_config.input_name() + "' shape " + - triton::common::DimsListToString(dims) + - " contains an invalid dimension"); - } else if (byte_size == triton::common::OVERFLOW_SIZE) { + if (static_cast(state_size) > SIZE_MAX / sizeof(int32_t)) { return Status( Status::Code::INVALID_ARG, - "state '" + state_config.input_name() + - "' causes total byte size to exceed maximum size of " + - std::to_string(INT64_MAX)); + "element count for input '" + state_config.input_name() + + "' exceeds maximum size of " + std::to_string(SIZE_MAX)); } + state_size = sizeof(int32_t) * static_cast(element_count); + } else { + int64_t byte_size = 0; + RETURN_IF_ERROR(GetByteSize( + state.second.data_type(), dims, state_config.input_name(), + &byte_size)); state_size = static_cast(byte_size); } if (use_growable_memory) { @@ -404,9 +388,16 @@ SequenceStates::OutputState( return OutputState(name, datatype, shape.data(), shape.size(), output_state); } -std::shared_ptr -SequenceStates::CopyAsNull(const std::shared_ptr& from) +Status +SequenceStates::CopyAsNull( + const std::shared_ptr& from, + std::shared_ptr* to) { + if (to == nullptr) { + return Status( + Status::Code::INVALID_ARG, "SequenceStates 'to' must not be null"); + } + *to = nullptr; std::shared_ptr lsequence_states; if (from != nullptr) { lsequence_states.reset(new SequenceStates); @@ -425,9 +416,18 @@ SequenceStates::CopyAsNull(const std::shared_ptr& from) if (from_input_state_tensor->DType() == inference::DataType::TYPE_STRING) { // Use all-zero input states for null requests. - auto element_count = - triton::common::GetElementCount(from_input_state_tensor->Shape()); - auto state_size = sizeof(int32_t) * element_count; + int64_t element_count = 0; + RETURN_IF_ERROR(GetElementCount( + from_input_state_tensor->Shape(), from_input_state_tensor->Name(), + &element_count)); + if (static_cast(element_count) > SIZE_MAX / sizeof(int32_t)) { + return Status( + Status::Code::INVALID_ARG, + "element count for input '" + from_input_state_tensor->Name() + + "' exceeds maximum size of " + std::to_string(SIZE_MAX)); + } + size_t state_size = + static_cast(element_count) * sizeof(int32_t); data = std::make_shared( state_size, TRITONSERVER_MEMORY_CPU, 0); } else { @@ -454,6 +454,7 @@ SequenceStates::CopyAsNull(const std::shared_ptr& from) false /* use_growable_memory */))); } } - return lsequence_states; + *to = std::move(lsequence_states); + return Status::Success; } }} // namespace triton::core diff --git a/src/sequence_state.h b/src/sequence_state.h index 7faba3429..c2e9fe909 100644 --- a/src/sequence_state.h +++ b/src/sequence_state.h @@ -1,4 +1,4 @@ -// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -160,8 +160,10 @@ class SequenceStates { const std::vector& shape, SequenceState** output_state); // Create a copy of the 'from' sequence states for NULL requests. - static std::shared_ptr CopyAsNull( - const std::shared_ptr& from); + // On success, sets *to and returns Status::Success; on failure returns error. + static Status CopyAsNull( + const std::shared_ptr& from, + std::shared_ptr* to); const std::map>& InputStates() { From aab8e04b78312f82a1cc4301d35b7dcdc7c51d3f Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 6 Mar 2026 11:53:13 -0800 Subject: [PATCH 09/15] Copyrights --- src/infer_request.h | 2 +- src/model_repository_manager/model_repository_manager.cc | 2 +- src/sequence_batch_scheduler/sequence_utils.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/infer_request.h b/src/infer_request.h index e5c9fff87..02ab5a4f0 100644 --- a/src/infer_request.h +++ b/src/infer_request.h @@ -1,4 +1,4 @@ -// Copyright 2020-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2020-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions diff --git a/src/model_repository_manager/model_repository_manager.cc b/src/model_repository_manager/model_repository_manager.cc index 96df7b637..b9927e647 100644 --- a/src/model_repository_manager/model_repository_manager.cc +++ b/src/model_repository_manager/model_repository_manager.cc @@ -1,4 +1,4 @@ -// Copyright 2018-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2018-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions diff --git a/src/sequence_batch_scheduler/sequence_utils.cc b/src/sequence_batch_scheduler/sequence_utils.cc index b469f1472..96abf62bf 100644 --- a/src/sequence_batch_scheduler/sequence_utils.cc +++ b/src/sequence_batch_scheduler/sequence_utils.cc @@ -1,4 +1,4 @@ -// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions From a63e5767b4b847f6abda474ebe9a7f186df5ea1a Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 6 Mar 2026 13:24:02 -0800 Subject: [PATCH 10/15] Refactor GetByteSize --- src/backend_model_instance.cc | 18 ++------ src/infer_request.cc | 43 ++++++------------- src/model_config_utils.h | 41 ++++++++++++------ .../sequence_batch_scheduler.cc | 17 ++------ src/sequence_state.cc | 42 ++++-------------- 5 files changed, 57 insertions(+), 104 deletions(-) diff --git a/src/backend_model_instance.cc b/src/backend_model_instance.cc index 73710339f..7b48ca6d0 100644 --- a/src/backend_model_instance.cc +++ b/src/backend_model_instance.cc @@ -372,20 +372,10 @@ TritonModelInstance::GenerateWarmupData() int64_t max_zero_byte_size = 0; int64_t max_random_byte_size = 0; for (const auto& input_meta : warmup_setting.inputs()) { - int64_t element_count = 0; - RETURN_IF_ERROR(GetElementCount( - input_meta.second.dims(), input_meta.first, &element_count)); - int64_t dtype_byte_size = - triton::common::GetDataTypeByteSize(input_meta.second.data_type()); - dtype_byte_size = - dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size; - if (element_count > INT64_MAX / dtype_byte_size) { - return Status( - Status::Code::INVALID_ARG, - "element count for input '" + input_meta.first + - "' exceeds maximum size of " + std::to_string(INT64_MAX)); - } - int64_t batch_byte_size = element_count * dtype_byte_size; + int64_t batch_byte_size = 0; + RETURN_IF_ERROR(GetByteSize( + input_meta.second.data_type(), input_meta.second.dims(), + input_meta.first, &batch_byte_size)); switch (input_meta.second.input_data_type_case()) { case inference::ModelWarmup_Input::InputDataTypeCase::kZeroData: diff --git a/src/infer_request.cc b/src/infer_request.cc index 2b310400b..2b0f106f6 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -594,17 +594,10 @@ InferenceRequest::CopyAsNull( } if (input.second.DType() == inference::DataType::TYPE_STRING) { - int64_t element_count = 0; - RETURN_IF_ERROR( - GetElementCount(input.second.Shape(), input.first, &element_count)); - if (static_cast(element_count) > SIZE_MAX / sizeof(int32_t)) { - return Status( - Status::Code::INVALID_ARG, - "element count for input '" + input.first + - "' exceeds maximum size of " + std::to_string(SIZE_MAX)); - } - size_t str_byte_size = - static_cast(sizeof(int32_t) * element_count); + size_t str_byte_size = 0; + RETURN_IF_ERROR(GetByteSize( + inference::DataType::TYPE_STRING, input.second.Shape(), input.first, + reinterpret_cast(&str_byte_size))); max_str_byte_size = std::max(str_byte_size, max_str_byte_size); if (str_byte_size > max_byte_size) { max_byte_size = str_byte_size; @@ -652,19 +645,12 @@ InferenceRequest::CopyAsNull( if (input.first == *max_input_name) { new_input->SetData(data); } else { - if (inference::DataType::TYPE_STRING == input.second.DType()) { - int64_t element_count = 0; - RETURN_IF_ERROR( - GetElementCount(input.second.Shape(), input.first, &element_count)); - if (static_cast(element_count) > SIZE_MAX / sizeof(int32_t)) { - return Status( - Status::Code::INVALID_ARG, - "element count for input '" + input.first + - "' exceeds maximum size of " + std::to_string(SIZE_MAX)); - } - new_input->AppendData( - data_base, static_cast(element_count) * sizeof(int32_t), - mem_type, mem_id); + if (input.second.DType() == inference::DataType::TYPE_STRING) { + int64_t str_byte_size = 0; + RETURN_IF_ERROR(GetByteSize( + inference::DataType::TYPE_STRING, input.second.Shape(), input.first, + &str_byte_size)); + new_input->AppendData(data_base, str_byte_size, mem_type, mem_id); } else { new_input->AppendData( data_base, input.second.Data()->TotalByteSize(), mem_type, mem_id); @@ -867,13 +853,8 @@ InferenceRequest::LoadInputStates() // Add the input states to the inference request. if (sequence_states_ != nullptr) { if (sequence_states_->IsNullRequest()) { - std::shared_ptr copied; - Status status = SequenceStates::CopyAsNull( - sequence_states_->NullSequenceStates(), &copied); - if (!status.IsOk()) { - return status; - } - sequence_states_ = copied; + RETURN_IF_ERROR(SequenceStates::CopyAsNull( + sequence_states_->NullSequenceStates(), &sequence_states_)); } for (auto& input_state_pair : sequence_states_->InputStates()) { auto& input_state = input_state_pair.second; diff --git a/src/model_config_utils.h b/src/model_config_utils.h index 14bad0b0e..58b785029 100644 --- a/src/model_config_utils.h +++ b/src/model_config_utils.h @@ -345,19 +345,36 @@ GetByteSize( const inference::DataType& dtype, const T& dims, const std::string& name, int64_t* size) { - *size = triton::common::GetByteSize(dtype, dims); - if (*size == triton::common::INVALID_SIZE) { - return Status( - Status::Code::INVALID_ARG, "tensor '" + name + - "' contains an invalid dimension " + - triton::common::DimsListToString(dims)); - } else if (*size == triton::common::OVERFLOW_SIZE) { - return Status( - Status::Code::INVALID_ARG, "byte size for tensor '" + name + - "' exceeds maximum size of " + - std::to_string(INT64_MAX)); + int64_t byte_size = 0; + if (dtype == inference::DataType::TYPE_STRING) { + int64_t element_count = 0; + RETURN_IF_ERROR(GetElementCount(dims, name, &element_count)); + + // Total number of bytes required is equal to the element count + // multiplied by 4. + if (element_count > static_cast(INT64_MAX / sizeof(int32_t))) { + return Status( + Status::Code::INVALID_ARG, "byte size for tensor '" + name + + "' exceeds maximum size of " + + std::to_string(INT64_MAX)); + } + byte_size = sizeof(int32_t) * element_count; } else { - return Status::Success; + byte_size = triton::common::GetByteSize(dtype, dims); + if (byte_size == triton::common::INVALID_SIZE) { + return Status( + Status::Code::INVALID_ARG, + "tensor '" + name + "' contains an invalid dimension " + + triton::common::DimsListToString(dims)); + } else if (byte_size == triton::common::OVERFLOW_SIZE) { + return Status( + Status::Code::INVALID_ARG, "byte size for tensor '" + name + + "' exceeds maximum size of " + + std::to_string(INT64_MAX)); + } } + *size = byte_size; + return Status::Success; } + }} // namespace triton::core diff --git a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc index df93594e2..2847cee7a 100644 --- a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc +++ b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc @@ -405,19 +405,10 @@ SequenceBatchScheduler::GenerateInitialStateData( } // Calculate total memory byte size - int64_t element_count = 0; - RETURN_IF_ERROR(GetElementCount( - initial_state.dims(), state.input_name(), &element_count)); - size_t dtype_byte_size = - triton::common::GetDataTypeByteSize(initial_state.data_type()); - dtype_byte_size = dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size; - if (static_cast(element_count) > SIZE_MAX / dtype_byte_size) { - return Status( - Status::Code::INVALID_ARG, - "element count for input '" + state.input_name() + - "' exceeds maximum size of " + std::to_string(SIZE_MAX)); - } - size_t total_byte_size = static_cast(element_count) * dtype_byte_size; + size_t total_byte_size = 0; + RETURN_IF_ERROR(GetByteSize( + initial_state.data_type(), initial_state.dims(), state.input_name(), + reinterpret_cast(&total_byte_size))); switch (initial_state.state_data_case()) { case inference::ModelSequenceBatching_InitialState::StateDataCase:: diff --git a/src/sequence_state.cc b/src/sequence_state.cc index 1939124f5..511bfb9bb 100644 --- a/src/sequence_state.cc +++ b/src/sequence_state.cc @@ -212,27 +212,10 @@ SequenceStates::Initialize( initial_state_it->second.data_->TotalByteSize()); } } else { - size_t state_size; - if (state.second.data_type() == inference::DataType::TYPE_STRING) { - int64_t element_count = 0; - RETURN_IF_ERROR( - GetElementCount(dims, state_config.input_name(), &element_count)); - // Total number of bytes required is equal to the element count - // multiplied by 4. - if (static_cast(state_size) > SIZE_MAX / sizeof(int32_t)) { - return Status( - Status::Code::INVALID_ARG, - "element count for input '" + state_config.input_name() + - "' exceeds maximum size of " + std::to_string(SIZE_MAX)); - } - state_size = sizeof(int32_t) * static_cast(element_count); - } else { - int64_t byte_size = 0; - RETURN_IF_ERROR(GetByteSize( - state.second.data_type(), dims, state_config.input_name(), - &byte_size)); - state_size = static_cast(byte_size); - } + int64_t state_size = 0; + RETURN_IF_ERROR(GetByteSize( + state.second.data_type(), dims, state_config.input_name(), + &state_size)); if (use_growable_memory) { std::unique_ptr growable_memory; RETURN_IF_ERROR(GrowableMemory::Create( @@ -415,19 +398,10 @@ SequenceStates::CopyAsNull( std::shared_ptr data; if (from_input_state_tensor->DType() == inference::DataType::TYPE_STRING) { - // Use all-zero input states for null requests. - int64_t element_count = 0; - RETURN_IF_ERROR(GetElementCount( - from_input_state_tensor->Shape(), from_input_state_tensor->Name(), - &element_count)); - if (static_cast(element_count) > SIZE_MAX / sizeof(int32_t)) { - return Status( - Status::Code::INVALID_ARG, - "element count for input '" + from_input_state_tensor->Name() + - "' exceeds maximum size of " + std::to_string(SIZE_MAX)); - } - size_t state_size = - static_cast(element_count) * sizeof(int32_t); + int64_t state_size = 0; + RETURN_IF_ERROR(GetByteSize( + inference::DataType::TYPE_STRING, from_input_state_tensor->Shape(), + from_input_state_tensor->Name(), &state_size)); data = std::make_shared( state_size, TRITONSERVER_MEMORY_CPU, 0); } else { From 0e20a056ebeeebd4ef9d25eab4735b81f3f81b11 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 6 Mar 2026 13:29:33 -0800 Subject: [PATCH 11/15] fIx --- src/backend_model_instance.cc | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/src/backend_model_instance.cc b/src/backend_model_instance.cc index 7b48ca6d0..6710bcdf3 100644 --- a/src/backend_model_instance.cc +++ b/src/backend_model_instance.cc @@ -431,24 +431,10 @@ TritonModelInstance::GenerateWarmupData() // Second pass to prepare original inputs. std::vector> input_sps; for (const auto& input_meta : warmup_setting.inputs()) { - int64_t batch1_element_count, dtype_byte_size = 0; - RETURN_IF_ERROR(GetElementCount( - input_meta.second.dims(), input_meta.first, &batch1_element_count)); + size_t batch_byte_size = 0; RETURN_IF_ERROR(GetByteSize( input_meta.second.data_type(), input_meta.second.dims(), - input_meta.first, &dtype_byte_size)); - dtype_byte_size = - dtype_byte_size == 0 ? sizeof(int32_t) : dtype_byte_size; - - if (static_cast(batch1_element_count) > - SIZE_MAX / dtype_byte_size) { - return Status( - Status::Code::INVALID_ARG, - "element count for input '" + input_meta.first + - "' exceeds maximum size of " + std::to_string(SIZE_MAX)); - } - size_t batch_byte_size = - static_cast(batch1_element_count) * dtype_byte_size; + input_meta.first, reinterpret_cast(&batch_byte_size))); const char* allocated_ptr; switch (input_meta.second.input_data_type_case()) { From 197e7ea23cd7b94eb2b6cde683bcabb7e7cd6fb4 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 6 Mar 2026 13:32:49 -0800 Subject: [PATCH 12/15] xx --- src/sequence_state.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/sequence_state.cc b/src/sequence_state.cc index 511bfb9bb..584ae4056 100644 --- a/src/sequence_state.cc +++ b/src/sequence_state.cc @@ -381,6 +381,7 @@ SequenceStates::CopyAsNull( Status::Code::INVALID_ARG, "SequenceStates 'to' must not be null"); } *to = nullptr; + std::shared_ptr lsequence_states; if (from != nullptr) { lsequence_states.reset(new SequenceStates); From 807dde8e21c57328b2d333e5444c2faa6d55ad03 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 6 Mar 2026 19:57:26 -0800 Subject: [PATCH 13/15] Address comments --- src/backend_model_instance.cc | 8 ++++--- src/infer_request.cc | 5 ++-- src/model_config_utils.cc | 5 ++-- src/model_config_utils.h | 23 ++++++++++++++----- .../sequence_batch_scheduler.cc | 5 ++-- 5 files changed, 31 insertions(+), 15 deletions(-) diff --git a/src/backend_model_instance.cc b/src/backend_model_instance.cc index 6710bcdf3..31a7b9ebe 100644 --- a/src/backend_model_instance.cc +++ b/src/backend_model_instance.cc @@ -431,10 +431,11 @@ TritonModelInstance::GenerateWarmupData() // Second pass to prepare original inputs. std::vector> input_sps; for (const auto& input_meta : warmup_setting.inputs()) { - size_t batch_byte_size = 0; + int64_t batch_byte_size_signed = 0; RETURN_IF_ERROR(GetByteSize( input_meta.second.data_type(), input_meta.second.dims(), - input_meta.first, reinterpret_cast(&batch_byte_size))); + input_meta.first, &batch_byte_size_signed)); + size_t batch_byte_size = static_cast(batch_byte_size_signed); const char* allocated_ptr; switch (input_meta.second.input_data_type_case()) { @@ -460,10 +461,11 @@ TritonModelInstance::GenerateWarmupData() {model_->LocalizedModelPath(), kWarmupDataFolder, input_meta.second.input_data_file()}), input_data)); + if (input_meta.second.data_type() == inference::DataType::TYPE_STRING) { batch_byte_size = input_data->size(); - } else if (((size_t)batch_byte_size) > input_data->size()) { + } else if (batch_byte_size > input_data->size()) { return Status( Status::Code::INVALID_ARG, lrequest->LogRequest() + "warmup setting expects " + diff --git a/src/infer_request.cc b/src/infer_request.cc index 2b0f106f6..6101302f1 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -594,10 +594,11 @@ InferenceRequest::CopyAsNull( } if (input.second.DType() == inference::DataType::TYPE_STRING) { - size_t str_byte_size = 0; + int64_t str_byte_size_signed = 0; RETURN_IF_ERROR(GetByteSize( inference::DataType::TYPE_STRING, input.second.Shape(), input.first, - reinterpret_cast(&str_byte_size))); + &str_byte_size_signed)); + size_t str_byte_size = static_cast(str_byte_size_signed); max_str_byte_size = std::max(str_byte_size, max_str_byte_size); if (str_byte_size > max_byte_size) { max_byte_size = str_byte_size; diff --git a/src/model_config_utils.cc b/src/model_config_utils.cc index a0261363c..012d8f9b3 100644 --- a/src/model_config_utils.cc +++ b/src/model_config_utils.cc @@ -355,9 +355,10 @@ ValidateIOShape( int64_t dims_size = 0; int64_t reshape_size = 0; - RETURN_IF_ERROR(GetElementCount(io.dims(), "dims", &dims_size)); RETURN_IF_ERROR( - GetElementCount(io.reshape().shape(), "reshape", &reshape_size)); + GetElementCount(io.dims(), io.name() + " dims", &dims_size)); + RETURN_IF_ERROR(GetElementCount( + io.reshape().shape(), io.name() + " reshape", &reshape_size)); // dims and reshape must both have same element count // or both have variable-size dimension. diff --git a/src/model_config_utils.h b/src/model_config_utils.h index 58b785029..e24c56c33 100644 --- a/src/model_config_utils.h +++ b/src/model_config_utils.h @@ -25,12 +25,13 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #pragma once +#include + #include "filesystem/api.h" #include "model_config.pb.h" #include "status.h" #include "triton/common/model_config.h" #include "tritonserver_apis.h" - namespace triton { namespace core { /// Enumeration for the different backend types. @@ -323,20 +324,26 @@ template Status GetElementCount(const T& dims, const std::string& name, int64_t* cnt) { - *cnt = triton::common::GetElementCount(dims); - if (*cnt == triton::common::INVALID_SIZE) { + if (cnt == nullptr) { + return Status(Status::Code::INTERNAL, "argument `cnt` cannot be nullptr"); + } + + int64_t element_count = 0; + element_count = triton::common::GetElementCount(dims); + if (element_count == triton::common::INVALID_SIZE) { return Status( Status::Code::INVALID_ARG, "tensor '" + name + "' contains an invalid dimension in shape " + triton::common::DimsListToString(dims)); - } else if (*cnt == triton::common::OVERFLOW_SIZE) { + } else if (element_count == triton::common::OVERFLOW_SIZE) { return Status( Status::Code::INVALID_ARG, "element count for tensor '" + name + "' exceeds maximum size of " + std::to_string(INT64_MAX)); - } else { - return Status::Success; } + + *cnt = element_count; + return Status::Success; } template @@ -345,6 +352,10 @@ GetByteSize( const inference::DataType& dtype, const T& dims, const std::string& name, int64_t* size) { + if (size == nullptr) { + return Status(Status::Code::INTERNAL, "argument `size` cannot be nullptr"); + } + int64_t byte_size = 0; if (dtype == inference::DataType::TYPE_STRING) { int64_t element_count = 0; diff --git a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc index 2847cee7a..8d3b2d1b9 100644 --- a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc +++ b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc @@ -405,10 +405,11 @@ SequenceBatchScheduler::GenerateInitialStateData( } // Calculate total memory byte size - size_t total_byte_size = 0; + int64_t total_byte_size_signed = 0; RETURN_IF_ERROR(GetByteSize( initial_state.data_type(), initial_state.dims(), state.input_name(), - reinterpret_cast(&total_byte_size))); + &total_byte_size_signed)); + size_t total_byte_size = static_cast(total_byte_size_signed); switch (initial_state.state_data_case()) { case inference::ModelSequenceBatching_InitialState::StateDataCase:: From a6d6963f191ea94aa835cb5635244d7370e6b08e Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 6 Mar 2026 20:12:58 -0800 Subject: [PATCH 14/15] Fix --- src/backend_model_instance.cc | 7 +++++++ src/model_config_utils.h | 5 +++++ src/sequence_batch_scheduler/sequence_batch_scheduler.cc | 5 ++--- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/backend_model_instance.cc b/src/backend_model_instance.cc index 31a7b9ebe..fd8d5eb52 100644 --- a/src/backend_model_instance.cc +++ b/src/backend_model_instance.cc @@ -376,6 +376,13 @@ TritonModelInstance::GenerateWarmupData() RETURN_IF_ERROR(GetByteSize( input_meta.second.data_type(), input_meta.second.dims(), input_meta.first, &batch_byte_size)); + if (batch_byte_size == triton::common::WILDCARD_SIZE) { + return Status( + Status::Code::INVALID_ARG, + "warmup setting expects all variable-size dimensions are specified " + "for input '" + + input_meta.first + "'"); + } switch (input_meta.second.input_data_type_case()) { case inference::ModelWarmup_Input::InputDataTypeCase::kZeroData: diff --git a/src/model_config_utils.h b/src/model_config_utils.h index e24c56c33..61547077b 100644 --- a/src/model_config_utils.h +++ b/src/model_config_utils.h @@ -361,6 +361,11 @@ GetByteSize( int64_t element_count = 0; RETURN_IF_ERROR(GetElementCount(dims, name, &element_count)); + if (element_count == triton::common::WILDCARD_SIZE) { + *size = triton::common::WILDCARD_SIZE; + return Status::Success; + } + // Total number of bytes required is equal to the element count // multiplied by 4. if (element_count > static_cast(INT64_MAX / sizeof(int32_t))) { diff --git a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc index 8d3b2d1b9..f51877301 100644 --- a/src/sequence_batch_scheduler/sequence_batch_scheduler.cc +++ b/src/sequence_batch_scheduler/sequence_batch_scheduler.cc @@ -1757,9 +1757,8 @@ DirectSequenceBatch::BatcherThread(const int nice) std::unique_ptr ni = nullptr; Status status = InferenceRequest::CopyAsNull(*null_irequest, &ni); if (!status.IsOk()) { - LOG_ERROR - << "internal: unexpecting failure copying null request: " - << status.Message(); + LOG_ERROR << "internal: unexpected failure copying null request: " + << status.Message(); } // Note that when the not-ready control input of the // request is "true" the model can't assume that any From 92beeb8d55ea0d6ece7ffd6fa1b5aee05595b8e3 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Tue, 10 Mar 2026 20:28:29 -0700 Subject: [PATCH 15/15] Fix bug --- src/infer_request.cc | 1 - src/sequence_state.cc | 1 - 2 files changed, 2 deletions(-) diff --git a/src/infer_request.cc b/src/infer_request.cc index 6101302f1..bdcc8e031 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -523,7 +523,6 @@ InferenceRequest::CopyAsNull( return Status( Status::Code::INVALID_ARG, "InferenceRequest 'to' must not be null"); } - *to = nullptr; // Create a copy of 'from' request with artificial inputs and no requested // outputs. Maybe more efficient to share inputs and other metadata, diff --git a/src/sequence_state.cc b/src/sequence_state.cc index 584ae4056..66e03b9f8 100644 --- a/src/sequence_state.cc +++ b/src/sequence_state.cc @@ -380,7 +380,6 @@ SequenceStates::CopyAsNull( return Status( Status::Code::INVALID_ARG, "SequenceStates 'to' must not be null"); } - *to = nullptr; std::shared_ptr lsequence_states; if (from != nullptr) {