From ae32dc4e4d732181a19f3dfdcb101a40b7d21160 Mon Sep 17 00:00:00 2001 From: Kartica Modi Date: Fri, 12 Jun 2026 07:04:55 -0700 Subject: [PATCH 1/5] [core] Fixing ray check errors due to double ray.cancel()/keyboard interrupts (#63663) This PR builds on https://github.com/ray-project/ray/pull/61102 and resolves this stack trace mentioned in the PR. For quick overview, this is the bug: After many workers were OOM-killed, regular task workers crashed with: Check failed: objects_valid 1 return objects expected, 1 returned. Object at idx 0 was not stored. The CHECK in `TaskReceiver::HandleTaskExecutionResult` fires when the task execution handler returns `Status::OK()` but a return object slot is still `nullptr`. The linked PR explains how the exception could be caused due to a double `ray.cancel()` on the task. Although I couldn't find double cancellation triggered on any worker when checking the failing job logs, double `ray.cancel()` could nonetheless generate this exception (repro in the added test case in the PR). So while the exact trigger of this incident might be something else that produces the same `status=OK` / `return_objects[i].second=nullptr` state, the fixes proposed in the PR are definitely good safeguards that should be added. --------- Signed-off-by: Kartica Modi Signed-off-by: kartica.modi --- python/ray/_raylet.pyx | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 0933767614ad..b9f0963f2200 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -2363,6 +2363,32 @@ cdef CRayStatus task_execution_handler( msg = "Unexpected exception raised in task execution handler: {}".format(e) logger.error(msg) return CRayStatus.UnexpectedSystemExit(msg) + except BaseException as e: + # Safety net: any BaseException that is not Exception or SystemExit + # (e.g. KeyboardInterrupt, GeneratorExit) would otherwise escape this + # cdef function. Without this, Cython silently returns + # CRayStatus.OK() for unhandled non-Exception/non-SystemExit + # exceptions, causing a CHECK failure in HandleTaskExecutionResult + # when return objects are not populated. + # Convert to UnexpectedSystemExit so the C++ side + # treats this as a clean worker-exiting task failure. + # + # The motivating case is a rapid double `ray.cancel()`. The first + # cancel raises a KeyboardInterrupt that is caught by + # `execute_task_with_cancellation_handler`'s + # `except KeyboardInterrupt` clause, which calls + # `store_task_errors`. If a second cancel arrives while + # `store_task_errors` is running, it queues another SIGINT that + # fires inside the error-storage path. That KeyboardInterrupt + # cannot be re-caught (we are already inside `except + # KeyboardInterrupt`), so it escapes all the way out to this + # handler. + msg = ( + "BaseException escaped task execution handlers: " + f"{type(e).__name__}: {e}" + ) + logger.error(msg) + return CRayStatus.UnexpectedSystemExit(msg) return CRayStatus.OK() From 1faaaa7cb28b58278f2550db0f63325fb266755a Mon Sep 17 00:00:00 2001 From: ZTE Ray Date: Fri, 12 Jun 2026 22:15:14 +0800 Subject: [PATCH 2/5] [Autoscaler] Support environment variable configuration for log rotation and deduplicate label deprecation warnings (#63955) ## Motivation This PR addresses two operational issues in the KubeRay autoscaler: 1. **Log Rotation Flexibility**: Operators need to customize log rotation settings based on their deployment environment's disk space and retention policies. Hardcoded values don't work well across different environments (development vs production, small vs large clusters). 2. **Log Noise Reduction**: The repeated deprecation warning for `rayStartParams.labels` clutters logs, making it harder to identify genuine issues. Since KubeRay v1.5+ recommends using the top-level `Labels` field, the warning should inform users once without spamming. ## Implementation Details ### Files Modified - `python/ray/autoscaler/_private/kuberay/run_autoscaler.py` - `python/ray/autoscaler/_private/kuberay/autoscaling_config.py` ### Changes #### 1. Environment Variable Configuration for Log Rotation **File**: `run_autoscaler.py` Added support for `RAY_ROTATION_MAX_BYTES` and `RAY_ROTATION_BACKUP_COUNT` environment variables: ```python # Before (hardcoded) setup_component_logger( max_bytes=LOGGING_ROTATE_BYTES, backup_count=LOGGING_ROTATE_BACKUP_COUNT, ) # After (environment variable override) max_bytes = int(os.getenv("RAY_ROTATION_MAX_BYTES", LOGGING_ROTATE_BYTES)) backup_count = int(os.getenv("RAY_ROTATION_BACKUP_COUNT", LOGGING_ROTATE_BACKUP_COUNT)) setup_component_logger( max_bytes=max_bytes, backup_count=backup_count, ) ``` **Usage Example**: ```yaml # In RayCluster CR spec: headGroupSpec: rayStartParams: RAY_ROTATION_MAX_BYTES: "52428800" # 50MB RAY_ROTATION_BACKUP_COUNT: "10" # Keep 10 backups ``` #### 2. Deduplicate Label Deprecation Warning **File**: `autoscaling_config.py` Added `log_once()` to ensure the warning is printed only once: ```python # Before (repeated on every iteration) if labels_str: logger.warning(...) # After (printed once) if labels_str and log_once("raystartparams_labels_warning"): logger.warning(...) ``` The `log_once()` function from `ray.util.debug` uses an internal flag to ensure the message is logged only the first time the condition is met. ### Breaking Changes None. This is a backward-compatible enhancement: - Environment variables are optional (fallback to existing constants) - Warning behavior is reduced (from repeated to once), which is an improvement ## Verification ### Test Log Rotation Configuration 1. Deploy a RayCluster with custom environment variables: ```bash kubectl set env deployment/raycluster-kuberay-head \ RAY_ROTATION_MAX_BYTES=10485760 \ RAY_ROTATION_BACKUP_COUNT=5 ``` 2. Generate log activity and verify rotation occurs at 10MB instead of default 3. Verify only 5 backup files are retained ### Test Label Warning Deduplication 1. Deploy a RayCluster with deprecated `rayStartParams.labels`: ```yaml workerGroupSpecs: - rayStartParams: labels: "{\"app\": \"myapp\"}" ``` 2. Wait for multiple autoscaler iterations (e.g., 1 minute) 3. Check autoscaler logs: ```bash kubectl logs -c autoscaler | grep "Ignoring labels" ``` 4. **Expected**: Warning appears only once 5. **Previous behavior**: Warning appeared every 5 seconds ### Unit Tests No new tests added - this is a configuration enhancement that doesn't change core autoscaler logic. Existing tests should continue to pass. ## Related Issues Close #63954 --------- Signed-off-by: daiping8 --- .../_private/kuberay/autoscaling_config.py | 39 ++++++++++++++++--- .../_private/kuberay/run_autoscaler.py | 8 ++-- .../tests/kuberay/test_autoscaling_config.py | 8 +++- 3 files changed, 44 insertions(+), 11 deletions(-) diff --git a/python/ray/autoscaler/_private/kuberay/autoscaling_config.py b/python/ray/autoscaler/_private/kuberay/autoscaling_config.py index e61f453d87c8..9b09c1b6f8ec 100644 --- a/python/ray/autoscaler/_private/kuberay/autoscaling_config.py +++ b/python/ray/autoscaler/_private/kuberay/autoscaling_config.py @@ -18,6 +18,7 @@ ) from ray.autoscaler._private.kuberay import node_provider, utils from ray.autoscaler._private.util import validate_config +from ray.util.debug import log_once logger = logging.getLogger(__name__) @@ -194,6 +195,8 @@ def _node_type_from_group_spec( group_spec: Dict[str, Any], is_head: bool ) -> Dict[str, Any]: """Converts CR group spec to autoscaler node type.""" + group_name = _HEAD_GROUP_NAME if is_head else group_spec["groupName"] + if is_head: # The head node type has no workers because the head is not a worker. min_workers = max_workers = 0 @@ -204,7 +207,7 @@ def _node_type_from_group_spec( max_workers = group_spec["maxReplicas"] * group_spec.get("numOfHosts", 1) resources = _get_ray_resources_from_group_spec(group_spec, is_head) - labels = _get_labels_from_group_spec(group_spec) + labels = _get_labels_from_group_spec(group_spec, group_name) node_type = { "min_workers": min_workers, @@ -308,21 +311,45 @@ def _get_ray_resources_from_group_spec( return resources -def _get_labels_from_group_spec(group_spec: Dict[str, Any]) -> Dict[str, str]: +def _get_labels_from_group_spec( + group_spec: Dict[str, Any], group_name: str = "" +) -> Dict[str, str]: """ Parses Ray node labels for the autoscaling config based on the following priority: 1. Top-level `labels` field in the group spec. 2. `labels` field in `rayStartParams`. + + Args: + group_spec: The group specification dictionary. + group_name: The name of the group (used in warning messages). + + Returns: + A dictionary of labels for the node type. """ labels_dict = {} ray_start_params = group_spec.get("rayStartParams", {}) labels_str = ray_start_params.get("labels") - if labels_str: - logger.warning( - f"Ignoring labels: {labels_str} set in rayStartParams. Group labels are supported in the top-level Labels field starting in KubeRay v1.5" - ) + # Use a unique log_once key per group to ensure each group's warning is shown. + log_once_key = ( + f"raystartparams_labels_warning_{group_name}" + if group_name + else "raystartparams_labels_warning" + ) + if labels_str and log_once(log_once_key): + if group_name: + logger.warning( + f"Ignoring labels: {labels_str} set in rayStartParams for group " + f"'{group_name}'. Group labels are supported in the top-level " + "Labels field starting in KubeRay v1.5" + ) + else: + logger.warning( + f"Ignoring labels: {labels_str} set in rayStartParams. " + "Group labels are supported in the top-level Labels field " + "starting in KubeRay v1.5" + ) # Check for top-level structured Labels field. if "labels" in group_spec and isinstance(group_spec.get("labels"), dict): diff --git a/python/ray/autoscaler/_private/kuberay/run_autoscaler.py b/python/ray/autoscaler/_private/kuberay/run_autoscaler.py index 6013f9b423a7..2d626b21bf9e 100644 --- a/python/ray/autoscaler/_private/kuberay/run_autoscaler.py +++ b/python/ray/autoscaler/_private/kuberay/run_autoscaler.py @@ -9,7 +9,7 @@ LOGGING_ROTATE_BACKUP_COUNT, LOGGING_ROTATE_BYTES, ) -from ray._common.utils import try_to_create_directory +from ray._common.utils import env_integer, try_to_create_directory from ray._private import ray_constants from ray._private.ray_logging import setup_component_logger from ray._private.services import get_node_ip_address @@ -119,13 +119,15 @@ def _setup_logging(log_dir: str) -> None: try_to_create_directory(log_dir) # Write logs at info level to monitor.log. + max_bytes = env_integer("RAY_ROTATION_MAX_BYTES", LOGGING_ROTATE_BYTES) + backup_count = env_integer("RAY_ROTATION_BACKUP_COUNT", LOGGING_ROTATE_BACKUP_COUNT) setup_component_logger( logging_level=ray_constants.LOGGER_LEVEL, logging_format=ray_constants.LOGGER_FORMAT, log_dir=log_dir, filename=ray_constants.MONITOR_LOG_FILE_NAME, # monitor.log - max_bytes=LOGGING_ROTATE_BYTES, - backup_count=LOGGING_ROTATE_BACKUP_COUNT, + max_bytes=max_bytes, + backup_count=backup_count, ) # For the autoscaler, the root logger _also_ needs to write to stderr, not just diff --git a/python/ray/tests/kuberay/test_autoscaling_config.py b/python/ray/tests/kuberay/test_autoscaling_config.py index eb557c558818..e4bcc3094015 100644 --- a/python/ray/tests/kuberay/test_autoscaling_config.py +++ b/python/ray/tests/kuberay/test_autoscaling_config.py @@ -517,7 +517,7 @@ def test_resource_quantity(input: str, output: int): _get_basic_autoscaling_config(), None, None, - "Ignoring labels: ray.io/accelerator-type=TPU-V4 set in rayStartParams. Group labels are supported in the top-level Labels field starting in KubeRay v1.5", + "Ignoring labels: ray.io/accelerator-type=TPU-V4 set in rayStartParams for group 'tpu-group'. Group labels are supported in the top-level Labels field starting in KubeRay v1.5", id="groups-with-raystartparam-labels", ), pytest.param( @@ -525,7 +525,7 @@ def test_resource_quantity(input: str, output: int): _get_autoscaling_config_with_top_level_labels(), None, None, - "Ignoring labels: instance-type=n2 set in rayStartParams. Group labels are supported in the top-level Labels field starting in KubeRay v1.5", + "Ignoring labels: instance-type=n2 set in rayStartParams for group 'small-group'. Group labels are supported in the top-level Labels field starting in KubeRay v1.5", id="groups-with-top-level-labels", ), pytest.param( @@ -566,6 +566,10 @@ def test_autoscaling_config( expected_log_warning: Optional[str], ): ray_cr_in["metadata"]["namespace"] = "default" + # Reset log_once state to ensure each test case is independent. + from ray.util.debug import _logged + + _logged.clear() with mock.patch(f"{AUTOSCALING_CONFIG_MODULE_PATH}.logger") as mock_logger: if expected_error: with pytest.raises(expected_error, match=expected_error_message): From ccd2c5c0e6fd7f55a1b87123aa123cce64d28705 Mon Sep 17 00:00:00 2001 From: Rueian Date: Fri, 12 Jun 2026 08:33:37 -0700 Subject: [PATCH 3/5] [core] impl _num_objects_per_yield (#63943) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Ray Data implements data processing operators with streaming generators in a special way that yields 2 times in a single iteration: first, the block, then the block metadata. This works with today’s task-based streaming generator backpressure if the quota is 2 (`_generator_backpressure_num_objects`=2), which allows both yields to arrive at the caller without blocking. However, Ray Data is moving to actor-based backpressure, where multiple streaming generators on the same actor share the same backpressure quota. The problem is that with shared backpressure quota, one yield from a streaming generator can be blocked by another yield in another streaming generator. For example, if a caller launches 2 concurrent streaming generators on the actor at the same time, they will not be able to yield the metadata after yielding the block until the caller takes those blocks out because the quota is already drained. However, the caller can’t handle those blocks alone without their metadata. It will need another iteration to take all the metadata out. The result of this is that we will have bad caller performance since it needs more round trips to make progress on these concurrent generators. ### Solution To solve this problem, we need a way to atomically count multiple yields for backpressure. Introducing `_num_objects_per_yield`, a new private option that declares the number of ray references to unpack for each yield, similar to the current num_returns on a Ray task. ```python @ray.remote(_num_objects_per_yield =2) def generator(): for _ in range(10): yield block, meta gen = generator.remore() assert ray.get(gen._next_sync()) == block assert ray.get(gen._next_sync()) == meta ``` --------- Signed-off-by: Rueian Huang --- cpp/src/ray/runtime/task/task_executor.cc | 1 + cpp/src/ray/runtime/task/task_executor.h | 1 + python/ray/_common/ray_option_utils.py | 10 + python/ray/_raylet.pyx | 179 ++++++++++++------ python/ray/actor.py | 63 ++++++ python/ray/includes/common.pxd | 5 +- python/ray/includes/libcoreworker.pxd | 3 +- python/ray/remote_function.py | 16 ++ python/ray/tests/test_streaming_generator.py | 121 ++++++++++++ .../test_streaming_generator_backpressure.py | 13 ++ src/ray/common/task/task_spec.cc | 5 + src/ray/common/task/task_spec.h | 2 + src/ray/common/task/task_util.h | 4 +- src/ray/core_worker/common.h | 5 + src/ray/core_worker/core_worker.cc | 43 +++-- src/ray/core_worker/core_worker.h | 13 +- src/ray/core_worker/core_worker_options.h | 2 + src/ray/core_worker/generator_waiter.cc | 6 +- src/ray/core_worker/generator_waiter.h | 2 +- .../java/io_ray_runtime_RayNativeRuntime.cc | 1 + ...io_ray_runtime_task_NativeTaskSubmitter.cc | 1 + src/ray/core_worker/task_manager.cc | 21 +- src/ray/core_worker/tests/core_worker_test.cc | 1 + .../core_worker/tests/task_manager_test.cc | 2 +- src/ray/protobuf/common.proto | 3 + src/ray/protobuf/core_worker.proto | 4 +- 26 files changed, 432 insertions(+), 95 deletions(-) diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc index ced9277fd825..e39fd8c9b5a5 100644 --- a/cpp/src/ray/runtime/task/task_executor.cc +++ b/cpp/src/ray/runtime/task/task_executor.cc @@ -140,6 +140,7 @@ Status TaskExecutor::ExecuteTask( bool is_streaming_generator, bool retry_exception, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, const std::optional &tensor_transport) { RAY_LOG(DEBUG) << "Execute task type: " << TaskType_Name(task_type) << " name:" << task_name; diff --git a/cpp/src/ray/runtime/task/task_executor.h b/cpp/src/ray/runtime/task/task_executor.h index a3bd2964c3e1..bb14254bffd7 100644 --- a/cpp/src/ray/runtime/task/task_executor.h +++ b/cpp/src/ray/runtime/task/task_executor.h @@ -97,6 +97,7 @@ class TaskExecutor { bool is_streaming_generator, bool retry_exception, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, /* This is used by the in-actor RDT object store. However, it is only supported in * the Python frontend. */ const std::optional &tensor_transport); diff --git a/python/ray/_common/ray_option_utils.py b/python/ray/_common/ray_option_utils.py index 6185025c9b86..166223a4f5c8 100644 --- a/python/ray/_common/ray_option_utils.py +++ b/python/ray/_common/ray_option_utils.py @@ -221,6 +221,16 @@ def issubclass_safe(obj: Any, cls_: type) -> bool: "whenever `next` is called). Use -1 to disable this feature. " ), ), + "_num_objects_per_yield": Option( + (int, type(None)), + lambda x: None + if (x is None or x > 0) + else ( + "_num_objects_per_yield is a private streaming generator option " + "that must be set to a positive integer." + ), + default_value=1, + ), } _actor_only_options = { diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index b9f0963f2200..5539795b3453 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1088,6 +1088,7 @@ cdef class StreamingGeneratorExecutionContext: c_bool *is_retryable_error c_string *application_error shared_ptr[CGeneratorBackpressureWaiter] waiter + int64_t num_objects_per_yield def initialize(self, generator: Union[Generator, AsyncGenerator]): # We couldn't make this a part of `make` method because @@ -1120,6 +1121,7 @@ cdef class StreamingGeneratorExecutionContext: c_bool *is_retryable_error, c_string *application_error, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, ): cdef StreamingGeneratorExecutionContext self = ( StreamingGeneratorExecutionContext()) @@ -1141,6 +1143,7 @@ cdef class StreamingGeneratorExecutionContext: self.is_retryable_error = is_retryable_error self.application_error = application_error self.should_retry_exceptions = should_retry_exceptions + self.num_objects_per_yield = num_objects_per_yield self.waiter = make_shared[CGeneratorBackpressureWaiter]( generator_backpressure_num_objects, @@ -1167,28 +1170,53 @@ cdef report_streaming_generator_output( context: Streaming generator's execution context. output: The output yielded from a generator or raised as an exception. - generator_index: index of the output element in the - generated sequence + generator_index: The first ObjectRef stream index for this yield. """ worker = ray._private.worker.global_worker cdef: - # Ray Object created from an output. - c_pair[CObjectID, shared_ptr[CRayObject]] return_obj + # Ray Objects created from an output. + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] return_objs + size_t i + c_bool output_error_reported = False start = time.perf_counter() # Report the intermediate result if there was no error. - create_generator_return_obj( - output, - context.generator_id, - worker, - context.caller_address, - context.task_id, - context.return_size, - generator_index, - context.is_async, - &return_obj) + try: + create_generator_return_objs( + output, + context.generator_id, + worker, + context.caller_address, + context.task_id, + context.return_size, + generator_index, + context.num_objects_per_yield, + context.is_async, + &return_objs) + except Exception as e: + if ( + context.num_objects_per_yield == 1 + or return_objs.size() != context.num_objects_per_yield + ): + raise + + # Dynamic IDs for this grouped yield are already allocated. If storing + # failed after some objects were written, report the whole group as + # non-retryable so the caller does not block waiting for later stream + # indexes and the allocated IDs can be cleared. + context.is_retryable_error[0] = False + store_task_errors( + worker, e, + True, # task_exception + context.actor, # actor + context.actor_id, # actor id + context.function_name, context.task_type, context.title, + context.caller_address, + &return_objs, + context.application_error) + output_error_reported = True # Del output here so that we can GC the memory # usage asap. @@ -1199,22 +1227,25 @@ cdef report_streaming_generator_output( if interrupt_signal_event is not None and interrupt_signal_event.is_set(): return - context.streaming_generator_returns[0].push_back( - c_pair[CObjectID, c_bool]( - return_obj.first, - is_plasma_object(return_obj.second))) + for i in range(return_objs.size()): + context.streaming_generator_returns[0].push_back( + c_pair[CObjectID, c_bool]( + return_objs[i].first, + is_plasma_object(return_objs[i].second))) serialization_dur_s = time.perf_counter() - start with nogil: check_status(CCoreWorkerProcess.GetCoreWorker().ReportGeneratorItemReturns( - return_obj, + return_objs, context.generator_id, context.caller_address, generator_index, context.attempt_number, context.waiter)) + if output_error_reported: + return None return StreamingGeneratorStats( object_creation_dur_s=serialization_dur_s, @@ -1233,14 +1264,13 @@ cdef report_streaming_generator_exception( context: Streaming generator's execution context. output_or_exception: The output yielded from a generator or raised as an exception. - generator_index: index of the output element in the - generated sequence + generator_index: The ObjectRef stream index for this exception. """ worker = ray._private.worker.global_worker cdef: # Ray Object created from an output. - c_pair[CObjectID, shared_ptr[CRayObject]] return_obj + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] return_objs create_generator_error_object( e, @@ -1258,7 +1288,7 @@ cdef report_streaming_generator_exception( generator_index, context.is_async, context.should_retry_exceptions, - &return_obj, + &return_objs, context.is_retryable_error, context.application_error ) @@ -1274,12 +1304,12 @@ cdef report_streaming_generator_exception( context.streaming_generator_returns[0].push_back( c_pair[CObjectID, c_bool]( - return_obj.first, - is_plasma_object(return_obj.second))) + return_objs[0].first, + is_plasma_object(return_objs[0].second))) with nogil: check_status(CCoreWorkerProcess.GetCoreWorker().ReportGeneratorItemReturns( - return_obj, + return_objs, context.generator_id, context.caller_address, generator_index, @@ -1321,9 +1351,12 @@ cdef execute_streaming_generator_sync(StreamingGeneratorExecutionContext context # next output output = gen.send(stats) # Track serialization duration of the next output - stats = report_streaming_generator_output(context, output, gen_index, None) + stats = report_streaming_generator_output( + context, output, gen_index, None) + if stats is None: + break - gen_index += 1 + gen_index += context.num_objects_per_yield except StopIteration: break @@ -1405,7 +1438,9 @@ async def execute_streaming_generator_async( cur_generator_index, interrupt_signal_event, ) - cur_generator_index += 1 + if stats is None: + break + cur_generator_index += context.num_objects_per_yield except StopAsyncIteration: break @@ -1444,7 +1479,7 @@ async def execute_streaming_generator_async( check_status(return_status) -cdef create_generator_return_obj( +cdef create_generator_return_objs( output, const CObjectID &generator_id, worker: "Worker", @@ -1452,9 +1487,10 @@ cdef create_generator_return_obj( TaskID task_id, return_size, generator_index, + int64_t num_objects_per_yield, is_async, - c_pair[CObjectID, shared_ptr[CRayObject]] *return_object): - """Create a generator return object based on a given output. + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *return_objects): + """Create generator return objects based on a given output. Args: output: The output from a next(generator). @@ -1465,33 +1501,55 @@ cdef create_generator_return_obj( the owner, so we can also call it "owner address". task_id: The task ID of the generator task. return_size: The number of static returns. - generator_index: The index of a current error object. + generator_index: The first ObjectRef stream index for this yield. + num_objects_per_yield: The number of ObjectRefs to create for each yield. is_async: Whether or not the given object is created within an async actor. - return_object(out): A Ray Object that contains the given output. + return_objects(out): Ray Objects that contain the given output. """ cdef: - c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result CoreWorker core_worker = worker.core_worker + int64_t stream_index + int64_t i + CObjectID return_id - return_id = core_worker.allocate_dynamic_return_id_for_generator( - caller_address, - task_id.native(), - return_size, - generator_index, - is_async, - ) - intermediate_result.push_back( + if num_objects_per_yield == 1: + outputs = (output,) + else: + if not isinstance(output, (tuple, list)): + raise ValueError( + "Streaming generator tasks with _num_objects_per_yield=" + f"{num_objects_per_yield} must yield a tuple or list " + f"of length {num_objects_per_yield}." + ) + if len(output) != num_objects_per_yield: + raise ValueError( + "Streaming generator task yielded " + f"{len(output)} objects, but _num_objects_per_yield=" + f"{num_objects_per_yield}." + ) + outputs = output + + return_objects.reserve(num_objects_per_yield) + for i in range(num_objects_per_yield): + stream_index = generator_index + i + return_id = core_worker.allocate_dynamic_return_id_for_generator( + caller_address, + task_id.native(), + return_size, + stream_index, + is_async, + ) + return_objects.push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( return_id, shared_ptr[CRayObject]())) + core_worker.store_task_outputs( - worker, [output], + worker, outputs, caller_address, - &intermediate_result, + return_objects, generator_id.Binary()) - return_object[0] = intermediate_result.back() - cdef create_generator_error_object( e: Exception, @@ -1509,7 +1567,7 @@ cdef create_generator_error_object( generator_index, is_async, c_bool should_retry_exceptions, - c_pair[CObjectID, shared_ptr[CRayObject]] *error_object, + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *error_objects, c_bool *is_retryable_error, c_string *application_error): """Create a generator error object. @@ -1538,17 +1596,16 @@ cdef create_generator_error_object( It is used to write an error message. actor_id: The ID of the actor. It is used to write an error message. return_size: The number of static returns. - generator_index: The index of a current error object. + generator_index: The ObjectRef stream index for this exception. is_async: Whether or not the given object is created within an async actor. - error_object(out): A Ray Object that contains the given error exception. + error_objects(out): Ray Objects that contain the given error exception. is_retryable_error(out): It is set to True if the generator raises an exception, and the error is retryable. application_error(out): It is set if the generator raises an application error. """ cdef: - c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result CoreWorker core_worker = worker.core_worker is_retryable_error[0] = determine_if_retryable( @@ -1578,7 +1635,7 @@ cdef create_generator_error_object( generator_index, is_async, ) - intermediate_result.push_back( + error_objects.push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( error_id, shared_ptr[CRayObject]())) store_task_errors( @@ -1588,11 +1645,9 @@ cdef create_generator_error_object( actor_id, # actor id function_name, task_type, title, caller_address, - &intermediate_result, + error_objects, application_error) - error_object[0] = intermediate_result.back() - cdef execute_dynamic_generator_and_store_task_outputs( generator, @@ -1701,6 +1756,7 @@ cdef void execute_task( c_bool is_streaming_generator, c_bool should_retry_exceptions, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, optional[c_string] c_tensor_transport) except *: worker = ray._private.worker.global_worker manager = worker.function_actor_manager @@ -1874,7 +1930,8 @@ cdef void execute_task( streaming_generator_returns, is_retryable_error, application_error, - generator_backpressure_num_objects) + generator_backpressure_num_objects, + num_objects_per_yield) # We cannot pass generator to cdef in Cython for some reasons. # It is a workaround. context.initialize(outputs) @@ -2068,6 +2125,7 @@ cdef execute_task_with_cancellation_handler( c_bool is_streaming_generator, c_bool should_retry_exceptions, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, optional[c_string] c_tensor_transport): is_retryable_error[0] = False @@ -2161,6 +2219,7 @@ cdef execute_task_with_cancellation_handler( is_streaming_generator, should_retry_exceptions, generator_backpressure_num_objects, + num_objects_per_yield, c_tensor_transport) # Check for cancellation. @@ -2279,6 +2338,7 @@ cdef CRayStatus task_execution_handler( c_bool is_streaming_generator, c_bool should_retry_exceptions, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, optional[c_string] c_tensor_transport) nogil: with gil, disable_client_hook(): # Initialize job_config if it hasn't already. @@ -2309,6 +2369,7 @@ cdef CRayStatus task_execution_handler( is_streaming_generator, should_retry_exceptions, generator_backpressure_num_objects, + num_objects_per_yield, c_tensor_transport) except Exception as e: sys_exit = SystemExit() @@ -3507,6 +3568,7 @@ cdef class CoreWorker: c_string debugger_breakpoint, c_string serialized_runtime_env_info, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, c_bool enable_task_events, labels, label_selector, @@ -3553,6 +3615,7 @@ cdef class CoreWorker: name, num_returns, c_resources, b"", generator_backpressure_num_objects, + num_objects_per_yield, serialized_runtime_env_info, enable_task_events, c_labels, @@ -3787,6 +3850,7 @@ cdef class CoreWorker: double num_method_cpus, c_string concurrency_group_name, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, c_bool enable_task_events, tensor_transport: Optional[str], dict labels=None): @@ -3843,6 +3907,7 @@ cdef class CoreWorker: c_resources, concurrency_group_name, generator_backpressure_num_objects, + num_objects_per_yield, serialized_runtime_env, enable_task_events, c_labels, @@ -4020,6 +4085,7 @@ cdef class CoreWorker: method_meta.max_task_retries, method_meta.retry_exceptions, method_meta.generator_backpressure_num_objects, # noqa + method_meta.num_objects_per_yield, method_meta.enable_task_events, enable_tensor_transport, method_meta.method_name_to_tensor_transport, @@ -4039,6 +4105,7 @@ cdef class CoreWorker: {}, # method max_task_retries {}, # method retry_exceptions {}, # generator_backpressure_num_objects + {}, # num_objects_per_yield {}, # enable_task_events False, # enable_tensor_transport None, # method_name_to_tensor_transport diff --git a/python/ray/actor.py b/python/ray/actor.py index 253ffa1e6107..2753b773c2c6 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -616,6 +616,7 @@ def method( max_task_retries: Optional[int] = None, retry_exceptions: Optional[Union[bool, list, tuple]] = None, _generator_backpressure_num_objects: Optional[int] = None, + _num_objects_per_yield: Optional[int] = None, enable_task_events: Optional[bool] = None, tensor_transport: Optional[str] = None, ) -> _MethodDecorator: @@ -693,6 +694,7 @@ def bar(self): "max_task_retries", "retry_exceptions", "_generator_backpressure_num_objects", + "_num_objects_per_yield", "enable_task_events", "tensor_transport", ] @@ -717,6 +719,11 @@ def annotate_method(method: Callable[_P, _Ret]): method.__ray_generator_backpressure_num_objects__ = kwargs[ "_generator_backpressure_num_objects" ] + if "_num_objects_per_yield" in kwargs: + ray_option_utils.task_options["_num_objects_per_yield"].validate( + "_num_objects_per_yield", kwargs["_num_objects_per_yield"] + ) + method.__ray_num_objects_per_yield__ = kwargs["_num_objects_per_yield"] if "enable_task_events" in kwargs and kwargs["enable_task_events"] is not None: method.__ray_enable_task_events__ = kwargs["enable_task_events"] if "tensor_transport" in kwargs: @@ -770,6 +777,7 @@ def __init__( retry_exceptions: Union[bool, list, tuple], is_generator: bool, generator_backpressure_num_objects: int, + num_objects_per_yield: int, enable_task_events: bool, decorator: Optional[Any] = None, signature: Optional[List[inspect.Parameter]] = None, @@ -787,6 +795,7 @@ def __init__( retry_exceptions: Boolean or list/tuple of exceptions to retry. is_generator: True if the method is a generator. generator_backpressure_num_objects: Generator-only config for backpressure. + num_objects_per_yield: Private generator-only config for grouped yields. enable_task_events: True if task events are enabled for this method. decorator: Optional decorator for the method invocation. signature: The signature of the actor method. @@ -805,6 +814,9 @@ def __init__( self._retry_exceptions = retry_exceptions self._is_generator = is_generator self._generator_backpressure_num_objects = generator_backpressure_num_objects + self._num_objects_per_yield = ( + 1 if num_objects_per_yield is None else num_objects_per_yield + ) self._enable_task_events = enable_task_events self._decorator = decorator self._signature = signature @@ -822,6 +834,7 @@ def bind(self, actor_handle: "ActorHandle") -> "ActorMethod": self._retry_exceptions, self._is_generator, self._generator_backpressure_num_objects, + self._num_objects_per_yield, self._enable_task_events, decorator=self._decorator, signature=self._signature, @@ -848,6 +861,7 @@ def __init__( retry_exceptions: Union[bool, list, tuple], is_generator: bool, generator_backpressure_num_objects: int, + num_objects_per_yield: int, enable_task_events: bool, decorator: Optional[Callable] = None, signature: Optional[List[inspect.Parameter]] = None, @@ -869,6 +883,7 @@ def __init__( generator_backpressure_num_objects: Generator-only config. If a number of unconsumed objects reach this threshold, the actor task stops pausing. + num_objects_per_yield: Private generator-only config for grouped yields. enable_task_events: True if task events is enabled, i.e., task events from the actor should be reported. Defaults to True. decorator: An optional decorator that should be applied to the actor @@ -892,6 +907,9 @@ def __init__( self._retry_exceptions = retry_exceptions self._is_generator = is_generator self._generator_backpressure_num_objects = generator_backpressure_num_objects + self._num_objects_per_yield = ( + 1 if num_objects_per_yield is None else num_objects_per_yield + ) self._enable_task_events = enable_task_events self._signature = signature # This is a decorator that is used to wrap the function invocation (as @@ -945,6 +963,11 @@ def options(self, **options: Any): A wrapper exposing ``.remote()`` / ``.bind()`` that applies the given options when the method is invoked. """ + if "_num_objects_per_yield" in options: + raise ValueError( + "_num_objects_per_yield cannot be overridden per actor method " + "call. Use @ray.method(_num_objects_per_yield=...) instead." + ) tensor_transport = options.get("tensor_transport", None) if tensor_transport is not None: @@ -967,6 +990,7 @@ def _bind( num_returns=None, concurrency_group=None, _generator_backpressure_num_objects=None, + _num_objects_per_yield=None, ) -> Union["ray.dag.ClassMethodNode", Tuple["ray.dag.ClassMethodNode", ...]]: from ray.dag.class_node import ( BIND_INDEX_KEY, @@ -983,6 +1007,11 @@ def _bind( "concurrency_group": concurrency_group, "_generator_backpressure_num_objects": _generator_backpressure_num_objects, } + if _num_objects_per_yield is not None: + ray_option_utils.task_options["_num_objects_per_yield"].validate( + "_num_objects_per_yield", _num_objects_per_yield + ) + options["_num_objects_per_yield"] = _num_objects_per_yield actor = self._actor if actor is None: @@ -1049,6 +1078,7 @@ def _remote( retry_exceptions=None, concurrency_group=None, _generator_backpressure_num_objects=None, + _num_objects_per_yield=None, enable_task_events=None, tensor_transport: Optional[str] = None, _labels: Optional[Dict[str, str]] = None, @@ -1067,6 +1097,11 @@ def _remote( _generator_backpressure_num_objects = ( self._generator_backpressure_num_objects ) + if _num_objects_per_yield is None: + _num_objects_per_yield = self._num_objects_per_yield + ray_option_utils.task_options["_num_objects_per_yield"].validate( + "_num_objects_per_yield", _num_objects_per_yield + ) if tensor_transport is None: tensor_transport = self._tensor_transport @@ -1120,6 +1155,7 @@ def invocation(args, kwargs): generator_backpressure_num_objects=( _generator_backpressure_num_objects ), + num_objects_per_yield=_num_objects_per_yield, enable_task_events=enable_task_events, tensor_transport=tensor_transport, labels=_labels, @@ -1149,6 +1185,7 @@ def __getstate__(self): "decorator": self._decorator, "is_generator": self._is_generator, "generator_backpressure_num_objects": self._generator_backpressure_num_objects, # noqa + "num_objects_per_yield": self._num_objects_per_yield, "enable_task_events": self._enable_task_events, "_tensor_transport": self._tensor_transport, } @@ -1162,6 +1199,7 @@ def __setstate__(self, state): state["retry_exceptions"], state["is_generator"], state["generator_backpressure_num_objects"], + state.get("num_objects_per_yield", 1), state["enable_task_events"], state["decorator"], state["_tensor_transport"], @@ -1250,6 +1288,7 @@ def create( self.method_is_generator = {} self.enable_task_events = {} self.generator_backpressure_num_objects = {} + self.num_objects_per_yield = {} self.concurrency_group_for_methods = {} self.method_name_to_tensor_transport: Dict[str, str] = {} @@ -1317,6 +1356,10 @@ def create( self.generator_backpressure_num_objects[ method_name ] = method.__ray_generator_backpressure_num_objects__ + if hasattr(method, "__ray_num_objects_per_yield__"): + self.num_objects_per_yield[ + method_name + ] = method.__ray_num_objects_per_yield__ if hasattr(method, "__ray_tensor_transport__"): self.method_name_to_tensor_transport[ @@ -2175,6 +2218,7 @@ def _remote( meta.method_meta.max_task_retries, meta.method_meta.retry_exceptions, meta.method_meta.generator_backpressure_num_objects, + meta.method_meta.num_objects_per_yield, meta.method_meta.enable_task_events, meta.enable_tensor_transport, meta.method_meta.method_name_to_tensor_transport, @@ -2237,6 +2281,8 @@ class ActorHandle(Generic[T]): _ray_method_generator_backpressure_num_objects: Generator-only config. The max number of objects to generate before it starts pausing a generator. + _ray_method_num_objects_per_yield: Private generator-only config. + The number of ObjectRefs produced by each streaming generator yield. _ray_method_enable_task_events: The value of whether task tracing is enabled for the actor methods. This overrides the actor's default value (`_ray_enable_task_events`). @@ -2273,6 +2319,7 @@ def __init__( method_max_task_retries: Dict[str, int], method_retry_exceptions: Dict[str, Union[bool, list, tuple]], method_generator_backpressure_num_objects: Dict[str, int], + method_num_objects_per_yield: Dict[str, int], method_enable_task_events: Dict[str, bool], enable_tensor_transport: bool, method_name_to_tensor_transport: Dict[str, str], @@ -2297,6 +2344,7 @@ def __init__( method_max_task_retries: Dictionary mapping method names to their maximum task retries. method_retry_exceptions: Dictionary mapping method names to their retry exception settings. method_generator_backpressure_num_objects: Dictionary mapping method names to their generator backpressure settings. + method_num_objects_per_yield: Dictionary mapping method names to their grouped-yield arity. method_enable_task_events: Dictionary mapping method names to whether task events are enabled. enable_tensor_transport: Whether tensor transport is enabled for this actor. If True, then methods can be called with @@ -2327,6 +2375,7 @@ def __init__( self._ray_method_generator_backpressure_num_objects = ( method_generator_backpressure_num_objects ) + self._ray_method_num_objects_per_yield = method_num_objects_per_yield self._ray_method_enable_task_events = method_enable_task_events self._ray_enable_tensor_transport = enable_tensor_transport self._ray_method_name_to_tensor_transport = method_name_to_tensor_transport @@ -2372,6 +2421,9 @@ def __init__( generator_backpressure_num_objects=self._ray_method_generator_backpressure_num_objects.get( method_name ), + num_objects_per_yield=self._ray_method_num_objects_per_yield.get( + method_name, 1 + ), enable_task_events=self._ray_method_enable_task_events.get( method_name, self._ray_enable_task_events ), @@ -2412,6 +2464,7 @@ def _actor_method_call( retry_exceptions: Union[bool, list, tuple] = None, concurrency_group_name: Optional[str] = None, generator_backpressure_num_objects: Optional[int] = None, + num_objects_per_yield: Optional[int] = None, enable_task_events: Optional[bool] = None, tensor_transport: Optional[str] = None, labels: Optional[Dict[str, str]] = None, @@ -2435,6 +2488,8 @@ def _actor_method_call( concurrency_group_name: The name of the concurrency group to use. generator_backpressure_num_objects: The number of objects to generate before applying backpressure. + num_objects_per_yield: Private streaming generator option for how many + ObjectRefs each yield should unpack into. enable_task_events: True if tracing is enabled, i.e., task events from the actor should be reported. tensor_transport: The tensor transport protocol to use for the actor method. @@ -2487,6 +2542,8 @@ def _actor_method_call( if generator_backpressure_num_objects is None: generator_backpressure_num_objects = -1 + if num_objects_per_yield is None: + num_objects_per_yield = 1 object_refs = worker.core_worker.submit_actor_task( self._ray_actor_language, @@ -2501,6 +2558,7 @@ def _actor_method_call( self._ray_actor_method_cpus, concurrency_group_name if concurrency_group_name is not None else b"", generator_backpressure_num_objects, + num_objects_per_yield, enable_task_events, tensor_transport, labels, @@ -2578,6 +2636,7 @@ def remote(self, *args, **kwargs): False, # retry_exceptions False, # is_generator self._ray_method_generator_backpressure_num_objects.get(item, -1), + self._ray_method_num_objects_per_yield.get(item, 1), self._ray_enable_task_events, # enable_task_events # Currently, cross-lang actor method not support decorator decorator=None, @@ -2658,6 +2717,9 @@ def _serialization_helper(self): "method_generator_backpressure_num_objects": ( self._ray_method_generator_backpressure_num_objects ), + "method_num_objects_per_yield": ( + self._ray_method_num_objects_per_yield + ), "method_enable_task_events": self._ray_method_enable_task_events, "enable_tensor_transport": self._ray_enable_tensor_transport, "method_name_to_tensor_transport": self._ray_method_name_to_tensor_transport, @@ -2716,6 +2778,7 @@ def _deserialization_helper( state["method_max_task_retries"], state["method_retry_exceptions"], state["method_generator_backpressure_num_objects"], + state.get("method_num_objects_per_yield", {}), state["method_enable_task_events"], state["enable_tensor_transport"], state["method_name_to_tensor_transport"], diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index b3fd284b13d0..e6687cd053b7 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -383,16 +383,19 @@ cdef extern from "ray/core_worker/common.h" nogil: CTaskOptions(c_string name, int num_returns, unordered_map[c_string, double] &resources, c_string concurrency_group_name, - int64_t generator_backpressure_num_objects) + int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield) CTaskOptions(c_string name, int num_returns, unordered_map[c_string, double] &resources, c_string concurrency_group_name, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, c_string serialized_runtime_env) CTaskOptions(c_string name, int num_returns, unordered_map[c_string, double] &resources, c_string concurrency_group_name, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, c_string serialized_runtime_env, c_bool enable_task_events, const unordered_map[c_string, c_string] &labels, diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 60a99d85f85a..d3c416476a8b 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -315,7 +315,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_vector[shared_ptr[CObjectLocation]] *results) CRayStatus TriggerGlobalGC() CRayStatus ReportGeneratorItemReturns( - const pair[CObjectID, shared_ptr[CRayObject]] &dynamic_return_object, + const c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] &dynamic_return_objects, const CObjectID &generator_id, const CAddress &caller_address, int64_t item_index, @@ -407,6 +407,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_bool is_streaming_generator, c_bool should_retry_exceptions, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, optional[c_string] tensor_transport ) nogil) task_execution_callback (void(const CObjectID &) nogil) free_actor_object_callback diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 57fae3e65551..1ee1bf26a604 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -281,6 +281,15 @@ def f(): # Task g will require 2 gpus instead of 1. g = f.options(num_gpus=2) """ + if "_num_objects_per_yield" in task_options: + num_objects_per_yield = ( + self._default_options.get("_num_objects_per_yield") or 1 + ) + if task_options["_num_objects_per_yield"] != num_objects_per_yield: + raise ValueError( + "_num_objects_per_yield cannot be overridden per task call. " + "Use @ray.remote(_num_objects_per_yield=...) instead." + ) func_cls = self @@ -435,6 +444,12 @@ def _remote( ] if generator_backpressure_num_objects is None: generator_backpressure_num_objects = -1 + num_objects_per_yield = task_options["_num_objects_per_yield"] + if num_objects_per_yield is None: + num_objects_per_yield = 1 + ray_option_utils.task_options["_num_objects_per_yield"].validate( + "_num_objects_per_yield", num_objects_per_yield + ) max_retries = task_options["max_retries"] retry_exceptions = task_options["retry_exceptions"] @@ -517,6 +532,7 @@ def invocation(args, kwargs): worker.debugger_breakpoint, serialized_runtime_env_info or "{}", generator_backpressure_num_objects, + num_objects_per_yield, enable_task_events, labels, label_selector, diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 666efe6b7a40..34154f427adc 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -538,6 +538,127 @@ async def verify_async_task_async_generator(): asyncio.run(verify_async_task_async_generator()) +def test_streaming_generator_num_objects_per_yield(shutdown_only): + ray.init() + + @ray.remote(_num_objects_per_yield=2) + def generator(): + for i in range(3): + stats = yield i, f"metadata-{i}" + assert stats is None or stats.object_creation_dur_s >= 0 + + gen = generator.remote() + for i in range(3): + assert ray.get(next(gen)) == i + assert ray.get(next(gen)) == f"metadata-{i}" + + with pytest.raises(StopIteration): + next(gen) + + @ray.remote + def per_call(): + yield 1, 2 + + with pytest.raises(ValueError, match="_num_objects_per_yield"): + per_call.options(_num_objects_per_yield=2).remote() + + +def test_actor_streaming_generator_num_objects_per_yield(shutdown_only): + ray.init() + + @ray.remote + class Actor: + @ray.method(_num_objects_per_yield=2) + def decorated(self): + yield "block", "metadata" + + def per_call(self): + yield 1, 2 + + actor = Actor.remote() + + gen = actor.decorated.remote() + assert ray.get(next(gen)) == "block" + assert ray.get(next(gen)) == "metadata" + with pytest.raises(StopIteration): + next(gen) + + with pytest.raises(ValueError, match="_num_objects_per_yield"): + actor.per_call.options(_num_objects_per_yield=2).remote() + + +def test_streaming_generator_num_objects_per_yield_invalid_yield(shutdown_only): + ray.init() + + @ray.remote(_num_objects_per_yield=2) + def generator(): + yield (1,) + + gen = generator.remote() + with pytest.raises(ValueError, match="_num_objects_per_yield=2"): + ray.get(next(gen)) + + with pytest.raises(StopIteration): + next(gen) + + +def test_streaming_generator_num_objects_per_yield_serialization_failure(shutdown_only): + ray.init() + + @ray.remote(_num_objects_per_yield=2) + def generator(): + yield threading.Lock(), 1 + + gen = generator.remote() + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(next(gen)) + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(next(gen)) + + with pytest.raises(StopIteration): + next(gen) + + +def test_streaming_generator_num_objects_per_yield_partial_store_failure( + shutdown_only, +): + ray.init() + + @ray.remote(_num_objects_per_yield=2) + def generator(): + # If a later object fails after an earlier object has been stored, the + # caller should still receive a ref for every object in the grouped yield. + yield 1, threading.Lock() + + gen = generator.remote() + assert ray.get(next(gen)) == 1 + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(next(gen)) + + with pytest.raises(StopIteration): + next(gen) + + +def test_streaming_generator_num_objects_per_yield_failure_not_retried( + shutdown_only, +): + ray.init() + + @ray.remote(_num_objects_per_yield=2, retry_exceptions=True, max_retries=1) + def generator(): + # Once grouped-yield IDs are allocated, the whole group must be + # reported so those temporary refs are cleared instead of retrying. + yield 1, threading.Lock() + + gen = generator.remote() + assert ray.get(next(gen)) == 1 + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(next(gen)) + + with pytest.raises(StopIteration): + next(gen) + + def test_streaming_generator_exception(shutdown_only): # Verify the exceptions are correctly raised. # Also verify the followup next will raise StopIteration. diff --git a/python/ray/tests/test_streaming_generator_backpressure.py b/python/ray/tests/test_streaming_generator_backpressure.py index 6d2a3b3d76e5..a88225754b95 100644 --- a/python/ray/tests/test_streaming_generator_backpressure.py +++ b/python/ray/tests/test_streaming_generator_backpressure.py @@ -248,6 +248,19 @@ async def f(self): def f(): pass + with pytest.raises(ValueError, match="_num_objects_per_yield"): + + @ray.remote(_num_objects_per_yield=0) + def g(): + pass + + with pytest.raises(ValueError, match="_num_objects_per_yield"): + + class Actor: + @ray.method(_num_objects_per_yield=0) + def g(self): + pass + def test_threaded_actor_generator_backpressure(shutdown_only): ray.init() diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index d7b0986038ff..e717eaae5bb7 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -249,6 +249,11 @@ int64_t TaskSpecification::GeneratorBackpressureNumObjects() const { return result; } +int64_t TaskSpecification::NumObjectsPerYield() const { + auto result = message_->num_objects_per_yield(); + return result == 0 ? 1 : result; +} + std::vector TaskSpecification::DynamicReturnIds() const { RAY_CHECK(message_->returns_dynamic()); std::vector dynamic_return_ids; diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index e78184c6eb66..32ff710f1610 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -193,6 +193,8 @@ class TaskSpecification : public MessageWrapper { int64_t GeneratorBackpressureNumObjects() const; + int64_t NumObjectsPerYield() const; + std::vector DynamicReturnIds() const; void AddDynamicReturnId(const ObjectID &dynamic_return_id); diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index c0f0b2f1eb35..e1afa6e3fe5f 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -160,7 +160,8 @@ class TaskSpecBuilder { const std::unordered_map &labels = {}, const LabelSelector &label_selector = {}, const std::vector &fallback_strategy = - std::vector()) { + std::vector(), + uint64_t num_objects_per_yield = 1) { message_->set_type(TaskType::NORMAL_TASK); message_->set_name(name); message_->set_language(language); @@ -179,6 +180,7 @@ class TaskSpecBuilder { message_->set_returns_dynamic(returns_dynamic); message_->set_streaming_generator(is_streaming_generator); message_->set_generator_backpressure_num_objects(generator_backpressure_num_objects); + message_->set_num_objects_per_yield(num_objects_per_yield); message_->mutable_required_resources()->insert(required_resources.begin(), required_resources.end()); message_->mutable_required_placement_resources()->insert( diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 6c91be760265..f1e03824f961 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -70,6 +70,7 @@ struct TaskOptions { std::unordered_map &resources_p, std::string concurrency_group_name_p = "", int64_t generator_backpressure_num_objects_p = -1, + int64_t num_objects_per_yield_p = 1, std::string serialized_runtime_env_info_p = "{}", bool enable_task_events_p = kDefaultTaskEventEnabled, std::unordered_map labels_p = {}, @@ -82,6 +83,7 @@ struct TaskOptions { concurrency_group_name(std::move(concurrency_group_name_p)), serialized_runtime_env_info(std::move(serialized_runtime_env_info_p)), generator_backpressure_num_objects(generator_backpressure_num_objects_p), + num_objects_per_yield(num_objects_per_yield_p), enable_task_events(enable_task_events_p), labels(std::move(labels_p)), label_selector(std::move(label_selector_p)), @@ -104,6 +106,9 @@ struct TaskOptions { /// -1 means either streaming generator is not used or /// it is used but the feature is disabled. int64_t generator_backpressure_num_objects; + /// Only applicable when streaming generator is used. + /// The number of ObjectRefs produced by each generator yield. + int64_t num_objects_per_yield = 1; /// True if task events (worker::TaskEvent) from this task should be reported, default /// to true. bool enable_task_events = kDefaultTaskEventEnabled; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 7257247b5344..deb0bc28301a 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1835,7 +1835,8 @@ void CoreWorker::BuildCommonTaskSpec( bool enable_task_events, const std::unordered_map &labels, const LabelSelector &label_selector, - const std::vector &fallback_strategy) { + const std::vector &fallback_strategy, + int64_t num_objects_per_yield) { // Build common task spec. auto override_runtime_env_info = OverrideTaskOrActorRuntimeEnvInfo(serialized_runtime_env_info); @@ -1885,7 +1886,8 @@ void CoreWorker::BuildCommonTaskSpec( enable_task_events, labels, label_selector, - fallback_strategy); + fallback_strategy, + num_objects_per_yield); // Set task arguments. for (const auto &arg : args) { builder.AddArg(*arg); @@ -1966,7 +1968,8 @@ std::vector CoreWorker::SubmitTask( /*enable_task_events=*/task_options.enable_task_events, task_options.labels, task_options.label_selector, - task_options.fallback_strategy); + task_options.fallback_strategy, + task_options.num_objects_per_yield); ActorID root_detached_actor_id; if (!worker_context_->GetRootDetachedActorID().IsNil()) { root_detached_actor_id = worker_context_->GetRootDetachedActorID(); @@ -2404,7 +2407,8 @@ Status CoreWorker::SubmitActorTask( /*enable_task_events=*/task_options.enable_task_events, /*labels=*/task_options.labels, /*label_selector=*/{}, - /*fallback_strategy=*/{}); + /*fallback_strategy=*/{}, + task_options.num_objects_per_yield); // NOTE: placement_group_capture_child_tasks and runtime_env will // be ignored in the actor because we should always follow the actor's option. @@ -2881,6 +2885,7 @@ Status CoreWorker::ExecuteTask( /*retry_exception=*/task_spec.ShouldRetryExceptions(), /*generator_backpressure_num_objects=*/ task_spec.GeneratorBackpressureNumObjects(), + /*num_objects_per_yield=*/task_spec.NumObjectsPerYield(), /*tensor_transport=*/task_spec.TensorTransport()); // Get the reference counts for any IDs that we borrowed during this task, @@ -3113,7 +3118,8 @@ ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address, } Status CoreWorker::ReportGeneratorItemReturns( - const std::pair> &dynamic_return_object, + const std::vector>> + &dynamic_return_objects, const ObjectID &generator_id, const rpc::Address &owner_address, int64_t item_index, @@ -3126,26 +3132,33 @@ Status CoreWorker::ReportGeneratorItemReturns( request.set_attempt_number(attempt_number); auto client = core_worker_client_pool_->GetOrConnect(owner_address); - // This means it is the last report when the task has finished executing. - if (!dynamic_return_object.first.IsNil()) { + std::vector return_ids; + return_ids.reserve(dynamic_return_objects.size()); + for (const auto &dynamic_return_object : dynamic_return_objects) { + if (dynamic_return_object.first.IsNil()) { + continue; + } SerializeReturnObject(dynamic_return_object.first, dynamic_return_object.second, - request.mutable_returned_object()); - std::vector deleted; + request.add_returned_objects()); + return_ids.push_back(dynamic_return_object.first); + } + if (!return_ids.empty()) { // When we allocate a dynamic return ID (AllocateDynamicReturnId), - // we borrow the object. When the object value is allocatd, the + // we borrow the object. When the object value is allocated, the // memory store is updated. We should clear borrowers and memory store // here. + std::vector deleted; ReferenceCounterInterface::ReferenceTableProto borrowed_refs; - reference_counter_->PopAndClearLocalBorrowers( - {dynamic_return_object.first}, &borrowed_refs, &deleted); + reference_counter_->PopAndClearLocalBorrowers(return_ids, &borrowed_refs, &deleted); memory_store_->Delete(deleted); } - const auto return_id = dynamic_return_object.first; + + const auto return_id = return_ids.empty() ? ObjectID::Nil() : return_ids.front(); RAY_LOG(DEBUG) << "Write the object ref stream, index: " << item_index - << ", id: " << return_id; + << ", id: " << return_id << ", count: " << return_ids.size(); - waiter->IncrementObjectGenerated(); + waiter->IncrementObjectGenerated(return_ids.size()); client->ReportGeneratorItemReturns( std::move(request), diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 83f25904ad08..43f1d34c83af 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -774,12 +774,9 @@ class CoreWorker : public std::enable_shared_from_this { /// NOTE: The API doesn't guarantee the ordering of the report. The /// owner is supposed to reorder the report based on the item_index. /// - /// \param[in] returned_object A intermediate ray object to report - /// to the owner before the task terminates. This object must have been + /// \param[in] returned_objects Intermediate ray objects to report + /// to the owner before the task terminates. These objects must have been /// created dynamically from this worker via AllocateReturnObject. - /// If the Object ID is nil, it means it is the end of the task return. - /// In this case, the owner is responsible for setting finished = true, - /// otherwise it will panic. /// \param[in] generator_id The return object ref ID from a current generator /// task. /// \param[in] owner_address The address of the owner of the current task. @@ -790,7 +787,8 @@ class CoreWorker : public std::enable_shared_from_this { /// \param[in] waiter The class to pause the thread if generator backpressure limit /// is reached. Status ReportGeneratorItemReturns( - const std::pair> &returned_object, + const std::vector>> + &returned_objects, const ObjectID &generator_id, const rpc::Address &owner_address, int64_t item_index, @@ -1428,7 +1426,8 @@ class CoreWorker : public std::enable_shared_from_this { bool enable_task_events = true, const std::unordered_map &labels = {}, const LabelSelector &label_selector = {}, - const std::vector &fallback_strategy = {}); + const std::vector &fallback_strategy = {}, + int64_t num_objects_per_yield = 1); void SetCurrentTaskId(const TaskID &task_id, uint64_t attempt_number, diff --git a/src/ray/core_worker/core_worker_options.h b/src/ray/core_worker/core_worker_options.h index 4dc47af27ddc..847931f48810 100644 --- a/src/ray/core_worker/core_worker_options.h +++ b/src/ray/core_worker/core_worker_options.h @@ -78,6 +78,8 @@ struct CoreWorkerOptions { // The max number of unconsumed objects where a generator // can run without a pause. int64_t generator_backpressure_num_objects, + // The number of ObjectRefs produced by each streaming generator yield. + int64_t num_objects_per_yield, const std::optional &tensor_transport)>; CoreWorkerOptions() diff --git a/src/ray/core_worker/generator_waiter.cc b/src/ray/core_worker/generator_waiter.cc index 132e1247f9d5..875f106a6f10 100644 --- a/src/ray/core_worker/generator_waiter.cc +++ b/src/ray/core_worker/generator_waiter.cc @@ -69,9 +69,11 @@ Status GeneratorBackpressureWaiter::WaitAllObjectsReported() { return return_status; } -void GeneratorBackpressureWaiter::IncrementObjectGenerated() { +void GeneratorBackpressureWaiter::IncrementObjectGenerated( + int64_t num_objects_generated) { + RAY_CHECK_GE(num_objects_generated, 0); absl::MutexLock lock(&mutex_); - total_objects_generated_ += 1; + total_objects_generated_ += num_objects_generated; num_object_reports_in_flight_++; } diff --git a/src/ray/core_worker/generator_waiter.h b/src/ray/core_worker/generator_waiter.h index 4d04c18f34ae..492a0359b5bc 100644 --- a/src/ray/core_worker/generator_waiter.h +++ b/src/ray/core_worker/generator_waiter.h @@ -63,7 +63,7 @@ class GeneratorBackpressureWaiter { /// Increment the number of objects generated. The executor should call this /// before sending an object report to the caller. - void IncrementObjectGenerated(); + void IncrementObjectGenerated(int64_t num_objects_generated = 1); /// Handle a completed object report. The executor should call this after /// receiving an ack from the caller for an object report. diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 3d047dbf4e6f..6a4a627408da 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -143,6 +143,7 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env, bool is_streaming_generator, bool should_retry_exceptions, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, const std::optional &tensor_transport) { // These 3 parameters are used for Python only, and Java worker // will not use them. diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index 363d51234a12..cee2e8182990 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -162,6 +162,7 @@ inline TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject callOptio resources, concurrency_group_name, /*generator_backpressure_num_objects*/ -1, + /*num_objects_per_yield*/ 1, serialzied_runtime_env_info}; return task_options; } diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index d26b09570b47..e15c7f3b9607 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -818,13 +818,14 @@ bool TaskManager::HandleReportGeneratorItemReturns( } size_t num_objects_written = 0; - if (request.has_returned_object()) { - const rpc::ReturnObject &returned_object = request.returned_object(); + for (int64_t i = 0; i < request.returned_objects_size(); i++) { + const rpc::ReturnObject &returned_object = request.returned_objects(i); const auto object_id = ObjectID::FromBinary(returned_object.object_id()); + const auto object_index = item_index + i; RAY_LOG(DEBUG) << "Write an object " << object_id << " to the object ref stream of id " << generator_id; - auto index_not_used_yet = stream_it->second.InsertToStream(object_id, item_index); + auto index_not_used_yet = stream_it->second.InsertToStream(object_id, object_index); // If the ref was written to a stream, we should also // own the dynamically generated task return. @@ -849,18 +850,22 @@ bool TaskManager::HandleReportGeneratorItemReturns( // Handle backpressure if needed. auto total_generated = stream_it->second.TotalNumObjectWritten(); auto total_consumed = stream_it->second.TotalNumObjectConsumed(); + auto last_item_index = request.returned_objects_size() == 0 + ? item_index + : item_index + request.returned_objects_size() - 1; - if (stream_it->second.IsObjectConsumed(item_index)) { + if (stream_it->second.IsObjectConsumed(last_item_index)) { execution_signal_callback(Status::OK(), total_consumed); return false; } // Otherwise, follow the regular backpressure logic. - // NOTE, here we check `item_index - last_consumed_index >= backpressure_threshold`, - // instead of the number of unconsumed items, because we may receive the - // `HandleReportGeneratorItemReturns` requests out of order. + // NOTE, here we check `last_item_index - last_consumed_index >= + // backpressure_threshold`, instead of the number of unconsumed items, because we may + // receive the `HandleReportGeneratorItemReturns` requests out of order. if (backpressure_threshold != -1 && - (item_index - stream_it->second.LastConsumedIndex()) >= backpressure_threshold) { + (last_item_index - stream_it->second.LastConsumedIndex()) >= + backpressure_threshold) { RAY_LOG(DEBUG) << "Stream " << generator_id << " is backpressured. total_generated: " << total_generated << ". total_consumed: " << total_consumed diff --git a/src/ray/core_worker/tests/core_worker_test.cc b/src/ray/core_worker/tests/core_worker_test.cc index db8297d2a1f5..9aee2aa21bfa 100644 --- a/src/ray/core_worker/tests/core_worker_test.cc +++ b/src/ray/core_worker/tests/core_worker_test.cc @@ -93,6 +93,7 @@ class CoreWorkerTest : public ::testing::Test { bool is_streaming_generator, bool retry_exception, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, const std::optional &tensor_transport) -> Status { return Status::OK(); }; diff --git a/src/ray/core_worker/tests/task_manager_test.cc b/src/ray/core_worker/tests/task_manager_test.cc index 3e1e91bba447..0b6bdb1855a8 100644 --- a/src/ray/core_worker/tests/task_manager_test.cc +++ b/src/ray/core_worker/tests/task_manager_test.cc @@ -88,7 +88,7 @@ rpc::ReportGeneratorItemReturnsRequest GetIntermediateTaskReturn( request.mutable_worker_addr()->CopyFrom(addr); request.set_item_index(idx); request.set_generator_id(generator_id.Binary()); - rpc::ReturnObject *returned_object = request.mutable_returned_object(); + rpc::ReturnObject *returned_object = request.add_returned_objects(); returned_object->set_object_id(dynamic_return_id.Binary()); returned_object->set_data(data->Data(), data->Size()); returned_object->set_in_plasma(set_in_plasma); diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 6d293e80f951..ec7a8b92c88e 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -622,6 +622,9 @@ message TaskSpec { optional string tensor_transport = 44; // A list of fallback options defining the fallback strategy for scheduling. FallbackStrategy fallback_strategy = 45; + // Private streaming generator option. The number of ObjectRefs produced by + // each generator yield. Defaults to 1. + uint64 num_objects_per_yield = 46; } message TaskInfoEntry { diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 443bb71976ba..5735eee215a2 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -417,8 +417,8 @@ message NumPendingTasksReply { } message ReportGeneratorItemReturnsRequest { - // Object returned from the executor (can be inlined or in plasma). - ReturnObject returned_object = 1; + // Objects returned from the executor (can be inlined or in plasma). + repeated ReturnObject returned_objects = 1; // The address of the executor. Address worker_addr = 2; // The index of the task return. It is used to From 2f0866ccb85f6eb1e1c271c0684595b9fac7f639 Mon Sep 17 00:00:00 2001 From: Lehui Liu Date: Fri, 12 Jun 2026 09:40:49 -0700 Subject: [PATCH 4/5] [train][Preemption handling 1/n] Add preemption watcher for node-drain observability (#63807) --- python/ray/train/v2/BUILD.bazel | 16 ++ .../callbacks/preemption_callback.py | 85 +++++++ python/ray/train/v2/_internal/constants.py | 11 + .../execution/controller/controller.py | 24 ++ .../v2/_internal/execution/preemption.py | 213 ++++++++++++++++++ .../train/v2/tests/test_preemption_watcher.py | 192 ++++++++++++++++ 6 files changed, 541 insertions(+) create mode 100644 python/ray/train/v2/_internal/callbacks/preemption_callback.py create mode 100644 python/ray/train/v2/_internal/execution/preemption.py create mode 100644 python/ray/train/v2/tests/test_preemption_watcher.py diff --git a/python/ray/train/v2/BUILD.bazel b/python/ray/train/v2/BUILD.bazel index bd10115880f4..0b0a17556097 100644 --- a/python/ray/train/v2/BUILD.bazel +++ b/python/ray/train/v2/BUILD.bazel @@ -528,6 +528,22 @@ py_test( ], ) +py_test( + name = "test_preemption_watcher", + size = "small", + srcs = ["tests/test_preemption_watcher.py"], + env = {"RAY_TRAIN_V2_ENABLED": "1"}, + tags = [ + "exclusive", + "team:ml", + "train_v2", + ], + deps = [ + ":conftest", + "//:ray_lib", + ], +) + py_test( name = "test_result", size = "medium", diff --git a/python/ray/train/v2/_internal/callbacks/preemption_callback.py b/python/ray/train/v2/_internal/callbacks/preemption_callback.py new file mode 100644 index 000000000000..d2bc512d32f7 --- /dev/null +++ b/python/ray/train/v2/_internal/callbacks/preemption_callback.py @@ -0,0 +1,85 @@ +import logging +import os +from typing import TYPE_CHECKING, Dict, List, Optional + +import ray +from ray.actor import ActorHandle +from ray.train.v2._internal.constants import ( + DEFAULT_PREEMPTION_POLL_INTERVAL_S, + PREEMPTION_POLL_INTERVAL_S_ENV_VAR, +) +from ray.train.v2._internal.execution.callback import WorkerGroupCallback +from ray.train.v2._internal.execution.preemption import PreemptionWatcher + +if TYPE_CHECKING: + from ray.train.v2._internal.execution.worker_group import ( + WorkerGroup, + WorkerGroupContext, + ) + +logger = logging.getLogger(__name__) + + +class PreemptionCallback(WorkerGroupCallback): + """Manages a :class:`PreemptionWatcher` across worker-group lifecycles. + + Spawns a fresh watcher in :meth:`after_worker_group_start` and stops it on + every teardown path (shutdown and abort). Each worker group gets its own + watcher and failure-domain map, so elastic resizes and restarts never + leak stale state. + """ + + def __init__(self) -> None: + self._poll_interval_s: float = float( + os.getenv( + PREEMPTION_POLL_INTERVAL_S_ENV_VAR, + str(DEFAULT_PREEMPTION_POLL_INTERVAL_S), + ) + ) + self._watcher: Optional[ActorHandle] = None + + def after_worker_group_start(self, worker_group: "WorkerGroup") -> None: + # Tear down any watcher from a previous worker group first. Worker-group + # startup can fail after this hook without running the shutdown hook, so + # this also prevents leaking an orphaned watcher across a reschedule. + self._stop_watcher() + + node_to_ranks: Dict[str, List[int]] = {} + for w in worker_group.get_workers(): + node_to_ranks.setdefault(w.metadata.node_id, []).append( + w.distributed_context.world_rank + ) + + watcher_cls = ray.remote(num_cpus=0, max_restarts=-1)(PreemptionWatcher) + self._watcher = watcher_cls.remote( + node_to_ranks=node_to_ranks, + poll_interval_s=self._poll_interval_s, + ) + + logger.debug( + "PreemptionCallback: started watcher for %d node(s).", + len(node_to_ranks), + ) + + def before_worker_group_shutdown(self, worker_group: "WorkerGroup") -> None: + self._stop_watcher() + + def after_worker_group_abort( + self, worker_group_context: "WorkerGroupContext" + ) -> None: + # abort() doesn't run the shutdown hook, so tear the watcher down here + # too — otherwise it keeps polling GCS until the cluster reaps it. + self._stop_watcher() + + def _stop_watcher(self) -> None: + if self._watcher is None: + return + watcher = self._watcher + self._watcher = None + # Force-kill (non-blocking) rather than a synchronous graceful stop, so + # we never block the controller's event loop. The watcher's daemon poll + # thread dies with the actor process and holds no external resources. + try: + ray.kill(watcher) + except Exception: + logger.warning("Failed to kill PreemptionWatcher actor.", exc_info=True) diff --git a/python/ray/train/v2/_internal/constants.py b/python/ray/train/v2/_internal/constants.py index a116e12ff73a..2bcbbce73e5f 100644 --- a/python/ray/train/v2/_internal/constants.py +++ b/python/ray/train/v2/_internal/constants.py @@ -58,6 +58,15 @@ ) DEFAULT_CHECKPOINT_UPLOAD_WARN_INTERVAL_S: float = 60 +# Feature flag for the preemption watcher. Default-on; provides a quick +# rollback path if the watcher actor misbehaves in a cluster. +ENABLE_PREEMPTION_WATCHER_ENV_VAR = "RAY_TRAIN_ENABLE_PREEMPTION_WATCHER" +DEFAULT_ENABLE_PREEMPTION_WATCHER: bool = True + +# How often the preemption watcher polls Ray Core's drain state. +PREEMPTION_POLL_INTERVAL_S_ENV_VAR = "RAY_TRAIN_PREEMPTION_POLL_INTERVAL_S" +DEFAULT_PREEMPTION_POLL_INTERVAL_S: float = 5.0 + # Environment variable to enable the print function patching. ENABLE_PRINT_PATCH_ENV_VAR = "RAY_TRAIN_ENABLE_PRINT_PATCH" DEFAULT_ENABLE_PRINT_PATCH = True @@ -118,6 +127,8 @@ STATE_ACTOR_RECONCILIATION_INTERVAL_S_ENV_VAR, RAY_WARN_BLOCKING_GET_INSIDE_ASYNC_ENV_VAR, TORCHFT_LIGHTHOUSE_ADDR_ENV_VAR, + ENABLE_PREEMPTION_WATCHER_ENV_VAR, + PREEMPTION_POLL_INTERVAL_S_ENV_VAR, } diff --git a/python/ray/train/v2/_internal/execution/controller/controller.py b/python/ray/train/v2/_internal/execution/controller/controller.py index e7a509bc6cc9..c86345d62a31 100644 --- a/python/ray/train/v2/_internal/execution/controller/controller.py +++ b/python/ray/train/v2/_internal/execution/controller/controller.py @@ -12,8 +12,10 @@ from ray.exceptions import AsyncioActorExit from ray.train.v2._internal.constants import ( DEFAULT_ENABLE_CONTROLLER_LOGGING, + DEFAULT_ENABLE_PREEMPTION_WATCHER, DEFAULT_HEALTH_CHECK_INTERVAL_S, ENABLE_CONTROLLER_STRUCTURED_LOGGING_ENV_VAR, + ENABLE_PREEMPTION_WATCHER_ENV_VAR, HEALTH_CHECK_INTERVAL_S_ENV_VAR, ) from ray.train.v2._internal.execution.callback import ( @@ -188,6 +190,28 @@ def __init__( else False ) + # Register the preemption-observability callback when not in TorchFT + # mode (replica groups handle peer loss via their own quorum). + enable_preemption_watcher = ray_constants.env_bool( + ENABLE_PREEMPTION_WATCHER_ENV_VAR, + DEFAULT_ENABLE_PREEMPTION_WATCHER, + ) + if self._manages_replica_groups: + if enable_preemption_watcher and ray_constants.env_set_by_user( + ENABLE_PREEMPTION_WATCHER_ENV_VAR + ): + logger.info( + "The preemption watcher is not compatible with replica " + "groups (e.g. TorchFT), which handle peer loss via their " + "own quorum; skipping it." + ) + elif enable_preemption_watcher: + from ray.train.v2._internal.callbacks.preemption_callback import ( + PreemptionCallback, + ) + + self._worker_group_callbacks_to_propagate.append(PreemptionCallback()) + self._worker_group: Optional[WorkerGroup] = None self._state = InitializingState() self._return_value: Optional[Any] = None diff --git a/python/ray/train/v2/_internal/execution/preemption.py b/python/ray/train/v2/_internal/execution/preemption.py new file mode 100644 index 000000000000..d8ea6ab078bc --- /dev/null +++ b/python/ray/train/v2/_internal/execution/preemption.py @@ -0,0 +1,213 @@ +import logging +import threading +from dataclasses import dataclass +from typing import Dict, List, Optional, Set + +import ray +from ray.train.v2._internal.constants import DEFAULT_PREEMPTION_POLL_INTERVAL_S +from ray.util.tpu import get_tpu_slice_name_from_node + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class PreemptionInfo: + """Information about an imminent preemption event. + + Attributes: + deadline_ms: Earliest preemption deadline (UNIX time in milliseconds) + across all preempted nodes. ``None`` if no deadline was reported. + preempted_node_to_ranks: Map of preempted ``node_id`` to the worker ``world_rank``s affected when that node + is preempted. + """ + + deadline_ms: Optional[int] + preempted_node_to_ranks: Dict[str, List[int]] + + @property + def preempted_node_ids(self) -> List[str]: + """Preempted node IDs, sorted lexicographically.""" + return sorted(self.preempted_node_to_ranks) + + @property + def preempted_ranks(self) -> List[int]: + """All affected ranks across the preempted nodes, sorted ascending.""" + return sorted( + {r for ranks in self.preempted_node_to_ranks.values() for r in ranks} + ) + + +def _get_draining_nodes() -> Dict[str, int]: + """Ray Core's draining nodes as ``{node_id_hex: deadline_ms}`` (0 = no deadline).""" + return ray._private.state.state.get_draining_nodes() + + +class PreemptionWatcher: + """Polls Ray Core for node drains and logs detected preemption events. + + One watcher per worker group, spawned as a ``num_cpus=0`` actor by + ``PreemptionCallback``. The poll loop runs in a background thread. The + failure-domain map is built once on construction and is immutable for the + watcher's lifetime. + + The failure-domain map records which of our ranks are affected if a node is + preempted: for a GPU node, the ranks on that node; for a TPU node, every + rank in the node's slice, since a TPU slice is preempted atomically. + + Args: + node_to_ranks: Map ``node_id_hex -> [ranks on that node]``. Used both + as the set of nodes we care about (drains elsewhere are ignored) + and as the seed for failure-domain expansion. + poll_interval_s: Seconds between drain-state polls. + """ + + def __init__( + self, + node_to_ranks: Dict[str, List[int]], + poll_interval_s: float = DEFAULT_PREEMPTION_POLL_INTERVAL_S, + ): + self._node_to_ranks: Dict[str, List[int]] = { + nid: sorted(ranks) for nid, ranks in node_to_ranks.items() + } + self._poll_interval_s = poll_interval_s + self._failure_domain_map: Dict[str, List[int]] = self._build_failure_domain_map( + self._node_to_ranks + ) + + self._stop_event = threading.Event() + self._last_drained: Dict[str, int] = {} + self._latest_info: Optional[PreemptionInfo] = None + + self._monitor_thread = threading.Thread( + target=self._watch_loop, + name="PreemptionWatcher", + daemon=True, + ) + self._monitor_thread.start() + + @staticmethod + def _build_failure_domain_map( + node_to_ranks: Dict[str, List[int]], + ) -> Dict[str, List[int]]: + """Map each node we host to all ranks in its failure domain. + + - Non-TPU (e.g. GPU) clusters: the failure domain is the node itself, + so a drain on a node flags only the ranks this job runs there. + - TPU multislice: every host in a slice is reclaimed atomically, so a + drain on any host is fate-shared with the rest. + """ + per_node = {nid: sorted(set(ranks)) for nid, ranks in node_to_ranks.items()} + + try: + all_nodes = ray.nodes() + + # Slice label for each node we host (None for non-TPU nodes). + node_to_slice: Dict[str, Optional[str]] = { + node["NodeID"]: get_tpu_slice_name_from_node(node) + for node in all_nodes + if node["NodeID"] in node_to_ranks + } + + # Union our ranks per slice. + slice_to_ranks: Dict[str, Set[int]] = {} + for node_id, ranks in node_to_ranks.items(): + slice_label = node_to_slice.get(node_id) + if slice_label: + slice_to_ranks.setdefault(slice_label, set()).update(ranks) + + # Non-TPU cluster (or none of our nodes are on a slice): per-node. + if not slice_to_ranks: + return per_node + + result: Dict[str, List[int]] = {} + for node_id, ranks in node_to_ranks.items(): + slice_label = node_to_slice.get(node_id) + if slice_label: + result[node_id] = sorted(slice_to_ranks[slice_label]) + else: + result[node_id] = sorted(set(ranks)) + return result + except Exception: + logger.debug( + "Could not build failure-domain map; falling back to per-node " + "domains (no TPU-slice expansion).", + exc_info=True, + ) + return per_node + + def get_latest_preemption_info(self) -> Optional[PreemptionInfo]: + """Most recent :class:`PreemptionInfo` observed, or ``None``.""" + return self._latest_info + + def _watch_loop(self) -> None: + logger.debug( + "PreemptionWatcher polling %d node(s) every %.1fs.", + len(self._node_to_ranks), + self._poll_interval_s, + ) + while not self._stop_event.is_set(): + self._poll_once() + self._stop_event.wait(timeout=self._poll_interval_s) + logger.debug("PreemptionWatcher stopped.") + + def _poll_once(self) -> None: + """Poll the drain source once and dispatch on change. + + Per-poll exceptions are caught and logged so a transient GCS hiccup + doesn't kill the watcher loop. + """ + try: + drained = _get_draining_nodes() or {} + # Keep only drains on this job's own nodes (others are ignored). + # That's complete for TPU — an SPMD job fully occupies its slice, so + # every fate-shared host is one of our nodes and a drain on any slice + # host appears here. For GPU, a drain on a host we don't run on is + # correctly irrelevant. + relevant = { + n: d for n, d in drained.items() if n in self._failure_domain_map + } + if relevant != self._last_drained: + self._on_drain_change(relevant) + self._last_drained = relevant + except Exception: + # TODO(lehui): consider exponential backoff when the drain API keeps + # failing, instead of retrying at the fixed poll interval. + logger.warning("PreemptionWatcher poll failed", exc_info=True) + + def _on_drain_change(self, drained: Dict[str, int]) -> None: + """Handle a change in the drained-node set. + + ``drained`` has already been narrowed to this job's nodes by the + caller (``_poll_once``). + """ + if not drained: + return + + affected_node_ids = sorted(drained.keys()) + preempted_node_to_ranks = { + node_id: self._failure_domain_map[node_id] for node_id in affected_node_ids + } + + # Earliest deadline across the preempted nodes; None if none reported one + # (Ray Core uses 0 for "no deadline", which is falsy and filtered out). + reported_deadlines = [drained[n] for n in affected_node_ids if drained[n]] + deadline_ms = min(reported_deadlines) if reported_deadlines else None + + info = PreemptionInfo( + deadline_ms=deadline_ms, + preempted_node_to_ranks=preempted_node_to_ranks, + ) + self._latest_info = info + + logger.warning( + "PreemptionWatcher: preemption detected — " + "preempted_node_ids=%s, preempted_ranks=%s, deadline_ms=%s", + info.preempted_node_ids, + info.preempted_ranks, + deadline_ms, + ) + # TODO(lehui): forward the detected preemption to the workers so the + # training loop can react to it. + # TODO(lehui): coalesce preemptions seen within one window into a single + # worker-group restart, so a staggered drain (node A at t, node B at + # t+60s) doesn't cause back-to-back restarts. diff --git a/python/ray/train/v2/tests/test_preemption_watcher.py b/python/ray/train/v2/tests/test_preemption_watcher.py new file mode 100644 index 000000000000..32cce03ddfcd --- /dev/null +++ b/python/ray/train/v2/tests/test_preemption_watcher.py @@ -0,0 +1,192 @@ +"""Unit tests for the preemption watcher.""" +from typing import Dict +from unittest.mock import Mock, patch + +import pytest + +import ray +from ray.train.v2._internal.callbacks.preemption_callback import PreemptionCallback +from ray.train.v2._internal.execution.preemption import PreemptionWatcher + +_PREEMPTION_MOD = "ray.train.v2._internal.execution.preemption" + + +def _make_watcher( + node_to_ranks: Dict[str, list], + fd_map: Dict[str, list], +) -> PreemptionWatcher: + """Construct a watcher with a fixed failure-domain map, then halt its loop. + + The background poll thread is stopped (while the GCS drain call is mocked + out) so tests can drive ``_poll_once()`` synchronously and deterministically. + """ + fd = {nid: sorted(ranks) for nid, ranks in fd_map.items()} + with patch.object( + PreemptionWatcher, "_build_failure_domain_map", return_value=fd + ), patch(f"{_PREEMPTION_MOD}._get_draining_nodes", return_value={}): + watcher = PreemptionWatcher(node_to_ranks=node_to_ranks) + watcher._stop_event.set() + watcher._monitor_thread.join(timeout=5) + return watcher + + +def _poll_once_with(watcher: PreemptionWatcher, **patch_kwargs) -> None: + """Drive one poll with the GCS drain call mocked (``return_value``/``side_effect``).""" + with patch(f"{_PREEMPTION_MOD}._get_draining_nodes", **patch_kwargs): + watcher._poll_once() + + +class TestPreemptionWatcher: + @pytest.mark.parametrize( + "fd_map, drained, exp_nodes, exp_ranks, exp_deadline_ms", + [ + # Single drained node. + ({"node-a": [0, 1]}, {"node-a": 30_000}, ["node-a"], [0, 1], 30_000), + # Drains on nodes outside our failure domains are filtered out. + ( + {"node-a": [0, 1]}, + {"node-a": 30_000, "other": 30_000}, + ["node-a"], + [0, 1], + 30_000, + ), + # Multiple of our nodes: ranks aggregated, earliest deadline wins. + ( + {"node-a": [0, 1], "node-b": [2, 3]}, + {"node-a": 45_000, "node-b": 20_000}, + ["node-a", "node-b"], + [0, 1, 2, 3], + 20_000, + ), + # Failure-domain-expanded map: a drain on one host flags the whole + # fate-shared set (e.g. a TPU slice). + ( + {"host-0": [0, 1, 2, 3], "host-1": [0, 1, 2, 3]}, + {"host-0": 30_000}, + ["host-0"], + [0, 1, 2, 3], + 30_000, + ), + # deadline_ms=0 from Ray Core means "no deadline" -> surfaced as None. + ({"node-a": [0]}, {"node-a": 0}, ["node-a"], [0], None), + # A None deadline is also surfaced as None rather than raising. + ({"node-a": [0]}, {"node-a": None}, ["node-a"], [0], None), + ], + ) + def test_drain_produces_info( + self, fd_map, drained, exp_nodes, exp_ranks, exp_deadline_ms + ): + watcher = _make_watcher(node_to_ranks=fd_map, fd_map=fd_map) + _poll_once_with(watcher, return_value=dict(drained)) + + info = watcher.get_latest_preemption_info() + assert info.preempted_node_ids == exp_nodes + assert info.preempted_ranks == exp_ranks + assert info.deadline_ms == exp_deadline_ms + # The node->ranks map keys are the preempted nodes; each maps to its + # failure domain. The flat getters above derive from it. + assert info.preempted_node_to_ranks == {n: sorted(fd_map[n]) for n in exp_nodes} + + def test_no_drain_means_no_info(self): + watcher = _make_watcher(node_to_ranks={"node-a": [0]}, fd_map={"node-a": [0]}) + _poll_once_with(watcher, return_value={}) + assert watcher.get_latest_preemption_info() is None + + def test_poll_swallows_drain_errors(self): + """A raising drain call must not propagate out of a poll.""" + watcher = _make_watcher(node_to_ranks={"node-a": [0]}, fd_map={"node-a": [0]}) + _poll_once_with( + watcher, side_effect=RuntimeError("transient") + ) # must not raise + assert watcher.get_latest_preemption_info() is None + + def test_none_drain_result_is_safe(self): + """A drain call returning None is treated as no drains.""" + watcher = _make_watcher(node_to_ranks={"node-a": [0]}, fd_map={"node-a": [0]}) + _poll_once_with(watcher, return_value=None) # must not raise + assert watcher.get_latest_preemption_info() is None + + +class TestBuildFailureDomainMap: + def test_falls_back_on_ray_nodes_error(self, monkeypatch): + """If ray.nodes() raises, fall back to per-node domains.""" + + def raise_runtime(): + raise RuntimeError("no ray runtime") + + monkeypatch.setattr(ray, "nodes", raise_runtime) + result = PreemptionWatcher._build_failure_domain_map( + {"node-a": [0, 1], "node-b": [2, 3]} + ) + assert result == {"node-a": [0, 1], "node-b": [2, 3]} + + def test_falls_back_on_slice_lookup_error(self): + """If the TPU slice lookup raises, fall back to per-node domains.""" + with patch.object(ray, "nodes", return_value=[{"NodeID": "node-a"}]), patch( + f"{_PREEMPTION_MOD}.get_tpu_slice_name_from_node", + side_effect=RuntimeError("boom"), + ): + result = PreemptionWatcher._build_failure_domain_map({"node-a": [0, 1]}) + assert result == {"node-a": [0, 1]} + + @pytest.mark.parametrize( + "node_to_ranks, slice_labels, cluster_node_ids, expected", + [ + # GPU cluster (no slice labels): per-node domains. node-1 may be + # shared with another workload, but only our ranks are in + # node_to_ranks, so a drain on node-1 flags only our rank there. + ( + {"node-0": [0, 1, 2, 3], "node-1": [4]}, + {}, + ["node-0", "node-1"], + {"node-0": [0, 1, 2, 3], "node-1": [4]}, + ), + # TPU slice fully occupied by this job: a drain on any host is + # fate-shared, so every host maps to the union of the slice's ranks. + ( + {"sa-0": [0, 1], "sa-1": [2, 3]}, + {"sa-0": "A", "sa-1": "A"}, + ["sa-0", "sa-1"], + {"sa-0": [0, 1, 2, 3], "sa-1": [0, 1, 2, 3]}, + ), + ], + ) + def test_build_failure_domain_map( + self, node_to_ranks, slice_labels, cluster_node_ids, expected + ): + fake_nodes = [{"NodeID": nid} for nid in cluster_node_ids] + + def fake_slice_label(node): + return slice_labels.get(node["NodeID"]) + + with patch.object(ray, "nodes", return_value=fake_nodes), patch( + f"{_PREEMPTION_MOD}.get_tpu_slice_name_from_node", + side_effect=fake_slice_label, + ): + result = PreemptionWatcher._build_failure_domain_map(node_to_ranks) + + assert result == expected + + +class TestPreemptionCallbackTeardown: + @pytest.mark.parametrize( + "teardown_hook", + ["before_worker_group_shutdown", "after_worker_group_abort"], + ) + def test_teardown_kills_watcher(self, teardown_hook): + """Both shutdown and abort must kill the watcher (abort skips shutdown).""" + callback = PreemptionCallback() + watcher = Mock() + callback._watcher = watcher + + with patch.object(ray, "kill") as mock_kill: + getattr(callback, teardown_hook)(None) + + mock_kill.assert_called_once_with(watcher) + assert callback._watcher is None + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", "-x", __file__])) From 90c5a3a716899708ab743c6c22316f0923e6b78c Mon Sep 17 00:00:00 2001 From: kourosh hakhamaneshi <31483498+kouroshHakha@users.noreply.github.com> Date: Fri, 12 Jun 2026 11:37:53 -0700 Subject: [PATCH 5/5] [serve.llm] Add MoRIIO KV-connector backend for prefill/decode (#63951) Signed-off-by: Kourosh Hakhamaneshi Co-authored-by: Claude Opus 4.8 (1M context) --- .../serve/engines/vllm/kv_transfer/base.py | 49 ++- .../serve/engines/vllm/kv_transfer/factory.py | 1 + .../serve/engines/vllm/kv_transfer/moriio.py | 342 ++++++++++++++++ .../prefill_decode/builder.py | 30 ++ .../prefill_decode/pd_server.py | 39 +- .../test_moriio_connector.py | 365 ++++++++++++++++++ .../test_prefill_decode_disagg.py | 29 ++ 7 files changed, 844 insertions(+), 11 deletions(-) create mode 100644 python/ray/llm/_internal/serve/engines/vllm/kv_transfer/moriio.py create mode 100644 python/ray/llm/tests/serve/cpu/deployments/llm/vllm/kv_transfer_backends/test_moriio_connector.py diff --git a/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/base.py b/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/base.py index 243b3ac6e5e6..bfd0294d53af 100644 --- a/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/base.py +++ b/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/base.py @@ -18,6 +18,30 @@ RequestType = Union[ChatCompletionRequest, CompletionRequest] +def base_prefill_kv_transfer_params() -> Dict[str, Any]: + """The ``kv_transfer_params`` common to a prefill (producer) request. + + Tells the prefill engine to produce KV for a remote decode. Connectors layer + their own keys (e.g. a transfer id, DP/TP routing) on top of these. + """ + return { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + } + + +def clamp_request_to_single_token(request: "RequestType") -> None: + """Clamp a prefill request to a single, non-streaming token (in place).""" + request.max_tokens = 1 + if hasattr(request, "max_completion_tokens"): + request.max_completion_tokens = 1 + request.stream = False + if hasattr(request, "stream_options"): + request.stream_options = None + + class BaseConnectorBackend(abc.ABC): # ---- P/D coordination protocol ---- # @@ -159,6 +183,19 @@ def setup(self) -> None: """ pass + def replica_metadata(self) -> Dict[str, Any]: + """Static per-replica coordination data published to the orchestrator. + + Surfaced via the replica-metadata hook on ``ReplicaSelection`` so that a + connector opting into ``requires_peer_binding`` can address the selected + prefill peer. The default backend publishes nothing; connectors that need + to advertise an address (e.g. MoRIIO's zmq endpoint) override this. + + Returns: + A JSON-serializable dict of per-replica metadata (empty by default). + """ + return {} + class DefaultPDProtocolMixin: """The default P/D protocol policy: no peer binding, sequential handoff. @@ -187,19 +224,11 @@ def prepare_prefill_request( ), "kv_transfer_params should be empty before orchestrator" prefill_request = request.model_copy(deep=True) prefill_request.kv_transfer_params = { - "do_remote_decode": True, - "do_remote_prefill": False, - "remote_engine_id": None, - "remote_block_ids": None, + **base_prefill_kv_transfer_params(), "remote_host": None, "remote_port": None, } - prefill_request.max_tokens = 1 - if hasattr(prefill_request, "max_completion_tokens"): - prefill_request.max_completion_tokens = 1 - prefill_request.stream = False - if hasattr(prefill_request, "stream_options"): - prefill_request.stream_options = None + clamp_request_to_single_token(prefill_request) return prefill_request def prepare_decode_request( diff --git a/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/factory.py b/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/factory.py index aeec2348c554..54b7c12be713 100644 --- a/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/factory.py +++ b/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/factory.py @@ -121,6 +121,7 @@ def unregister_backend(cls, name: str) -> None: "LMCacheConnectorV1": "ray.llm._internal.serve.engines.vllm.kv_transfer.lmcache:LMCacheConnectorV1Backend", "NixlConnector": "ray.llm._internal.serve.engines.vllm.kv_transfer.nixl:NixlConnectorBackend", "MultiConnector": "ray.llm._internal.serve.engines.vllm.kv_transfer.multi_connector:MultiConnectorBackend", + "MoRIIOConnector": "ray.llm._internal.serve.engines.vllm.kv_transfer.moriio:MoRIIOConnectorBackend", } diff --git a/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/moriio.py b/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/moriio.py new file mode 100644 index 000000000000..d05e4dca5627 --- /dev/null +++ b/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/moriio.py @@ -0,0 +1,342 @@ +"""MoRIIO connector backend for Ray Serve LLM (analogue of nixl.py). + +Configures a vLLM engine's ``kv_transfer_config.kv_connector_extra_config`` for +the MoRIIO connector and computes per-replica handshake/notify ports so colocated +replicas don't collide. Also builds the engine's advertised zmq address so the +P/D orchestrator can discover it via the replica-metadata hook +(``ReplicaSelection.replica_metadata``), and implements the PD connector protocol +(``requires_peer_binding`` / ``concurrent_handoff`` / ``prepare_prefill_request`` / +``prepare_decode_request``) so the decode orchestrator can address the selected +prefill peer by request id. + +Unlike NIXL/LMCache, MoRIIO does NOT use ``DefaultPDProtocolMixin``: it has custom +request shaping (a dual-address request_id + transfer_id) and therefore IMPLEMENTS +the abstract ``prepare_*`` methods directly on ``BaseConnectorBackend``. + +Two transfer disciplines, selected by ``read_mode``: + * WRITE (default): prefill PUSHES KV to decode -> concurrent handoff. + * READ: decode PULLS KV from prefill -> sequential handoff; the decode request + forwards the ``remote_block_ids`` / ``remote_engine_id`` the prefill engine + returned. + +The dual-address request_id and the transfer_id are derived DETERMINISTICALLY +from the incoming request id (a hash), so ``prepare_prefill_request`` and +``prepare_decode_request`` produce identical ids across their two separate calls +without per-request backend state (the backend instance is shared across +requests). + +Registered with Ray's public connector registry via the factory. +""" + +import hashlib +import logging +import re +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple + +import ray +from ray.llm._internal.serve.engines.vllm.kv_transfer.base import ( + BaseConnectorBackend, + base_prefill_kv_transfer_params, + clamp_request_to_single_token, +) + +if TYPE_CHECKING: + from ray.llm._internal.serve.engines.vllm.kv_transfer.base import RequestType + +logger = logging.getLogger(__name__) + +# Defaults mirror vLLM's MoRIIOConstants (DEFAULT_HANDSHAKE_PORT / NOTIFY_PORT). +# Prefill uses these bases; decode is shifted (see builder.py) so a colocated +# P+D pair on one node doesn't collide. +DEFAULT_HANDSHAKE_PORT_BASE = 6301 +DEFAULT_NOTIFY_PORT_BASE = 61005 + +# experimental_configs keys understood by this backend. +HANDSHAKE_PORT_BASE_KEY = "MORI_HANDSHAKE_PORT_BASE" +NOTIFY_PORT_BASE_KEY = "MORI_NOTIFY_PORT_BASE" + +# --------------------------------------------------------------------------- +# Dual-address request_id / zmq address encoding. +# +# These MUST stay byte-compatible with the regexes vLLM's MoRIIO connector uses +# to recover peer addresses from the request_id: +# +# vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py +# _PREFILL_ZMQ_RE = re.compile(r"___prefill_addr_(.+?)___decode_addr_") +# _DECODE_ZMQ_RE = re.compile(r"___decode_addr_(.+)_[0-9a-f]{32}(?:-.*)?$") +# # zmq address: "host:IP,handshake:PORT,notify:PORT" +# --------------------------------------------------------------------------- + +_PREFILL_PREFIX = "___prefill_addr_" +_DECODE_PREFIX = "___decode_addr_" +_TRANSFER_PREFIX = "tx" + +# Copies of vLLM's regexes for local validation / round-trip tests. +_PREFILL_ZMQ_RE = re.compile(r"___prefill_addr_(.+?)___decode_addr_") +_DECODE_ZMQ_RE = re.compile(r"___decode_addr_(.+)_[0-9a-f]{32}(?:-.*)?$") + + +def build_zmq_address(host: str, handshake_port: int, notify_port: int) -> str: + """Build the MORI zmq address string ``host:IP,handshake:PORT,notify:PORT``.""" + return f"host:{host},handshake:{handshake_port},notify:{notify_port}" + + +def parse_zmq_address(zmq_address: str) -> Tuple[str, int, int]: + """Inverse of :func:`build_zmq_address` -> ``(host, handshake_port, notify_port)``.""" + parts = {} + for segment in zmq_address.split(","): + key, _, val = segment.partition(":") + parts[key.strip()] = val.strip() + return parts["host"], int(parts["handshake"]), int(parts["notify"]) + + +def parse_peer_zmq(request_id: str, is_producer: bool) -> str: + """Recover the peer's zmq address from a request id (for tests/debugging). + + Producer (prefill) wants the *decode* address; consumer wants the *prefill*. + """ + rex = _DECODE_ZMQ_RE if is_producer else _PREFILL_ZMQ_RE + m = rex.search(request_id) + if not m: + raise ValueError(f"No peer zmq address in request_id: {request_id!r}") + return m.group(1) + + +def _read_mode_enabled(extra_config: Dict[str, Any]) -> bool: + """Mirror vLLM's ``get_moriio_mode`` parse of ``read_mode``. + + true / 1 -> READ; anything else -> WRITE (default). + """ + return str(extra_config.get("read_mode", "false")).lower().strip() in ( + "true", + "1", + ) + + +class MoRIIOConnectorBackend(BaseConnectorBackend): + """Set up MoRIIO ports/extra_config and implement the PD connector protocol.""" + + # The advertised zmq address ("host:IP,handshake:PORT,notify:PORT"), + # computed by setup(); consumers reach it via this backend instance. + _zmq_address: Optional[str] = None + + # MORI addresses peers by the dual-address request id, so the orchestrator + # must bind to the selected prefill replica BEFORE dispatch. + requires_peer_binding: bool = True + + def _extra_config(self) -> dict: + cfg = self.kv_transfer_config.setdefault("kv_connector_extra_config", {}) + return cfg + + @property + def _read_mode(self) -> bool: + """True iff this engine's MoRIIO connector is configured for READ mode.""" + extra = self._extra_config() + return _read_mode_enabled(extra) + + @property + def concurrent_handoff(self) -> bool: + """WRITE -> concurrent (prefill pushes); READ -> sequential (decode pulls).""" + return not self._read_mode + + def setup(self) -> None: + offset = self._compute_port_offset() + + handshake_base = int( + self.llm_config.experimental_configs.get( + HANDSHAKE_PORT_BASE_KEY, DEFAULT_HANDSHAKE_PORT_BASE + ) + ) + notify_base = int( + self.llm_config.experimental_configs.get( + NOTIFY_PORT_BASE_KEY, DEFAULT_NOTIFY_PORT_BASE + ) + ) + + # NOTE: vLLM internally adds get_port_offset(dp_rank, tp_rank) on top of + # these bases. For TP/DP>1, reserve a stride >= tp_size*pp_size when + # shifting decode's base in the builder so the two offset schemes never + # overlap. + handshake_port = handshake_base + offset + notify_port = notify_base + offset + + extra = self._extra_config() + # Required keys for vLLM's config parser (KeyError otherwise) -- proxyless. + extra.setdefault("proxy_ip", "") # empty => ping/registration thread disabled + extra.setdefault("proxy_ping_port", "0") + # TODO: real Serve replica HTTP port. Harmless placeholder while + # proxy_ip="" (only used to build request_address for the disabled ping). + extra.setdefault("http_port", str(8000 + offset)) + # WRITE mode (prefill pushes). READ would be "true". + extra.setdefault("read_mode", "false") + extra["handshake_port"] = str(handshake_port) + extra["notify_port"] = str(notify_port) + + # Advertise the Ray internal cluster IP as the zmq host. + host = ray.util.get_node_ip_address() + zmq_address = build_zmq_address(host, handshake_port, notify_port) + # Stash so replica_metadata() can publish it; the decode + # orchestrator reads the selected prefill replica's copy off the peer. + self._zmq_address = zmq_address + # NOTE: cross-node correctness additionally needs each worker to + # advertise the node INTERNAL IP (set VLLM_HOST_IP inside every worker + # process). VLLM_HOST_IP is excluded from vLLM's driver->worker env copy, + # so it can only be set in-process -- handled by a vLLM general-plugin + # shipped separately. Single-node deployments work without it. + + # ---- parallelism (data/tensor) ---- + + def _dp_rank(self) -> int: + rank = self.llm_config.engine_kwargs.get("data_parallel_rank") + return rank if isinstance(rank, int) and rank >= 0 else 0 + + def _dp_size(self) -> int: + return int(self.llm_config.engine_kwargs.get("data_parallel_size") or 1) + + def _tp_size(self) -> int: + return int(self.llm_config.engine_kwargs.get("tensor_parallel_size") or 1) + + # ---- replica metadata (published via the replica-metadata hook) ---- + + def replica_metadata(self) -> dict: + """Static per-replica coordination data published to the orchestrator. + + The prefill replica publishes its MORI zmq address and its parallelism + (DP rank/size, TP size); the decode orchestrator reads them off the + selected prefill replica's ``ReplicaSelection.replica_metadata`` and uses + them to address the right remote (dp_rank, tp) workers. + """ + return { + "mori_zmq_address": self._zmq_address, + "dp_rank": self._dp_rank(), + "dp_size": self._dp_size(), + "tp_size": self._tp_size(), + } + + def _remote_routing(self, remote: Dict[str, Any]) -> Dict[str, Any]: + """``kv_transfer_params`` keys telling vLLM which remote workers to reach. + + ``remote`` is the metadata of the *other* side of the transfer: the + decode (this orchestrator) for a prefill request, the selected prefill + peer for a decode request. vLLM addresses a remote worker at + ``advertised_base + get_port_offset(remote_dp_rank, tp_index)`` and + handshakes all ``remote_dp_size`` ranks, so both must match the target + replica. ``tp_size`` is the remote's TP (symmetric across P/D). + """ + return { + "remote_dp_rank": int(remote.get("dp_rank", 0)), + "remote_dp_size": int(remote.get("dp_size", 1)), + "tp_size": int(remote.get("tp_size", self._tp_size())), + } + + def _own_routing(self) -> Dict[str, Any]: + """``_remote_routing`` input describing this (decode orchestrator) replica + -- the remote side for a prefill request.""" + return { + "dp_rank": self._dp_rank(), + "dp_size": self._dp_size(), + "tp_size": self._tp_size(), + } + + # ---- request shaping (PD connector protocol) ---- + + def _dual_ids( + self, request: Any, peer: Optional[Dict[str, Any]] + ) -> Tuple[str, str]: + """Compute the (dual-address request_id, transfer_id) for this request. + + ``prepare_prefill_request`` and ``prepare_decode_request`` are two + independent, stateless calls for the same request, so both ids are + derived deterministically (hash of a stable per-request seed) — no + per-request backend state. + """ + prefill_zmq = (peer or {}).get("mori_zmq_address") + decode_zmq = self._zmq_address + if not prefill_zmq: + raise ValueError( + "MoRIIO peer is missing 'mori_zmq_address': the selected prefill " + "replica did not publish its address (is MoRIIOConnector " + "configured on the prefill deployment?)." + ) + if not decode_zmq: + raise ValueError( + "MoRIIO decode zmq address is not set: setup() must run on this " + "engine before requests are shaped." + ) + # The incoming request_id (always populated -- OpenAI models default it + # to a uuid) is the seed. Both prepare_* calls run on the same request + # object, so they agree; uniqueness per request is inherited from it. + seed = str(request.request_id) + # 32 hex chars (the trailing uid _PREFILL_ZMQ_RE / _DECODE_ZMQ_RE + # anchor on); a hash of the seed, so both prepare_* calls agree. + uid = hashlib.sha256(seed.encode()).hexdigest()[:32] + # Wire format consumed by vLLM's MoRIIO connector. + request_id = f"{_PREFILL_PREFIX}{prefill_zmq}{_DECODE_PREFIX}{decode_zmq}_{uid}" + transfer_id = f"{_TRANSFER_PREFIX}-{uid}" + return request_id, transfer_id + + def prepare_prefill_request( + self, *, request: "RequestType", peer: Optional[Dict[str, Any]] + ) -> "RequestType": + request_id, transfer_id = self._dual_ids(request, peer) + prefill_request = request.model_copy(deep=True) + # The dual-address id (peer zmq encoded in it) must reach the engine: + # setting request_id explicitly makes the LLMServer pipeline preserve it + # (not clobber it with the Serve id) and the engine copies it into the + # X-Request-Id header that vLLM's MoRIIO connector parses. + prefill_request.request_id = request_id + prefill_request.kv_transfer_params = { + **base_prefill_kv_transfer_params(), + "transfer_id": transfer_id, + # The prefill engine's remote is the decode (this orchestrator). + **self._remote_routing(self._own_routing()), + } + clamp_request_to_single_token(prefill_request) + return prefill_request + + def prepare_decode_request( + self, + *, + request: "RequestType", + peer: Optional[Dict[str, Any]], + prefill_response: Optional[Any], + ) -> "RequestType": + request_id, transfer_id = self._dual_ids(request, peer) + decode_request = request.model_copy(deep=True) + decode_request.request_id = request_id + # The decode engine's remote is the selected prefill peer. + remote_routing = self._remote_routing(peer or {}) + + if not self._read_mode: + # WRITE: prefill pushes KV; decode just needs do_remote_prefill + the + # shared transfer_id (no block ids -- they are pushed, not pulled). + decode_request.kv_transfer_params = { + "do_remote_prefill": True, + "do_remote_decode": False, + "remote_engine_id": None, + "remote_block_ids": None, + "transfer_id": transfer_id, + **remote_routing, + } + return decode_request + + # READ: decode PULLS KV; forward the remote_block_ids / remote_engine_id + # the prefill engine returned on its response. If absent (e.g. prompt < + # block_size / full prefix hit), fall back to a local recompute. + prefill_kv_params = getattr(prefill_response, "kv_transfer_params", None) + params = dict(prefill_kv_params) if prefill_kv_params else {} + if params.get("remote_block_ids") and params.get("remote_engine_id"): + params.setdefault("transfer_id", transfer_id) + params["do_remote_prefill"] = True + params["do_remote_decode"] = False + # Address the prefill peer's (dp_rank, dp_size, tp) workers. + params.update(remote_routing) + decode_request.kv_transfer_params = params + else: + logger.warning( + "[MORI][READ] prefill returned no remote_block_ids/remote_engine_id " + "(kv_transfer_params=%s); decode will recompute locally.", + prefill_kv_params, + ) + decode_request.kv_transfer_params = None + return decode_request diff --git a/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/builder.py b/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/builder.py index f095d39fd599..61af89007db3 100644 --- a/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/builder.py +++ b/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/builder.py @@ -183,6 +183,36 @@ def _default_decode_nixl_port_base(self): ) return self + @model_validator(mode="after") + def _default_decode_moriio_port_base(self): + """Shift decode's MoRIIO handshake/notify bases off prefill's defaults. + + Mirrors ``_default_decode_nixl_port_base``: a colocated P+D pair on one + node would otherwise share MoRIIO's default handshake/notify ports. Only + applies when the decode config uses the MoRIIO connector. The +1000 + stride is well above any realistic tp_size*pp_size offset added on top. + """ + kv_transfer_config = ( + self.decode_config.engine_kwargs.get("kv_transfer_config") or {} + ) + if kv_transfer_config.get("kv_connector") != "MoRIIOConnector": + return self + + from ray.llm._internal.serve.engines.vllm.kv_transfer.moriio import ( + DEFAULT_HANDSHAKE_PORT_BASE, + DEFAULT_NOTIFY_PORT_BASE, + HANDSHAKE_PORT_BASE_KEY, + NOTIFY_PORT_BASE_KEY, + ) + + self.decode_config.experimental_configs.setdefault( + HANDSHAKE_PORT_BASE_KEY, DEFAULT_HANDSHAKE_PORT_BASE + 1000 + ) + self.decode_config.experimental_configs.setdefault( + NOTIFY_PORT_BASE_KEY, DEFAULT_NOTIFY_PORT_BASE + 1000 + ) + return self + # --------------------------------------------------------------------------- # Builder diff --git a/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/pd_server.py b/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/pd_server.py index ddca4dafce1b..30bcd43edeb2 100644 --- a/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/pd_server.py +++ b/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/pd_server.py @@ -175,6 +175,12 @@ async def _pd_handle_request( the resolved KV-connector backend. With the default backend flags (``requires_peer_binding=False``, ``concurrent_handoff=False``) the control flow and calls are identical to the historical NIXL/default flow. + + A connector that encodes coordination data in the request id (MoRIIO's + dual-address id) just stamps ``request.request_id`` in ``prepare_*``; it + then reaches both engines unchanged -- the LLMServer pipeline preserves + an explicitly-set request_id (it no longer clobbers it with the Serve + id) and the engine copies it into the ``X-Request-Id`` header it reads. """ # Determine method name for the handle call @@ -439,9 +445,21 @@ async def _maybe_prewarm(self) -> None: logger.info("[PDDecodeServer] Starting pre-warm across all P replicas.") + backend = self._get_connector_backend() + if backend.requires_peer_binding: + # Peer-binding connectors (e.g. MoRIIO) shape a prefill request + # against a specific selected replica's metadata; a peerless + # broadcast prewarm cannot bind one. The connector handshake + # completes on the first real request instead. + logger.info( + "[PDDecodeServer] Skipping pre-warm: connector %s requires peer " + "binding (handshake completes on the first real request).", + type(backend).__name__, + ) + return + model_id = self._llm_config.model_id dummy = self._make_dummy_request(model_id) - backend = self._get_connector_backend() prefill_req = backend.prepare_prefill_request(request=dummy, peer=None) # Broadcast to every live P replica; retry until they are up. @@ -551,6 +569,25 @@ class PDPrefillServer(LLMServer): method used during the pre-warm handshake. """ + async def record_replica_metadata(self) -> Dict[str, Any]: + """Publish this prefill replica's connector coordination metadata. + + Read by the decode orchestrator via the replica-metadata hook + (``ReplicaSelection.replica_metadata``) so peer-binding connectors (e.g. + MoRIIO) can address the selected prefill replica. Returns ``{}`` for + connectors that publish nothing (the ``BaseConnectorBackend`` default). + + Returns the metadata of the backend that engine init + (``setup_engine_backend``) created, ``setup()``-ed, and stored on this + server's ``_llm_config``. The replica-metadata hook is captured after + engine init, so for connector deployments the backend is present by + then; with no backend stored there is nothing to publish. + """ + backend = getattr(self._llm_config, "kv_connector_backend", None) + if backend is None: + return {} + return backend.replica_metadata() + async def prewarm_prefill( self, prefill_request: CompletionRequest ) -> Optional[dict]: diff --git a/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/kv_transfer_backends/test_moriio_connector.py b/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/kv_transfer_backends/test_moriio_connector.py new file mode 100644 index 000000000000..8878ab36ca4c --- /dev/null +++ b/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/kv_transfer_backends/test_moriio_connector.py @@ -0,0 +1,365 @@ +import re +import sys +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from ray.llm._internal.serve.engines.vllm.kv_transfer.base import ( + BaseConnectorBackend, +) +from ray.llm._internal.serve.engines.vllm.kv_transfer.factory import ( + KVConnectorBackendFactory, +) +from ray.llm._internal.serve.engines.vllm.kv_transfer.moriio import ( + _DECODE_ZMQ_RE, + _PREFILL_ZMQ_RE, + DEFAULT_HANDSHAKE_PORT_BASE, + DEFAULT_NOTIFY_PORT_BASE, + MoRIIOConnectorBackend, + parse_peer_zmq, + parse_zmq_address, +) +from ray.serve.llm import LLMConfig +from ray.serve.schema import ReplicaRank + +_TEST_HOST = "10.0.0.5" + + +def _replica_context(global_rank: int) -> SimpleNamespace: + return SimpleNamespace( + rank=ReplicaRank(rank=global_rank, node_rank=0, local_rank=global_rank) + ) + + +def _make_backend( + read_mode: bool = False, + extra_exp: dict = None, + dp_rank: int = None, + dp_size: int = None, + tp_size: int = None, +): + extra_config = {} + if read_mode: + extra_config["read_mode"] = "true" + engine_kwargs = dict( + kv_transfer_config=dict( + kv_connector="MoRIIOConnector", + kv_role="kv_both", + kv_connector_extra_config=extra_config, + ) + ) + if dp_rank is not None: + engine_kwargs["data_parallel_rank"] = dp_rank + if dp_size is not None: + engine_kwargs["data_parallel_size"] = dp_size + if tp_size is not None: + engine_kwargs["tensor_parallel_size"] = tp_size + return MoRIIOConnectorBackend( + llm_config=LLMConfig( + model_loading_config=dict(model_id="Qwen/Qwen3-0.6B"), + engine_kwargs=engine_kwargs, + experimental_configs=extra_exp or {}, + ), + ) + + +def _setup(backend, rank: int = 0): + with ( + patch.dict("os.environ", {}, clear=True), + patch("ray.util.get_node_ip_address", return_value=_TEST_HOST), + patch("ray.serve.get_replica_context", return_value=_replica_context(rank)), + ): + backend.setup() + + +class TestMoRIIOConnectorBackendSetup: + def test_setup_sets_ports_zmq_and_extra_config(self): + backend = _make_backend() + _setup(backend, rank=0) + + extra = backend.kv_transfer_config["kv_connector_extra_config"] + assert extra["handshake_port"] == str(DEFAULT_HANDSHAKE_PORT_BASE) + assert extra["notify_port"] == str(DEFAULT_NOTIFY_PORT_BASE) + assert extra["proxy_ip"] == "" + assert extra["proxy_ping_port"] == "0" + assert "http_port" in extra + assert extra["read_mode"] == "false" + + zmq = backend._zmq_address + host, hs, notify = parse_zmq_address(zmq) + assert host == _TEST_HOST + assert hs == DEFAULT_HANDSHAKE_PORT_BASE + assert notify == DEFAULT_NOTIFY_PORT_BASE + + def test_setup_port_offset_uses_replica_rank(self): + backend = _make_backend() + num_devices = backend.llm_config.get_engine_config().num_devices + _setup(backend, rank=2) + extra = backend.kv_transfer_config["kv_connector_extra_config"] + assert extra["handshake_port"] == str( + DEFAULT_HANDSHAKE_PORT_BASE + 2 * num_devices + ) + assert extra["notify_port"] == str(DEFAULT_NOTIFY_PORT_BASE + 2 * num_devices) + + def test_setup_respects_overridden_bases(self): + backend = _make_backend( + extra_exp={ + "MORI_HANDSHAKE_PORT_BASE": 7000, + "MORI_NOTIFY_PORT_BASE": 62000, + } + ) + _setup(backend, rank=0) + extra = backend.kv_transfer_config["kv_connector_extra_config"] + assert extra["handshake_port"] == "7000" + assert extra["notify_port"] == "62000" + + def test_requires_peer_binding(self): + assert MoRIIOConnectorBackend.requires_peer_binding is True + + def test_concurrent_handoff_write_vs_read(self): + write_backend = _make_backend(read_mode=False) + read_backend = _make_backend(read_mode=True) + assert write_backend.concurrent_handoff is True + assert read_backend.concurrent_handoff is False + assert write_backend._read_mode is False + assert read_backend._read_mode is True + + @pytest.mark.parametrize( + "value,expected_read", + [ + ("true", True), + ("True", True), + ("1", True), + ("false", False), + ("0", False), + ("", False), + ], + ) + def test_read_mode_parsing(self, value, expected_read): + backend = MoRIIOConnectorBackend( + llm_config=LLMConfig( + model_loading_config=dict(model_id="Qwen/Qwen3-0.6B"), + engine_kwargs=dict( + kv_transfer_config=dict( + kv_connector="MoRIIOConnector", + kv_connector_extra_config={"read_mode": value}, + ) + ), + ), + ) + assert backend._read_mode is expected_read + + def test_replica_metadata_returns_zmq(self): + backend = _make_backend() + _setup(backend, rank=0) + meta = backend.replica_metadata() + assert meta["mori_zmq_address"] == backend._zmq_address + # Parallelism is published so the orchestrator can address remote workers. + assert meta["dp_rank"] == 0 and meta["dp_size"] == 1 and meta["tp_size"] == 1 + + def test_replica_metadata_publishes_dp_tp(self): + backend = _make_backend(dp_rank=2, dp_size=4, tp_size=8) + _setup(backend, rank=0) + meta = backend.replica_metadata() + assert meta["dp_rank"] == 2 + assert meta["dp_size"] == 4 + assert meta["tp_size"] == 8 + + def test_replica_metadata_default_empty(self): + # The default backend publishes nothing (concrete default on the base). + assert BaseConnectorBackend.replica_metadata(None) == {} + + +class TestMoRIIORequestId: + def _prepared_pair(self, backend, request, peer): + prefill = backend.prepare_prefill_request(request=request, peer=peer) + # prefill_response is unused in WRITE mode; pass a dummy with no params. + decode = backend.prepare_decode_request( + request=request, + peer=peer, + prefill_response=SimpleNamespace(kv_transfer_params=None), + ) + return prefill, decode + + def _request_with_copy(self, request_id="user-req-123"): + class _Req: + def __init__(self, rid): + self.request_id = rid + self.kv_transfer_params = None + self.max_tokens = 128 + self.max_completion_tokens = 128 + self.stream = True + self.stream_options = {"include_usage": True} + + def model_copy(self, deep=False): + new = _Req(self.request_id) + return new + + return _Req(request_id) + + def test_prefill_and_decode_share_request_id_and_transfer_id(self): + backend = _make_backend(read_mode=False) + _setup(backend, rank=0) + decode_zmq = backend._zmq_address + prefill_zmq = "host:10.0.0.9,handshake:6301,notify:61005" + peer = {"mori_zmq_address": prefill_zmq} + + req = self._request_with_copy("user-req-123") + prefill, decode = self._prepared_pair(backend, req, peer) + + assert prefill.request_id == decode.request_id + assert ( + prefill.kv_transfer_params["transfer_id"] + == decode.kv_transfer_params["transfer_id"] + ) + + # Round-trips to the right peer zmq via the vLLM regexes. + assert _PREFILL_ZMQ_RE.search(prefill.request_id).group(1) == prefill_zmq + assert _DECODE_ZMQ_RE.search(prefill.request_id) is not None + assert parse_peer_zmq(prefill.request_id, is_producer=False) == prefill_zmq + assert parse_peer_zmq(prefill.request_id, is_producer=True) == decode_zmq + + # transfer_id format tx-<32hex>. + assert re.fullmatch( + r"tx-[0-9a-f]{32}", prefill.kv_transfer_params["transfer_id"] + ) + + def test_id_is_deterministic_across_calls(self): + backend = _make_backend(read_mode=False) + _setup(backend, rank=0) + peer = {"mori_zmq_address": "host:10.0.0.9,handshake:6301,notify:61005"} + + p1 = backend.prepare_prefill_request( + request=self._request_with_copy("R"), peer=peer + ) + p2 = backend.prepare_prefill_request( + request=self._request_with_copy("R"), peer=peer + ) + assert p1.request_id == p2.request_id + assert ( + p1.kv_transfer_params["transfer_id"] == p2.kv_transfer_params["transfer_id"] + ) + + def test_prefill_kv_params_write(self): + backend = _make_backend(read_mode=False) + _setup(backend, rank=0) + peer = {"mori_zmq_address": "host:10.0.0.9,handshake:6301,notify:61005"} + prefill = backend.prepare_prefill_request( + request=self._request_with_copy(), peer=peer + ) + assert prefill.kv_transfer_params["do_remote_decode"] is True + assert prefill.kv_transfer_params["do_remote_prefill"] is False + assert prefill.max_tokens == 1 + assert prefill.stream is False + # vLLM reads "tp_size" (not "remote_tp_size"); DP defaults at dp_size=1. + assert "remote_tp_size" not in prefill.kv_transfer_params + assert prefill.kv_transfer_params["tp_size"] == 1 + assert prefill.kv_transfer_params["remote_dp_rank"] == 0 + assert prefill.kv_transfer_params["remote_dp_size"] == 1 + + def test_decode_kv_params_write(self): + backend = _make_backend(read_mode=False) + _setup(backend, rank=0) + peer = {"mori_zmq_address": "host:10.0.0.9,handshake:6301,notify:61005"} + decode = backend.prepare_decode_request( + request=self._request_with_copy(), + peer=peer, + prefill_response=SimpleNamespace(kv_transfer_params=None), + ) + assert decode.kv_transfer_params["do_remote_prefill"] is True + assert decode.kv_transfer_params["do_remote_decode"] is False + assert decode.kv_transfer_params["remote_block_ids"] is None + assert "remote_tp_size" not in decode.kv_transfer_params + assert decode.kv_transfer_params["tp_size"] == 1 + assert decode.kv_transfer_params["remote_dp_rank"] == 0 + assert decode.kv_transfer_params["remote_dp_size"] == 1 + + def test_dp_routing_targets_correct_ranks(self): + """With DP>1, the prefill request addresses the decode (this + orchestrator) rank; the decode request addresses the selected prefill + peer's rank (read from peer metadata).""" + # This orchestrator is decode dp_rank=1 of a 2-way DP decode group. + backend = _make_backend(read_mode=False, dp_rank=1, dp_size=2, tp_size=4) + _setup(backend, rank=0) + # The selected prefill peer is dp_rank=3 of a 4-way DP prefill group. + peer = { + "mori_zmq_address": "host:10.0.0.9,handshake:6301,notify:61005", + "dp_rank": 3, + "dp_size": 4, + "tp_size": 4, + } + prefill = backend.prepare_prefill_request( + request=self._request_with_copy(), peer=peer + ) + decode = backend.prepare_decode_request( + request=self._request_with_copy(), + peer=peer, + prefill_response=SimpleNamespace(kv_transfer_params=None), + ) + # Prefill engine's remote == this decode orchestrator (rank 1 of 2). + assert prefill.kv_transfer_params["remote_dp_rank"] == 1 + assert prefill.kv_transfer_params["remote_dp_size"] == 2 + # Decode engine's remote == the selected prefill peer (rank 3 of 4). + assert decode.kv_transfer_params["remote_dp_rank"] == 3 + assert decode.kv_transfer_params["remote_dp_size"] == 4 + assert decode.kv_transfer_params["tp_size"] == 4 + + def test_decode_kv_params_read_forwards_prefill_params(self): + backend = _make_backend(read_mode=True) + _setup(backend, rank=0) + peer = {"mori_zmq_address": "host:10.0.0.9,handshake:6301,notify:61005"} + prefill_resp = SimpleNamespace( + kv_transfer_params={ + "remote_block_ids": [1, 2, 3], + "remote_engine_id": "eng-7", + } + ) + decode = backend.prepare_decode_request( + request=self._request_with_copy(), peer=peer, prefill_response=prefill_resp + ) + assert decode.kv_transfer_params["do_remote_prefill"] is True + assert decode.kv_transfer_params["remote_block_ids"] == [1, 2, 3] + assert decode.kv_transfer_params["remote_engine_id"] == "eng-7" + assert "transfer_id" in decode.kv_transfer_params + # READ also stamps the remote (prefill peer) routing. + assert decode.kv_transfer_params["remote_dp_rank"] == 0 + assert decode.kv_transfer_params["remote_dp_size"] == 1 + assert decode.kv_transfer_params["tp_size"] == 1 + + def test_decode_read_fallback_when_no_remote_params(self): + backend = _make_backend(read_mode=True) + _setup(backend, rank=0) + peer = {"mori_zmq_address": "host:10.0.0.9,handshake:6301,notify:61005"} + decode = backend.prepare_decode_request( + request=self._request_with_copy(), + peer=peer, + prefill_response=SimpleNamespace(kv_transfer_params=None), + ) + assert decode.kv_transfer_params is None + + +class TestMoRIIOZmqValidation: + def test_missing_peer_zmq_raises(self): + """A missing/empty peer mori_zmq_address must raise a clear error, not + silently build a request id containing "None".""" + backend = _make_backend(read_mode=False) + _setup(backend, rank=0) + request = TestMoRIIORequestId()._request_with_copy("user-req-123") + for peer in (None, {}, {"mori_zmq_address": ""}): + with pytest.raises(ValueError, match="mori_zmq_address"): + backend.prepare_prefill_request(request=request, peer=peer) + + +class TestMoRIIOFactory: + def test_registered(self): + assert KVConnectorBackendFactory.is_registered("MoRIIOConnector") + + def test_create_backend_returns_class(self): + backend_class = KVConnectorBackendFactory.get_backend_class("MoRIIOConnector") + assert backend_class is MoRIIOConnectorBackend + assert issubclass(backend_class, BaseConnectorBackend) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/serve/cpu/deployments/prefill_decode_disagg/test_prefill_decode_disagg.py b/python/ray/llm/tests/serve/cpu/deployments/prefill_decode_disagg/test_prefill_decode_disagg.py index 06d245c0c409..adccf0f6aa21 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/prefill_decode_disagg/test_prefill_decode_disagg.py +++ b/python/ray/llm/tests/serve/cpu/deployments/prefill_decode_disagg/test_prefill_decode_disagg.py @@ -780,6 +780,35 @@ async def _gen(): assert chunks == [prefill_error] assert decode_aborted["value"] is True + @pytest.mark.asyncio + async def test_prewarm_skipped_for_peer_binding_backend(self): + """Pre-warm broadcasts a peerless prefill request, which a peer-binding + connector (e.g. MoRIIO) cannot shape -- so it must be skipped rather + than crash decode-replica init.""" + from ray.llm._internal.serve.engines.vllm.kv_transfer.base import ( + BaseConnectorBackend, + ) + + class _PeerBindingBackend(BaseConnectorBackend): + requires_peer_binding = True + + def prepare_prefill_request(self, *, request, peer): + raise AssertionError("prepare_prefill_request must not be called") + + def prepare_decode_request(self, *, request, peer, prefill_response): + raise AssertionError("prepare_decode_request must not be called") + + server = PDDecodeServer.__new__(PDDecodeServer) + server._llm_config = LLMConfig( + model_loading_config=ModelLoadingConfig(model_id="test-model"), + experimental_configs={"_prewarm_prefill_decode": True}, + ) + server._llm_config._kv_connector_backend = _PeerBindingBackend( + server._llm_config + ) + # Must return without raising (and without touching prepare_*). + await server._maybe_prewarm() + class TestBuildPDOpenaiApp: """Test suite for build_pd_openai_app function."""