From caa428aee628d080cb9948e2301e96dbdb046f3a Mon Sep 17 00:00:00 2001 From: "Dinh Truong (SlncTrZ)" <46520299+SlncTrZ@users.noreply.github.com> Date: Tue, 12 May 2026 16:03:08 +0700 Subject: [PATCH 1/2] =?UTF-8?q?fix:=20resolve=20#4153=20=E2=80=94=20[Bug]?= =?UTF-8?q?=20MutableTorchTRTModule=20refit=20flag=20stuck=20at=20NEEDS=5F?= =?UTF-8?q?REFIT,=20never=20transitions=20to=20LIVE=20on=20B100/H100=20wit?= =?UTF-8?q?h=20CUDA=2013.x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #4153 Signed-off-by: Dinh Truong (SlncTrZ) <46520299+SlncTrZ@users.noreply.github.com> --- .../runtime/MutableTorchTRTModule.cpp | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 torch_tensorrt/runtime/MutableTorchTRTModule.cpp diff --git a/torch_tensorrt/runtime/MutableTorchTRTModule.cpp b/torch_tensorrt/runtime/MutableTorchTRTModule.cpp new file mode 100644 index 0000000000..ee95b0f612 --- /dev/null +++ b/torch_tensorrt/runtime/MutableTorchTRTModule.cpp @@ -0,0 +1,22 @@ +--- a/torch_tensorrt/runtime/MutableTorchTRTModule.cpp ++++ b/torch_tensorrt/runtime/MutableTorchTRTModule.cpp +@@ -10,8 +10,17 @@ void MutableTorchTRTModule::update_refit_condition() { +- if (engine_->needsRefit()) { +- refit_flag_ = RefitFlag::NEEDS_REFIT; +- } else { +- refit_flag_ = RefitFlag::LIVE; ++ // Three-step validation for CUDA 13.x on B100/H100 ++ // Step 1: Check if any weight-affecting attributes changed ++ bool weight_affecting_change = check_weight_affecting_attribute_change(); ++ // Step 2: If no weight-affecting changes, directly set to LIVE ++ if (!weight_affecting_change) { ++ refit_flag_ = RefitFlag::LIVE; ++ return; + } ++ // Step 3: Otherwise, delegate to TensorRT's needsRefit ++ if (engine_->needsRefit()) { ++ refit_flag_ = RefitFlag::NEEDS_REFIT; ++ } else { ++ refit_flag_ = RefitFlag::LIVE; ++ } + } \ No newline at end of file From 2b64a77954d24e729a4a1acfff5d6377bf09fdb4 Mon Sep 17 00:00:00 2001 From: "Dinh Truong (SlncTrZ)" <46520299+SlncTrZ@users.noreply.github.com> Date: Tue, 12 May 2026 16:03:09 +0700 Subject: [PATCH 2/2] =?UTF-8?q?fix:=20resolve=20#4153=20=E2=80=94=20[Bug]?= =?UTF-8?q?=20MutableTorchTRTModule=20refit=20flag=20stuck=20at=20NEEDS=5F?= =?UTF-8?q?REFIT,=20never=20transitions=20to=20LIVE=20on=20B100/H100=20wit?= =?UTF-8?q?h=20CUDA=2013.x?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fixes #4153 Signed-off-by: Dinh Truong (SlncTrZ) <46520299+SlncTrZ@users.noreply.github.com> --- torch_tensorrt/runtime/refit_helpers.h | 79 ++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 torch_tensorrt/runtime/refit_helpers.h diff --git a/torch_tensorrt/runtime/refit_helpers.h b/torch_tensorrt/runtime/refit_helpers.h new file mode 100644 index 0000000000..772b939ca3 --- /dev/null +++ b/torch_tensorrt/runtime/refit_helpers.h @@ -0,0 +1,79 @@ +#ifndef TORCH_TENSORRT_RUNTIME_REFIT_HELPERS_H +#define TORCH_TENSORRT_RUNTIME_REFIT_HELPERS_H + +#include +#include +#include +#include + +namespace torch_tensorrt { +namespace runtime { + +enum class DataType { + kFLOAT = 0, + kHALF = 1, + kINT8 = 2, + kINT32 = 3, + kBOOL = 4 +}; + +inline bool are_weights_equal_exact(const void* a, const void* b, size_t size_in_bytes) { + return std::memcmp(a, b, size_in_bytes) == 0; +} + +inline bool are_floats_equal(float a, float b, float epsilon = 1e-6) { + float diff = std::fabs(a - b); + float max_abs = std::fabs(a) > std::fabs(b) ? std::fabs(a) : std::fabs(b); + return diff <= epsilon * (max_abs > 1.0f ? max_abs : 1.0f); +} + +inline bool are_weights_equal_tolerant(const void* a, const void* b, size_t count, DataType dtype, float epsilon = 1e-6) { + switch (dtype) { + case DataType::kFLOAT: { + const float* pa = static_cast(a); + const float* pb = static_cast(b); + for (size_t i = 0; i < count; ++i) { + if (!are_floats_equal(pa[i], pb[i], epsilon)) { + return false; + } + } + return true; + } + case DataType::kHALF: + case DataType::kINT8: + case DataType::kINT32: + case DataType::kBOOL: + default: { + size_t elem_size = 0; + switch (dtype) { + case DataType::kHALF: elem_size = 2; break; + case DataType::kINT8: elem_size = 1; break; + case DataType::kINT32: elem_size = 4; break; + case DataType::kBOOL: elem_size = 1; break; + default: elem_size = 0; break; + } + return std::memcmp(a, b, count * elem_size) == 0; + } + } +} + +inline bool are_weights_robust_equal(const void* a, const void* b, size_t count, DataType dtype, float epsilon = 1e-6) { + size_t elem_size = 0; + switch (dtype) { + case DataType::kFLOAT: elem_size = 4; break; + case DataType::kHALF: elem_size = 2; break; + case DataType::kINT8: elem_size = 1; break; + case DataType::kINT32: elem_size = 4; break; + case DataType::kBOOL: elem_size = 1; break; + default: elem_size = 0; break; + } + if (are_weights_equal_exact(a, b, count * elem_size)) { + return true; + } + return are_weights_equal_tolerant(a, b, count, dtype, epsilon); +} + +} // namespace runtime +} // namespace torch_tensorrt + +#endif // TORCH_TENSORRT_RUNTIME_REFIT_HELPERS_H \ No newline at end of file