Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 61 additions & 39 deletions core/runtime/execute_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,28 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
bool cudagraphs_enabled = (CUDAGRAPHS_MODE == SUBGRAPH_CUDAGRAPHS);
bool shape_changed = _validate_shapes(inputs, compiled_engine);

auto current_device_id = inputs.size() > 0 ? inputs[0].device().index() : at::cuda::current_device();
auto default_stream = c10::cuda::getDefaultCUDAStream(current_device_id);
auto previous_caller_stream = compiled_engine->caller_stream;
auto previous_engine_stream = compiled_engine->engine_stream;
compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
bool caller_on_default = (compiled_engine->caller_stream == default_stream);
if (caller_on_default) {
// Refresh on first call or after the previous call ran on the caller's
// non-default stream (which is no longer current).
if (previous_engine_stream == default_stream || previous_engine_stream == previous_caller_stream) {
compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
}
} else {
// Honor caller's non-default stream so its scheduling choice (e.g. SM
// partitioning via a CUDA Green Context) is preserved end to end.
compiled_engine->engine_stream = compiled_engine->caller_stream;
}
if (cudagraphs_enabled && compiled_engine->engine_stream != previous_engine_stream) {
// Captured CUDA graph was recorded against the old stream; force re-record.
compiled_engine->runtime_states.context_changed = true;
}

// Whether cudagraphs needs to record the graph on this pass
auto result = compiled_engine->runtime_states.set_runtime_states(
cudagraphs_enabled, compiled_engine->use_pre_allocated_outputs, shape_changed);
Expand Down Expand Up @@ -310,19 +332,6 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
}
}

auto current_device_id = -1;
if (inputs.size() > 0) {
current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
} else if (outputs.size() > 0) {
current_device_id = outputs[0].device().index(); // Done this way to avoid a call to cudart
}

compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
// Create a new stream if the engine stream is the default stream
compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
}

compiled_engine->record_active_input_tensor_stream_usage(
cudagraphs_enabled ? compiled_engine->caller_stream : compiled_engine->engine_stream);

Expand All @@ -335,10 +344,12 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
}

// Block engine stream until results are available on caller stream
at::cuda::CUDAEvent caller_exec_complete;
caller_exec_complete.record(compiled_engine->caller_stream);
caller_exec_complete.block(compiled_engine->engine_stream);
if (caller_on_default) {
// Block engine stream until results are available on caller stream
at::cuda::CUDAEvent caller_exec_complete;
caller_exec_complete.record(compiled_engine->caller_stream);
caller_exec_complete.block(compiled_engine->engine_stream);
}

if (!cudagraphs_enabled) {
// Direct execution uses the caller buffers directly
Expand Down Expand Up @@ -371,10 +382,12 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
compiled_engine->pre_allocated_outputs = create_output_tensors(compiled_engine);
}

// Block caller stream until engine execution is complete
at::cuda::CUDAEvent trt_exec_complete;
trt_exec_complete.record(compiled_engine->engine_stream);
trt_exec_complete.block(compiled_engine->caller_stream);
if (caller_on_default) {
// Block caller stream until engine execution is complete
at::cuda::CUDAEvent trt_exec_complete;
trt_exec_complete.record(compiled_engine->engine_stream);
trt_exec_complete.block(compiled_engine->caller_stream);
}

if (cudagraphs_enabled) {
// If in CUDAGraph mode, copy persistent staging outputs to returned tensors on the caller stream.
Expand Down Expand Up @@ -421,17 +434,22 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
create_output_allocator(compiled_engine);
}

auto current_device_id = -1;
if (inputs.size() > 0) {
current_device_id = inputs[0].device().index(); // Done this way to avoid a call to cudart
} else {
current_device_id = at::cuda::current_device();
}

auto current_device_id = inputs.size() > 0 ? inputs[0].device().index() : at::cuda::current_device();
auto default_stream = c10::cuda::getDefaultCUDAStream(current_device_id);
auto previous_caller_stream = compiled_engine->caller_stream;
auto previous_engine_stream = compiled_engine->engine_stream;
compiled_engine->caller_stream = c10::cuda::getCurrentCUDAStream(current_device_id);
if (compiled_engine->engine_stream == c10::cuda::getDefaultCUDAStream(current_device_id)) {
// Create a new stream if the engine stream is the default stream
compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
bool caller_on_default = (compiled_engine->caller_stream == default_stream);
if (caller_on_default) {
// Refresh on first call or after the previous call ran on the caller's
// non-default stream (which is no longer current).
if (previous_engine_stream == default_stream || previous_engine_stream == previous_caller_stream) {
compiled_engine->engine_stream = c10::cuda::getStreamFromPool(false, current_device_id);
}
} else {
// Honor caller's non-default stream so its scheduling choice (e.g. SM
// partitioning via a CUDA Green Context) is preserved end to end.
compiled_engine->engine_stream = compiled_engine->caller_stream;
}

compiled_engine->record_active_input_tensor_stream_usage(compiled_engine->engine_stream);
Expand All @@ -445,10 +463,12 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
std::make_unique<torch::autograd::profiler::RecordProfile>(compiled_engine->enqueue_profile_path);
}

