diff --git a/torch_tensorrt/runtime/MutableTorchTRTModule.cpp b/torch_tensorrt/runtime/MutableTorchTRTModule.cpp new file mode 100644 index 0000000000..ee95b0f612 --- /dev/null +++ b/torch_tensorrt/runtime/MutableTorchTRTModule.cpp @@ -0,0 +1,22 @@ +--- a/torch_tensorrt/runtime/MutableTorchTRTModule.cpp ++++ b/torch_tensorrt/runtime/MutableTorchTRTModule.cpp +@@ -10,8 +10,17 @@ void MutableTorchTRTModule::update_refit_condition() { +- if (engine_->needsRefit()) { +- refit_flag_ = RefitFlag::NEEDS_REFIT; +- } else { +- refit_flag_ = RefitFlag::LIVE; ++ // Three-step validation for CUDA 13.x on B100/H100 ++ // Step 1: Check if any weight-affecting attributes changed ++ bool weight_affecting_change = check_weight_affecting_attribute_change(); ++ // Step 2: If no weight-affecting changes, directly set to LIVE ++ if (!weight_affecting_change) { ++ refit_flag_ = RefitFlag::LIVE; ++ return; + } ++ // Step 3: Otherwise, delegate to TensorRT's needsRefit ++ if (engine_->needsRefit()) { ++ refit_flag_ = RefitFlag::NEEDS_REFIT; ++ } else { ++ refit_flag_ = RefitFlag::LIVE; ++ } + } \ No newline at end of file diff --git a/torch_tensorrt/runtime/refit_helpers.h b/torch_tensorrt/runtime/refit_helpers.h new file mode 100644 index 0000000000..772b939ca3 --- /dev/null +++ b/torch_tensorrt/runtime/refit_helpers.h @@ -0,0 +1,79 @@ +#ifndef TORCH_TENSORRT_RUNTIME_REFIT_HELPERS_H +#define TORCH_TENSORRT_RUNTIME_REFIT_HELPERS_H + +#include +#include +#include +#include + +namespace torch_tensorrt { +namespace runtime { + +enum class DataType { + kFLOAT = 0, + kHALF = 1, + kINT8 = 2, + kINT32 = 3, + kBOOL = 4 +}; + +inline bool are_weights_equal_exact(const void* a, const void* b, size_t size_in_bytes) { + return std::memcmp(a, b, size_in_bytes) == 0; +} + +inline bool are_floats_equal(float a, float b, float epsilon = 1e-6) { + float diff = std::fabs(a - b); + float max_abs = std::fabs(a) > std::fabs(b) ? std::fabs(a) : std::fabs(b); + return diff <= epsilon * (max_abs > 1.0f ? max_abs : 1.0f); +} + +inline bool are_weights_equal_tolerant(const void* a, const void* b, size_t count, DataType dtype, float epsilon = 1e-6) { + switch (dtype) { + case DataType::kFLOAT: { + const float* pa = static_cast(a); + const float* pb = static_cast(b); + for (size_t i = 0; i < count; ++i) { + if (!are_floats_equal(pa[i], pb[i], epsilon)) { + return false; + } + } + return true; + } + case DataType::kHALF: + case DataType::kINT8: + case DataType::kINT32: + case DataType::kBOOL: + default: { + size_t elem_size = 0; + switch (dtype) { + case DataType::kHALF: elem_size = 2; break; + case DataType::kINT8: elem_size = 1; break; + case DataType::kINT32: elem_size = 4; break; + case DataType::kBOOL: elem_size = 1; break; + default: elem_size = 0; break; + } + return std::memcmp(a, b, count * elem_size) == 0; + } + } +} + +inline bool are_weights_robust_equal(const void* a, const void* b, size_t count, DataType dtype, float epsilon = 1e-6) { + size_t elem_size = 0; + switch (dtype) { + case DataType::kFLOAT: elem_size = 4; break; + case DataType::kHALF: elem_size = 2; break; + case DataType::kINT8: elem_size = 1; break; + case DataType::kINT32: elem_size = 4; break; + case DataType::kBOOL: elem_size = 1; break; + default: elem_size = 0; break; + } + if (are_weights_equal_exact(a, b, count * elem_size)) { + return true; + } + return are_weights_equal_tolerant(a, b, count, dtype, epsilon); +} + +} // namespace runtime +} // namespace torch_tensorrt + +#endif // TORCH_TENSORRT_RUNTIME_REFIT_HELPERS_H \ No newline at end of file