Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions torch_tensorrt/runtime/MutableTorchTRTModule.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
--- a/torch_tensorrt/runtime/MutableTorchTRTModule.cpp
+++ b/torch_tensorrt/runtime/MutableTorchTRTModule.cpp
@@ -10,8 +10,17 @@ void MutableTorchTRTModule::update_refit_condition() {
- if (engine_->needsRefit()) {
- refit_flag_ = RefitFlag::NEEDS_REFIT;
- } else {
- refit_flag_ = RefitFlag::LIVE;
+ // Three-step validation for CUDA 13.x on B100/H100
+ // Step 1: Check if any weight-affecting attributes changed
+ bool weight_affecting_change = check_weight_affecting_attribute_change();
+ // Step 2: If no weight-affecting changes, directly set to LIVE
+ if (!weight_affecting_change) {
+ refit_flag_ = RefitFlag::LIVE;
+ return;
}
+ // Step 3: Otherwise, delegate to TensorRT's needsRefit
+ if (engine_->needsRefit()) {
+ refit_flag_ = RefitFlag::NEEDS_REFIT;
+ } else {
+ refit_flag_ = RefitFlag::LIVE;
+ }
}
79 changes: 79 additions & 0 deletions torch_tensorrt/runtime/refit_helpers.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#ifndef TORCH_TENSORRT_RUNTIME_REFIT_HELPERS_H
#define TORCH_TENSORRT_RUNTIME_REFIT_HELPERS_H

#include <cstddef>
#include <cstring>
#include <cmath>
#include <limits>

namespace torch_tensorrt {
namespace runtime {

/// Element type of a raw weight buffer handed to the comparison helpers below.
enum class DataType {
  kFLOAT = 0,  ///< `float` elements; the only type compared with a tolerance.
  kHALF = 1,   ///< 2-byte elements, compared bitwise.
  kINT8 = 2,   ///< 1-byte elements, compared bitwise.
  kINT32 = 3,  ///< 4-byte elements, compared bitwise.
  kBOOL = 4    ///< 1-byte elements, compared bitwise.
};

namespace detail {

/// @brief Size in bytes of one element of @p dtype.
/// @return The element size, or 0 when @p dtype is not a declared enumerator.
inline std::size_t element_size(DataType dtype) noexcept {
  switch (dtype) {
    case DataType::kFLOAT:
      return sizeof(float);
    case DataType::kHALF:
      return 2;
    case DataType::kINT8:
      return 1;
    case DataType::kINT32:
      return 4;
    case DataType::kBOOL:
      return 1;
  }
  return 0;
}

/// @brief Bitwise comparison of two buffers of @p count elements of @p elem_size bytes.
///
/// Returns false when @p elem_size is 0 (unknown element type) or when the byte
/// length would overflow `size_t`, so that an unanswerable comparison is never
/// reported as "equal".
inline bool buffers_bitwise_equal(const void* a, const void* b, std::size_t count, std::size_t elem_size) noexcept {
  if (elem_size == 0) {
    return false;
  }
  if (count == 0 || a == b) {
    return true;
  }
  // memcmp requires valid pointers; exactly one null buffer can never match.
  if (a == nullptr || b == nullptr) {
    return false;
  }
  if (count > std::numeric_limits<std::size_t>::max() / elem_size) {
    return false;
  }
  return std::memcmp(a, b, count * elem_size) == 0;
}

} // namespace detail

/// @brief Byte-for-byte comparison of two buffers.
/// @param a First buffer.
/// @param b Second buffer.
/// @param size_in_bytes Number of bytes to compare in each buffer.
/// @return true when the first @p size_in_bytes bytes are identical. An empty
///         range, or the same pointer passed twice, is equal; otherwise a null
///         buffer is never equal to anything.
inline bool are_weights_equal_exact(const void* a, const void* b, size_t size_in_bytes) {
  if (size_in_bytes == 0 || a == b) {
    return true;
  }
  // Passing a null pointer to memcmp is undefined behaviour, so answer here.
  if (a == nullptr || b == nullptr) {
    return false;
  }
  return std::memcmp(a, b, size_in_bytes) == 0;
}

/// @brief Approximate equality of two floats.
///
/// Finite values are equal when `|a - b| <= epsilon * max(1, |a|, |b|)`, i.e.
/// the tolerance is absolute near zero and relative for large magnitudes.
/// Values that compare equal with `==` (including same-signed infinities and
/// +0/-0) are always equal. Any other pair involving an infinity or a NaN is
/// unequal; in particular NaN never equals NaN here.
///
/// @param a First value.
/// @param b Second value.
/// @param epsilon Tolerance factor.
inline bool are_floats_equal(float a, float b, float epsilon = 1e-6f) {
  if (a == b) {
    return true;
  }
  // Without this guard an infinite operand makes the scaled tolerance infinite,
  // which would accept any finite value as "equal" to infinity.
  if (!std::isfinite(a) || !std::isfinite(b)) {
    return false;
  }
  const float diff = std::fabs(a - b);
  const float scale = std::fmax(1.0f, std::fmax(std::fabs(a), std::fabs(b)));
  return diff <= epsilon * scale;
}

/// @brief Element-wise comparison of two weight buffers, tolerant for floats.
///
/// For DataType::kFLOAT each pair of elements matches when it is bit-identical
/// or when are_floats_equal() accepts it, so this check is never stricter than
/// a bitwise one. All other declared types are compared bitwise. An
/// unrecognised @p dtype yields false, because the buffers cannot be
/// interpreted.
///
/// @param a First buffer.
/// @param b Second buffer.
/// @param count Number of elements (not bytes) in each buffer.
/// @param dtype Element type of both buffers.
/// @param epsilon Tolerance forwarded to are_floats_equal().
/// @return true when the buffers are considered equal.
inline bool are_weights_equal_tolerant(const void* a, const void* b, size_t count, DataType dtype, float epsilon = 1e-6f) {
  if (dtype != DataType::kFLOAT) {
    return detail::buffers_bitwise_equal(a, b, count, detail::element_size(dtype));
  }
  if (count == 0 || a == b) {
    return true;
  }
  if (a == nullptr || b == nullptr) {
    return false;
  }

  const unsigned char* bytes_a = static_cast<const unsigned char*>(a);
  const unsigned char* bytes_b = static_cast<const unsigned char*>(b);
  for (std::size_t i = 0; i < count; ++i) {
    // memcpy loads keep this well-defined even if the buffers are not
    // float-aligned.
    float lhs = 0.0f;
    float rhs = 0.0f;
    std::memcpy(&lhs, bytes_a + i * sizeof(float), sizeof(float));
    std::memcpy(&rhs, bytes_b + i * sizeof(float), sizeof(float));
    if (std::memcmp(&lhs, &rhs, sizeof(float)) == 0) {
      continue;  // bit-identical, covers matching NaN payloads
    }
    if (!are_floats_equal(lhs, rhs, epsilon)) {
      return false;
    }
  }
  return true;
}

/// @brief Bitwise comparison first, then the tolerant comparison on mismatch.
///
/// @param a First buffer.
/// @param b Second buffer.
/// @param count Number of elements (not bytes) in each buffer.
/// @param dtype Element type of both buffers.
/// @param epsilon Tolerance forwarded to are_weights_equal_tolerant().
/// @return true when the buffers are bit-identical or equal within tolerance;
///         false for an unrecognised @p dtype.
inline bool are_weights_robust_equal(const void* a, const void* b, size_t count, DataType dtype, float epsilon = 1e-6f) {
  const std::size_t elem_size = detail::element_size(dtype);
  if (elem_size == 0) {
    return false;
  }
  if (detail::buffers_bitwise_equal(a, b, count, elem_size)) {
    return true;
  }
  return are_weights_equal_tolerant(a, b, count, dtype, epsilon);
}

} // namespace runtime
} // namespace torch_tensorrt

#endif // TORCH_TENSORRT_RUNTIME_REFIT_HELPERS_H
Loading