diff --git a/cpp/src/ray/runtime/task/task_executor.cc b/cpp/src/ray/runtime/task/task_executor.cc index ced9277fd825..e39fd8c9b5a5 100644 --- a/cpp/src/ray/runtime/task/task_executor.cc +++ b/cpp/src/ray/runtime/task/task_executor.cc @@ -140,6 +140,7 @@ Status TaskExecutor::ExecuteTask( bool is_streaming_generator, bool retry_exception, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, const std::optional &tensor_transport) { RAY_LOG(DEBUG) << "Execute task type: " << TaskType_Name(task_type) << " name:" << task_name; diff --git a/cpp/src/ray/runtime/task/task_executor.h b/cpp/src/ray/runtime/task/task_executor.h index a3bd2964c3e1..bb14254bffd7 100644 --- a/cpp/src/ray/runtime/task/task_executor.h +++ b/cpp/src/ray/runtime/task/task_executor.h @@ -97,6 +97,7 @@ class TaskExecutor { bool is_streaming_generator, bool retry_exception, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, /* This is used by the in-actor RDT object store. However, it is only supported in * the Python frontend. */ const std::optional &tensor_transport); diff --git a/python/ray/_common/ray_option_utils.py b/python/ray/_common/ray_option_utils.py index 6185025c9b86..166223a4f5c8 100644 --- a/python/ray/_common/ray_option_utils.py +++ b/python/ray/_common/ray_option_utils.py @@ -221,6 +221,16 @@ def issubclass_safe(obj: Any, cls_: type) -> bool: "whenever `next` is called). Use -1 to disable this feature. " ), ), + "_num_objects_per_yield": Option( + (int, type(None)), + lambda x: None + if (x is None or x > 0) + else ( + "_num_objects_per_yield is a private streaming generator option " + "that must be set to a positive integer." + ), + default_value=1, + ), } _actor_only_options = { diff --git a/python/ray/_raylet.pyx b/python/ray/_raylet.pyx index 0933767614ad..5539795b3453 100644 --- a/python/ray/_raylet.pyx +++ b/python/ray/_raylet.pyx @@ -1088,6 +1088,7 @@ cdef class StreamingGeneratorExecutionContext: c_bool *is_retryable_error c_string *application_error shared_ptr[CGeneratorBackpressureWaiter] waiter + int64_t num_objects_per_yield def initialize(self, generator: Union[Generator, AsyncGenerator]): # We couldn't make this a part of `make` method because @@ -1120,6 +1121,7 @@ cdef class StreamingGeneratorExecutionContext: c_bool *is_retryable_error, c_string *application_error, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, ): cdef StreamingGeneratorExecutionContext self = ( StreamingGeneratorExecutionContext()) @@ -1141,6 +1143,7 @@ cdef class StreamingGeneratorExecutionContext: self.is_retryable_error = is_retryable_error self.application_error = application_error self.should_retry_exceptions = should_retry_exceptions + self.num_objects_per_yield = num_objects_per_yield self.waiter = make_shared[CGeneratorBackpressureWaiter]( generator_backpressure_num_objects, @@ -1167,28 +1170,53 @@ cdef report_streaming_generator_output( context: Streaming generator's execution context. output: The output yielded from a generator or raised as an exception. - generator_index: index of the output element in the - generated sequence + generator_index: The first ObjectRef stream index for this yield. """ worker = ray._private.worker.global_worker cdef: - # Ray Object created from an output. - c_pair[CObjectID, shared_ptr[CRayObject]] return_obj + # Ray Objects created from an output. + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] return_objs + size_t i + c_bool output_error_reported = False start = time.perf_counter() # Report the intermediate result if there was no error. - create_generator_return_obj( - output, - context.generator_id, - worker, - context.caller_address, - context.task_id, - context.return_size, - generator_index, - context.is_async, - &return_obj) + try: + create_generator_return_objs( + output, + context.generator_id, + worker, + context.caller_address, + context.task_id, + context.return_size, + generator_index, + context.num_objects_per_yield, + context.is_async, + &return_objs) + except Exception as e: + if ( + context.num_objects_per_yield == 1 + or return_objs.size() != context.num_objects_per_yield + ): + raise + + # Dynamic IDs for this grouped yield are already allocated. If storing + # failed after some objects were written, report the whole group as + # non-retryable so the caller does not block waiting for later stream + # indexes and the allocated IDs can be cleared. + context.is_retryable_error[0] = False + store_task_errors( + worker, e, + True, # task_exception + context.actor, # actor + context.actor_id, # actor id + context.function_name, context.task_type, context.title, + context.caller_address, + &return_objs, + context.application_error) + output_error_reported = True # Del output here so that we can GC the memory # usage asap. @@ -1199,22 +1227,25 @@ cdef report_streaming_generator_output( if interrupt_signal_event is not None and interrupt_signal_event.is_set(): return - context.streaming_generator_returns[0].push_back( - c_pair[CObjectID, c_bool]( - return_obj.first, - is_plasma_object(return_obj.second))) + for i in range(return_objs.size()): + context.streaming_generator_returns[0].push_back( + c_pair[CObjectID, c_bool]( + return_objs[i].first, + is_plasma_object(return_objs[i].second))) serialization_dur_s = time.perf_counter() - start with nogil: check_status(CCoreWorkerProcess.GetCoreWorker().ReportGeneratorItemReturns( - return_obj, + return_objs, context.generator_id, context.caller_address, generator_index, context.attempt_number, context.waiter)) + if output_error_reported: + return None return StreamingGeneratorStats( object_creation_dur_s=serialization_dur_s, @@ -1233,14 +1264,13 @@ cdef report_streaming_generator_exception( context: Streaming generator's execution context. output_or_exception: The output yielded from a generator or raised as an exception. - generator_index: index of the output element in the - generated sequence + generator_index: The ObjectRef stream index for this exception. """ worker = ray._private.worker.global_worker cdef: # Ray Object created from an output. - c_pair[CObjectID, shared_ptr[CRayObject]] return_obj + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] return_objs create_generator_error_object( e, @@ -1258,7 +1288,7 @@ cdef report_streaming_generator_exception( generator_index, context.is_async, context.should_retry_exceptions, - &return_obj, + &return_objs, context.is_retryable_error, context.application_error ) @@ -1274,12 +1304,12 @@ cdef report_streaming_generator_exception( context.streaming_generator_returns[0].push_back( c_pair[CObjectID, c_bool]( - return_obj.first, - is_plasma_object(return_obj.second))) + return_objs[0].first, + is_plasma_object(return_objs[0].second))) with nogil: check_status(CCoreWorkerProcess.GetCoreWorker().ReportGeneratorItemReturns( - return_obj, + return_objs, context.generator_id, context.caller_address, generator_index, @@ -1321,9 +1351,12 @@ cdef execute_streaming_generator_sync(StreamingGeneratorExecutionContext context # next output output = gen.send(stats) # Track serialization duration of the next output - stats = report_streaming_generator_output(context, output, gen_index, None) + stats = report_streaming_generator_output( + context, output, gen_index, None) + if stats is None: + break - gen_index += 1 + gen_index += context.num_objects_per_yield except StopIteration: break @@ -1405,7 +1438,9 @@ async def execute_streaming_generator_async( cur_generator_index, interrupt_signal_event, ) - cur_generator_index += 1 + if stats is None: + break + cur_generator_index += context.num_objects_per_yield except StopAsyncIteration: break @@ -1444,7 +1479,7 @@ async def execute_streaming_generator_async( check_status(return_status) -cdef create_generator_return_obj( +cdef create_generator_return_objs( output, const CObjectID &generator_id, worker: "Worker", @@ -1452,9 +1487,10 @@ cdef create_generator_return_obj( TaskID task_id, return_size, generator_index, + int64_t num_objects_per_yield, is_async, - c_pair[CObjectID, shared_ptr[CRayObject]] *return_object): - """Create a generator return object based on a given output. + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *return_objects): + """Create generator return objects based on a given output. Args: output: The output from a next(generator). @@ -1465,33 +1501,55 @@ cdef create_generator_return_obj( the owner, so we can also call it "owner address". task_id: The task ID of the generator task. return_size: The number of static returns. - generator_index: The index of a current error object. + generator_index: The first ObjectRef stream index for this yield. + num_objects_per_yield: The number of ObjectRefs to create for each yield. is_async: Whether or not the given object is created within an async actor. - return_object(out): A Ray Object that contains the given output. + return_objects(out): Ray Objects that contain the given output. """ cdef: - c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result CoreWorker core_worker = worker.core_worker + int64_t stream_index + int64_t i + CObjectID return_id - return_id = core_worker.allocate_dynamic_return_id_for_generator( - caller_address, - task_id.native(), - return_size, - generator_index, - is_async, - ) - intermediate_result.push_back( + if num_objects_per_yield == 1: + outputs = (output,) + else: + if not isinstance(output, (tuple, list)): + raise ValueError( + "Streaming generator tasks with _num_objects_per_yield=" + f"{num_objects_per_yield} must yield a tuple or list " + f"of length {num_objects_per_yield}." + ) + if len(output) != num_objects_per_yield: + raise ValueError( + "Streaming generator task yielded " + f"{len(output)} objects, but _num_objects_per_yield=" + f"{num_objects_per_yield}." + ) + outputs = output + + return_objects.reserve(num_objects_per_yield) + for i in range(num_objects_per_yield): + stream_index = generator_index + i + return_id = core_worker.allocate_dynamic_return_id_for_generator( + caller_address, + task_id.native(), + return_size, + stream_index, + is_async, + ) + return_objects.push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( return_id, shared_ptr[CRayObject]())) + core_worker.store_task_outputs( - worker, [output], + worker, outputs, caller_address, - &intermediate_result, + return_objects, generator_id.Binary()) - return_object[0] = intermediate_result.back() - cdef create_generator_error_object( e: Exception, @@ -1509,7 +1567,7 @@ cdef create_generator_error_object( generator_index, is_async, c_bool should_retry_exceptions, - c_pair[CObjectID, shared_ptr[CRayObject]] *error_object, + c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] *error_objects, c_bool *is_retryable_error, c_string *application_error): """Create a generator error object. @@ -1538,17 +1596,16 @@ cdef create_generator_error_object( It is used to write an error message. actor_id: The ID of the actor. It is used to write an error message. return_size: The number of static returns. - generator_index: The index of a current error object. + generator_index: The ObjectRef stream index for this exception. is_async: Whether or not the given object is created within an async actor. - error_object(out): A Ray Object that contains the given error exception. + error_objects(out): Ray Objects that contain the given error exception. is_retryable_error(out): It is set to True if the generator raises an exception, and the error is retryable. application_error(out): It is set if the generator raises an application error. """ cdef: - c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] intermediate_result CoreWorker core_worker = worker.core_worker is_retryable_error[0] = determine_if_retryable( @@ -1578,7 +1635,7 @@ cdef create_generator_error_object( generator_index, is_async, ) - intermediate_result.push_back( + error_objects.push_back( c_pair[CObjectID, shared_ptr[CRayObject]]( error_id, shared_ptr[CRayObject]())) store_task_errors( @@ -1588,11 +1645,9 @@ cdef create_generator_error_object( actor_id, # actor id function_name, task_type, title, caller_address, - &intermediate_result, + error_objects, application_error) - error_object[0] = intermediate_result.back() - cdef execute_dynamic_generator_and_store_task_outputs( generator, @@ -1701,6 +1756,7 @@ cdef void execute_task( c_bool is_streaming_generator, c_bool should_retry_exceptions, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, optional[c_string] c_tensor_transport) except *: worker = ray._private.worker.global_worker manager = worker.function_actor_manager @@ -1874,7 +1930,8 @@ cdef void execute_task( streaming_generator_returns, is_retryable_error, application_error, - generator_backpressure_num_objects) + generator_backpressure_num_objects, + num_objects_per_yield) # We cannot pass generator to cdef in Cython for some reasons. # It is a workaround. context.initialize(outputs) @@ -2068,6 +2125,7 @@ cdef execute_task_with_cancellation_handler( c_bool is_streaming_generator, c_bool should_retry_exceptions, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, optional[c_string] c_tensor_transport): is_retryable_error[0] = False @@ -2161,6 +2219,7 @@ cdef execute_task_with_cancellation_handler( is_streaming_generator, should_retry_exceptions, generator_backpressure_num_objects, + num_objects_per_yield, c_tensor_transport) # Check for cancellation. @@ -2279,6 +2338,7 @@ cdef CRayStatus task_execution_handler( c_bool is_streaming_generator, c_bool should_retry_exceptions, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, optional[c_string] c_tensor_transport) nogil: with gil, disable_client_hook(): # Initialize job_config if it hasn't already. @@ -2309,6 +2369,7 @@ cdef CRayStatus task_execution_handler( is_streaming_generator, should_retry_exceptions, generator_backpressure_num_objects, + num_objects_per_yield, c_tensor_transport) except Exception as e: sys_exit = SystemExit() @@ -2363,6 +2424,32 @@ cdef CRayStatus task_execution_handler( msg = "Unexpected exception raised in task execution handler: {}".format(e) logger.error(msg) return CRayStatus.UnexpectedSystemExit(msg) + except BaseException as e: + # Safety net: any BaseException that is not Exception or SystemExit + # (e.g. KeyboardInterrupt, GeneratorExit) would otherwise escape this + # cdef function. Without this, Cython silently returns + # CRayStatus.OK() for unhandled non-Exception/non-SystemExit + # exceptions, causing a CHECK failure in HandleTaskExecutionResult + # when return objects are not populated. + # Convert to UnexpectedSystemExit so the C++ side + # treats this as a clean worker-exiting task failure. + # + # The motivating case is a rapid double `ray.cancel()`. The first + # cancel raises a KeyboardInterrupt that is caught by + # `execute_task_with_cancellation_handler`'s + # `except KeyboardInterrupt` clause, which calls + # `store_task_errors`. If a second cancel arrives while + # `store_task_errors` is running, it queues another SIGINT that + # fires inside the error-storage path. That KeyboardInterrupt + # cannot be re-caught (we are already inside `except + # KeyboardInterrupt`), so it escapes all the way out to this + # handler. + msg = ( + "BaseException escaped task execution handlers: " + f"{type(e).__name__}: {e}" + ) + logger.error(msg) + return CRayStatus.UnexpectedSystemExit(msg) return CRayStatus.OK() @@ -3481,6 +3568,7 @@ cdef class CoreWorker: c_string debugger_breakpoint, c_string serialized_runtime_env_info, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, c_bool enable_task_events, labels, label_selector, @@ -3527,6 +3615,7 @@ cdef class CoreWorker: name, num_returns, c_resources, b"", generator_backpressure_num_objects, + num_objects_per_yield, serialized_runtime_env_info, enable_task_events, c_labels, @@ -3761,6 +3850,7 @@ cdef class CoreWorker: double num_method_cpus, c_string concurrency_group_name, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, c_bool enable_task_events, tensor_transport: Optional[str], dict labels=None): @@ -3817,6 +3907,7 @@ cdef class CoreWorker: c_resources, concurrency_group_name, generator_backpressure_num_objects, + num_objects_per_yield, serialized_runtime_env, enable_task_events, c_labels, @@ -3994,6 +4085,7 @@ cdef class CoreWorker: method_meta.max_task_retries, method_meta.retry_exceptions, method_meta.generator_backpressure_num_objects, # noqa + method_meta.num_objects_per_yield, method_meta.enable_task_events, enable_tensor_transport, method_meta.method_name_to_tensor_transport, @@ -4013,6 +4105,7 @@ cdef class CoreWorker: {}, # method max_task_retries {}, # method retry_exceptions {}, # generator_backpressure_num_objects + {}, # num_objects_per_yield {}, # enable_task_events False, # enable_tensor_transport None, # method_name_to_tensor_transport diff --git a/python/ray/actor.py b/python/ray/actor.py index 253ffa1e6107..2753b773c2c6 100644 --- a/python/ray/actor.py +++ b/python/ray/actor.py @@ -616,6 +616,7 @@ def method( max_task_retries: Optional[int] = None, retry_exceptions: Optional[Union[bool, list, tuple]] = None, _generator_backpressure_num_objects: Optional[int] = None, + _num_objects_per_yield: Optional[int] = None, enable_task_events: Optional[bool] = None, tensor_transport: Optional[str] = None, ) -> _MethodDecorator: @@ -693,6 +694,7 @@ def bar(self): "max_task_retries", "retry_exceptions", "_generator_backpressure_num_objects", + "_num_objects_per_yield", "enable_task_events", "tensor_transport", ] @@ -717,6 +719,11 @@ def annotate_method(method: Callable[_P, _Ret]): method.__ray_generator_backpressure_num_objects__ = kwargs[ "_generator_backpressure_num_objects" ] + if "_num_objects_per_yield" in kwargs: + ray_option_utils.task_options["_num_objects_per_yield"].validate( + "_num_objects_per_yield", kwargs["_num_objects_per_yield"] + ) + method.__ray_num_objects_per_yield__ = kwargs["_num_objects_per_yield"] if "enable_task_events" in kwargs and kwargs["enable_task_events"] is not None: method.__ray_enable_task_events__ = kwargs["enable_task_events"] if "tensor_transport" in kwargs: @@ -770,6 +777,7 @@ def __init__( retry_exceptions: Union[bool, list, tuple], is_generator: bool, generator_backpressure_num_objects: int, + num_objects_per_yield: int, enable_task_events: bool, decorator: Optional[Any] = None, signature: Optional[List[inspect.Parameter]] = None, @@ -787,6 +795,7 @@ def __init__( retry_exceptions: Boolean or list/tuple of exceptions to retry. is_generator: True if the method is a generator. generator_backpressure_num_objects: Generator-only config for backpressure. + num_objects_per_yield: Private generator-only config for grouped yields. enable_task_events: True if task events are enabled for this method. decorator: Optional decorator for the method invocation. signature: The signature of the actor method. @@ -805,6 +814,9 @@ def __init__( self._retry_exceptions = retry_exceptions self._is_generator = is_generator self._generator_backpressure_num_objects = generator_backpressure_num_objects + self._num_objects_per_yield = ( + 1 if num_objects_per_yield is None else num_objects_per_yield + ) self._enable_task_events = enable_task_events self._decorator = decorator self._signature = signature @@ -822,6 +834,7 @@ def bind(self, actor_handle: "ActorHandle") -> "ActorMethod": self._retry_exceptions, self._is_generator, self._generator_backpressure_num_objects, + self._num_objects_per_yield, self._enable_task_events, decorator=self._decorator, signature=self._signature, @@ -848,6 +861,7 @@ def __init__( retry_exceptions: Union[bool, list, tuple], is_generator: bool, generator_backpressure_num_objects: int, + num_objects_per_yield: int, enable_task_events: bool, decorator: Optional[Callable] = None, signature: Optional[List[inspect.Parameter]] = None, @@ -869,6 +883,7 @@ def __init__( generator_backpressure_num_objects: Generator-only config. If a number of unconsumed objects reach this threshold, the actor task stops pausing. + num_objects_per_yield: Private generator-only config for grouped yields. enable_task_events: True if task events is enabled, i.e., task events from the actor should be reported. Defaults to True. decorator: An optional decorator that should be applied to the actor @@ -892,6 +907,9 @@ def __init__( self._retry_exceptions = retry_exceptions self._is_generator = is_generator self._generator_backpressure_num_objects = generator_backpressure_num_objects + self._num_objects_per_yield = ( + 1 if num_objects_per_yield is None else num_objects_per_yield + ) self._enable_task_events = enable_task_events self._signature = signature # This is a decorator that is used to wrap the function invocation (as @@ -945,6 +963,11 @@ def options(self, **options: Any): A wrapper exposing ``.remote()`` / ``.bind()`` that applies the given options when the method is invoked. """ + if "_num_objects_per_yield" in options: + raise ValueError( + "_num_objects_per_yield cannot be overridden per actor method " + "call. Use @ray.method(_num_objects_per_yield=...) instead." + ) tensor_transport = options.get("tensor_transport", None) if tensor_transport is not None: @@ -967,6 +990,7 @@ def _bind( num_returns=None, concurrency_group=None, _generator_backpressure_num_objects=None, + _num_objects_per_yield=None, ) -> Union["ray.dag.ClassMethodNode", Tuple["ray.dag.ClassMethodNode", ...]]: from ray.dag.class_node import ( BIND_INDEX_KEY, @@ -983,6 +1007,11 @@ def _bind( "concurrency_group": concurrency_group, "_generator_backpressure_num_objects": _generator_backpressure_num_objects, } + if _num_objects_per_yield is not None: + ray_option_utils.task_options["_num_objects_per_yield"].validate( + "_num_objects_per_yield", _num_objects_per_yield + ) + options["_num_objects_per_yield"] = _num_objects_per_yield actor = self._actor if actor is None: @@ -1049,6 +1078,7 @@ def _remote( retry_exceptions=None, concurrency_group=None, _generator_backpressure_num_objects=None, + _num_objects_per_yield=None, enable_task_events=None, tensor_transport: Optional[str] = None, _labels: Optional[Dict[str, str]] = None, @@ -1067,6 +1097,11 @@ def _remote( _generator_backpressure_num_objects = ( self._generator_backpressure_num_objects ) + if _num_objects_per_yield is None: + _num_objects_per_yield = self._num_objects_per_yield + ray_option_utils.task_options["_num_objects_per_yield"].validate( + "_num_objects_per_yield", _num_objects_per_yield + ) if tensor_transport is None: tensor_transport = self._tensor_transport @@ -1120,6 +1155,7 @@ def invocation(args, kwargs): generator_backpressure_num_objects=( _generator_backpressure_num_objects ), + num_objects_per_yield=_num_objects_per_yield, enable_task_events=enable_task_events, tensor_transport=tensor_transport, labels=_labels, @@ -1149,6 +1185,7 @@ def __getstate__(self): "decorator": self._decorator, "is_generator": self._is_generator, "generator_backpressure_num_objects": self._generator_backpressure_num_objects, # noqa + "num_objects_per_yield": self._num_objects_per_yield, "enable_task_events": self._enable_task_events, "_tensor_transport": self._tensor_transport, } @@ -1162,6 +1199,7 @@ def __setstate__(self, state): state["retry_exceptions"], state["is_generator"], state["generator_backpressure_num_objects"], + state.get("num_objects_per_yield", 1), state["enable_task_events"], state["decorator"], state["_tensor_transport"], @@ -1250,6 +1288,7 @@ def create( self.method_is_generator = {} self.enable_task_events = {} self.generator_backpressure_num_objects = {} + self.num_objects_per_yield = {} self.concurrency_group_for_methods = {} self.method_name_to_tensor_transport: Dict[str, str] = {} @@ -1317,6 +1356,10 @@ def create( self.generator_backpressure_num_objects[ method_name ] = method.__ray_generator_backpressure_num_objects__ + if hasattr(method, "__ray_num_objects_per_yield__"): + self.num_objects_per_yield[ + method_name + ] = method.__ray_num_objects_per_yield__ if hasattr(method, "__ray_tensor_transport__"): self.method_name_to_tensor_transport[ @@ -2175,6 +2218,7 @@ def _remote( meta.method_meta.max_task_retries, meta.method_meta.retry_exceptions, meta.method_meta.generator_backpressure_num_objects, + meta.method_meta.num_objects_per_yield, meta.method_meta.enable_task_events, meta.enable_tensor_transport, meta.method_meta.method_name_to_tensor_transport, @@ -2237,6 +2281,8 @@ class ActorHandle(Generic[T]): _ray_method_generator_backpressure_num_objects: Generator-only config. The max number of objects to generate before it starts pausing a generator. + _ray_method_num_objects_per_yield: Private generator-only config. + The number of ObjectRefs produced by each streaming generator yield. _ray_method_enable_task_events: The value of whether task tracing is enabled for the actor methods. This overrides the actor's default value (`_ray_enable_task_events`). @@ -2273,6 +2319,7 @@ def __init__( method_max_task_retries: Dict[str, int], method_retry_exceptions: Dict[str, Union[bool, list, tuple]], method_generator_backpressure_num_objects: Dict[str, int], + method_num_objects_per_yield: Dict[str, int], method_enable_task_events: Dict[str, bool], enable_tensor_transport: bool, method_name_to_tensor_transport: Dict[str, str], @@ -2297,6 +2344,7 @@ def __init__( method_max_task_retries: Dictionary mapping method names to their maximum task retries. method_retry_exceptions: Dictionary mapping method names to their retry exception settings. method_generator_backpressure_num_objects: Dictionary mapping method names to their generator backpressure settings. + method_num_objects_per_yield: Dictionary mapping method names to their grouped-yield arity. method_enable_task_events: Dictionary mapping method names to whether task events are enabled. enable_tensor_transport: Whether tensor transport is enabled for this actor. If True, then methods can be called with @@ -2327,6 +2375,7 @@ def __init__( self._ray_method_generator_backpressure_num_objects = ( method_generator_backpressure_num_objects ) + self._ray_method_num_objects_per_yield = method_num_objects_per_yield self._ray_method_enable_task_events = method_enable_task_events self._ray_enable_tensor_transport = enable_tensor_transport self._ray_method_name_to_tensor_transport = method_name_to_tensor_transport @@ -2372,6 +2421,9 @@ def __init__( generator_backpressure_num_objects=self._ray_method_generator_backpressure_num_objects.get( method_name ), + num_objects_per_yield=self._ray_method_num_objects_per_yield.get( + method_name, 1 + ), enable_task_events=self._ray_method_enable_task_events.get( method_name, self._ray_enable_task_events ), @@ -2412,6 +2464,7 @@ def _actor_method_call( retry_exceptions: Union[bool, list, tuple] = None, concurrency_group_name: Optional[str] = None, generator_backpressure_num_objects: Optional[int] = None, + num_objects_per_yield: Optional[int] = None, enable_task_events: Optional[bool] = None, tensor_transport: Optional[str] = None, labels: Optional[Dict[str, str]] = None, @@ -2435,6 +2488,8 @@ def _actor_method_call( concurrency_group_name: The name of the concurrency group to use. generator_backpressure_num_objects: The number of objects to generate before applying backpressure. + num_objects_per_yield: Private streaming generator option for how many + ObjectRefs each yield should unpack into. enable_task_events: True if tracing is enabled, i.e., task events from the actor should be reported. tensor_transport: The tensor transport protocol to use for the actor method. @@ -2487,6 +2542,8 @@ def _actor_method_call( if generator_backpressure_num_objects is None: generator_backpressure_num_objects = -1 + if num_objects_per_yield is None: + num_objects_per_yield = 1 object_refs = worker.core_worker.submit_actor_task( self._ray_actor_language, @@ -2501,6 +2558,7 @@ def _actor_method_call( self._ray_actor_method_cpus, concurrency_group_name if concurrency_group_name is not None else b"", generator_backpressure_num_objects, + num_objects_per_yield, enable_task_events, tensor_transport, labels, @@ -2578,6 +2636,7 @@ def remote(self, *args, **kwargs): False, # retry_exceptions False, # is_generator self._ray_method_generator_backpressure_num_objects.get(item, -1), + self._ray_method_num_objects_per_yield.get(item, 1), self._ray_enable_task_events, # enable_task_events # Currently, cross-lang actor method not support decorator decorator=None, @@ -2658,6 +2717,9 @@ def _serialization_helper(self): "method_generator_backpressure_num_objects": ( self._ray_method_generator_backpressure_num_objects ), + "method_num_objects_per_yield": ( + self._ray_method_num_objects_per_yield + ), "method_enable_task_events": self._ray_method_enable_task_events, "enable_tensor_transport": self._ray_enable_tensor_transport, "method_name_to_tensor_transport": self._ray_method_name_to_tensor_transport, @@ -2716,6 +2778,7 @@ def _deserialization_helper( state["method_max_task_retries"], state["method_retry_exceptions"], state["method_generator_backpressure_num_objects"], + state.get("method_num_objects_per_yield", {}), state["method_enable_task_events"], state["enable_tensor_transport"], state["method_name_to_tensor_transport"], diff --git a/python/ray/autoscaler/_private/kuberay/autoscaling_config.py b/python/ray/autoscaler/_private/kuberay/autoscaling_config.py index e61f453d87c8..9b09c1b6f8ec 100644 --- a/python/ray/autoscaler/_private/kuberay/autoscaling_config.py +++ b/python/ray/autoscaler/_private/kuberay/autoscaling_config.py @@ -18,6 +18,7 @@ ) from ray.autoscaler._private.kuberay import node_provider, utils from ray.autoscaler._private.util import validate_config +from ray.util.debug import log_once logger = logging.getLogger(__name__) @@ -194,6 +195,8 @@ def _node_type_from_group_spec( group_spec: Dict[str, Any], is_head: bool ) -> Dict[str, Any]: """Converts CR group spec to autoscaler node type.""" + group_name = _HEAD_GROUP_NAME if is_head else group_spec["groupName"] + if is_head: # The head node type has no workers because the head is not a worker. min_workers = max_workers = 0 @@ -204,7 +207,7 @@ def _node_type_from_group_spec( max_workers = group_spec["maxReplicas"] * group_spec.get("numOfHosts", 1) resources = _get_ray_resources_from_group_spec(group_spec, is_head) - labels = _get_labels_from_group_spec(group_spec) + labels = _get_labels_from_group_spec(group_spec, group_name) node_type = { "min_workers": min_workers, @@ -308,21 +311,45 @@ def _get_ray_resources_from_group_spec( return resources -def _get_labels_from_group_spec(group_spec: Dict[str, Any]) -> Dict[str, str]: +def _get_labels_from_group_spec( + group_spec: Dict[str, Any], group_name: str = "" +) -> Dict[str, str]: """ Parses Ray node labels for the autoscaling config based on the following priority: 1. Top-level `labels` field in the group spec. 2. `labels` field in `rayStartParams`. + + Args: + group_spec: The group specification dictionary. + group_name: The name of the group (used in warning messages). + + Returns: + A dictionary of labels for the node type. """ labels_dict = {} ray_start_params = group_spec.get("rayStartParams", {}) labels_str = ray_start_params.get("labels") - if labels_str: - logger.warning( - f"Ignoring labels: {labels_str} set in rayStartParams. Group labels are supported in the top-level Labels field starting in KubeRay v1.5" - ) + # Use a unique log_once key per group to ensure each group's warning is shown. + log_once_key = ( + f"raystartparams_labels_warning_{group_name}" + if group_name + else "raystartparams_labels_warning" + ) + if labels_str and log_once(log_once_key): + if group_name: + logger.warning( + f"Ignoring labels: {labels_str} set in rayStartParams for group " + f"'{group_name}'. Group labels are supported in the top-level " + "Labels field starting in KubeRay v1.5" + ) + else: + logger.warning( + f"Ignoring labels: {labels_str} set in rayStartParams. " + "Group labels are supported in the top-level Labels field " + "starting in KubeRay v1.5" + ) # Check for top-level structured Labels field. if "labels" in group_spec and isinstance(group_spec.get("labels"), dict): diff --git a/python/ray/autoscaler/_private/kuberay/run_autoscaler.py b/python/ray/autoscaler/_private/kuberay/run_autoscaler.py index 6013f9b423a7..2d626b21bf9e 100644 --- a/python/ray/autoscaler/_private/kuberay/run_autoscaler.py +++ b/python/ray/autoscaler/_private/kuberay/run_autoscaler.py @@ -9,7 +9,7 @@ LOGGING_ROTATE_BACKUP_COUNT, LOGGING_ROTATE_BYTES, ) -from ray._common.utils import try_to_create_directory +from ray._common.utils import env_integer, try_to_create_directory from ray._private import ray_constants from ray._private.ray_logging import setup_component_logger from ray._private.services import get_node_ip_address @@ -119,13 +119,15 @@ def _setup_logging(log_dir: str) -> None: try_to_create_directory(log_dir) # Write logs at info level to monitor.log. + max_bytes = env_integer("RAY_ROTATION_MAX_BYTES", LOGGING_ROTATE_BYTES) + backup_count = env_integer("RAY_ROTATION_BACKUP_COUNT", LOGGING_ROTATE_BACKUP_COUNT) setup_component_logger( logging_level=ray_constants.LOGGER_LEVEL, logging_format=ray_constants.LOGGER_FORMAT, log_dir=log_dir, filename=ray_constants.MONITOR_LOG_FILE_NAME, # monitor.log - max_bytes=LOGGING_ROTATE_BYTES, - backup_count=LOGGING_ROTATE_BACKUP_COUNT, + max_bytes=max_bytes, + backup_count=backup_count, ) # For the autoscaler, the root logger _also_ needs to write to stderr, not just diff --git a/python/ray/includes/common.pxd b/python/ray/includes/common.pxd index b3fd284b13d0..e6687cd053b7 100644 --- a/python/ray/includes/common.pxd +++ b/python/ray/includes/common.pxd @@ -383,16 +383,19 @@ cdef extern from "ray/core_worker/common.h" nogil: CTaskOptions(c_string name, int num_returns, unordered_map[c_string, double] &resources, c_string concurrency_group_name, - int64_t generator_backpressure_num_objects) + int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield) CTaskOptions(c_string name, int num_returns, unordered_map[c_string, double] &resources, c_string concurrency_group_name, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, c_string serialized_runtime_env) CTaskOptions(c_string name, int num_returns, unordered_map[c_string, double] &resources, c_string concurrency_group_name, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, c_string serialized_runtime_env, c_bool enable_task_events, const unordered_map[c_string, c_string] &labels, diff --git a/python/ray/includes/libcoreworker.pxd b/python/ray/includes/libcoreworker.pxd index 60a99d85f85a..d3c416476a8b 100644 --- a/python/ray/includes/libcoreworker.pxd +++ b/python/ray/includes/libcoreworker.pxd @@ -315,7 +315,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_vector[shared_ptr[CObjectLocation]] *results) CRayStatus TriggerGlobalGC() CRayStatus ReportGeneratorItemReturns( - const pair[CObjectID, shared_ptr[CRayObject]] &dynamic_return_object, + const c_vector[c_pair[CObjectID, shared_ptr[CRayObject]]] &dynamic_return_objects, const CObjectID &generator_id, const CAddress &caller_address, int64_t item_index, @@ -407,6 +407,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil: c_bool is_streaming_generator, c_bool should_retry_exceptions, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, optional[c_string] tensor_transport ) nogil) task_execution_callback (void(const CObjectID &) nogil) free_actor_object_callback diff --git a/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/base.py b/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/base.py index 243b3ac6e5e6..bfd0294d53af 100644 --- a/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/base.py +++ b/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/base.py @@ -18,6 +18,30 @@ RequestType = Union[ChatCompletionRequest, CompletionRequest] +def base_prefill_kv_transfer_params() -> Dict[str, Any]: + """The ``kv_transfer_params`` common to a prefill (producer) request. + + Tells the prefill engine to produce KV for a remote decode. Connectors layer + their own keys (e.g. a transfer id, DP/TP routing) on top of these. + """ + return { + "do_remote_decode": True, + "do_remote_prefill": False, + "remote_engine_id": None, + "remote_block_ids": None, + } + + +def clamp_request_to_single_token(request: "RequestType") -> None: + """Clamp a prefill request to a single, non-streaming token (in place).""" + request.max_tokens = 1 + if hasattr(request, "max_completion_tokens"): + request.max_completion_tokens = 1 + request.stream = False + if hasattr(request, "stream_options"): + request.stream_options = None + + class BaseConnectorBackend(abc.ABC): # ---- P/D coordination protocol ---- # @@ -159,6 +183,19 @@ def setup(self) -> None: """ pass + def replica_metadata(self) -> Dict[str, Any]: + """Static per-replica coordination data published to the orchestrator. + + Surfaced via the replica-metadata hook on ``ReplicaSelection`` so that a + connector opting into ``requires_peer_binding`` can address the selected + prefill peer. The default backend publishes nothing; connectors that need + to advertise an address (e.g. MoRIIO's zmq endpoint) override this. + + Returns: + A JSON-serializable dict of per-replica metadata (empty by default). + """ + return {} + class DefaultPDProtocolMixin: """The default P/D protocol policy: no peer binding, sequential handoff. @@ -187,19 +224,11 @@ def prepare_prefill_request( ), "kv_transfer_params should be empty before orchestrator" prefill_request = request.model_copy(deep=True) prefill_request.kv_transfer_params = { - "do_remote_decode": True, - "do_remote_prefill": False, - "remote_engine_id": None, - "remote_block_ids": None, + **base_prefill_kv_transfer_params(), "remote_host": None, "remote_port": None, } - prefill_request.max_tokens = 1 - if hasattr(prefill_request, "max_completion_tokens"): - prefill_request.max_completion_tokens = 1 - prefill_request.stream = False - if hasattr(prefill_request, "stream_options"): - prefill_request.stream_options = None + clamp_request_to_single_token(prefill_request) return prefill_request def prepare_decode_request( diff --git a/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/factory.py b/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/factory.py index aeec2348c554..54b7c12be713 100644 --- a/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/factory.py +++ b/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/factory.py @@ -121,6 +121,7 @@ def unregister_backend(cls, name: str) -> None: "LMCacheConnectorV1": "ray.llm._internal.serve.engines.vllm.kv_transfer.lmcache:LMCacheConnectorV1Backend", "NixlConnector": "ray.llm._internal.serve.engines.vllm.kv_transfer.nixl:NixlConnectorBackend", "MultiConnector": "ray.llm._internal.serve.engines.vllm.kv_transfer.multi_connector:MultiConnectorBackend", + "MoRIIOConnector": "ray.llm._internal.serve.engines.vllm.kv_transfer.moriio:MoRIIOConnectorBackend", } diff --git a/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/moriio.py b/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/moriio.py new file mode 100644 index 000000000000..d05e4dca5627 --- /dev/null +++ b/python/ray/llm/_internal/serve/engines/vllm/kv_transfer/moriio.py @@ -0,0 +1,342 @@ +"""MoRIIO connector backend for Ray Serve LLM (analogue of nixl.py). + +Configures a vLLM engine's ``kv_transfer_config.kv_connector_extra_config`` for +the MoRIIO connector and computes per-replica handshake/notify ports so colocated +replicas don't collide. Also builds the engine's advertised zmq address so the +P/D orchestrator can discover it via the replica-metadata hook +(``ReplicaSelection.replica_metadata``), and implements the PD connector protocol +(``requires_peer_binding`` / ``concurrent_handoff`` / ``prepare_prefill_request`` / +``prepare_decode_request``) so the decode orchestrator can address the selected +prefill peer by request id. + +Unlike NIXL/LMCache, MoRIIO does NOT use ``DefaultPDProtocolMixin``: it has custom +request shaping (a dual-address request_id + transfer_id) and therefore IMPLEMENTS +the abstract ``prepare_*`` methods directly on ``BaseConnectorBackend``. + +Two transfer disciplines, selected by ``read_mode``: + * WRITE (default): prefill PUSHES KV to decode -> concurrent handoff. + * READ: decode PULLS KV from prefill -> sequential handoff; the decode request + forwards the ``remote_block_ids`` / ``remote_engine_id`` the prefill engine + returned. + +The dual-address request_id and the transfer_id are derived DETERMINISTICALLY +from the incoming request id (a hash), so ``prepare_prefill_request`` and +``prepare_decode_request`` produce identical ids across their two separate calls +without per-request backend state (the backend instance is shared across +requests). + +Registered with Ray's public connector registry via the factory. +""" + +import hashlib +import logging +import re +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple + +import ray +from ray.llm._internal.serve.engines.vllm.kv_transfer.base import ( + BaseConnectorBackend, + base_prefill_kv_transfer_params, + clamp_request_to_single_token, +) + +if TYPE_CHECKING: + from ray.llm._internal.serve.engines.vllm.kv_transfer.base import RequestType + +logger = logging.getLogger(__name__) + +# Defaults mirror vLLM's MoRIIOConstants (DEFAULT_HANDSHAKE_PORT / NOTIFY_PORT). +# Prefill uses these bases; decode is shifted (see builder.py) so a colocated +# P+D pair on one node doesn't collide. +DEFAULT_HANDSHAKE_PORT_BASE = 6301 +DEFAULT_NOTIFY_PORT_BASE = 61005 + +# experimental_configs keys understood by this backend. +HANDSHAKE_PORT_BASE_KEY = "MORI_HANDSHAKE_PORT_BASE" +NOTIFY_PORT_BASE_KEY = "MORI_NOTIFY_PORT_BASE" + +# --------------------------------------------------------------------------- +# Dual-address request_id / zmq address encoding. +# +# These MUST stay byte-compatible with the regexes vLLM's MoRIIO connector uses +# to recover peer addresses from the request_id: +# +# vllm/distributed/kv_transfer/kv_connector/v1/moriio/moriio_common.py +# _PREFILL_ZMQ_RE = re.compile(r"___prefill_addr_(.+?)___decode_addr_") +# _DECODE_ZMQ_RE = re.compile(r"___decode_addr_(.+)_[0-9a-f]{32}(?:-.*)?$") +# # zmq address: "host:IP,handshake:PORT,notify:PORT" +# --------------------------------------------------------------------------- + +_PREFILL_PREFIX = "___prefill_addr_" +_DECODE_PREFIX = "___decode_addr_" +_TRANSFER_PREFIX = "tx" + +# Copies of vLLM's regexes for local validation / round-trip tests. +_PREFILL_ZMQ_RE = re.compile(r"___prefill_addr_(.+?)___decode_addr_") +_DECODE_ZMQ_RE = re.compile(r"___decode_addr_(.+)_[0-9a-f]{32}(?:-.*)?$") + + +def build_zmq_address(host: str, handshake_port: int, notify_port: int) -> str: + """Build the MORI zmq address string ``host:IP,handshake:PORT,notify:PORT``.""" + return f"host:{host},handshake:{handshake_port},notify:{notify_port}" + + +def parse_zmq_address(zmq_address: str) -> Tuple[str, int, int]: + """Inverse of :func:`build_zmq_address` -> ``(host, handshake_port, notify_port)``.""" + parts = {} + for segment in zmq_address.split(","): + key, _, val = segment.partition(":") + parts[key.strip()] = val.strip() + return parts["host"], int(parts["handshake"]), int(parts["notify"]) + + +def parse_peer_zmq(request_id: str, is_producer: bool) -> str: + """Recover the peer's zmq address from a request id (for tests/debugging). + + Producer (prefill) wants the *decode* address; consumer wants the *prefill*. + """ + rex = _DECODE_ZMQ_RE if is_producer else _PREFILL_ZMQ_RE + m = rex.search(request_id) + if not m: + raise ValueError(f"No peer zmq address in request_id: {request_id!r}") + return m.group(1) + + +def _read_mode_enabled(extra_config: Dict[str, Any]) -> bool: + """Mirror vLLM's ``get_moriio_mode`` parse of ``read_mode``. + + true / 1 -> READ; anything else -> WRITE (default). + """ + return str(extra_config.get("read_mode", "false")).lower().strip() in ( + "true", + "1", + ) + + +class MoRIIOConnectorBackend(BaseConnectorBackend): + """Set up MoRIIO ports/extra_config and implement the PD connector protocol.""" + + # The advertised zmq address ("host:IP,handshake:PORT,notify:PORT"), + # computed by setup(); consumers reach it via this backend instance. + _zmq_address: Optional[str] = None + + # MORI addresses peers by the dual-address request id, so the orchestrator + # must bind to the selected prefill replica BEFORE dispatch. + requires_peer_binding: bool = True + + def _extra_config(self) -> dict: + cfg = self.kv_transfer_config.setdefault("kv_connector_extra_config", {}) + return cfg + + @property + def _read_mode(self) -> bool: + """True iff this engine's MoRIIO connector is configured for READ mode.""" + extra = self._extra_config() + return _read_mode_enabled(extra) + + @property + def concurrent_handoff(self) -> bool: + """WRITE -> concurrent (prefill pushes); READ -> sequential (decode pulls).""" + return not self._read_mode + + def setup(self) -> None: + offset = self._compute_port_offset() + + handshake_base = int( + self.llm_config.experimental_configs.get( + HANDSHAKE_PORT_BASE_KEY, DEFAULT_HANDSHAKE_PORT_BASE + ) + ) + notify_base = int( + self.llm_config.experimental_configs.get( + NOTIFY_PORT_BASE_KEY, DEFAULT_NOTIFY_PORT_BASE + ) + ) + + # NOTE: vLLM internally adds get_port_offset(dp_rank, tp_rank) on top of + # these bases. For TP/DP>1, reserve a stride >= tp_size*pp_size when + # shifting decode's base in the builder so the two offset schemes never + # overlap. + handshake_port = handshake_base + offset + notify_port = notify_base + offset + + extra = self._extra_config() + # Required keys for vLLM's config parser (KeyError otherwise) -- proxyless. + extra.setdefault("proxy_ip", "") # empty => ping/registration thread disabled + extra.setdefault("proxy_ping_port", "0") + # TODO: real Serve replica HTTP port. Harmless placeholder while + # proxy_ip="" (only used to build request_address for the disabled ping). + extra.setdefault("http_port", str(8000 + offset)) + # WRITE mode (prefill pushes). READ would be "true". + extra.setdefault("read_mode", "false") + extra["handshake_port"] = str(handshake_port) + extra["notify_port"] = str(notify_port) + + # Advertise the Ray internal cluster IP as the zmq host. + host = ray.util.get_node_ip_address() + zmq_address = build_zmq_address(host, handshake_port, notify_port) + # Stash so replica_metadata() can publish it; the decode + # orchestrator reads the selected prefill replica's copy off the peer. + self._zmq_address = zmq_address + # NOTE: cross-node correctness additionally needs each worker to + # advertise the node INTERNAL IP (set VLLM_HOST_IP inside every worker + # process). VLLM_HOST_IP is excluded from vLLM's driver->worker env copy, + # so it can only be set in-process -- handled by a vLLM general-plugin + # shipped separately. Single-node deployments work without it. + + # ---- parallelism (data/tensor) ---- + + def _dp_rank(self) -> int: + rank = self.llm_config.engine_kwargs.get("data_parallel_rank") + return rank if isinstance(rank, int) and rank >= 0 else 0 + + def _dp_size(self) -> int: + return int(self.llm_config.engine_kwargs.get("data_parallel_size") or 1) + + def _tp_size(self) -> int: + return int(self.llm_config.engine_kwargs.get("tensor_parallel_size") or 1) + + # ---- replica metadata (published via the replica-metadata hook) ---- + + def replica_metadata(self) -> dict: + """Static per-replica coordination data published to the orchestrator. + + The prefill replica publishes its MORI zmq address and its parallelism + (DP rank/size, TP size); the decode orchestrator reads them off the + selected prefill replica's ``ReplicaSelection.replica_metadata`` and uses + them to address the right remote (dp_rank, tp) workers. + """ + return { + "mori_zmq_address": self._zmq_address, + "dp_rank": self._dp_rank(), + "dp_size": self._dp_size(), + "tp_size": self._tp_size(), + } + + def _remote_routing(self, remote: Dict[str, Any]) -> Dict[str, Any]: + """``kv_transfer_params`` keys telling vLLM which remote workers to reach. + + ``remote`` is the metadata of the *other* side of the transfer: the + decode (this orchestrator) for a prefill request, the selected prefill + peer for a decode request. vLLM addresses a remote worker at + ``advertised_base + get_port_offset(remote_dp_rank, tp_index)`` and + handshakes all ``remote_dp_size`` ranks, so both must match the target + replica. ``tp_size`` is the remote's TP (symmetric across P/D). + """ + return { + "remote_dp_rank": int(remote.get("dp_rank", 0)), + "remote_dp_size": int(remote.get("dp_size", 1)), + "tp_size": int(remote.get("tp_size", self._tp_size())), + } + + def _own_routing(self) -> Dict[str, Any]: + """``_remote_routing`` input describing this (decode orchestrator) replica + -- the remote side for a prefill request.""" + return { + "dp_rank": self._dp_rank(), + "dp_size": self._dp_size(), + "tp_size": self._tp_size(), + } + + # ---- request shaping (PD connector protocol) ---- + + def _dual_ids( + self, request: Any, peer: Optional[Dict[str, Any]] + ) -> Tuple[str, str]: + """Compute the (dual-address request_id, transfer_id) for this request. + + ``prepare_prefill_request`` and ``prepare_decode_request`` are two + independent, stateless calls for the same request, so both ids are + derived deterministically (hash of a stable per-request seed) — no + per-request backend state. + """ + prefill_zmq = (peer or {}).get("mori_zmq_address") + decode_zmq = self._zmq_address + if not prefill_zmq: + raise ValueError( + "MoRIIO peer is missing 'mori_zmq_address': the selected prefill " + "replica did not publish its address (is MoRIIOConnector " + "configured on the prefill deployment?)." + ) + if not decode_zmq: + raise ValueError( + "MoRIIO decode zmq address is not set: setup() must run on this " + "engine before requests are shaped." + ) + # The incoming request_id (always populated -- OpenAI models default it + # to a uuid) is the seed. Both prepare_* calls run on the same request + # object, so they agree; uniqueness per request is inherited from it. + seed = str(request.request_id) + # 32 hex chars (the trailing uid _PREFILL_ZMQ_RE / _DECODE_ZMQ_RE + # anchor on); a hash of the seed, so both prepare_* calls agree. + uid = hashlib.sha256(seed.encode()).hexdigest()[:32] + # Wire format consumed by vLLM's MoRIIO connector. + request_id = f"{_PREFILL_PREFIX}{prefill_zmq}{_DECODE_PREFIX}{decode_zmq}_{uid}" + transfer_id = f"{_TRANSFER_PREFIX}-{uid}" + return request_id, transfer_id + + def prepare_prefill_request( + self, *, request: "RequestType", peer: Optional[Dict[str, Any]] + ) -> "RequestType": + request_id, transfer_id = self._dual_ids(request, peer) + prefill_request = request.model_copy(deep=True) + # The dual-address id (peer zmq encoded in it) must reach the engine: + # setting request_id explicitly makes the LLMServer pipeline preserve it + # (not clobber it with the Serve id) and the engine copies it into the + # X-Request-Id header that vLLM's MoRIIO connector parses. + prefill_request.request_id = request_id + prefill_request.kv_transfer_params = { + **base_prefill_kv_transfer_params(), + "transfer_id": transfer_id, + # The prefill engine's remote is the decode (this orchestrator). + **self._remote_routing(self._own_routing()), + } + clamp_request_to_single_token(prefill_request) + return prefill_request + + def prepare_decode_request( + self, + *, + request: "RequestType", + peer: Optional[Dict[str, Any]], + prefill_response: Optional[Any], + ) -> "RequestType": + request_id, transfer_id = self._dual_ids(request, peer) + decode_request = request.model_copy(deep=True) + decode_request.request_id = request_id + # The decode engine's remote is the selected prefill peer. + remote_routing = self._remote_routing(peer or {}) + + if not self._read_mode: + # WRITE: prefill pushes KV; decode just needs do_remote_prefill + the + # shared transfer_id (no block ids -- they are pushed, not pulled). + decode_request.kv_transfer_params = { + "do_remote_prefill": True, + "do_remote_decode": False, + "remote_engine_id": None, + "remote_block_ids": None, + "transfer_id": transfer_id, + **remote_routing, + } + return decode_request + + # READ: decode PULLS KV; forward the remote_block_ids / remote_engine_id + # the prefill engine returned on its response. If absent (e.g. prompt < + # block_size / full prefix hit), fall back to a local recompute. + prefill_kv_params = getattr(prefill_response, "kv_transfer_params", None) + params = dict(prefill_kv_params) if prefill_kv_params else {} + if params.get("remote_block_ids") and params.get("remote_engine_id"): + params.setdefault("transfer_id", transfer_id) + params["do_remote_prefill"] = True + params["do_remote_decode"] = False + # Address the prefill peer's (dp_rank, dp_size, tp) workers. + params.update(remote_routing) + decode_request.kv_transfer_params = params + else: + logger.warning( + "[MORI][READ] prefill returned no remote_block_ids/remote_engine_id " + "(kv_transfer_params=%s); decode will recompute locally.", + prefill_kv_params, + ) + decode_request.kv_transfer_params = None + return decode_request diff --git a/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/builder.py b/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/builder.py index f095d39fd599..61af89007db3 100644 --- a/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/builder.py +++ b/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/builder.py @@ -183,6 +183,36 @@ def _default_decode_nixl_port_base(self): ) return self + @model_validator(mode="after") + def _default_decode_moriio_port_base(self): + """Shift decode's MoRIIO handshake/notify bases off prefill's defaults. + + Mirrors ``_default_decode_nixl_port_base``: a colocated P+D pair on one + node would otherwise share MoRIIO's default handshake/notify ports. Only + applies when the decode config uses the MoRIIO connector. The +1000 + stride is well above any realistic tp_size*pp_size offset added on top. + """ + kv_transfer_config = ( + self.decode_config.engine_kwargs.get("kv_transfer_config") or {} + ) + if kv_transfer_config.get("kv_connector") != "MoRIIOConnector": + return self + + from ray.llm._internal.serve.engines.vllm.kv_transfer.moriio import ( + DEFAULT_HANDSHAKE_PORT_BASE, + DEFAULT_NOTIFY_PORT_BASE, + HANDSHAKE_PORT_BASE_KEY, + NOTIFY_PORT_BASE_KEY, + ) + + self.decode_config.experimental_configs.setdefault( + HANDSHAKE_PORT_BASE_KEY, DEFAULT_HANDSHAKE_PORT_BASE + 1000 + ) + self.decode_config.experimental_configs.setdefault( + NOTIFY_PORT_BASE_KEY, DEFAULT_NOTIFY_PORT_BASE + 1000 + ) + return self + # --------------------------------------------------------------------------- # Builder diff --git a/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/pd_server.py b/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/pd_server.py index ddca4dafce1b..30bcd43edeb2 100644 --- a/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/pd_server.py +++ b/python/ray/llm/_internal/serve/serving_patterns/prefill_decode/pd_server.py @@ -175,6 +175,12 @@ async def _pd_handle_request( the resolved KV-connector backend. With the default backend flags (``requires_peer_binding=False``, ``concurrent_handoff=False``) the control flow and calls are identical to the historical NIXL/default flow. + + A connector that encodes coordination data in the request id (MoRIIO's + dual-address id) just stamps ``request.request_id`` in ``prepare_*``; it + then reaches both engines unchanged -- the LLMServer pipeline preserves + an explicitly-set request_id (it no longer clobbers it with the Serve + id) and the engine copies it into the ``X-Request-Id`` header it reads. """ # Determine method name for the handle call @@ -439,9 +445,21 @@ async def _maybe_prewarm(self) -> None: logger.info("[PDDecodeServer] Starting pre-warm across all P replicas.") + backend = self._get_connector_backend() + if backend.requires_peer_binding: + # Peer-binding connectors (e.g. MoRIIO) shape a prefill request + # against a specific selected replica's metadata; a peerless + # broadcast prewarm cannot bind one. The connector handshake + # completes on the first real request instead. + logger.info( + "[PDDecodeServer] Skipping pre-warm: connector %s requires peer " + "binding (handshake completes on the first real request).", + type(backend).__name__, + ) + return + model_id = self._llm_config.model_id dummy = self._make_dummy_request(model_id) - backend = self._get_connector_backend() prefill_req = backend.prepare_prefill_request(request=dummy, peer=None) # Broadcast to every live P replica; retry until they are up. @@ -551,6 +569,25 @@ class PDPrefillServer(LLMServer): method used during the pre-warm handshake. """ + async def record_replica_metadata(self) -> Dict[str, Any]: + """Publish this prefill replica's connector coordination metadata. + + Read by the decode orchestrator via the replica-metadata hook + (``ReplicaSelection.replica_metadata``) so peer-binding connectors (e.g. + MoRIIO) can address the selected prefill replica. Returns ``{}`` for + connectors that publish nothing (the ``BaseConnectorBackend`` default). + + Returns the metadata of the backend that engine init + (``setup_engine_backend``) created, ``setup()``-ed, and stored on this + server's ``_llm_config``. The replica-metadata hook is captured after + engine init, so for connector deployments the backend is present by + then; with no backend stored there is nothing to publish. + """ + backend = getattr(self._llm_config, "kv_connector_backend", None) + if backend is None: + return {} + return backend.replica_metadata() + async def prewarm_prefill( self, prefill_request: CompletionRequest ) -> Optional[dict]: diff --git a/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/kv_transfer_backends/test_moriio_connector.py b/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/kv_transfer_backends/test_moriio_connector.py new file mode 100644 index 000000000000..8878ab36ca4c --- /dev/null +++ b/python/ray/llm/tests/serve/cpu/deployments/llm/vllm/kv_transfer_backends/test_moriio_connector.py @@ -0,0 +1,365 @@ +import re +import sys +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + +from ray.llm._internal.serve.engines.vllm.kv_transfer.base import ( + BaseConnectorBackend, +) +from ray.llm._internal.serve.engines.vllm.kv_transfer.factory import ( + KVConnectorBackendFactory, +) +from ray.llm._internal.serve.engines.vllm.kv_transfer.moriio import ( + _DECODE_ZMQ_RE, + _PREFILL_ZMQ_RE, + DEFAULT_HANDSHAKE_PORT_BASE, + DEFAULT_NOTIFY_PORT_BASE, + MoRIIOConnectorBackend, + parse_peer_zmq, + parse_zmq_address, +) +from ray.serve.llm import LLMConfig +from ray.serve.schema import ReplicaRank + +_TEST_HOST = "10.0.0.5" + + +def _replica_context(global_rank: int) -> SimpleNamespace: + return SimpleNamespace( + rank=ReplicaRank(rank=global_rank, node_rank=0, local_rank=global_rank) + ) + + +def _make_backend( + read_mode: bool = False, + extra_exp: dict = None, + dp_rank: int = None, + dp_size: int = None, + tp_size: int = None, +): + extra_config = {} + if read_mode: + extra_config["read_mode"] = "true" + engine_kwargs = dict( + kv_transfer_config=dict( + kv_connector="MoRIIOConnector", + kv_role="kv_both", + kv_connector_extra_config=extra_config, + ) + ) + if dp_rank is not None: + engine_kwargs["data_parallel_rank"] = dp_rank + if dp_size is not None: + engine_kwargs["data_parallel_size"] = dp_size + if tp_size is not None: + engine_kwargs["tensor_parallel_size"] = tp_size + return MoRIIOConnectorBackend( + llm_config=LLMConfig( + model_loading_config=dict(model_id="Qwen/Qwen3-0.6B"), + engine_kwargs=engine_kwargs, + experimental_configs=extra_exp or {}, + ), + ) + + +def _setup(backend, rank: int = 0): + with ( + patch.dict("os.environ", {}, clear=True), + patch("ray.util.get_node_ip_address", return_value=_TEST_HOST), + patch("ray.serve.get_replica_context", return_value=_replica_context(rank)), + ): + backend.setup() + + +class TestMoRIIOConnectorBackendSetup: + def test_setup_sets_ports_zmq_and_extra_config(self): + backend = _make_backend() + _setup(backend, rank=0) + + extra = backend.kv_transfer_config["kv_connector_extra_config"] + assert extra["handshake_port"] == str(DEFAULT_HANDSHAKE_PORT_BASE) + assert extra["notify_port"] == str(DEFAULT_NOTIFY_PORT_BASE) + assert extra["proxy_ip"] == "" + assert extra["proxy_ping_port"] == "0" + assert "http_port" in extra + assert extra["read_mode"] == "false" + + zmq = backend._zmq_address + host, hs, notify = parse_zmq_address(zmq) + assert host == _TEST_HOST + assert hs == DEFAULT_HANDSHAKE_PORT_BASE + assert notify == DEFAULT_NOTIFY_PORT_BASE + + def test_setup_port_offset_uses_replica_rank(self): + backend = _make_backend() + num_devices = backend.llm_config.get_engine_config().num_devices + _setup(backend, rank=2) + extra = backend.kv_transfer_config["kv_connector_extra_config"] + assert extra["handshake_port"] == str( + DEFAULT_HANDSHAKE_PORT_BASE + 2 * num_devices + ) + assert extra["notify_port"] == str(DEFAULT_NOTIFY_PORT_BASE + 2 * num_devices) + + def test_setup_respects_overridden_bases(self): + backend = _make_backend( + extra_exp={ + "MORI_HANDSHAKE_PORT_BASE": 7000, + "MORI_NOTIFY_PORT_BASE": 62000, + } + ) + _setup(backend, rank=0) + extra = backend.kv_transfer_config["kv_connector_extra_config"] + assert extra["handshake_port"] == "7000" + assert extra["notify_port"] == "62000" + + def test_requires_peer_binding(self): + assert MoRIIOConnectorBackend.requires_peer_binding is True + + def test_concurrent_handoff_write_vs_read(self): + write_backend = _make_backend(read_mode=False) + read_backend = _make_backend(read_mode=True) + assert write_backend.concurrent_handoff is True + assert read_backend.concurrent_handoff is False + assert write_backend._read_mode is False + assert read_backend._read_mode is True + + @pytest.mark.parametrize( + "value,expected_read", + [ + ("true", True), + ("True", True), + ("1", True), + ("false", False), + ("0", False), + ("", False), + ], + ) + def test_read_mode_parsing(self, value, expected_read): + backend = MoRIIOConnectorBackend( + llm_config=LLMConfig( + model_loading_config=dict(model_id="Qwen/Qwen3-0.6B"), + engine_kwargs=dict( + kv_transfer_config=dict( + kv_connector="MoRIIOConnector", + kv_connector_extra_config={"read_mode": value}, + ) + ), + ), + ) + assert backend._read_mode is expected_read + + def test_replica_metadata_returns_zmq(self): + backend = _make_backend() + _setup(backend, rank=0) + meta = backend.replica_metadata() + assert meta["mori_zmq_address"] == backend._zmq_address + # Parallelism is published so the orchestrator can address remote workers. + assert meta["dp_rank"] == 0 and meta["dp_size"] == 1 and meta["tp_size"] == 1 + + def test_replica_metadata_publishes_dp_tp(self): + backend = _make_backend(dp_rank=2, dp_size=4, tp_size=8) + _setup(backend, rank=0) + meta = backend.replica_metadata() + assert meta["dp_rank"] == 2 + assert meta["dp_size"] == 4 + assert meta["tp_size"] == 8 + + def test_replica_metadata_default_empty(self): + # The default backend publishes nothing (concrete default on the base). + assert BaseConnectorBackend.replica_metadata(None) == {} + + +class TestMoRIIORequestId: + def _prepared_pair(self, backend, request, peer): + prefill = backend.prepare_prefill_request(request=request, peer=peer) + # prefill_response is unused in WRITE mode; pass a dummy with no params. + decode = backend.prepare_decode_request( + request=request, + peer=peer, + prefill_response=SimpleNamespace(kv_transfer_params=None), + ) + return prefill, decode + + def _request_with_copy(self, request_id="user-req-123"): + class _Req: + def __init__(self, rid): + self.request_id = rid + self.kv_transfer_params = None + self.max_tokens = 128 + self.max_completion_tokens = 128 + self.stream = True + self.stream_options = {"include_usage": True} + + def model_copy(self, deep=False): + new = _Req(self.request_id) + return new + + return _Req(request_id) + + def test_prefill_and_decode_share_request_id_and_transfer_id(self): + backend = _make_backend(read_mode=False) + _setup(backend, rank=0) + decode_zmq = backend._zmq_address + prefill_zmq = "host:10.0.0.9,handshake:6301,notify:61005" + peer = {"mori_zmq_address": prefill_zmq} + + req = self._request_with_copy("user-req-123") + prefill, decode = self._prepared_pair(backend, req, peer) + + assert prefill.request_id == decode.request_id + assert ( + prefill.kv_transfer_params["transfer_id"] + == decode.kv_transfer_params["transfer_id"] + ) + + # Round-trips to the right peer zmq via the vLLM regexes. + assert _PREFILL_ZMQ_RE.search(prefill.request_id).group(1) == prefill_zmq + assert _DECODE_ZMQ_RE.search(prefill.request_id) is not None + assert parse_peer_zmq(prefill.request_id, is_producer=False) == prefill_zmq + assert parse_peer_zmq(prefill.request_id, is_producer=True) == decode_zmq + + # transfer_id format tx-<32hex>. + assert re.fullmatch( + r"tx-[0-9a-f]{32}", prefill.kv_transfer_params["transfer_id"] + ) + + def test_id_is_deterministic_across_calls(self): + backend = _make_backend(read_mode=False) + _setup(backend, rank=0) + peer = {"mori_zmq_address": "host:10.0.0.9,handshake:6301,notify:61005"} + + p1 = backend.prepare_prefill_request( + request=self._request_with_copy("R"), peer=peer + ) + p2 = backend.prepare_prefill_request( + request=self._request_with_copy("R"), peer=peer + ) + assert p1.request_id == p2.request_id + assert ( + p1.kv_transfer_params["transfer_id"] == p2.kv_transfer_params["transfer_id"] + ) + + def test_prefill_kv_params_write(self): + backend = _make_backend(read_mode=False) + _setup(backend, rank=0) + peer = {"mori_zmq_address": "host:10.0.0.9,handshake:6301,notify:61005"} + prefill = backend.prepare_prefill_request( + request=self._request_with_copy(), peer=peer + ) + assert prefill.kv_transfer_params["do_remote_decode"] is True + assert prefill.kv_transfer_params["do_remote_prefill"] is False + assert prefill.max_tokens == 1 + assert prefill.stream is False + # vLLM reads "tp_size" (not "remote_tp_size"); DP defaults at dp_size=1. + assert "remote_tp_size" not in prefill.kv_transfer_params + assert prefill.kv_transfer_params["tp_size"] == 1 + assert prefill.kv_transfer_params["remote_dp_rank"] == 0 + assert prefill.kv_transfer_params["remote_dp_size"] == 1 + + def test_decode_kv_params_write(self): + backend = _make_backend(read_mode=False) + _setup(backend, rank=0) + peer = {"mori_zmq_address": "host:10.0.0.9,handshake:6301,notify:61005"} + decode = backend.prepare_decode_request( + request=self._request_with_copy(), + peer=peer, + prefill_response=SimpleNamespace(kv_transfer_params=None), + ) + assert decode.kv_transfer_params["do_remote_prefill"] is True + assert decode.kv_transfer_params["do_remote_decode"] is False + assert decode.kv_transfer_params["remote_block_ids"] is None + assert "remote_tp_size" not in decode.kv_transfer_params + assert decode.kv_transfer_params["tp_size"] == 1 + assert decode.kv_transfer_params["remote_dp_rank"] == 0 + assert decode.kv_transfer_params["remote_dp_size"] == 1 + + def test_dp_routing_targets_correct_ranks(self): + """With DP>1, the prefill request addresses the decode (this + orchestrator) rank; the decode request addresses the selected prefill + peer's rank (read from peer metadata).""" + # This orchestrator is decode dp_rank=1 of a 2-way DP decode group. + backend = _make_backend(read_mode=False, dp_rank=1, dp_size=2, tp_size=4) + _setup(backend, rank=0) + # The selected prefill peer is dp_rank=3 of a 4-way DP prefill group. + peer = { + "mori_zmq_address": "host:10.0.0.9,handshake:6301,notify:61005", + "dp_rank": 3, + "dp_size": 4, + "tp_size": 4, + } + prefill = backend.prepare_prefill_request( + request=self._request_with_copy(), peer=peer + ) + decode = backend.prepare_decode_request( + request=self._request_with_copy(), + peer=peer, + prefill_response=SimpleNamespace(kv_transfer_params=None), + ) + # Prefill engine's remote == this decode orchestrator (rank 1 of 2). + assert prefill.kv_transfer_params["remote_dp_rank"] == 1 + assert prefill.kv_transfer_params["remote_dp_size"] == 2 + # Decode engine's remote == the selected prefill peer (rank 3 of 4). + assert decode.kv_transfer_params["remote_dp_rank"] == 3 + assert decode.kv_transfer_params["remote_dp_size"] == 4 + assert decode.kv_transfer_params["tp_size"] == 4 + + def test_decode_kv_params_read_forwards_prefill_params(self): + backend = _make_backend(read_mode=True) + _setup(backend, rank=0) + peer = {"mori_zmq_address": "host:10.0.0.9,handshake:6301,notify:61005"} + prefill_resp = SimpleNamespace( + kv_transfer_params={ + "remote_block_ids": [1, 2, 3], + "remote_engine_id": "eng-7", + } + ) + decode = backend.prepare_decode_request( + request=self._request_with_copy(), peer=peer, prefill_response=prefill_resp + ) + assert decode.kv_transfer_params["do_remote_prefill"] is True + assert decode.kv_transfer_params["remote_block_ids"] == [1, 2, 3] + assert decode.kv_transfer_params["remote_engine_id"] == "eng-7" + assert "transfer_id" in decode.kv_transfer_params + # READ also stamps the remote (prefill peer) routing. + assert decode.kv_transfer_params["remote_dp_rank"] == 0 + assert decode.kv_transfer_params["remote_dp_size"] == 1 + assert decode.kv_transfer_params["tp_size"] == 1 + + def test_decode_read_fallback_when_no_remote_params(self): + backend = _make_backend(read_mode=True) + _setup(backend, rank=0) + peer = {"mori_zmq_address": "host:10.0.0.9,handshake:6301,notify:61005"} + decode = backend.prepare_decode_request( + request=self._request_with_copy(), + peer=peer, + prefill_response=SimpleNamespace(kv_transfer_params=None), + ) + assert decode.kv_transfer_params is None + + +class TestMoRIIOZmqValidation: + def test_missing_peer_zmq_raises(self): + """A missing/empty peer mori_zmq_address must raise a clear error, not + silently build a request id containing "None".""" + backend = _make_backend(read_mode=False) + _setup(backend, rank=0) + request = TestMoRIIORequestId()._request_with_copy("user-req-123") + for peer in (None, {}, {"mori_zmq_address": ""}): + with pytest.raises(ValueError, match="mori_zmq_address"): + backend.prepare_prefill_request(request=request, peer=peer) + + +class TestMoRIIOFactory: + def test_registered(self): + assert KVConnectorBackendFactory.is_registered("MoRIIOConnector") + + def test_create_backend_returns_class(self): + backend_class = KVConnectorBackendFactory.get_backend_class("MoRIIOConnector") + assert backend_class is MoRIIOConnectorBackend + assert issubclass(backend_class, BaseConnectorBackend) + + +if __name__ == "__main__": + sys.exit(pytest.main(["-v", __file__])) diff --git a/python/ray/llm/tests/serve/cpu/deployments/prefill_decode_disagg/test_prefill_decode_disagg.py b/python/ray/llm/tests/serve/cpu/deployments/prefill_decode_disagg/test_prefill_decode_disagg.py index 06d245c0c409..adccf0f6aa21 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/prefill_decode_disagg/test_prefill_decode_disagg.py +++ b/python/ray/llm/tests/serve/cpu/deployments/prefill_decode_disagg/test_prefill_decode_disagg.py @@ -780,6 +780,35 @@ async def _gen(): assert chunks == [prefill_error] assert decode_aborted["value"] is True + @pytest.mark.asyncio + async def test_prewarm_skipped_for_peer_binding_backend(self): + """Pre-warm broadcasts a peerless prefill request, which a peer-binding + connector (e.g. MoRIIO) cannot shape -- so it must be skipped rather + than crash decode-replica init.""" + from ray.llm._internal.serve.engines.vllm.kv_transfer.base import ( + BaseConnectorBackend, + ) + + class _PeerBindingBackend(BaseConnectorBackend): + requires_peer_binding = True + + def prepare_prefill_request(self, *, request, peer): + raise AssertionError("prepare_prefill_request must not be called") + + def prepare_decode_request(self, *, request, peer, prefill_response): + raise AssertionError("prepare_decode_request must not be called") + + server = PDDecodeServer.__new__(PDDecodeServer) + server._llm_config = LLMConfig( + model_loading_config=ModelLoadingConfig(model_id="test-model"), + experimental_configs={"_prewarm_prefill_decode": True}, + ) + server._llm_config._kv_connector_backend = _PeerBindingBackend( + server._llm_config + ) + # Must return without raising (and without touching prepare_*). + await server._maybe_prewarm() + class TestBuildPDOpenaiApp: """Test suite for build_pd_openai_app function.""" diff --git a/python/ray/remote_function.py b/python/ray/remote_function.py index 57fae3e65551..1ee1bf26a604 100644 --- a/python/ray/remote_function.py +++ b/python/ray/remote_function.py @@ -281,6 +281,15 @@ def f(): # Task g will require 2 gpus instead of 1. g = f.options(num_gpus=2) """ + if "_num_objects_per_yield" in task_options: + num_objects_per_yield = ( + self._default_options.get("_num_objects_per_yield") or 1 + ) + if task_options["_num_objects_per_yield"] != num_objects_per_yield: + raise ValueError( + "_num_objects_per_yield cannot be overridden per task call. " + "Use @ray.remote(_num_objects_per_yield=...) instead." + ) func_cls = self @@ -435,6 +444,12 @@ def _remote( ] if generator_backpressure_num_objects is None: generator_backpressure_num_objects = -1 + num_objects_per_yield = task_options["_num_objects_per_yield"] + if num_objects_per_yield is None: + num_objects_per_yield = 1 + ray_option_utils.task_options["_num_objects_per_yield"].validate( + "_num_objects_per_yield", num_objects_per_yield + ) max_retries = task_options["max_retries"] retry_exceptions = task_options["retry_exceptions"] @@ -517,6 +532,7 @@ def invocation(args, kwargs): worker.debugger_breakpoint, serialized_runtime_env_info or "{}", generator_backpressure_num_objects, + num_objects_per_yield, enable_task_events, labels, label_selector, diff --git a/python/ray/tests/kuberay/test_autoscaling_config.py b/python/ray/tests/kuberay/test_autoscaling_config.py index eb557c558818..e4bcc3094015 100644 --- a/python/ray/tests/kuberay/test_autoscaling_config.py +++ b/python/ray/tests/kuberay/test_autoscaling_config.py @@ -517,7 +517,7 @@ def test_resource_quantity(input: str, output: int): _get_basic_autoscaling_config(), None, None, - "Ignoring labels: ray.io/accelerator-type=TPU-V4 set in rayStartParams. Group labels are supported in the top-level Labels field starting in KubeRay v1.5", + "Ignoring labels: ray.io/accelerator-type=TPU-V4 set in rayStartParams for group 'tpu-group'. Group labels are supported in the top-level Labels field starting in KubeRay v1.5", id="groups-with-raystartparam-labels", ), pytest.param( @@ -525,7 +525,7 @@ def test_resource_quantity(input: str, output: int): _get_autoscaling_config_with_top_level_labels(), None, None, - "Ignoring labels: instance-type=n2 set in rayStartParams. Group labels are supported in the top-level Labels field starting in KubeRay v1.5", + "Ignoring labels: instance-type=n2 set in rayStartParams for group 'small-group'. Group labels are supported in the top-level Labels field starting in KubeRay v1.5", id="groups-with-top-level-labels", ), pytest.param( @@ -566,6 +566,10 @@ def test_autoscaling_config( expected_log_warning: Optional[str], ): ray_cr_in["metadata"]["namespace"] = "default" + # Reset log_once state to ensure each test case is independent. + from ray.util.debug import _logged + + _logged.clear() with mock.patch(f"{AUTOSCALING_CONFIG_MODULE_PATH}.logger") as mock_logger: if expected_error: with pytest.raises(expected_error, match=expected_error_message): diff --git a/python/ray/tests/test_streaming_generator.py b/python/ray/tests/test_streaming_generator.py index 666efe6b7a40..34154f427adc 100644 --- a/python/ray/tests/test_streaming_generator.py +++ b/python/ray/tests/test_streaming_generator.py @@ -538,6 +538,127 @@ async def verify_async_task_async_generator(): asyncio.run(verify_async_task_async_generator()) +def test_streaming_generator_num_objects_per_yield(shutdown_only): + ray.init() + + @ray.remote(_num_objects_per_yield=2) + def generator(): + for i in range(3): + stats = yield i, f"metadata-{i}" + assert stats is None or stats.object_creation_dur_s >= 0 + + gen = generator.remote() + for i in range(3): + assert ray.get(next(gen)) == i + assert ray.get(next(gen)) == f"metadata-{i}" + + with pytest.raises(StopIteration): + next(gen) + + @ray.remote + def per_call(): + yield 1, 2 + + with pytest.raises(ValueError, match="_num_objects_per_yield"): + per_call.options(_num_objects_per_yield=2).remote() + + +def test_actor_streaming_generator_num_objects_per_yield(shutdown_only): + ray.init() + + @ray.remote + class Actor: + @ray.method(_num_objects_per_yield=2) + def decorated(self): + yield "block", "metadata" + + def per_call(self): + yield 1, 2 + + actor = Actor.remote() + + gen = actor.decorated.remote() + assert ray.get(next(gen)) == "block" + assert ray.get(next(gen)) == "metadata" + with pytest.raises(StopIteration): + next(gen) + + with pytest.raises(ValueError, match="_num_objects_per_yield"): + actor.per_call.options(_num_objects_per_yield=2).remote() + + +def test_streaming_generator_num_objects_per_yield_invalid_yield(shutdown_only): + ray.init() + + @ray.remote(_num_objects_per_yield=2) + def generator(): + yield (1,) + + gen = generator.remote() + with pytest.raises(ValueError, match="_num_objects_per_yield=2"): + ray.get(next(gen)) + + with pytest.raises(StopIteration): + next(gen) + + +def test_streaming_generator_num_objects_per_yield_serialization_failure(shutdown_only): + ray.init() + + @ray.remote(_num_objects_per_yield=2) + def generator(): + yield threading.Lock(), 1 + + gen = generator.remote() + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(next(gen)) + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(next(gen)) + + with pytest.raises(StopIteration): + next(gen) + + +def test_streaming_generator_num_objects_per_yield_partial_store_failure( + shutdown_only, +): + ray.init() + + @ray.remote(_num_objects_per_yield=2) + def generator(): + # If a later object fails after an earlier object has been stored, the + # caller should still receive a ref for every object in the grouped yield. + yield 1, threading.Lock() + + gen = generator.remote() + assert ray.get(next(gen)) == 1 + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(next(gen)) + + with pytest.raises(StopIteration): + next(gen) + + +def test_streaming_generator_num_objects_per_yield_failure_not_retried( + shutdown_only, +): + ray.init() + + @ray.remote(_num_objects_per_yield=2, retry_exceptions=True, max_retries=1) + def generator(): + # Once grouped-yield IDs are allocated, the whole group must be + # reported so those temporary refs are cleared instead of retrying. + yield 1, threading.Lock() + + gen = generator.remote() + assert ray.get(next(gen)) == 1 + with pytest.raises(ray.exceptions.RayTaskError): + ray.get(next(gen)) + + with pytest.raises(StopIteration): + next(gen) + + def test_streaming_generator_exception(shutdown_only): # Verify the exceptions are correctly raised. # Also verify the followup next will raise StopIteration. diff --git a/python/ray/tests/test_streaming_generator_backpressure.py b/python/ray/tests/test_streaming_generator_backpressure.py index 6d2a3b3d76e5..a88225754b95 100644 --- a/python/ray/tests/test_streaming_generator_backpressure.py +++ b/python/ray/tests/test_streaming_generator_backpressure.py @@ -248,6 +248,19 @@ async def f(self): def f(): pass + with pytest.raises(ValueError, match="_num_objects_per_yield"): + + @ray.remote(_num_objects_per_yield=0) + def g(): + pass + + with pytest.raises(ValueError, match="_num_objects_per_yield"): + + class Actor: + @ray.method(_num_objects_per_yield=0) + def g(self): + pass + def test_threaded_actor_generator_backpressure(shutdown_only): ray.init() diff --git a/python/ray/train/v2/BUILD.bazel b/python/ray/train/v2/BUILD.bazel index bd10115880f4..0b0a17556097 100644 --- a/python/ray/train/v2/BUILD.bazel +++ b/python/ray/train/v2/BUILD.bazel @@ -528,6 +528,22 @@ py_test( ], ) +py_test( + name = "test_preemption_watcher", + size = "small", + srcs = ["tests/test_preemption_watcher.py"], + env = {"RAY_TRAIN_V2_ENABLED": "1"}, + tags = [ + "exclusive", + "team:ml", + "train_v2", + ], + deps = [ + ":conftest", + "//:ray_lib", + ], +) + py_test( name = "test_result", size = "medium", diff --git a/python/ray/train/v2/_internal/callbacks/preemption_callback.py b/python/ray/train/v2/_internal/callbacks/preemption_callback.py new file mode 100644 index 000000000000..d2bc512d32f7 --- /dev/null +++ b/python/ray/train/v2/_internal/callbacks/preemption_callback.py @@ -0,0 +1,85 @@ +import logging +import os +from typing import TYPE_CHECKING, Dict, List, Optional + +import ray +from ray.actor import ActorHandle +from ray.train.v2._internal.constants import ( + DEFAULT_PREEMPTION_POLL_INTERVAL_S, + PREEMPTION_POLL_INTERVAL_S_ENV_VAR, +) +from ray.train.v2._internal.execution.callback import WorkerGroupCallback +from ray.train.v2._internal.execution.preemption import PreemptionWatcher + +if TYPE_CHECKING: + from ray.train.v2._internal.execution.worker_group import ( + WorkerGroup, + WorkerGroupContext, + ) + +logger = logging.getLogger(__name__) + + +class PreemptionCallback(WorkerGroupCallback): + """Manages a :class:`PreemptionWatcher` across worker-group lifecycles. + + Spawns a fresh watcher in :meth:`after_worker_group_start` and stops it on + every teardown path (shutdown and abort). Each worker group gets its own + watcher and failure-domain map, so elastic resizes and restarts never + leak stale state. + """ + + def __init__(self) -> None: + self._poll_interval_s: float = float( + os.getenv( + PREEMPTION_POLL_INTERVAL_S_ENV_VAR, + str(DEFAULT_PREEMPTION_POLL_INTERVAL_S), + ) + ) + self._watcher: Optional[ActorHandle] = None + + def after_worker_group_start(self, worker_group: "WorkerGroup") -> None: + # Tear down any watcher from a previous worker group first. Worker-group + # startup can fail after this hook without running the shutdown hook, so + # this also prevents leaking an orphaned watcher across a reschedule. + self._stop_watcher() + + node_to_ranks: Dict[str, List[int]] = {} + for w in worker_group.get_workers(): + node_to_ranks.setdefault(w.metadata.node_id, []).append( + w.distributed_context.world_rank + ) + + watcher_cls = ray.remote(num_cpus=0, max_restarts=-1)(PreemptionWatcher) + self._watcher = watcher_cls.remote( + node_to_ranks=node_to_ranks, + poll_interval_s=self._poll_interval_s, + ) + + logger.debug( + "PreemptionCallback: started watcher for %d node(s).", + len(node_to_ranks), + ) + + def before_worker_group_shutdown(self, worker_group: "WorkerGroup") -> None: + self._stop_watcher() + + def after_worker_group_abort( + self, worker_group_context: "WorkerGroupContext" + ) -> None: + # abort() doesn't run the shutdown hook, so tear the watcher down here + # too — otherwise it keeps polling GCS until the cluster reaps it. + self._stop_watcher() + + def _stop_watcher(self) -> None: + if self._watcher is None: + return + watcher = self._watcher + self._watcher = None + # Force-kill (non-blocking) rather than a synchronous graceful stop, so + # we never block the controller's event loop. The watcher's daemon poll + # thread dies with the actor process and holds no external resources. + try: + ray.kill(watcher) + except Exception: + logger.warning("Failed to kill PreemptionWatcher actor.", exc_info=True) diff --git a/python/ray/train/v2/_internal/constants.py b/python/ray/train/v2/_internal/constants.py index a116e12ff73a..2bcbbce73e5f 100644 --- a/python/ray/train/v2/_internal/constants.py +++ b/python/ray/train/v2/_internal/constants.py @@ -58,6 +58,15 @@ ) DEFAULT_CHECKPOINT_UPLOAD_WARN_INTERVAL_S: float = 60 +# Feature flag for the preemption watcher. Default-on; provides a quick +# rollback path if the watcher actor misbehaves in a cluster. +ENABLE_PREEMPTION_WATCHER_ENV_VAR = "RAY_TRAIN_ENABLE_PREEMPTION_WATCHER" +DEFAULT_ENABLE_PREEMPTION_WATCHER: bool = True + +# How often the preemption watcher polls Ray Core's drain state. +PREEMPTION_POLL_INTERVAL_S_ENV_VAR = "RAY_TRAIN_PREEMPTION_POLL_INTERVAL_S" +DEFAULT_PREEMPTION_POLL_INTERVAL_S: float = 5.0 + # Environment variable to enable the print function patching. ENABLE_PRINT_PATCH_ENV_VAR = "RAY_TRAIN_ENABLE_PRINT_PATCH" DEFAULT_ENABLE_PRINT_PATCH = True @@ -118,6 +127,8 @@ STATE_ACTOR_RECONCILIATION_INTERVAL_S_ENV_VAR, RAY_WARN_BLOCKING_GET_INSIDE_ASYNC_ENV_VAR, TORCHFT_LIGHTHOUSE_ADDR_ENV_VAR, + ENABLE_PREEMPTION_WATCHER_ENV_VAR, + PREEMPTION_POLL_INTERVAL_S_ENV_VAR, } diff --git a/python/ray/train/v2/_internal/execution/controller/controller.py b/python/ray/train/v2/_internal/execution/controller/controller.py index e7a509bc6cc9..c86345d62a31 100644 --- a/python/ray/train/v2/_internal/execution/controller/controller.py +++ b/python/ray/train/v2/_internal/execution/controller/controller.py @@ -12,8 +12,10 @@ from ray.exceptions import AsyncioActorExit from ray.train.v2._internal.constants import ( DEFAULT_ENABLE_CONTROLLER_LOGGING, + DEFAULT_ENABLE_PREEMPTION_WATCHER, DEFAULT_HEALTH_CHECK_INTERVAL_S, ENABLE_CONTROLLER_STRUCTURED_LOGGING_ENV_VAR, + ENABLE_PREEMPTION_WATCHER_ENV_VAR, HEALTH_CHECK_INTERVAL_S_ENV_VAR, ) from ray.train.v2._internal.execution.callback import ( @@ -188,6 +190,28 @@ def __init__( else False ) + # Register the preemption-observability callback when not in TorchFT + # mode (replica groups handle peer loss via their own quorum). + enable_preemption_watcher = ray_constants.env_bool( + ENABLE_PREEMPTION_WATCHER_ENV_VAR, + DEFAULT_ENABLE_PREEMPTION_WATCHER, + ) + if self._manages_replica_groups: + if enable_preemption_watcher and ray_constants.env_set_by_user( + ENABLE_PREEMPTION_WATCHER_ENV_VAR + ): + logger.info( + "The preemption watcher is not compatible with replica " + "groups (e.g. TorchFT), which handle peer loss via their " + "own quorum; skipping it." + ) + elif enable_preemption_watcher: + from ray.train.v2._internal.callbacks.preemption_callback import ( + PreemptionCallback, + ) + + self._worker_group_callbacks_to_propagate.append(PreemptionCallback()) + self._worker_group: Optional[WorkerGroup] = None self._state = InitializingState() self._return_value: Optional[Any] = None diff --git a/python/ray/train/v2/_internal/execution/preemption.py b/python/ray/train/v2/_internal/execution/preemption.py new file mode 100644 index 000000000000..d8ea6ab078bc --- /dev/null +++ b/python/ray/train/v2/_internal/execution/preemption.py @@ -0,0 +1,213 @@ +import logging +import threading +from dataclasses import dataclass +from typing import Dict, List, Optional, Set + +import ray +from ray.train.v2._internal.constants import DEFAULT_PREEMPTION_POLL_INTERVAL_S +from ray.util.tpu import get_tpu_slice_name_from_node + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class PreemptionInfo: + """Information about an imminent preemption event. + + Attributes: + deadline_ms: Earliest preemption deadline (UNIX time in milliseconds) + across all preempted nodes. ``None`` if no deadline was reported. + preempted_node_to_ranks: Map of preempted ``node_id`` to the worker ``world_rank``s affected when that node + is preempted. + """ + + deadline_ms: Optional[int] + preempted_node_to_ranks: Dict[str, List[int]] + + @property + def preempted_node_ids(self) -> List[str]: + """Preempted node IDs, sorted lexicographically.""" + return sorted(self.preempted_node_to_ranks) + + @property + def preempted_ranks(self) -> List[int]: + """All affected ranks across the preempted nodes, sorted ascending.""" + return sorted( + {r for ranks in self.preempted_node_to_ranks.values() for r in ranks} + ) + + +def _get_draining_nodes() -> Dict[str, int]: + """Ray Core's draining nodes as ``{node_id_hex: deadline_ms}`` (0 = no deadline).""" + return ray._private.state.state.get_draining_nodes() + + +class PreemptionWatcher: + """Polls Ray Core for node drains and logs detected preemption events. + + One watcher per worker group, spawned as a ``num_cpus=0`` actor by + ``PreemptionCallback``. The poll loop runs in a background thread. The + failure-domain map is built once on construction and is immutable for the + watcher's lifetime. + + The failure-domain map records which of our ranks are affected if a node is + preempted: for a GPU node, the ranks on that node; for a TPU node, every + rank in the node's slice, since a TPU slice is preempted atomically. + + Args: + node_to_ranks: Map ``node_id_hex -> [ranks on that node]``. Used both + as the set of nodes we care about (drains elsewhere are ignored) + and as the seed for failure-domain expansion. + poll_interval_s: Seconds between drain-state polls. + """ + + def __init__( + self, + node_to_ranks: Dict[str, List[int]], + poll_interval_s: float = DEFAULT_PREEMPTION_POLL_INTERVAL_S, + ): + self._node_to_ranks: Dict[str, List[int]] = { + nid: sorted(ranks) for nid, ranks in node_to_ranks.items() + } + self._poll_interval_s = poll_interval_s + self._failure_domain_map: Dict[str, List[int]] = self._build_failure_domain_map( + self._node_to_ranks + ) + + self._stop_event = threading.Event() + self._last_drained: Dict[str, int] = {} + self._latest_info: Optional[PreemptionInfo] = None + + self._monitor_thread = threading.Thread( + target=self._watch_loop, + name="PreemptionWatcher", + daemon=True, + ) + self._monitor_thread.start() + + @staticmethod + def _build_failure_domain_map( + node_to_ranks: Dict[str, List[int]], + ) -> Dict[str, List[int]]: + """Map each node we host to all ranks in its failure domain. + + - Non-TPU (e.g. GPU) clusters: the failure domain is the node itself, + so a drain on a node flags only the ranks this job runs there. + - TPU multislice: every host in a slice is reclaimed atomically, so a + drain on any host is fate-shared with the rest. + """ + per_node = {nid: sorted(set(ranks)) for nid, ranks in node_to_ranks.items()} + + try: + all_nodes = ray.nodes() + + # Slice label for each node we host (None for non-TPU nodes). + node_to_slice: Dict[str, Optional[str]] = { + node["NodeID"]: get_tpu_slice_name_from_node(node) + for node in all_nodes + if node["NodeID"] in node_to_ranks + } + + # Union our ranks per slice. + slice_to_ranks: Dict[str, Set[int]] = {} + for node_id, ranks in node_to_ranks.items(): + slice_label = node_to_slice.get(node_id) + if slice_label: + slice_to_ranks.setdefault(slice_label, set()).update(ranks) + + # Non-TPU cluster (or none of our nodes are on a slice): per-node. + if not slice_to_ranks: + return per_node + + result: Dict[str, List[int]] = {} + for node_id, ranks in node_to_ranks.items(): + slice_label = node_to_slice.get(node_id) + if slice_label: + result[node_id] = sorted(slice_to_ranks[slice_label]) + else: + result[node_id] = sorted(set(ranks)) + return result + except Exception: + logger.debug( + "Could not build failure-domain map; falling back to per-node " + "domains (no TPU-slice expansion).", + exc_info=True, + ) + return per_node + + def get_latest_preemption_info(self) -> Optional[PreemptionInfo]: + """Most recent :class:`PreemptionInfo` observed, or ``None``.""" + return self._latest_info + + def _watch_loop(self) -> None: + logger.debug( + "PreemptionWatcher polling %d node(s) every %.1fs.", + len(self._node_to_ranks), + self._poll_interval_s, + ) + while not self._stop_event.is_set(): + self._poll_once() + self._stop_event.wait(timeout=self._poll_interval_s) + logger.debug("PreemptionWatcher stopped.") + + def _poll_once(self) -> None: + """Poll the drain source once and dispatch on change. + + Per-poll exceptions are caught and logged so a transient GCS hiccup + doesn't kill the watcher loop. + """ + try: + drained = _get_draining_nodes() or {} + # Keep only drains on this job's own nodes (others are ignored). + # That's complete for TPU — an SPMD job fully occupies its slice, so + # every fate-shared host is one of our nodes and a drain on any slice + # host appears here. For GPU, a drain on a host we don't run on is + # correctly irrelevant. + relevant = { + n: d for n, d in drained.items() if n in self._failure_domain_map + } + if relevant != self._last_drained: + self._on_drain_change(relevant) + self._last_drained = relevant + except Exception: + # TODO(lehui): consider exponential backoff when the drain API keeps + # failing, instead of retrying at the fixed poll interval. + logger.warning("PreemptionWatcher poll failed", exc_info=True) + + def _on_drain_change(self, drained: Dict[str, int]) -> None: + """Handle a change in the drained-node set. + + ``drained`` has already been narrowed to this job's nodes by the + caller (``_poll_once``). + """ + if not drained: + return + + affected_node_ids = sorted(drained.keys()) + preempted_node_to_ranks = { + node_id: self._failure_domain_map[node_id] for node_id in affected_node_ids + } + + # Earliest deadline across the preempted nodes; None if none reported one + # (Ray Core uses 0 for "no deadline", which is falsy and filtered out). + reported_deadlines = [drained[n] for n in affected_node_ids if drained[n]] + deadline_ms = min(reported_deadlines) if reported_deadlines else None + + info = PreemptionInfo( + deadline_ms=deadline_ms, + preempted_node_to_ranks=preempted_node_to_ranks, + ) + self._latest_info = info + + logger.warning( + "PreemptionWatcher: preemption detected — " + "preempted_node_ids=%s, preempted_ranks=%s, deadline_ms=%s", + info.preempted_node_ids, + info.preempted_ranks, + deadline_ms, + ) + # TODO(lehui): forward the detected preemption to the workers so the + # training loop can react to it. + # TODO(lehui): coalesce preemptions seen within one window into a single + # worker-group restart, so a staggered drain (node A at t, node B at + # t+60s) doesn't cause back-to-back restarts. diff --git a/python/ray/train/v2/tests/test_preemption_watcher.py b/python/ray/train/v2/tests/test_preemption_watcher.py new file mode 100644 index 000000000000..32cce03ddfcd --- /dev/null +++ b/python/ray/train/v2/tests/test_preemption_watcher.py @@ -0,0 +1,192 @@ +"""Unit tests for the preemption watcher.""" +from typing import Dict +from unittest.mock import Mock, patch + +import pytest + +import ray +from ray.train.v2._internal.callbacks.preemption_callback import PreemptionCallback +from ray.train.v2._internal.execution.preemption import PreemptionWatcher + +_PREEMPTION_MOD = "ray.train.v2._internal.execution.preemption" + + +def _make_watcher( + node_to_ranks: Dict[str, list], + fd_map: Dict[str, list], +) -> PreemptionWatcher: + """Construct a watcher with a fixed failure-domain map, then halt its loop. + + The background poll thread is stopped (while the GCS drain call is mocked + out) so tests can drive ``_poll_once()`` synchronously and deterministically. + """ + fd = {nid: sorted(ranks) for nid, ranks in fd_map.items()} + with patch.object( + PreemptionWatcher, "_build_failure_domain_map", return_value=fd + ), patch(f"{_PREEMPTION_MOD}._get_draining_nodes", return_value={}): + watcher = PreemptionWatcher(node_to_ranks=node_to_ranks) + watcher._stop_event.set() + watcher._monitor_thread.join(timeout=5) + return watcher + + +def _poll_once_with(watcher: PreemptionWatcher, **patch_kwargs) -> None: + """Drive one poll with the GCS drain call mocked (``return_value``/``side_effect``).""" + with patch(f"{_PREEMPTION_MOD}._get_draining_nodes", **patch_kwargs): + watcher._poll_once() + + +class TestPreemptionWatcher: + @pytest.mark.parametrize( + "fd_map, drained, exp_nodes, exp_ranks, exp_deadline_ms", + [ + # Single drained node. + ({"node-a": [0, 1]}, {"node-a": 30_000}, ["node-a"], [0, 1], 30_000), + # Drains on nodes outside our failure domains are filtered out. + ( + {"node-a": [0, 1]}, + {"node-a": 30_000, "other": 30_000}, + ["node-a"], + [0, 1], + 30_000, + ), + # Multiple of our nodes: ranks aggregated, earliest deadline wins. + ( + {"node-a": [0, 1], "node-b": [2, 3]}, + {"node-a": 45_000, "node-b": 20_000}, + ["node-a", "node-b"], + [0, 1, 2, 3], + 20_000, + ), + # Failure-domain-expanded map: a drain on one host flags the whole + # fate-shared set (e.g. a TPU slice). + ( + {"host-0": [0, 1, 2, 3], "host-1": [0, 1, 2, 3]}, + {"host-0": 30_000}, + ["host-0"], + [0, 1, 2, 3], + 30_000, + ), + # deadline_ms=0 from Ray Core means "no deadline" -> surfaced as None. + ({"node-a": [0]}, {"node-a": 0}, ["node-a"], [0], None), + # A None deadline is also surfaced as None rather than raising. + ({"node-a": [0]}, {"node-a": None}, ["node-a"], [0], None), + ], + ) + def test_drain_produces_info( + self, fd_map, drained, exp_nodes, exp_ranks, exp_deadline_ms + ): + watcher = _make_watcher(node_to_ranks=fd_map, fd_map=fd_map) + _poll_once_with(watcher, return_value=dict(drained)) + + info = watcher.get_latest_preemption_info() + assert info.preempted_node_ids == exp_nodes + assert info.preempted_ranks == exp_ranks + assert info.deadline_ms == exp_deadline_ms + # The node->ranks map keys are the preempted nodes; each maps to its + # failure domain. The flat getters above derive from it. + assert info.preempted_node_to_ranks == {n: sorted(fd_map[n]) for n in exp_nodes} + + def test_no_drain_means_no_info(self): + watcher = _make_watcher(node_to_ranks={"node-a": [0]}, fd_map={"node-a": [0]}) + _poll_once_with(watcher, return_value={}) + assert watcher.get_latest_preemption_info() is None + + def test_poll_swallows_drain_errors(self): + """A raising drain call must not propagate out of a poll.""" + watcher = _make_watcher(node_to_ranks={"node-a": [0]}, fd_map={"node-a": [0]}) + _poll_once_with( + watcher, side_effect=RuntimeError("transient") + ) # must not raise + assert watcher.get_latest_preemption_info() is None + + def test_none_drain_result_is_safe(self): + """A drain call returning None is treated as no drains.""" + watcher = _make_watcher(node_to_ranks={"node-a": [0]}, fd_map={"node-a": [0]}) + _poll_once_with(watcher, return_value=None) # must not raise + assert watcher.get_latest_preemption_info() is None + + +class TestBuildFailureDomainMap: + def test_falls_back_on_ray_nodes_error(self, monkeypatch): + """If ray.nodes() raises, fall back to per-node domains.""" + + def raise_runtime(): + raise RuntimeError("no ray runtime") + + monkeypatch.setattr(ray, "nodes", raise_runtime) + result = PreemptionWatcher._build_failure_domain_map( + {"node-a": [0, 1], "node-b": [2, 3]} + ) + assert result == {"node-a": [0, 1], "node-b": [2, 3]} + + def test_falls_back_on_slice_lookup_error(self): + """If the TPU slice lookup raises, fall back to per-node domains.""" + with patch.object(ray, "nodes", return_value=[{"NodeID": "node-a"}]), patch( + f"{_PREEMPTION_MOD}.get_tpu_slice_name_from_node", + side_effect=RuntimeError("boom"), + ): + result = PreemptionWatcher._build_failure_domain_map({"node-a": [0, 1]}) + assert result == {"node-a": [0, 1]} + + @pytest.mark.parametrize( + "node_to_ranks, slice_labels, cluster_node_ids, expected", + [ + # GPU cluster (no slice labels): per-node domains. node-1 may be + # shared with another workload, but only our ranks are in + # node_to_ranks, so a drain on node-1 flags only our rank there. + ( + {"node-0": [0, 1, 2, 3], "node-1": [4]}, + {}, + ["node-0", "node-1"], + {"node-0": [0, 1, 2, 3], "node-1": [4]}, + ), + # TPU slice fully occupied by this job: a drain on any host is + # fate-shared, so every host maps to the union of the slice's ranks. + ( + {"sa-0": [0, 1], "sa-1": [2, 3]}, + {"sa-0": "A", "sa-1": "A"}, + ["sa-0", "sa-1"], + {"sa-0": [0, 1, 2, 3], "sa-1": [0, 1, 2, 3]}, + ), + ], + ) + def test_build_failure_domain_map( + self, node_to_ranks, slice_labels, cluster_node_ids, expected + ): + fake_nodes = [{"NodeID": nid} for nid in cluster_node_ids] + + def fake_slice_label(node): + return slice_labels.get(node["NodeID"]) + + with patch.object(ray, "nodes", return_value=fake_nodes), patch( + f"{_PREEMPTION_MOD}.get_tpu_slice_name_from_node", + side_effect=fake_slice_label, + ): + result = PreemptionWatcher._build_failure_domain_map(node_to_ranks) + + assert result == expected + + +class TestPreemptionCallbackTeardown: + @pytest.mark.parametrize( + "teardown_hook", + ["before_worker_group_shutdown", "after_worker_group_abort"], + ) + def test_teardown_kills_watcher(self, teardown_hook): + """Both shutdown and abort must kill the watcher (abort skips shutdown).""" + callback = PreemptionCallback() + watcher = Mock() + callback._watcher = watcher + + with patch.object(ray, "kill") as mock_kill: + getattr(callback, teardown_hook)(None) + + mock_kill.assert_called_once_with(watcher) + assert callback._watcher is None + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", "-x", __file__])) diff --git a/src/ray/common/task/task_spec.cc b/src/ray/common/task/task_spec.cc index d7b0986038ff..e717eaae5bb7 100644 --- a/src/ray/common/task/task_spec.cc +++ b/src/ray/common/task/task_spec.cc @@ -249,6 +249,11 @@ int64_t TaskSpecification::GeneratorBackpressureNumObjects() const { return result; } +int64_t TaskSpecification::NumObjectsPerYield() const { + auto result = message_->num_objects_per_yield(); + return result == 0 ? 1 : result; +} + std::vector TaskSpecification::DynamicReturnIds() const { RAY_CHECK(message_->returns_dynamic()); std::vector dynamic_return_ids; diff --git a/src/ray/common/task/task_spec.h b/src/ray/common/task/task_spec.h index e78184c6eb66..32ff710f1610 100644 --- a/src/ray/common/task/task_spec.h +++ b/src/ray/common/task/task_spec.h @@ -193,6 +193,8 @@ class TaskSpecification : public MessageWrapper { int64_t GeneratorBackpressureNumObjects() const; + int64_t NumObjectsPerYield() const; + std::vector DynamicReturnIds() const; void AddDynamicReturnId(const ObjectID &dynamic_return_id); diff --git a/src/ray/common/task/task_util.h b/src/ray/common/task/task_util.h index c0f0b2f1eb35..e1afa6e3fe5f 100644 --- a/src/ray/common/task/task_util.h +++ b/src/ray/common/task/task_util.h @@ -160,7 +160,8 @@ class TaskSpecBuilder { const std::unordered_map &labels = {}, const LabelSelector &label_selector = {}, const std::vector &fallback_strategy = - std::vector()) { + std::vector(), + uint64_t num_objects_per_yield = 1) { message_->set_type(TaskType::NORMAL_TASK); message_->set_name(name); message_->set_language(language); @@ -179,6 +180,7 @@ class TaskSpecBuilder { message_->set_returns_dynamic(returns_dynamic); message_->set_streaming_generator(is_streaming_generator); message_->set_generator_backpressure_num_objects(generator_backpressure_num_objects); + message_->set_num_objects_per_yield(num_objects_per_yield); message_->mutable_required_resources()->insert(required_resources.begin(), required_resources.end()); message_->mutable_required_placement_resources()->insert( diff --git a/src/ray/core_worker/common.h b/src/ray/core_worker/common.h index 6c91be760265..f1e03824f961 100644 --- a/src/ray/core_worker/common.h +++ b/src/ray/core_worker/common.h @@ -70,6 +70,7 @@ struct TaskOptions { std::unordered_map &resources_p, std::string concurrency_group_name_p = "", int64_t generator_backpressure_num_objects_p = -1, + int64_t num_objects_per_yield_p = 1, std::string serialized_runtime_env_info_p = "{}", bool enable_task_events_p = kDefaultTaskEventEnabled, std::unordered_map labels_p = {}, @@ -82,6 +83,7 @@ struct TaskOptions { concurrency_group_name(std::move(concurrency_group_name_p)), serialized_runtime_env_info(std::move(serialized_runtime_env_info_p)), generator_backpressure_num_objects(generator_backpressure_num_objects_p), + num_objects_per_yield(num_objects_per_yield_p), enable_task_events(enable_task_events_p), labels(std::move(labels_p)), label_selector(std::move(label_selector_p)), @@ -104,6 +106,9 @@ struct TaskOptions { /// -1 means either streaming generator is not used or /// it is used but the feature is disabled. int64_t generator_backpressure_num_objects; + /// Only applicable when streaming generator is used. + /// The number of ObjectRefs produced by each generator yield. + int64_t num_objects_per_yield = 1; /// True if task events (worker::TaskEvent) from this task should be reported, default /// to true. bool enable_task_events = kDefaultTaskEventEnabled; diff --git a/src/ray/core_worker/core_worker.cc b/src/ray/core_worker/core_worker.cc index 7257247b5344..deb0bc28301a 100644 --- a/src/ray/core_worker/core_worker.cc +++ b/src/ray/core_worker/core_worker.cc @@ -1835,7 +1835,8 @@ void CoreWorker::BuildCommonTaskSpec( bool enable_task_events, const std::unordered_map &labels, const LabelSelector &label_selector, - const std::vector &fallback_strategy) { + const std::vector &fallback_strategy, + int64_t num_objects_per_yield) { // Build common task spec. auto override_runtime_env_info = OverrideTaskOrActorRuntimeEnvInfo(serialized_runtime_env_info); @@ -1885,7 +1886,8 @@ void CoreWorker::BuildCommonTaskSpec( enable_task_events, labels, label_selector, - fallback_strategy); + fallback_strategy, + num_objects_per_yield); // Set task arguments. for (const auto &arg : args) { builder.AddArg(*arg); @@ -1966,7 +1968,8 @@ std::vector CoreWorker::SubmitTask( /*enable_task_events=*/task_options.enable_task_events, task_options.labels, task_options.label_selector, - task_options.fallback_strategy); + task_options.fallback_strategy, + task_options.num_objects_per_yield); ActorID root_detached_actor_id; if (!worker_context_->GetRootDetachedActorID().IsNil()) { root_detached_actor_id = worker_context_->GetRootDetachedActorID(); @@ -2404,7 +2407,8 @@ Status CoreWorker::SubmitActorTask( /*enable_task_events=*/task_options.enable_task_events, /*labels=*/task_options.labels, /*label_selector=*/{}, - /*fallback_strategy=*/{}); + /*fallback_strategy=*/{}, + task_options.num_objects_per_yield); // NOTE: placement_group_capture_child_tasks and runtime_env will // be ignored in the actor because we should always follow the actor's option. @@ -2881,6 +2885,7 @@ Status CoreWorker::ExecuteTask( /*retry_exception=*/task_spec.ShouldRetryExceptions(), /*generator_backpressure_num_objects=*/ task_spec.GeneratorBackpressureNumObjects(), + /*num_objects_per_yield=*/task_spec.NumObjectsPerYield(), /*tensor_transport=*/task_spec.TensorTransport()); // Get the reference counts for any IDs that we borrowed during this task, @@ -3113,7 +3118,8 @@ ObjectID CoreWorker::AllocateDynamicReturnId(const rpc::Address &owner_address, } Status CoreWorker::ReportGeneratorItemReturns( - const std::pair> &dynamic_return_object, + const std::vector>> + &dynamic_return_objects, const ObjectID &generator_id, const rpc::Address &owner_address, int64_t item_index, @@ -3126,26 +3132,33 @@ Status CoreWorker::ReportGeneratorItemReturns( request.set_attempt_number(attempt_number); auto client = core_worker_client_pool_->GetOrConnect(owner_address); - // This means it is the last report when the task has finished executing. - if (!dynamic_return_object.first.IsNil()) { + std::vector return_ids; + return_ids.reserve(dynamic_return_objects.size()); + for (const auto &dynamic_return_object : dynamic_return_objects) { + if (dynamic_return_object.first.IsNil()) { + continue; + } SerializeReturnObject(dynamic_return_object.first, dynamic_return_object.second, - request.mutable_returned_object()); - std::vector deleted; + request.add_returned_objects()); + return_ids.push_back(dynamic_return_object.first); + } + if (!return_ids.empty()) { // When we allocate a dynamic return ID (AllocateDynamicReturnId), - // we borrow the object. When the object value is allocatd, the + // we borrow the object. When the object value is allocated, the // memory store is updated. We should clear borrowers and memory store // here. + std::vector deleted; ReferenceCounterInterface::ReferenceTableProto borrowed_refs; - reference_counter_->PopAndClearLocalBorrowers( - {dynamic_return_object.first}, &borrowed_refs, &deleted); + reference_counter_->PopAndClearLocalBorrowers(return_ids, &borrowed_refs, &deleted); memory_store_->Delete(deleted); } - const auto return_id = dynamic_return_object.first; + + const auto return_id = return_ids.empty() ? ObjectID::Nil() : return_ids.front(); RAY_LOG(DEBUG) << "Write the object ref stream, index: " << item_index - << ", id: " << return_id; + << ", id: " << return_id << ", count: " << return_ids.size(); - waiter->IncrementObjectGenerated(); + waiter->IncrementObjectGenerated(return_ids.size()); client->ReportGeneratorItemReturns( std::move(request), diff --git a/src/ray/core_worker/core_worker.h b/src/ray/core_worker/core_worker.h index 83f25904ad08..43f1d34c83af 100644 --- a/src/ray/core_worker/core_worker.h +++ b/src/ray/core_worker/core_worker.h @@ -774,12 +774,9 @@ class CoreWorker : public std::enable_shared_from_this { /// NOTE: The API doesn't guarantee the ordering of the report. The /// owner is supposed to reorder the report based on the item_index. /// - /// \param[in] returned_object A intermediate ray object to report - /// to the owner before the task terminates. This object must have been + /// \param[in] returned_objects Intermediate ray objects to report + /// to the owner before the task terminates. These objects must have been /// created dynamically from this worker via AllocateReturnObject. - /// If the Object ID is nil, it means it is the end of the task return. - /// In this case, the owner is responsible for setting finished = true, - /// otherwise it will panic. /// \param[in] generator_id The return object ref ID from a current generator /// task. /// \param[in] owner_address The address of the owner of the current task. @@ -790,7 +787,8 @@ class CoreWorker : public std::enable_shared_from_this { /// \param[in] waiter The class to pause the thread if generator backpressure limit /// is reached. Status ReportGeneratorItemReturns( - const std::pair> &returned_object, + const std::vector>> + &returned_objects, const ObjectID &generator_id, const rpc::Address &owner_address, int64_t item_index, @@ -1428,7 +1426,8 @@ class CoreWorker : public std::enable_shared_from_this { bool enable_task_events = true, const std::unordered_map &labels = {}, const LabelSelector &label_selector = {}, - const std::vector &fallback_strategy = {}); + const std::vector &fallback_strategy = {}, + int64_t num_objects_per_yield = 1); void SetCurrentTaskId(const TaskID &task_id, uint64_t attempt_number, diff --git a/src/ray/core_worker/core_worker_options.h b/src/ray/core_worker/core_worker_options.h index 4dc47af27ddc..847931f48810 100644 --- a/src/ray/core_worker/core_worker_options.h +++ b/src/ray/core_worker/core_worker_options.h @@ -78,6 +78,8 @@ struct CoreWorkerOptions { // The max number of unconsumed objects where a generator // can run without a pause. int64_t generator_backpressure_num_objects, + // The number of ObjectRefs produced by each streaming generator yield. + int64_t num_objects_per_yield, const std::optional &tensor_transport)>; CoreWorkerOptions() diff --git a/src/ray/core_worker/generator_waiter.cc b/src/ray/core_worker/generator_waiter.cc index 132e1247f9d5..875f106a6f10 100644 --- a/src/ray/core_worker/generator_waiter.cc +++ b/src/ray/core_worker/generator_waiter.cc @@ -69,9 +69,11 @@ Status GeneratorBackpressureWaiter::WaitAllObjectsReported() { return return_status; } -void GeneratorBackpressureWaiter::IncrementObjectGenerated() { +void GeneratorBackpressureWaiter::IncrementObjectGenerated( + int64_t num_objects_generated) { + RAY_CHECK_GE(num_objects_generated, 0); absl::MutexLock lock(&mutex_); - total_objects_generated_ += 1; + total_objects_generated_ += num_objects_generated; num_object_reports_in_flight_++; } diff --git a/src/ray/core_worker/generator_waiter.h b/src/ray/core_worker/generator_waiter.h index 4d04c18f34ae..492a0359b5bc 100644 --- a/src/ray/core_worker/generator_waiter.h +++ b/src/ray/core_worker/generator_waiter.h @@ -63,7 +63,7 @@ class GeneratorBackpressureWaiter { /// Increment the number of objects generated. The executor should call this /// before sending an object report to the caller. - void IncrementObjectGenerated(); + void IncrementObjectGenerated(int64_t num_objects_generated = 1); /// Handle a completed object report. The executor should call this after /// receiving an ack from the caller for an object report. diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc index 3d047dbf4e6f..6a4a627408da 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc @@ -143,6 +143,7 @@ Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(JNIEnv *env, bool is_streaming_generator, bool should_retry_exceptions, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, const std::optional &tensor_transport) { // These 3 parameters are used for Python only, and Java worker // will not use them. diff --git a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc index 363d51234a12..cee2e8182990 100644 --- a/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc +++ b/src/ray/core_worker/lib/java/io_ray_runtime_task_NativeTaskSubmitter.cc @@ -162,6 +162,7 @@ inline TaskOptions ToTaskOptions(JNIEnv *env, jint numReturns, jobject callOptio resources, concurrency_group_name, /*generator_backpressure_num_objects*/ -1, + /*num_objects_per_yield*/ 1, serialzied_runtime_env_info}; return task_options; } diff --git a/src/ray/core_worker/task_manager.cc b/src/ray/core_worker/task_manager.cc index d26b09570b47..e15c7f3b9607 100644 --- a/src/ray/core_worker/task_manager.cc +++ b/src/ray/core_worker/task_manager.cc @@ -818,13 +818,14 @@ bool TaskManager::HandleReportGeneratorItemReturns( } size_t num_objects_written = 0; - if (request.has_returned_object()) { - const rpc::ReturnObject &returned_object = request.returned_object(); + for (int64_t i = 0; i < request.returned_objects_size(); i++) { + const rpc::ReturnObject &returned_object = request.returned_objects(i); const auto object_id = ObjectID::FromBinary(returned_object.object_id()); + const auto object_index = item_index + i; RAY_LOG(DEBUG) << "Write an object " << object_id << " to the object ref stream of id " << generator_id; - auto index_not_used_yet = stream_it->second.InsertToStream(object_id, item_index); + auto index_not_used_yet = stream_it->second.InsertToStream(object_id, object_index); // If the ref was written to a stream, we should also // own the dynamically generated task return. @@ -849,18 +850,22 @@ bool TaskManager::HandleReportGeneratorItemReturns( // Handle backpressure if needed. auto total_generated = stream_it->second.TotalNumObjectWritten(); auto total_consumed = stream_it->second.TotalNumObjectConsumed(); + auto last_item_index = request.returned_objects_size() == 0 + ? item_index + : item_index + request.returned_objects_size() - 1; - if (stream_it->second.IsObjectConsumed(item_index)) { + if (stream_it->second.IsObjectConsumed(last_item_index)) { execution_signal_callback(Status::OK(), total_consumed); return false; } // Otherwise, follow the regular backpressure logic. - // NOTE, here we check `item_index - last_consumed_index >= backpressure_threshold`, - // instead of the number of unconsumed items, because we may receive the - // `HandleReportGeneratorItemReturns` requests out of order. + // NOTE, here we check `last_item_index - last_consumed_index >= + // backpressure_threshold`, instead of the number of unconsumed items, because we may + // receive the `HandleReportGeneratorItemReturns` requests out of order. if (backpressure_threshold != -1 && - (item_index - stream_it->second.LastConsumedIndex()) >= backpressure_threshold) { + (last_item_index - stream_it->second.LastConsumedIndex()) >= + backpressure_threshold) { RAY_LOG(DEBUG) << "Stream " << generator_id << " is backpressured. total_generated: " << total_generated << ". total_consumed: " << total_consumed diff --git a/src/ray/core_worker/tests/core_worker_test.cc b/src/ray/core_worker/tests/core_worker_test.cc index db8297d2a1f5..9aee2aa21bfa 100644 --- a/src/ray/core_worker/tests/core_worker_test.cc +++ b/src/ray/core_worker/tests/core_worker_test.cc @@ -93,6 +93,7 @@ class CoreWorkerTest : public ::testing::Test { bool is_streaming_generator, bool retry_exception, int64_t generator_backpressure_num_objects, + int64_t num_objects_per_yield, const std::optional &tensor_transport) -> Status { return Status::OK(); }; diff --git a/src/ray/core_worker/tests/task_manager_test.cc b/src/ray/core_worker/tests/task_manager_test.cc index 3e1e91bba447..0b6bdb1855a8 100644 --- a/src/ray/core_worker/tests/task_manager_test.cc +++ b/src/ray/core_worker/tests/task_manager_test.cc @@ -88,7 +88,7 @@ rpc::ReportGeneratorItemReturnsRequest GetIntermediateTaskReturn( request.mutable_worker_addr()->CopyFrom(addr); request.set_item_index(idx); request.set_generator_id(generator_id.Binary()); - rpc::ReturnObject *returned_object = request.mutable_returned_object(); + rpc::ReturnObject *returned_object = request.add_returned_objects(); returned_object->set_object_id(dynamic_return_id.Binary()); returned_object->set_data(data->Data(), data->Size()); returned_object->set_in_plasma(set_in_plasma); diff --git a/src/ray/protobuf/common.proto b/src/ray/protobuf/common.proto index 6d293e80f951..ec7a8b92c88e 100644 --- a/src/ray/protobuf/common.proto +++ b/src/ray/protobuf/common.proto @@ -622,6 +622,9 @@ message TaskSpec { optional string tensor_transport = 44; // A list of fallback options defining the fallback strategy for scheduling. FallbackStrategy fallback_strategy = 45; + // Private streaming generator option. The number of ObjectRefs produced by + // each generator yield. Defaults to 1. + uint64 num_objects_per_yield = 46; } message TaskInfoEntry { diff --git a/src/ray/protobuf/core_worker.proto b/src/ray/protobuf/core_worker.proto index 443bb71976ba..5735eee215a2 100644 --- a/src/ray/protobuf/core_worker.proto +++ b/src/ray/protobuf/core_worker.proto @@ -417,8 +417,8 @@ message NumPendingTasksReply { } message ReportGeneratorItemReturnsRequest { - // Object returned from the executor (can be inlined or in plasma). - ReturnObject returned_object = 1; + // Objects returned from the executor (can be inlined or in plasma). + repeated ReturnObject returned_objects = 1; // The address of the executor. Address worker_addr = 2; // The index of the task return. It is used to