Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 26 additions & 31 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ def convert_method_to_trt_engine(
module, torchtrt_arg_inputs, kwarg_inputs=torchtrt_kwarg_inputs, **kwargs
)

return dynamo_convert_exported_program_to_serialized_trt_engine( # type: ignore[no-any-return]
return dynamo_convert_exported_program_to_serialized_trt_engine(
exp_program,
arg_inputs=tuple(normalized_arg_inputs),
kwarg_inputs=torchtrt_kwarg_inputs,
Expand Down Expand Up @@ -1272,41 +1272,36 @@ def _save_as_executorch(exp_program: Any, file_path: str, **kwargs: Any) -> None


def _normalize_engine_constants_to_python(exp_program: "ExportedProgram") -> None:
pass
"""Convert C++ ``torch.classes.tensorrt.Engine`` constants to Python ``TRTEngine``.

The C++ runtime stores engine constants as ``torch._C.ScriptObject``
(``torch.classes.tensorrt.Engine``). Python ``TRTEngine`` is registered as
an opaque type so ``torch.export`` can serialise it with ``pickle``. By
converting before save the artifact is portable across both runtimes.
"""
import base64

# TODO: Uncomment this when cross serialization is enabled
# """Convert C++ ``torch.classes.tensorrt.Engine`` constants to Python ``TRTEngine``.

# The C++ runtime stores engine constants as ``torch._C.ScriptObject``
# (``torch.classes.tensorrt.Engine``). Python ``TRTEngine`` is registered as
# an opaque type so ``torch.export`` can serialise it with ``pickle``. By
# converting before save the artifact is portable across both runtimes.
# """
# import base64

# from torch_tensorrt.dynamo.runtime._serialized_engine_layout import ENGINE_IDX
# from torch_tensorrt.dynamo.runtime._TRTEngine import (
# EngineSerializer,
# TRTEngine,
# )

# for fqn, constant in list(exp_program.constants.items()):
# if isinstance(constant, (torch._C.ScriptObject, TRTEngine)):
from torch_tensorrt.dynamo.runtime._serialized_engine_layout import ENGINE_IDX
from torch_tensorrt.dynamo.runtime._TRTEngine import (
EngineSerializer,
TRTEngine,
)

# state = constant.__getstate__()
# if len(state) == 2 and (
# state[1] == "TRTEngine"
# or state[1] == "__torch__.torch.classes.tensorrt.Engine"
# ):
# serialized_info = list(state[0])
# serialized_info[ENGINE_IDX] = base64.b64decode(
# serialized_info[ENGINE_IDX]
# )
# exp_program.constants[fqn] = EngineSerializer(serialized_info)
for fqn, constant in list(exp_program.constants.items()):
if isinstance(constant, (torch._C.ScriptObject, TRTEngine)):

state = constant.__getstate__()
if len(state) == 2 and (
state[1] == "TRTEngine"
or state[1] == "__torch__.torch.classes.tensorrt.Engine"
):
serialized_info = list(state[0])
serialized_info[ENGINE_IDX] = base64.b64decode(
serialized_info[ENGINE_IDX]
)
exp_program.constants[fqn] = EngineSerializer(serialized_info)


#
def function_overload_with_kwargs(
fn: Callable[..., Any], *args: Any, **kwargs: Any
) -> Any:
Expand Down
45 changes: 36 additions & 9 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def cross_compile_for_windows(
max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS,
version_compatible: bool = _defaults.VERSION_COMPATIBLE,
optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
use_python_runtime: bool = False,
use_python_runtime: bool = False, # Deprecated; setting True emits DeprecationWarning. Kept for backward compatibility.
use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
dryrun: bool = _defaults.DRYRUN,
Expand Down Expand Up @@ -163,7 +163,7 @@ def cross_compile_for_windows(
max_aux_streams (Optional[int]): Maximum streams in the engine
version_compatible (bool): Build the TensorRT engines compatible with future versions of TensorRT (Restrict to lean runtime operators to provide version forward compatibility for the engines)
optimization_level: (Optional[int]): Setting a higher optimization level allows TensorRT to spend longer engine building time searching for more optimization options. The resulting engine may have better performance compared to an engine built with a lower optimization level. The default optimization level is 3. Valid values include integers from 0 to the maximum optimization level, which is currently 5. Setting it to be greater than the maximum level results in identical behavior to the maximum level.
use_python_runtime: (bool): Force the pure-Python TensorRT runtime (``TRTEngine`` + ``tensorrt::execute_engine_python``). The default is ``False``, which uses the C++ runtime when available and falls back to the Python runtime automatically when the C++ runtime is unavailable.
use_python_runtime: (bool): **Deprecated**. Kept for backward compatibility; emits a ``DeprecationWarning`` when set to ``True``. The Python and C++ runtimes are now merged and the runtime is selected automatically based on whether the C++ Torch-TensorRT runtime is available.
use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
Expand Down Expand Up @@ -220,6 +220,16 @@ def cross_compile_for_windows(
stacklevel=2,
)

if use_python_runtime:
warnings.warn(
"`use_python_runtime` is deprecated and has no effect. The Python and C++ "
"runtimes have been merged; the runtime is now selected automatically based "
"on whether the C++ Torch-TensorRT runtime is available. This argument will "
"be removed in a future release.",
DeprecationWarning,
stacklevel=2,
)

if "refit" in kwargs.keys():
warnings.warn(
"`refit` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted.",
Expand Down Expand Up @@ -334,7 +344,6 @@ def cross_compile_for_windows(
"dynamically_allocate_resources": dynamically_allocate_resources,
"decompose_attention": decompose_attention,
"attn_bias_is_causal": attn_bias_is_causal,
"use_python_runtime": use_python_runtime,
}

# disabling the following settings is not supported for the cross-compilation-for-Windows feature
Expand Down Expand Up @@ -424,7 +433,7 @@ def compile(
max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS,
version_compatible: bool = _defaults.VERSION_COMPATIBLE,
optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
use_python_runtime: bool = False,
use_python_runtime: bool = False, # Deprecated; setting True emits DeprecationWarning. Kept for backward compatibility.
use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
dryrun: bool = _defaults.DRYRUN,
Expand Down Expand Up @@ -519,7 +528,7 @@ def compile(
max_aux_streams (Optional[int]): Maximum streams in the engine
version_compatible (bool): Build the TensorRT engines compatible with future versions of TensorRT (Restrict to lean runtime operators to provide version forward compatibility for the engines)
optimization_level: (Optional[int]): Setting a higher optimization level allows TensorRT to spend longer engine building time searching for more optimization options. The resulting engine may have better performance compared to an engine built with a lower optimization level. The default optimization level is 3. Valid values include integers from 0 to the maximum optimization level, which is currently 5. Setting it to be greater than the maximum level results in identical behavior to the maximum level.
use_python_runtime: (bool): Force the pure-Python TensorRT runtime (``TRTEngine`` + ``tensorrt::execute_engine_python``). The default is ``False``, which uses the C++ runtime when available and falls back to the Python runtime automatically when the C++ runtime is unavailable.
use_python_runtime: (bool): **Deprecated**. Kept for backward compatibility; emits a ``DeprecationWarning`` when set to ``True``. The Python and C++ runtimes are now merged and the runtime is selected automatically based on whether the C++ Torch-TensorRT runtime is available.
use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
Expand Down Expand Up @@ -579,6 +588,16 @@ def compile(
stacklevel=2,
)

if use_python_runtime:
warnings.warn(
"`use_python_runtime` is deprecated and has no effect. The Python and C++ "
"runtimes have been merged; the runtime is now selected automatically based "
"on whether the C++ Torch-TensorRT runtime is available. This argument will "
"be removed in a future release.",
DeprecationWarning,
stacklevel=2,
)

if "refit" in kwargs.keys():
warnings.warn(
"`refit` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted",
Expand Down Expand Up @@ -731,7 +750,6 @@ def compile(
"dynamically_allocate_resources": dynamically_allocate_resources,
"decompose_attention": decompose_attention,
"attn_bias_is_causal": attn_bias_is_causal,
"use_python_runtime": use_python_runtime,
}
logger.debug(f"CPU memory usage before lowering: {get_cpu_memory_usage()} MB")
settings = CompilationSettings(**compilation_options)
Expand Down Expand Up @@ -1218,7 +1236,7 @@ def convert_exported_program_to_serialized_trt_engine(
max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS,
version_compatible: bool = _defaults.VERSION_COMPATIBLE,
optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
use_python_runtime: bool = False,
use_python_runtime: bool = False, # Deprecated; setting True emits DeprecationWarning. Kept for backward compatibility.
use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
dryrun: bool = _defaults.DRYRUN,
Expand Down Expand Up @@ -1294,7 +1312,7 @@ def convert_exported_program_to_serialized_trt_engine(
max_aux_streams (Optional[int]): Maximum streams in the engine
version_compatible (bool): Build the TensorRT engines compatible with future versions of TensorRT (Restrict to lean runtime operators to provide version forward compatibility for the engines)
optimization_level: (Optional[int]): Setting a higher optimization level allows TensorRT to spend longer engine building time searching for more optimization options. The resulting engine may have better performance compared to an engine built with a lower optimization level. The default optimization level is 3. Valid values include integers from 0 to the maximum optimization level, which is currently 5. Setting it to be greater than the maximum level results in identical behavior to the maximum level.
use_python_runtime: (bool): Force the pure-Python TensorRT runtime (``TRTEngine`` + ``tensorrt::execute_engine_python``). The default is ``False``, which uses the C++ runtime when available and falls back to the Python runtime automatically when the C++ runtime is unavailable.
use_python_runtime: (bool): **Deprecated**. Kept for backward compatibility; emits a ``DeprecationWarning`` when set to ``True``. The Python and C++ runtimes are now merged and the runtime is selected automatically based on whether the C++ Torch-TensorRT runtime is available.
use_fast_partitioner: (bool): Use the adjacency based partitioning scheme instead of the global partitioner. Adjacency partitioning is faster but may not be optimal. Use the global partitioner (``False``) if looking for best performance
enable_experimental_decompositions (bool): Use the full set of operator decompositions. These decompositions may not be tested but serve to make the graph easier to convert to TensorRT, potentially increasing the amount of graphs run in TensorRT.
dryrun (bool): Toggle for "Dryrun" mode, running everything except conversion to TRT and logging outputs
Expand Down Expand Up @@ -1344,6 +1362,16 @@ def convert_exported_program_to_serialized_trt_engine(
stacklevel=2,
)

if use_python_runtime:
warnings.warn(
"`use_python_runtime` is deprecated and has no effect. The Python and C++ "
"runtimes have been merged; the runtime is now selected automatically based "
"on whether the C++ Torch-TensorRT runtime is available. This argument will "
"be removed in a future release.",
DeprecationWarning,
stacklevel=2,
)

if "refit" in kwargs.keys():
warnings.warn(
"`refit` is deprecated. Please set `immutable_weights=False` to build a refittable engine whose weights can be refitted",
Expand Down Expand Up @@ -1473,7 +1501,6 @@ def convert_exported_program_to_serialized_trt_engine(
"use_distributed_mode_trace": use_distributed_mode_trace,
"decompose_attention": decompose_attention,
"attn_bias_is_causal": attn_bias_is_causal,
"use_python_runtime": use_python_runtime,
}
if "runtime_cache_path" in compilation_options:
compilation_options.pop("runtime_cache_path")
Expand Down
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
DECOMPOSE_ATTENTION = False
ATTN_BIAS_IS_CAUSAL = True
DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy"
USE_PYTHON_RUNTIME = False

if platform.system() == "Linux":
import pwd
Expand Down
2 changes: 0 additions & 2 deletions py/torch_tensorrt/dynamo/_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,8 +472,6 @@ def inline_trt_modules(
continue
# Get the TRT submodule
trt_module = getattr(gm, name)
if trt_module._use_python_runtime:
raise ValueError("Python runtime is not supported for serialization")

# Ensure the trt module node in the main graph (gm) has inputs
trt_module_node = [node for node in gm.graph.nodes if node.name == name]
Expand Down
4 changes: 1 addition & 3 deletions py/torch_tensorrt/dynamo/_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
USE_DISTRIBUTED_MODE_TRACE,
USE_FAST_PARTITIONER,
USE_FP32_ACC,
USE_PYTHON_RUNTIME,
VERSION_COMPATIBLE,
WORKSPACE_SIZE,
default_device,
Expand Down Expand Up @@ -118,7 +117,6 @@ class CompilationSettings:
dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines
decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True.
attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False.
use_python_runtime (bool): Force the pure-Python TensorRT runtime (``TRTEngine`` + ``tensorrt::execute_engine_python``). When ``False`` (default) the C++ runtime is used if available and the Python runtime is used as a fallback otherwise.
"""

workspace_size: int = WORKSPACE_SIZE
Expand Down Expand Up @@ -181,7 +179,6 @@ class CompilationSettings:
dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES
decompose_attention: bool = DECOMPOSE_ATTENTION
attn_bias_is_causal: bool = ATTN_BIAS_IS_CAUSAL
use_python_runtime: bool = USE_PYTHON_RUNTIME

def __getstate__(self) -> dict[str, Any]:
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
Expand All @@ -196,6 +193,7 @@ def __getstate__(self) -> dict[str, Any]:
return state

def __setstate__(self, state: dict[str, Any]) -> None:
state.pop("use_python_runtime", None)
self.__dict__.update(state)


Expand Down
Loading
Loading