// Block engine stream until results are available on caller stream
at::cuda::CUDAEvent caller_exec_complete;
caller_exec_complete.record(compiled_engine->caller_stream);
caller_exec_complete.block(compiled_engine->engine_stream);
if (caller_on_default) {
// Block engine stream until results are available on caller stream
at::cuda::CUDAEvent caller_exec_complete;
caller_exec_complete.record(compiled_engine->caller_stream);
caller_exec_complete.block(compiled_engine->engine_stream);
}

// Direct execution uses the caller buffers directly
compiled_engine->exec_ctx->enqueueV3(compiled_engine->engine_stream);
Expand All @@ -457,10 +477,12 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr

compiled_engine->clear_active_input_tensors();

// Block caller stream until engine execution is complete
at::cuda::CUDAEvent trt_exec_complete;
trt_exec_complete.record(compiled_engine->engine_stream);
trt_exec_complete.block(compiled_engine->caller_stream);
if (caller_on_default) {
// Block caller stream until engine execution is complete
at::cuda::CUDAEvent trt_exec_complete;
trt_exec_complete.record(compiled_engine->engine_stream);
trt_exec_complete.block(compiled_engine->caller_stream);
}

std::unique_ptr<torch::autograd::profiler::RecordProfile> output_profiler_guard;
if (compiled_engine->profile_execution) {
Expand Down
41 changes: 31 additions & 10 deletions py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,33 @@ def forward(
cudagraphs_enabled = torch_tensorrt.runtime.get_whole_cudagraphs_mode()
if cudagraphs_enabled:
shape_changed = self.validate_input_shapes(inputs)
need_cudagraphs_record = shape_changed or self.is_weight_streaming_set

current_device = (
inputs[0].device
if inputs and inputs[0].is_cuda
else torch.device("cuda", torch.cuda.current_device())
)
default_stream = torch.cuda.default_stream(current_device)
previous_caller_stream = self._caller_stream
previous_engine_stream = self._engine_stream
self._caller_stream = torch.cuda.current_stream(current_device)
caller_on_default = self._caller_stream == default_stream
if caller_on_default:
if (
previous_engine_stream is None
or previous_engine_stream == default_stream
or previous_engine_stream == previous_caller_stream
):
self._engine_stream = torch.cuda.Stream(current_device)
else:
# Honor caller's non-default stream so its scheduling choice (e.g. SM
# partitioning via a CUDA Green Context) is preserved end to end.
self._engine_stream = self._caller_stream
stream_changed = self._engine_stream != previous_engine_stream

need_cudagraphs_record = (
shape_changed or self.is_weight_streaming_set or stream_changed
)
if need_cudagraphs_record:
self._reset_captured_graph()
self._input_buffers = [None] * len(inputs)
Expand Down Expand Up @@ -172,14 +198,8 @@ def forward(
self._input_buffers, self.compiled_module._in_spec
)

self._caller_stream = torch.cuda.current_stream()
if (
self._engine_stream == torch.cuda.default_stream()
or self._engine_stream is None
):
self._engine_stream = torch.cuda.Stream()

self._engine_stream.wait_stream(self._caller_stream)
if caller_on_default:
self._engine_stream.wait_stream(self._caller_stream)

with torch.cuda.stream(self._engine_stream):
if need_cudagraphs_record:
Expand All @@ -188,7 +208,8 @@ def forward(
self._output_buffers = self.compiled_module(*args, **kwargs)

self.cudagraph.replay() # type: ignore
self._caller_stream.wait_stream(self._engine_stream)
if caller_on_default:
self._caller_stream.wait_stream(self._engine_stream)

if isinstance(self._output_buffers, (list, tuple)):
output_buffers = self._output_buffers
Expand Down
84 changes: 66 additions & 18 deletions py/torch_tensorrt/dynamo/runtime/_TRTEngine.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,9 +675,65 @@ def _profile_section(self, label: str) -> ContextManager[None]:

# --- execution ---

def _prepare_streams(self, contiguous_inputs: List[torch.Tensor]) -> bool:
    """Select the stream the engine will run on for this invocation.

    The caller's current stream (on the device of the first input if that
    input is a CUDA tensor, otherwise on the current CUDA device) is stored
    in ``self._caller_stream``. ``self._engine_stream`` is then updated:

    * Caller on the device's default stream: the existing engine stream is
      kept, unless it is ``None``, is the default stream, or equals the
      caller stream recorded by the previous invocation. In those cases a
      new ``torch.cuda.Stream`` is created on the device.
    * Caller on any other stream: the engine stream is set to the caller's
      stream, so both names refer to the same stream.

    When ``torch_tensorrt.runtime.get_cudagraphs_mode()`` is truthy and the
    engine stream differs from the one held on entry,
    ``self.runtime_states.context_changed`` is set to ``True``.

    NOTE(review): the original documentation states that this must be
    called before ``runtime_states.set_runtime_states`` because that call
    consumes and resets ``context_changed``; that method is not visible
    here, so confirm before reordering call sites.

    Args:
        contiguous_inputs: Input tensors for this invocation; only the
            first element is inspected, to pick the device.

    Returns:
        Whether the caller's current stream is the device's default stream.
        Call sites use this to decide whether the ``wait_stream`` pair
        around engine execution is required.
    """
    # Resolve the device from the first input when it lives on CUDA;
    # otherwise fall back to whatever CUDA device is current.
    if contiguous_inputs and contiguous_inputs[0].is_cuda:
        device = contiguous_inputs[0].device
    else:
        device = torch.device("cuda", torch.cuda.current_device())

    device_default = torch.cuda.default_stream(device)

    # Snapshot the streams left over from the previous invocation before
    # overwriting them, so the comparisons below see the old state.
    prior_caller, prior_engine = self._caller_stream, self._engine_stream

    self._caller_stream = torch.cuda.current_stream(device)
    caller_on_default = self._caller_stream == device_default

    if not caller_on_default:
        # Run the engine directly on the caller's non-default stream so the
        # caller's scheduling choice (e.g. SM partitioning via a CUDA Green
        # Context) is preserved end to end.
        self._engine_stream = self._caller_stream
    elif (
        prior_engine is None
        or prior_engine == device_default
        or prior_engine == prior_caller
    ):
        # First use, or the last invocation borrowed the caller's
        # non-default stream (no longer current): switch to a fresh stream.
        self._engine_stream = torch.cuda.Stream(device)

    if (
        torch_tensorrt.runtime.get_cudagraphs_mode()
        and self._engine_stream != prior_engine
    ):
        # Any captured CUDA graph was recorded against the old stream.
        self.runtime_states.context_changed = True

    return caller_on_default

def _execute_standard(
self, contiguous_inputs: List[torch.Tensor]
) -> torch.Tensor | Tuple[torch.Tensor, ...]:
# Pick the engine stream BEFORE set_runtime_states so that any
# stream-identity change observed this call flips
# runtime_states.context_changed in time to trigger same-call
# cudagraph recapture (set_runtime_states consumes and resets the
# flag). See PR #4232 and the C++ mirror in execute_engine.cpp.
caller_on_default = self._prepare_streams(contiguous_inputs)
shape_changed = self.validate_input_shapes(contiguous_inputs)
(
need_cudagraphs_record,
Expand Down Expand Up @@ -734,14 +790,8 @@ def _execute_standard(
self.context.set_tensor_address(output_name, outputs[o].data_ptr())

with self._profile_section("TRTEngine:TensorRTRuntime"):
self._caller_stream = torch.cuda.current_stream()
if (
self._engine_stream == torch.cuda.default_stream()
or self._engine_stream is None
):
self._engine_stream = torch.cuda.Stream()

self._engine_stream.wait_stream(self._caller_stream)
if caller_on_default:
self._engine_stream.wait_stream(self._caller_stream)
with torch.cuda.stream(self._engine_stream):
if self.resource_allocation_strategy:
self._dynamic_workspace = torch.empty(
Expand Down Expand Up @@ -770,7 +820,8 @@ def _execute_standard(
else:
self.context.execute_async_v3(self._engine_stream.cuda_stream)

self._caller_stream.wait_stream(self._engine_stream)
if caller_on_default:
self._caller_stream.wait_stream(self._engine_stream)

if self.use_pre_allocated_outputs and (
self.output_tensors_are_unowned
Expand Down Expand Up @@ -809,18 +860,15 @@ def _execute_output_allocator(
f"Failed to set output allocator for {output_name}"
)

with self._profile_section("TRTEngine:TensorRTRuntime"):
self._caller_stream = torch.cuda.current_stream()
if (
self._engine_stream == torch.cuda.default_stream()
or self._engine_stream is None
):
self._engine_stream = torch.cuda.Stream()
caller_on_default = self._prepare_streams(contiguous_inputs)

self._engine_stream.wait_stream(self._caller_stream)
with self._profile_section("TRTEngine:TensorRTRuntime"):
if caller_on_default:
self._engine_stream.wait_stream(self._caller_stream)
with torch.cuda.stream(self._engine_stream):
self.context.execute_async_v3(self._engine_stream.cuda_stream)
self._caller_stream.wait_stream(self._engine_stream)
if caller_on_default:
self._caller_stream.wait_stream(self._engine_stream)

outputs = []
assert self.output_allocator is not None
Expand Down
Loading