From 7b908d49912560bb67b2541135b98504fe83fc79 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Thu, 21 May 2026 13:16:43 -0700 Subject: [PATCH] fix: respect caller's CUDA stream in TRT runtime (Green Context support) Rebased on top of main (post #4222 Python runtime rework). The original PR #4232 modified _PythonTorchTensorRTModule.py which was removed by #4222; the same stream-selection logic has been ported into the new _TRTEngine class. Behavior: - If the caller is on the default CUDA stream, keep legacy behavior: run the engine on a dedicated pool stream and synchronise via wait_stream. - If the caller is on a non-default stream (e.g. attached to a CUDA Green Context for SM partitioning), honor it: reuse the caller's stream for the engine and skip the wait_stream pair (saving extra syncs and preserving the caller's scheduling choice end to end). - When CUDA graphs are enabled and the engine stream changes between invocations, trigger graph recapture via runtime_states.context_changed so we don't replay a graph recorded against a stale stream. Implementation: - core/runtime/execute_engine.cpp: same logic in C++ (cudagraphs replay path + standard path). - py/torch_tensorrt/dynamo/runtime/_TRTEngine.py: new _prepare_streams helper called from both _execute_standard and _execute_output_allocator. - py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py: same logic in the cudagraphs wrapper module. --- core/runtime/execute_engine.cpp | 100 +++++++++++------- .../runtime/_CudaGraphsTorchTensorRTModule.py | 41 +++++-- .../dynamo/runtime/_TRTEngine.py | 84 +++++++++++---- 3 files changed, 158 insertions(+), 67 deletions(-) diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp index 3310bbcede..88901126e2 100644 --- a/core/runtime/execute_engine.cpp +++ b/core/runtime/execute_engine.cpp @@ -243,6 +243,28 @@ std::vector execute_engine(std::vector inputs, c10::intr bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS); bool shape_changed = _validate_shapes(inputs, compiled_engine); + auto current_device_id = inputs.size() > 0 ? inputs[0].device().index() : at::cuda::current_device(); + auto default_stream = c10::cuda::getDefaultCUDAStream(current_device_id); + auto previous_caller_stream = compiled_engine->caller_stream; + auto previous_engine_stream = compiled_engine->engine_stream; + compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id); + bool caller_on_default = (compiled_engine->caller_stream == default_stream); + if (caller_on_default) { + // Refresh on first call or after the previous call ran on the caller's + // non-default stream (which is no longer current). + if (previous_engine_stream == default_stream || previous_engine_stream == previous_caller_stream) { + compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); + } + } else { + // Honor caller's non-default stream so its scheduling choice (e.g. SM + // partitioning via a CUDA Green Context) is preserved end to end. + compiled_engine->engine_stream = compiled_engine->caller_stream; + } + if (cudagraphs_enabled && compiled_engine->engine_stream != previous_engine_stream) { + // Captured CUDA graph was recorded against the old stream; force re-record. + compiled_engine->runtime_states.context_changed = true; + } + // Whether cudagraphs needs to record the graph on this pass auto result = compiled_engine->runtime_states.set_runtime_states( cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed); @@ -310,19 +332,6 @@ std::vector execute_engine(std::vector inputs, c10::intr } } - auto current_device_id = -1; - if (inputs.size() > 0) { - current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart - } else if (outputs.size() > 0) { - current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart - } - - compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id); - if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) { - // Create a new stream if the engine stream is the default stream - compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); - } - compiled_engine->record_active_input_tensor_stream_usage( cudagraphs_enabled ? compiled_engine->caller_stream : compiled_engine->engine_stream); @@ -335,10 +344,12 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->enqueue_profile_path); } - // Block engine stream until results are available on caller stream - at::cuda::CUDAEvent caller_exec_complete; - caller_exec_complete.record(compiled_engine->caller_stream); - caller_exec_complete.block(compiled_engine->engine_stream); + if (caller_on_default) { + // Block engine stream until results are available on caller stream + at::cuda::CUDAEvent caller_exec_complete; + caller_exec_complete.record(compiled_engine->caller_stream); + caller_exec_complete.block(compiled_engine->engine_stream); + } if (!cudagraphs_enabled) { // Direct execution uses the caller buffers directly @@ -371,10 +382,12 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine); } - // Block caller stream until engine execution is complete - at::cuda::CUDAEvent trt_exec_complete; - trt_exec_complete.record(compiled_engine->engine_stream); - trt_exec_complete.block(compiled_engine->caller_stream); + if (caller_on_default) { + // Block caller stream until engine execution is complete + at::cuda::CUDAEvent trt_exec_complete; + trt_exec_complete.record(compiled_engine->engine_stream); + trt_exec_complete.block(compiled_engine->caller_stream); + } if (cudagraphs_enabled) { // If in CUDAGraph mode, copy persistent staging outputs to returned tensors on the caller stream. @@ -421,17 +434,22 @@ std::vector execute_engine(std::vector inputs, c10::intr create_output_allocator(compiled_engine); } - auto current_device_id = -1; - if (inputs.size() > 0) { - current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart - } else { - current_device_id = at::cuda::current_device(); - } - + auto current_device_id = inputs.size() > 0 ? inputs[0].device().index() : at::cuda::current_device(); + auto default_stream = c10::cuda::getDefaultCUDAStream(current_device_id); + auto previous_caller_stream = compiled_engine->caller_stream; + auto previous_engine_stream = compiled_engine->engine_stream; compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id); - if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) { - // Create a new stream if the engine stream is the default stream - compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); + bool caller_on_default = (compiled_engine->caller_stream == default_stream); + if (caller_on_default) { + // Refresh on first call or after the previous call ran on the caller's + // non-default stream (which is no longer current). + if (previous_engine_stream == default_stream || previous_engine_stream == previous_caller_stream) { + compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id); + } + } else { + // Honor caller's non-default stream so its scheduling choice (e.g. SM + // partitioning via a CUDA Green Context) is preserved end to end. + compiled_engine->engine_stream = compiled_engine->caller_stream; } compiled_engine->record_active_input_tensor_stream_usage(compiled_engine->engine_stream); @@ -445,10 +463,12 @@ std::vector execute_engine(std::vector inputs, c10::intr std::make_unique(compiled_engine->enqueue_profile_path); } - // Block engine stream until results are available on caller stream - at::cuda::CUDAEvent caller_exec_complete; - caller_exec_complete.record(compiled_engine->caller_stream); - caller_exec_complete.block(compiled_engine->engine_stream); + if (caller_on_default) { + // Block engine stream until results are available on caller stream + at::cuda::CUDAEvent caller_exec_complete; + caller_exec_complete.record(compiled_engine->caller_stream); + caller_exec_complete.block(compiled_engine->engine_stream); + } // Direct execution uses the caller buffers directly compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream); @@ -457,10 +477,12 @@ std::vector execute_engine(std::vector inputs, c10::intr compiled_engine->clear_active_input_tensors(); - // Block caller stream until engine execution is complete - at::cuda::CUDAEvent trt_exec_complete; - trt_exec_complete.record(compiled_engine->engine_stream); - trt_exec_complete.block(compiled_engine->caller_stream); + if (caller_on_default) { + // Block caller stream until engine execution is complete + at::cuda::CUDAEvent trt_exec_complete; + trt_exec_complete.record(compiled_engine->engine_stream); + trt_exec_complete.block(compiled_engine->caller_stream); + } std::unique_ptr output_profiler_guard; if (compiled_engine->profile_execution) { diff --git a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py index 9e54fbac3d..2fea411070 100644 --- a/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py @@ -121,7 +121,33 @@ def forward( cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode() if cudagraphs_enabled: shape_changed = self.validate_input_shapes(inputs) - need_cudagraphs_record = shape_changed or self.is_weight_streaming_set + + current_device = ( + inputs[0].device + if inputs and inputs[0].is_cuda + else torch.device("cuda", torch.cuda.current_device()) + ) + default_stream = torch.cuda.default_stream(current_device) + previous_caller_stream = self._caller_stream + previous_engine_stream = self._engine_stream + self._caller_stream = torch.cuda.current_stream(current_device) + caller_on_default = self._caller_stream == default_stream + if caller_on_default: + if ( + previous_engine_stream is None + or previous_engine_stream == default_stream + or previous_engine_stream == previous_caller_stream + ): + self._engine_stream = torch.cuda.Stream(current_device) + else: + # Honor caller's non-default stream so its scheduling choice (e.g. SM + # partitioning via a CUDA Green Context) is preserved end to end. + self._engine_stream = self._caller_stream + stream_changed = self._engine_stream != previous_engine_stream + + need_cudagraphs_record = ( + shape_changed or self.is_weight_streaming_set or stream_changed + ) if need_cudagraphs_record: self._reset_captured_graph() self._input_buffers = [None] * len(inputs) @@ -172,14 +198,8 @@ def forward( self._input_buffers, self.compiled_module._in_spec ) - self._caller_stream = torch.cuda.current_stream() - if ( - self._engine_stream == torch.cuda.default_stream() - or self._engine_stream is None - ): - self._engine_stream = torch.cuda.Stream() - - self._engine_stream.wait_stream(self._caller_stream) + if caller_on_default: + self._engine_stream.wait_stream(self._caller_stream) with torch.cuda.stream(self._engine_stream): if need_cudagraphs_record: @@ -188,7 +208,8 @@ def forward( self._output_buffers = self.compiled_module(*args, **kwargs) self.cudagraph.replay() # type: ignore - self._caller_stream.wait_stream(self._engine_stream) + if caller_on_default: + self._caller_stream.wait_stream(self._engine_stream) if isinstance(self._output_buffers, (list, tuple)): output_buffers = self._output_buffers diff --git a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py index 74d363752f..1f0d50f397 100644 --- a/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py +++ b/py/torch_tensorrt/dynamo/runtime/_TRTEngine.py @@ -675,9 +675,65 @@ def _profile_section(self, label: str) -> ContextManager[None]: # --- execution --- + def _prepare_streams(self, contiguous_inputs: List[torch.Tensor]) -> bool: + """Pick the engine stream relative to the caller's current stream. + + If the caller is on the default stream we keep the legacy behavior of + running the engine on a dedicated pool stream (and synchronising via + wait_stream). If the caller is on a non-default stream (e.g. a stream + attached to a CUDA Green Context) we honor it by reusing the caller's + stream for the engine, so the caller's scheduling choice (e.g. SM + partitioning) is preserved end to end and no wait_stream sync is + needed. + + Returns ``caller_on_default`` so call sites can gate the wait_stream + pair on it. Also flips ``runtime_states.context_changed`` whenever + the engine stream changes while CUDA graphs are enabled so the + current invocation re-records the graph against the new stream. + Call sites MUST invoke this before ``runtime_states.set_runtime_states`` + because that call consumes and resets ``context_changed``. + """ + current_device = ( + contiguous_inputs[0].device + if contiguous_inputs and contiguous_inputs[0].is_cuda + else torch.device("cuda", torch.cuda.current_device()) + ) + default_stream = torch.cuda.default_stream(current_device) + previous_caller_stream = self._caller_stream + previous_engine_stream = self._engine_stream + self._caller_stream = torch.cuda.current_stream(current_device) + caller_on_default = self._caller_stream == default_stream + if caller_on_default: + # Refresh on first call or after the previous call ran on the + # caller's non-default stream (which is no longer current). + if ( + previous_engine_stream is None + or previous_engine_stream == default_stream + or previous_engine_stream == previous_caller_stream + ): + self._engine_stream = torch.cuda.Stream(current_device) + else: + # Honor caller's non-default stream so its scheduling choice (e.g. + # SM partitioning via a CUDA Green Context) is preserved end to + # end. + self._engine_stream = self._caller_stream + if ( + torch_tensorrt.runtime.get_cudagraphs_mode() + and self._engine_stream != previous_engine_stream + ): + # Captured CUDA graph was recorded against the old stream. + self.runtime_states.context_changed = True + return caller_on_default + def _execute_standard( self, contiguous_inputs: List[torch.Tensor] ) -> torch.Tensor | Tuple[torch.Tensor, ...]: + # Pick the engine stream BEFORE set_runtime_states so that any + # stream-identity change observed this call flips + # runtime_states.context_changed in time to trigger same-call + # cudagraph recapture (set_runtime_states consumes and resets the + # flag). See PR #4232 and the C++ mirror in execute_engine.cpp. + caller_on_default = self._prepare_streams(contiguous_inputs) shape_changed = self.validate_input_shapes(contiguous_inputs) ( need_cudagraphs_record, @@ -734,14 +790,8 @@ def _execute_standard( self.context.set_tensor_address(output_name, outputs[o].data_ptr()) with self._profile_section("TRTEngine:TensorRTRuntime"): - self._caller_stream = torch.cuda.current_stream() - if ( - self._engine_stream == torch.cuda.default_stream() - or self._engine_stream is None - ): - self._engine_stream = torch.cuda.Stream() - - self._engine_stream.wait_stream(self._caller_stream) + if caller_on_default: + self._engine_stream.wait_stream(self._caller_stream) with torch.cuda.stream(self._engine_stream): if self.resource_allocation_strategy: self._dynamic_workspace = torch.empty( @@ -770,7 +820,8 @@ def _execute_standard( else: self.context.execute_async_v3(self._engine_stream.cuda_stream) - self._caller_stream.wait_stream(self._engine_stream) + if caller_on_default: + self._caller_stream.wait_stream(self._engine_stream) if self.use_pre_allocated_outputs and ( self.output_tensors_are_unowned @@ -809,18 +860,15 @@ def _execute_output_allocator( f"Failed to set output allocator for {output_name}" ) - with self._profile_section("TRTEngine:TensorRTRuntime"): - self._caller_stream = torch.cuda.current_stream() - if ( - self._engine_stream == torch.cuda.default_stream() - or self._engine_stream is None - ): - self._engine_stream = torch.cuda.Stream() + caller_on_default = self._prepare_streams(contiguous_inputs) - self._engine_stream.wait_stream(self._caller_stream) + with self._profile_section("TRTEngine:TensorRTRuntime"): + if caller_on_default: + self._engine_stream.wait_stream(self._caller_stream) with torch.cuda.stream(self._engine_stream): self.context.execute_async_v3(self._engine_stream.cuda_stream) - self._caller_stream.wait_stream(self._engine_stream) + if caller_on_default: + self._caller_stream.wait_stream(self._engine_stream) outputs = [] assert self.output_allocator is not None