Skip to content

**[None][feat] Remote-G2 failed-transfer handling: structured load-error reporting, recompute fallback, and initiator-side abort**#4

Open
cheese-head wants to merge 3 commits into
oandreeva-nv:oandreeva/remote-g2-beta-registryfrom
cheese-head:expand-transfer-protocol
Open

**[None][feat] Remote-G2 failed-transfer handling: structured load-error reporting, recompute fallback, and initiator-side abort**#4
cheese-head wants to merge 3 commits into
oandreeva-nv:oandreeva/remote-g2-beta-registryfrom
cheese-head:expand-transfer-protocol

Conversation

@cheese-head

Copy link
Copy Markdown

Description

The remote-G2 KV-cache transfer protocol previously exposed only release() and
is_completed() -> bool. Two failure modes leaked out of that single gap:

  • No FAILED state. NIXL collapses every non-success, non-in-progress status
    into a single ERR, but the connector never inspected it, so a mid-flight
    transfer error was indistinguishable from "still in progress" until the 30 s
    transfer timeout fired. When trouble was finally noticed, get_finished()
    raised a RuntimeError straight into the executor's main loop — not a
    per-request-recoverable place to throw.
  • No initiator-side abort. A cancelled or preempted request could not tear
    down its in-flight transfer or lease, leaking the NIXL transfer and leaving
    the request tracked as loading.

Both stem from the contract being too narrow. This PR expands the connector
contract once (mirroring the kvbm-v2 / vLLM failure-notification shape: error
reporting + scheduler callback + abort hook) to close both.

What changed

  • Transfer layer (remote_g2_raw_nixl_adapter.py, remote_g2_transfer.py):
    added is_failed() (reads NIXL state; ERR/poll-exception = hard failure, so
    failures are detected immediately instead of waiting out the timeout) and
    abort(). Modern NIXL has no standalone abort_xfer()release_xfer_handle()
    cancels an active transfer before freeing it — so abort() surfaces that
    primitive through a distinct, idempotent entry point.
  • Connector (remote_g2_connector.py): get_finished() no longer raises.
    On failed state / poll exception / timeout it records the load's
    target_block_ids, releases handle + lease exactly once, emits the
    local_recompute fallback event, and keeps polling other loads. Added
    get_block_ids_with_load_errors() (drain-once) and abort_request().
  • Generic contract (kv_cache_connector.py): KvCacheConnectorWorker gains
    get_block_ids_with_load_errors() (default []) and abort_request()
    (default no-op). KvCacheConnectorManager gains handle_load_errors()
    (mpi_allgather failed block ids across ranks, map to affected requests via
    get_cache_indices), recompute_failed_load() (rewind), and abort_request().
  • Executor (py_executor.py): _kv_connector_handle_load_errors() runs in
    forward() after save — affected requests fall back to local recompute via
    recompute_failed_load(), or are terminated only if the rewind cannot apply.
    _handle_canceled_requests() now calls abort_request() for in-flight
    connector loads.
  • Build (cpp/cmake/modules/FindNIXL.cmake): supporting NIXL discovery
    changes for the GDS/remote transfer path.
  • Docs (docs/source/features/remote_g2_partial_prefix_recovery.md): design
    note for the deferred Phase 2 partial-prefix recovery.

Recovery semantics. A parked request had a computed prefix of
local_prefix + external_match tokens. On load failure, recompute_failed_load()
sets the request to CONTEXT_INIT, rewinds context_current_position /
prepopulated_prompt_len down to local_prefix, zeroes the connector
matched-token count, and re-admits the request without re-add_sequence (its
blocks are still allocated). The engine then recomputes the external region
locally instead of failing the request. Remote-G2 never publishes loaded blocks
into the reuse tree, so the in-place blocks need no invalidation.

Scope / non-goals. Phase 1 recomputes the entire external prefix on
failure. Partial-prefix recovery (keep valid blocks up to the first miss) is
deferred because NIXL exposes only a single aggregate status per batched handle,
so which blocks landed is unobservable without per-block transfers. The deferred
design is documented; the rewind machinery added here is reused by Phase 2 with
a different target (local + K*block_size).

Test Coverage

tests/unittest/_torch/test_remote_g2_connector.py (27 tests pass; runs in the
dynamo dev container with the C++ bindings):

  • Failure path (non-raising contract):
    • test_remote_g2_worker_failure_releases_once_and_publishes_nothing
    • test_remote_g2_worker_timeout_releases_transfer_and_lease_once
    • test_remote_g2_worker_emits_fallback_before_validity_or_publication
    • test_remote_g2_worker_emits_failed_after_validity_is_marked
  • Structured failure reporting (FAILED state + drain-once):
    • test_remote_g2_worker_failed_state_does_not_propagate_and_reports_blocks
    • test_remote_g2_worker_failed_block_ids_drained_once_per_step
  • Initiator-side abort:
    • test_abort_request_calls_result_abort
    • test_abort_request_releases_lease_once
    • test_abort_request_releases_transfer_handle_once
    • test_abort_request_after_completion_is_noop
    • test_abort_request_unknown_id_is_noop
  • Recompute fallback (real manager method, run in a fresh subprocess):
    • test_recompute_failed_load_rewinds_to_local_prefix

Run:

python -m pytest tests/unittest/_torch/test_remote_g2_connector.py -v

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Signed-off-by: Patrick Riel <priel@nvidia.com>
Signed-off-by: Patrick Riel <priel@nvidia.com>
… shape: error reporting + scheduler callback + abort hook

Signed-off-by: Patrick Riel <priel@nvidia.com>
Comment on lines +647 to +651
failed_block_ids = self.worker.get_block_ids_with_load_errors()
all_failed = mpi_allgather(failed_block_ids)
failed_set = set(
block_id for rank_ids in all_failed for block_id in rank_ids
)

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please verify the following for TP>1 :
handle_load_errors calls mpi_allgather which requires all TP ranks to participate simultaneously. With TP>1, kv_connector_manager only exists on rank 0 — non-rank-0 workers hit the if self.kv_connector_manager is None: return guard and skip the allgather, leaving rank 0 blocked indefinitely.

The existing get_finished path handles this correctly via _run_on_leader (rank 0 runs, mpi_broadcast to others). handle_load_errors should follow the same pattern — or kv_connector_manager needs to be present on all TP ranks.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants