diff --git a/src/bub/__init__.py b/src/bub/__init__.py index 0c4ce1be..43680844 100644 --- a/src/bub/__init__.py +++ b/src/bub/__init__.py @@ -13,13 +13,12 @@ from bub.framework import DEFAULT_HOME, BubFramework from bub.hookspecs import hookimpl from bub.tools import tool -from bub.turn_admission import AdmitDecision, SteeringBuffer, TurnSnapshot +from bub.turn_admission import AdmitDecision, TurnSnapshot __all__ = [ "AdmitDecision", "BubFramework", "Settings", - "SteeringBuffer", "TurnSnapshot", "config", "ensure_config", diff --git a/src/bub/builtin/agent.py b/src/bub/builtin/agent.py index 8743fa30..18c5d927 100644 --- a/src/bub/builtin/agent.py +++ b/src/bub/builtin/agent.py @@ -2,10 +2,12 @@ from __future__ import annotations +import asyncio import inspect import re import shlex import time +from collections import deque from collections.abc import AsyncGenerator, AsyncIterator, Callable, Collection, Coroutine, Iterable from contextlib import AsyncExitStack from dataclasses import dataclass, replace @@ -22,6 +24,7 @@ ) from bub.builtin.settings import load_settings from bub.builtin.tape import Tape +from bub.envelope import field_of from bub.framework import BubFramework from bub.runtime import AsyncStreamEvents, StreamEvent, StreamState from bub.skills import discover_skills, render_skills_prompt @@ -46,6 +49,7 @@ def __init__(self, framework: BubFramework) -> None: self.settings = load_settings() self.framework = framework self.model_runner = ModelRunner(self.settings) + self._steering_messages: dict[str, deque[dict[str, Any]]] = {} @cached_property def tape(self) -> Tape: @@ -79,6 +83,21 @@ async def generator() -> AsyncIterator[StreamEvent]: return AsyncStreamEvents(generator(), state=events._state) + def enqueue_steering_message(self, thread_id: str, message: dict[str, Any]) -> bool: + if thread_id not in self._steering_messages: + return False + self._steering_messages[thread_id].append(message) + return True + + def _drain_steering_messages(self, thread_id: str) -> list[dict[str, Any]]: + queue = self._steering_messages.pop(thread_id, None) + if queue is None: + return [] + return list(queue) + + def _has_steering_messages(self, thread_id: str) -> bool: + return bool(self._steering_messages.get(thread_id)) + async def run_stream( self, *, @@ -98,6 +117,8 @@ async def run_stream( tape = self.tape.session_tape( session_id, workspace_from_state(state), context=replace(self.tape.context, state=state) ) + thread_id = state.get("_runtime_thread_id", tape.name) + self._steering_messages.setdefault(thread_id, deque()) merge_back = not session_id.startswith("temp/") stack = AsyncExitStack() # The fork_tape context manager must not be exited until the last chunk of the stream is consumed. @@ -117,6 +138,11 @@ async def run_stream( allowed_skills=allowed_skills, allowed_tools=allowed_tools, ) + + @stack.callback + def cleanup() -> None: + self._steering_messages.pop(thread_id, None) + return self._events_with_callback(events, callback=stack.aclose) async def _run_command(self, tape: Tape, *, line: str) -> str: @@ -273,6 +299,8 @@ async def _stream_events_with_auto_handoff( state.error = output.error state.usage = output.usage elapsed_ms = int((time.monotonic() - start) * 1000) + thread_id = tape.context.state.get("_runtime_thread_id", tape.name) + should_continue = should_continue or self._has_steering_messages(thread_id) if not should_continue: await tape.append_event( "loop.step", @@ -355,12 +383,22 @@ async def _run_once_stream( resolved_model = model or self.settings.model model_tools_for_call = model_tools(tools) + thread_id = tape.context.state.get("_runtime_thread_id", tape.name) + steering_messages = list( + await asyncio.gather(*[ + self.framework.build_prompt( + message, session_id=field_of(message, "session_id"), state=tape.context.state + ) + for message in self._drain_steering_messages(thread_id) + ]) + ) return self.model_runner.run( tape=tape, model=resolved_model, tools=model_tools_for_call, system_prompt=system_prompt, prompt=prompt, + steering_messages=steering_messages, ) def _system_prompt( diff --git a/src/bub/builtin/hook_impl.py b/src/bub/builtin/hook_impl.py index a786fca9..8307f17d 100644 --- a/src/bub/builtin/hook_impl.py +++ b/src/bub/builtin/hook_impl.py @@ -118,6 +118,8 @@ async def load_state(self, message: ChannelMessage, session_id: str) -> State: state = {"session_id": session_id, "_runtime_agent": self._get_agent()} if context := field_of(message, "context_str"): state["context"] = context + if thread_id := field_of(message, "context", {}).get("thread_id"): + state["_runtime_thread_id"] = thread_id return state @hookimpl @@ -294,7 +296,7 @@ def build_tape_context(self) -> TapeContext: return default_tape_context() @hookimpl - def admit_message( + async def admit_message( self, session_id: str, message: Envelope, @@ -303,4 +305,12 @@ def admit_message( outbound_router = self.framework._outbound_router if outbound_router is None: return None - return outbound_router.admit_channel_message(session_id=session_id, message=message, turn=turn) + return await outbound_router.admit_channel_message(session_id=session_id, message=message, turn=turn) + + @hookimpl + async def handle_steering(self, message: Envelope, reason: str | None) -> bool: + """Handle a steering message that is admitted by the `admit_message` hook with action "steer".""" + agent = self._get_agent() + context = field_of(message, "context", {}) + thread_id = field_of(context, "thread_id") if isinstance(context, dict) else None + return agent.enqueue_steering_message(str(thread_id or field_of(message, "session_id", "")), message) diff --git a/src/bub/builtin/model_runner.py b/src/bub/builtin/model_runner.py index 0684f8fc..89b92351 100644 --- a/src/bub/builtin/model_runner.py +++ b/src/bub/builtin/model_runner.py @@ -85,6 +85,7 @@ def run( tools: list[Tool], system_prompt: str | None, prompt: str | list[dict], + steering_messages: list[list[dict[str, Any]] | str] | None = None, ) -> AsyncStreamEvents: state = StreamState() @@ -96,6 +97,7 @@ async def iterator() -> AsyncGenerator[StreamEvent, None]: system_prompt=system_prompt, prompt=prompt, model=model, + steering_messages=steering_messages, ) output = ModelOutputAccumulator() async with asyncio.timeout(self.settings.model_timeout_seconds): @@ -159,6 +161,7 @@ async def build_messages( system_prompt: str | None, prompt: str | list[dict], model: str, + steering_messages: list[list[dict[str, Any]] | str] | None = None, ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: prompt_message: dict[str, Any] = {"role": "user", "content": prompt} try: @@ -172,10 +175,12 @@ async def build_messages( model=model, ) raise + steering_messages_native = [{"role": "user", "content": message} for message in (steering_messages or [])] if system_prompt: messages = [{"role": "system", "content": system_prompt}, *messages] - messages.append(prompt_message) - return messages, [prompt_message] + new_messages = [*steering_messages_native, prompt_message] + messages.extend(new_messages) + return messages, new_messages async def record_context_error( self, diff --git a/src/bub/channels/base.py b/src/bub/channels/base.py index 83bbb8b0..44baed93 100644 --- a/src/bub/channels/base.py +++ b/src/bub/channels/base.py @@ -41,7 +41,7 @@ def stream_events(self, message: ChannelMessage, stream: AsyncIterable[StreamEve """Optionally wrap the output stream for this channel.""" return stream - def admit_message( + async def admit_message( self, session_id: str, message: Envelope, diff --git a/src/bub/channels/cli/__init__.py b/src/bub/channels/cli/__init__.py index a277ecbc..34900175 100644 --- a/src/bub/channels/cli/__init__.py +++ b/src/bub/channels/cli/__init__.py @@ -281,12 +281,12 @@ async def _main_loop(self) -> None: continue request = self._normalize_input(raw) - await self._echo_input(raw) message = ChannelMessage( session_id=self._message_template["session_id"], channel=self._message_template["channel"], chat_id=self._message_template["chat_id"], + context={"thread_id": self._message_template["session_id"]}, # use the same thread_id for all messages content=request, lifespan=self.message_lifespan(), ) @@ -332,11 +332,11 @@ def _prompt_label(self) -> str: symbol = ">" if self._mode == "agent" else "," return f"{cwd} {symbol} " - async def _echo_input(self, raw: str) -> None: + async def _echo_input(self, raw: str, steering: bool = False) -> None: stream_printer = getattr(self, "_stream_printer", None) if stream_printer is not None: await stream_printer.commit_live_text() - self._renderer.input_echo(self._prompt_label(), raw) + self._renderer.input_echo(self._prompt_label(), raw, steering=steering) async def stream_events( self, message: ChannelMessage, stream: AsyncIterable[StreamEvent] @@ -416,12 +416,13 @@ def _history_file(home: Path, workspace: Path) -> Path: workspace_hash = md5(str(workspace).encode("utf-8"), usedforsecurity=False).hexdigest() return home / "history" / f"{workspace_hash}.history" - def admit_message( + async def admit_message( self, session_id: str, message: Envelope, turn: TurnSnapshot, ) -> AdmitDecision | None: + await self._echo_input(message.content, steering=turn.is_running) if not turn.is_running: return None - return AdmitDecision("follow_up", reason="cli session is already generating") + return AdmitDecision("steer", reason="cli session is already generating") diff --git a/src/bub/channels/cli/renderer.py b/src/bub/channels/cli/renderer.py index 97eec2b9..6b568e26 100644 --- a/src/bub/channels/cli/renderer.py +++ b/src/bub/channels/cli/renderer.py @@ -52,10 +52,11 @@ def error(self, text: str) -> None: return self.console.print(f"[red bold]Error >[/]\n{text}") - def input_echo(self, prompt: str, text: str) -> None: + def input_echo(self, prompt: str, text: str, steering: bool = False) -> None: if not text.strip(): return - self.console.print(f"[bold]{prompt}[/]{text}", new_line_start=True) + mid = "[grey](steering)[/] " if steering else "" + self.console.print(f"[dim][bold]{prompt}[/]{mid}{text}[/]", new_line_start=True) def tool_call_start(self, *, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None: self.console.print(Text(_format_tool_call(name, args, kwargs), style="magenta"), new_line_start=True) diff --git a/src/bub/channels/manager.py b/src/bub/channels/manager.py index b3ec6726..34137020 100644 --- a/src/bub/channels/manager.py +++ b/src/bub/channels/manager.py @@ -120,11 +120,9 @@ def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> async def quit(self, session_id: str) -> None: controller = self._session_controllers.get(session_id) if controller is None: - self.framework.clear_steering(session_id) logger.info(f"channel.manager quit session_id={session_id}, cancelled 0 tasks") return controller.clear_pending() - controller.steering.drain_nowait() tasks = set(controller.active_tasks) current_task = asyncio.current_task() cancelled_count = 0 @@ -177,7 +175,7 @@ def enabled_channels(self) -> list[Channel]: def _controller(self, session_id: str) -> SessionTurnController: controller = self._session_controllers.get(session_id) if controller is None: - controller = SessionTurnController(session_id=session_id, steering=self.framework.steering(session_id)) + controller = SessionTurnController(session_id=session_id) self._session_controllers[session_id] = controller return controller @@ -185,10 +183,9 @@ def _drop_empty_controller(self, session_id: str) -> None: controller = self._session_controllers.get(session_id) if controller is None: return - if controller.active() or controller.pending_queue or controller.steering.count > 0: + if controller.active() or controller.pending_queue: return self._session_controllers.pop(session_id, None) - self.framework.clear_steering(session_id) def _on_task_done(self, session_id: str, task: asyncio.Task) -> None: if task.cancelled(): @@ -199,8 +196,6 @@ def _on_task_done(self, session_id: str, task: asyncio.Task) -> None: if controller is None: return controller.active_tasks.discard(task) - if not controller.active(): - controller.promote_steering_to_pending() self._schedule_pending(session_id) self._drop_empty_controller(session_id) @@ -249,13 +244,7 @@ async def _apply_admission_decision( if action == "follow_up": return self._queue_pending(controller, message, decision.reason) if action == "steer": - if controller.active(): - controller.steering.put_nowait(message) - logger.info( - "channel.manager admission steer session_id={} reason={}", - message.session_id, - decision.reason, - ) + if await self.framework.steer_message(message, decision.reason): return False return self._queue_pending(controller, message, decision.reason) logger.warning("channel.manager admission unknown action={} session_id={}", decision.action, message.session_id) @@ -326,23 +315,19 @@ async def listen_and_run(self) -> None: async def shutdown(self) -> None: count = 0 - session_ids = list(self._session_controllers) for controller in list(self._session_controllers.values()): controller.clear_pending() - controller.steering.drain_nowait() for task in set(controller.active_tasks): task.cancel() with contextlib.suppress(asyncio.CancelledError): await task count += 1 self._session_controllers.clear() - for session_id in session_ids: - self.framework.clear_steering(session_id) logger.info(f"channel.manager cancelled {count} in-flight tasks") for channel in self.enabled_channels(): await channel.stop() - def admit_channel_message( + async def admit_channel_message( self, session_id: str, message: Envelope, @@ -354,4 +339,4 @@ def admit_channel_message( channel = self.get_channel(str(channel_name)) if channel is None: return None - return channel.admit_message(session_id=session_id, message=message, turn=turn) + return await channel.admit_message(session_id=session_id, message=message, turn=turn) diff --git a/src/bub/channels/message.py b/src/bub/channels/message.py index 67e9d9cf..b026fa8c 100644 --- a/src/bub/channels/message.py +++ b/src/bub/channels/message.py @@ -53,7 +53,9 @@ def __post_init__(self) -> None: @property def context_str(self) -> str: """String representation of the context for prompt building.""" - return "|".join(f"{key}={value}" for key, value in self.context.items()) + return "|".join( + f"{key}={value}" for key, value in self.context.items() if not key.startswith("_") + ) # ignore internal keys @classmethod def from_batch(cls, batch: list[ChannelMessage]) -> ChannelMessage: diff --git a/src/bub/framework.py b/src/bub/framework.py index cbece9d5..063e5592 100644 --- a/src/bub/framework.py +++ b/src/bub/framework.py @@ -6,7 +6,7 @@ from collections.abc import AsyncGenerator, AsyncIterator, Iterator from dataclasses import dataclass from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import pluggy import typer @@ -19,7 +19,7 @@ from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs from bub.runtime import BubError, ErrorKind from bub.tape import AsyncTapeStore, TapeContext, TapeStore -from bub.turn_admission import AdmitDecision, SteeringBuffer, TurnSnapshot +from bub.turn_admission import AdmitDecision, TurnSnapshot from bub.types import Envelope, MessageHandler, OutboundChannelRouter, TurnResult if TYPE_CHECKING: @@ -48,7 +48,6 @@ def __init__(self, config_file: Path = DEFAULT_CONFIG_FILE) -> None: self._hook_runtime = HookRuntime(self._plugin_manager) self._plugin_status: dict[str, PluginStatus] = {} self._outbound_router: OutboundChannelRouter | None = None - self._steering_buffers: dict[str, SteeringBuffer] = {} self._tape_store: TapeStore | AsyncTapeStore | None = None configure.load(self.config_file) @@ -106,6 +105,17 @@ def _main( self._hook_runtime.call_many_sync("register_cli_commands", app=app) return app + async def build_prompt( + self, message: Envelope, session_id: str, state: dict[str, Any] + ) -> str | list[dict[str, Any]]: + """Build prompt for one message turn.""" + prompt = await self._hook_runtime.call_first( + "build_prompt", message=message, session_id=session_id, state=state + ) + if not prompt: + prompt = content_of(message) + return cast("str | list[dict[str, Any]]", prompt) + async def process_inbound(self, inbound: Envelope, stream_output: bool = False) -> TurnResult: """Run one inbound message through hooks and return turn result.""" @@ -113,17 +123,13 @@ async def process_inbound(self, inbound: Envelope, stream_output: bool = False) session_id = await self.resolve_session(inbound) if isinstance(inbound, dict): inbound.setdefault("session_id", session_id) - state = {"_runtime_workspace": str(self.workspace), "_runtime_steering": self.steering(session_id)} + state = {"_runtime_workspace": str(self.workspace)} for hook_state in reversed( await self._hook_runtime.call_many("load_state", message=inbound, session_id=session_id) ): if isinstance(hook_state, dict): state.update(hook_state) - prompt = await self._hook_runtime.call_first( - "build_prompt", message=inbound, session_id=session_id, state=state - ) - if not prompt: - prompt = content_of(inbound) + prompt = await self.build_prompt(inbound, session_id, state) model_output = "" try: model_output = await self._run_model(inbound, prompt, session_id, state, stream_output) @@ -220,16 +226,6 @@ async def admit_message(self, *, session_id: str, message: Envelope, turn: TurnS return decision raise TypeError("hook.admit_message must return AdmitDecision or None") - def steering(self, session_id: str) -> SteeringBuffer: - buffer = self._steering_buffers.get(session_id) - if buffer is None: - buffer = SteeringBuffer(session_id=session_id) - self._steering_buffers[session_id] = buffer - return buffer - - def clear_steering(self, session_id: str) -> None: - self._steering_buffers.pop(session_id, None) - @staticmethod def _default_session_id(message: Envelope) -> str: session_id = field_of(message, "session_id") @@ -329,3 +325,6 @@ def collect_onboard_config(self) -> dict[str, Any]: raise TypeError("hook.onboard_config must return dict or None") configure.merge(current_config, result) return configure.validate(current_config) + + async def steer_message(self, message: Envelope, reason: str | None = None) -> Any: + return await self._hook_runtime.call_first("handle_steering", message=message, reason=reason) diff --git a/src/bub/hookspecs.py b/src/bub/hookspecs.py index 6bb5da2b..e3ce40a8 100644 --- a/src/bub/hookspecs.py +++ b/src/bub/hookspecs.py @@ -121,3 +121,8 @@ def admit_message( Return ``None`` to keep Bub's default concurrent scheduling behavior. """ raise NotImplementedError + + @hookspec(firstresult=True) + async def handle_steering(self, message: Envelope, reason: str | None) -> bool: + """Handle a steering message that is admitted by the `admit_message` hook with action "steer".""" + raise NotImplementedError diff --git a/src/bub/turn_admission.py b/src/bub/turn_admission.py index a3216be6..444eb1dd 100644 --- a/src/bub/turn_admission.py +++ b/src/bub/turn_admission.py @@ -28,38 +28,6 @@ class TurnSnapshot: is_running: bool running_count: int pending_count: int - steering_count: int - - -@dataclass -class SteeringBuffer: - """Per-session queue for steering messages offered to active turns.""" - - session_id: str - _queue: deque[Envelope] = field(default_factory=deque, init=False, repr=False) - - def put_nowait(self, message: Envelope) -> None: - """Append one message.""" - - self._queue.append(message) - - @property - def count(self) -> int: - return len(self._queue) - - def get_nowait(self) -> Envelope | None: - """Return one queued message without waiting.""" - - if not self._queue: - return None - return self._queue.popleft() - - def drain_nowait(self) -> list[Envelope]: - """Drain steering input and acknowledge ownership of those messages.""" - - messages = list(self._queue) - self._queue.clear() - return messages @dataclass @@ -67,7 +35,6 @@ class SessionTurnController: """Per-session runtime queues used by ``ChannelManager``.""" session_id: str - steering: SteeringBuffer active_tasks: set[asyncio.Task] = field(default_factory=set) pending_queue: deque[Envelope] = field(default_factory=deque) @@ -81,17 +48,12 @@ def snapshot(self) -> TurnSnapshot: is_running=running_count > 0, running_count=running_count, pending_count=len(self.pending_queue), - steering_count=self.steering.count, ) def add_pending(self, message: Envelope) -> bool: self.pending_queue.append(message) return True - def add_pending_left(self, message: Envelope) -> bool: - self.pending_queue.appendleft(message) - return True - def pop_pending(self) -> Envelope | None: if not self.pending_queue: return None @@ -99,7 +61,3 @@ def pop_pending(self) -> Envelope | None: def clear_pending(self) -> None: self.pending_queue.clear() - - def promote_steering_to_pending(self) -> None: - for message in reversed(self.steering.drain_nowait()): - self.add_pending_left(message) diff --git a/src/bub/types.py b/src/bub/types.py index 7fb661a9..90197297 100644 --- a/src/bub/types.py +++ b/src/bub/types.py @@ -21,7 +21,7 @@ class OutboundChannelRouter(Protocol): async def dispatch_output(self, message: Envelope) -> bool: ... def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: ... async def quit(self, session_id: str) -> None: ... - def admit_channel_message( + async def admit_channel_message( self, session_id: str, message: Envelope, @@ -34,6 +34,6 @@ class TurnResult: """Result of one complete message turn.""" session_id: str - prompt: str + prompt: str | list[dict[str, Any]] model_output: str outbounds: list[Envelope] = field(default_factory=list) diff --git a/tests/test_builtin_agent.py b/tests/test_builtin_agent.py index bbb4cb2d..1ddde615 100644 --- a/tests/test_builtin_agent.py +++ b/tests/test_builtin_agent.py @@ -1,6 +1,7 @@ from __future__ import annotations import contextlib +from collections import defaultdict, deque from collections.abc import AsyncGenerator, AsyncIterator from typing import Any from unittest.mock import MagicMock, patch @@ -36,12 +37,18 @@ def _make_agent() -> Agent: framework.get_tape_store.return_value = None framework.get_system_prompt.return_value = "" + async def build_prompt(message: dict[str, Any], session_id: str, state: dict[str, Any]) -> str: + return str(message["content"]) + + framework.build_prompt = build_prompt + with patch.object(Agent, "__init__", lambda self, fw: None): agent = Agent.__new__(Agent) agent.settings = AgentSettings.model_construct(model="test:model", api_key="k", api_base="b", client_args={}) agent.framework = framework agent.model_runner = _FakeModelRunner(agent.settings) + agent._steering_messages = defaultdict(deque) return agent @@ -243,6 +250,46 @@ async def test_agent_run_model_defaults_to_none() -> None: assert completion_kwargs["model"] == "test:model" +@pytest.mark.asyncio +async def test_agent_run_injects_steering_messages_once_by_tape_name() -> None: + agent = _make_agent() + fork_capture = _ForkCapture() + fake_tapes = _FakeTapeFactory(fork_capture) + agent.tape = fake_tapes # type: ignore[assignment] + + agent._steering_messages["test-tape"] = deque() + assert agent.enqueue_steering_message("other-tape", {"role": "user", "content": "ignore me"}) is False + assert agent.enqueue_steering_message("test-tape", {"role": "user", "content": "first steer"}) is True + assert agent.enqueue_steering_message("test-tape", {"role": "assistant", "content": "second steer"}) is True + + result = await agent.run_stream(session_id="user/s1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 + [event async for event in result] + + completion_kwargs = _model_runner(agent).completion_kwargs + assert completion_kwargs is not None + completion_messages = completion_kwargs["messages"] + assert completion_messages[-3:] == [ + {"role": "user", "content": "first steer"}, + {"role": "user", "content": "second steer"}, + {"role": "user", "content": "hello"}, + ] + assert fake_tapes.tape.messages == [ + {"role": "user", "content": "first steer"}, + {"role": "user", "content": "second steer"}, + {"role": "user", "content": "hello"}, + {"role": "assistant", "content": "done"}, + ] + + result = await agent.run_stream(session_id="user/s1", prompt="again", state={"_runtime_workspace": "/tmp"}) # noqa: S108 + [event async for event in result] + + completion_kwargs = _model_runner(agent).completion_kwargs + assert completion_kwargs is not None + completion_messages = completion_kwargs["messages"] + assert completion_messages[-1] == {"role": "user", "content": "again"} + assert {"role": "user", "content": "ignore me"} not in completion_messages + + @pytest.mark.asyncio async def test_agent_run_resolves_allowed_tool_aliases_and_limits_prompt() -> None: allowed_name = "tests.allowed_agent_tool" diff --git a/tests/test_channels.py b/tests/test_channels.py index 54faa432..d279aa55 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -24,7 +24,7 @@ from bub.channels.message import ChannelMessage from bub.channels.telegram import BubMessageFilter, TelegramChannel, TelegramMessageParser from bub.runtime import StreamEvent -from bub.turn_admission import AdmitDecision, SessionTurnController, SteeringBuffer, TurnSnapshot +from bub.turn_admission import AdmitDecision, SessionTurnController, TurnSnapshot ANSI_RE = re.compile(r"\x1b(?:\[[0-?]*[ -/]*[@-~]|\][^\x07]*(?:\x07|\x1b\\)|[()][A-Za-z])") @@ -117,7 +117,8 @@ def __init__(self, channels: dict[str, Channel]) -> None: self.process_calls: list[tuple[ChannelMessage, bool]] = [] self.admission_decisions: list[AdmitDecision | None] = [] self.admission_calls: list[tuple[str, ChannelMessage, object]] = [] - self._steering_buffers: dict[str, SteeringBuffer] = {} + self.steering_calls: list[tuple[ChannelMessage, str | None]] = [] + self.steering_results: list[bool | None] = [] self.resolved_sessions: dict[str, str] = {} self._hook_runtime = SimpleNamespace(notify_error=self._notify_error) self.running_entries = 0 @@ -154,15 +155,11 @@ async def admit_message(self, *, session_id: str, message: ChannelMessage, turn) async def resolve_session(self, message: ChannelMessage) -> str: return self.resolved_sessions.get(message.session_id, message.session_id) - def steering(self, session_id: str) -> SteeringBuffer: - buffer = self._steering_buffers.get(session_id) - if buffer is None: - buffer = SteeringBuffer(session_id=session_id) - self._steering_buffers[session_id] = buffer - return buffer - - def clear_steering(self, session_id: str) -> None: - self._steering_buffers.pop(session_id, None) + async def steer_message(self, message: ChannelMessage, reason: str | None = None) -> bool | None: + self.steering_calls.append((message, reason)) + if self.steering_results: + return self.steering_results.pop(0) + return False async def _notify_error(self, *, stage: str, error: Exception, message: ChannelMessage | None) -> None: return None @@ -366,7 +363,7 @@ async def on_receive(message: ChannelMessage) -> None: channel._renderer = SimpleNamespace( welcome=lambda **kwargs: None, info=lambda message: None, - input_echo=lambda prompt, text: echoed.append((prompt, text)), + input_echo=lambda prompt, text, steering=False: echoed.append((prompt, text, steering)), ) channel._refresh_tape_info = _async_return(None) @@ -379,7 +376,7 @@ async def on_receive(message: ChannelMessage) -> None: assert channel._prompt.received_callables == [True, True, True] assert "Generating\n" not in channel._prompt.messages[0] assert "Generating\n" in channel._prompt.messages[1] - assert echoed == [(f"{Path.cwd().name} > ", "first"), (f"{Path.cwd().name} > ", "second")] + assert echoed == [] assert all(message.lifespan is not None for message in received) @@ -429,23 +426,29 @@ def test_cli_channel_generating_spinner_renders_above_input_not_toolbar(monkeypa assert first_frame != second_frame -def test_cli_channel_admit_message_queues_follow_up_when_turn_is_running() -> None: +@pytest.mark.asyncio +async def test_cli_channel_admit_message_steers_when_turn_is_running() -> None: channel = CliChannel.__new__(CliChannel) + channel._mode = "agent" + echoed: list[tuple[str, str, bool]] = [] + channel._renderer = SimpleNamespace( + input_echo=lambda prompt, text, steering=False: echoed.append((prompt, text, steering)), + ) turn = TurnSnapshot( session_id="cli_session", is_running=True, running_count=1, pending_count=0, - steering_count=0, ) - decision = channel.admit_message( + decision = await channel.admit_message( session_id="cli_session", message=_message("second", channel="cli", session_id="cli_session"), turn=turn, ) - assert decision == AdmitDecision("follow_up", reason="cli session is already generating") + assert decision == AdmitDecision("steer", reason="cli session is already generating") + assert echoed == [(f"{Path.cwd().name} > ", "second", True)] @pytest.mark.asyncio @@ -646,7 +649,6 @@ async def never_finish() -> None: assert turn.is_running is True assert turn.running_count == 1 assert turn.pending_count == 0 - assert turn.steering_count == 0 active.cancel() with contextlib.suppress(asyncio.CancelledError): @@ -726,12 +728,61 @@ async def never_finish() -> None: @pytest.mark.asyncio -async def test_channel_manager_admission_steer_promotes_undrained_messages_to_pending(load_config) -> None: +async def test_channel_manager_admission_steer_temporarily_queues_pending_message(load_config) -> None: + _load_channel_config(load_config, enabled_channels="telegram") + framework = FakeFramework({"telegram": FakeChannel("telegram")}) + framework.admission_decisions.append(AdmitDecision("steer", reason="correction")) + manager = ChannelManager(framework, enabled_channels=["telegram"]) + + async def never_finish() -> None: + await asyncio.sleep(10) + + active = asyncio.create_task(never_finish()) + manager._controller("telegram:chat").active_tasks = {active} + + admitted = await manager._admit_message(_message("steer me")) + + assert admitted is False + assert [(message.content, reason) for message, reason in framework.steering_calls] == [("steer me", "correction")] + assert [message.content for message in manager._session_controllers["telegram:chat"].pending_queue] == ["steer me"] + + active.cancel() + with contextlib.suppress(asyncio.CancelledError): + await active + + +@pytest.mark.asyncio +async def test_channel_manager_admission_steer_handler_takes_ownership(load_config) -> None: + _load_channel_config(load_config, enabled_channels="telegram") + framework = FakeFramework({"telegram": FakeChannel("telegram")}) + framework.admission_decisions.append(AdmitDecision("steer", reason="correction")) + framework.steering_results.append(True) + manager = ChannelManager(framework, enabled_channels=["telegram"]) + + async def never_finish() -> None: + await asyncio.sleep(10) + + active = asyncio.create_task(never_finish()) + manager._controller("telegram:chat").active_tasks = {active} + + admitted = await manager._admit_message(_message("steer me")) + + assert admitted is False + assert [(message.content, reason) for message, reason in framework.steering_calls] == [("steer me", "correction")] + assert not manager._session_controllers["telegram:chat"].pending_queue + + active.cancel() + with contextlib.suppress(asyncio.CancelledError): + await active + + +@pytest.mark.asyncio +async def test_channel_manager_admission_follow_up_preserves_pending_order(load_config) -> None: _load_channel_config(load_config, enabled_channels="telegram") framework = FakeFramework({"telegram": FakeChannel("telegram")}) framework.admission_decisions.extend([ - AdmitDecision("steer", reason="correction"), - AdmitDecision("steer", reason="correction"), + AdmitDecision("follow_up", reason="serial"), + AdmitDecision("follow_up", reason="serial"), ]) manager = ChannelManager(framework, enabled_channels=["telegram"]) @@ -752,58 +803,21 @@ async def test_channel_manager_admission_steer_promotes_undrained_messages_to_pe assert admitted is False assert admitted_again is False assert [message.content for message, _ in framework.process_calls] == [ + "already waiting", "actually do this", "then this", - "already waiting", ] -@pytest.mark.asyncio -async def test_channel_manager_admission_steer_drain_acknowledges_ownership(load_config) -> None: - _load_channel_config(load_config, enabled_channels="telegram") - framework = FakeFramework({"telegram": FakeChannel("telegram")}) - framework.admission_decisions.append(AdmitDecision("steer", reason="correction")) - manager = ChannelManager(framework, enabled_channels=["telegram"]) - - done = asyncio.create_task(asyncio.sleep(0)) - controller = manager._controller("telegram:chat") - controller.active_tasks = {done} - - admitted = await manager._admit_message(_message("consume me")) - drained = framework.steering("telegram:chat").drain_nowait() - await done - manager._on_task_done("telegram:chat", done) - - assert admitted is False - assert [message.content for message in drained] == ["consume me"] - assert framework.process_calls == [] - - def test_turn_admission_queues_preserve_messages_without_capacity_policy() -> None: - steering = SteeringBuffer(session_id="telegram:chat") - - steering.put_nowait(_message("one")) - steering.put_nowait(_message("two")) - steering.put_nowait(_message("three with a long body")) - drained_one = steering.get_nowait() - assert drained_one is not None - assert drained_one.content == "one" - assert [message.content for message in steering.drain_nowait()] == ["two", "three with a long body"] - - controller = SessionTurnController(session_id="telegram:chat", steering=SteeringBuffer(session_id="telegram:chat")) + controller = SessionTurnController(session_id="telegram:chat") controller.add_pending(_message("one")) controller.add_pending(_message("two")) controller.add_pending(_message("three with a long body")) assert [message.content for message in controller.pending_queue] == ["one", "two", "three with a long body"] - controller.add_pending_left(_message("priority")) - assert [message.content for message in controller.pending_queue] == [ - "priority", - "one", - "two", - "three with a long body", - ] + assert [message.content for message in controller.pending_queue] == ["one", "two", "three with a long body"] def test_cli_channel_normalize_input_prefixes_shell_commands() -> None: @@ -944,7 +958,7 @@ async def commit_live_text(self) -> None: channel._stream_printer = FakeStreamPrinter() channel._mode = "agent" - channel._renderer = SimpleNamespace(input_echo=lambda prompt, text: calls.append(f"echo:{text}")) + channel._renderer = SimpleNamespace(input_echo=lambda prompt, text, steering=False: calls.append(f"echo:{text}")) await channel._echo_input("steer now") diff --git a/tests/test_framework.py b/tests/test_framework.py index f39b8b03..d64175a3 100644 --- a/tests/test_framework.py +++ b/tests/test_framework.py @@ -20,7 +20,7 @@ from bub.framework import BubFramework from bub.hookspecs import hookimpl from bub.runtime import AsyncStreamEvents, StreamEvent, StreamState -from bub.turn_admission import AdmitDecision, SteeringBuffer, TurnSnapshot +from bub.turn_admission import AdmitDecision, TurnSnapshot def make_named_channel(name: str, label: str) -> Channel: @@ -285,26 +285,6 @@ async def dispatch_outbound(self, message) -> bool: assert saved_outputs == ["plain-text"] -@pytest.mark.asyncio -async def test_process_inbound_exposes_runtime_steering_handle() -> None: - framework = BubFramework() - observed_state: dict[str, Any] = {} - - class SteeringAwarePlugin: - @hookimpl - async def run_model(self, prompt, session_id, state) -> str: - observed_state.update(state) - return "ok" - - framework._plugin_manager.register(SteeringAwarePlugin(), name="steering-aware") - - result = await framework.process_inbound({"session_id": "session", "content": "hi"}) - - assert result.model_output == "ok" - assert isinstance(observed_state["_runtime_steering"], SteeringBuffer) - assert observed_state["_runtime_steering"].session_id == "session" - - @pytest.mark.asyncio async def test_framework_admit_message_calls_hook_with_snapshot() -> None: framework = BubFramework() @@ -326,7 +306,6 @@ def admit_message(self, session_id, message, turn): is_running=True, running_count=1, pending_count=1, - steering_count=0, ), ) diff --git a/website/src/content/docs/docs/reference/types.mdx b/website/src/content/docs/docs/reference/types.mdx index 2d98223f..8360f042 100644 --- a/website/src/content/docs/docs/reference/types.mdx +++ b/website/src/content/docs/docs/reference/types.mdx @@ -51,7 +51,7 @@ Normalizes one `render_outbound` return value to a list. `None` → `[]`; `list` type State = dict[str, Any] ``` -The per-turn state dict. The framework seeds it with `_runtime_workspace` and `_runtime_steering`, then merges the results of every `load_state` hook before the model call. The same dict is passed to `build_prompt`, `run_model[_stream]`, `save_state`, `render_outbound`, and `system_prompt`. +The per-turn state dict. The framework seeds it with `_runtime_workspace`, then merges the results of every `load_state` hook before the model call. The same dict is passed to `build_prompt`, `run_model[_stream]`, `save_state`, `render_outbound`, and `system_prompt`. ## `MessageHandler` @@ -113,23 +113,9 @@ class TurnSnapshot: is_running: bool running_count: int pending_count: int - steering_count: int ``` -`admit_message` implementations return `AdmitDecision` to tell the channel manager whether to process immediately, drop, queue as follow-up input, or steer one inbound message. - -```python -@dataclass -class SteeringBuffer: - session_id: str - - @property - def count(self) -> int: ... - def get_nowait(self) -> Envelope | None: ... - def drain_nowait(self) -> list[Envelope]: ... -``` - -`SteeringBuffer` is exposed to model hooks as `state["_runtime_steering"]`. `get_nowait()` removes one message; `drain_nowait()` removes all currently queued messages. Returned messages are owned by the model hook and will not be replayed. +`admit_message` implementations return `AdmitDecision` to tell the channel manager whether to process immediately, drop, queue as follow-up input, or route as steering input. The `steer` handler is currently a placeholder and falls back to the follow-up queue until dedicated steering handling is rebuilt. ## `Channel` @@ -206,8 +192,6 @@ class BubFramework: async def admit_message( self, *, session_id: str, message: Envelope, turn: TurnSnapshot ) -> AdmitDecision | None: ... - def steering(self, session_id: str) -> SteeringBuffer: ... - def clear_steering(self, session_id: str) -> None: ... @contextlib.asynccontextmanager async def running(self) -> AsyncGenerator[contextlib.AsyncExitStack, None]: ... @@ -229,8 +213,6 @@ class BubFramework: | `get_system_prompt(prompt, state)` | Run `system_prompt` impls (sync), reverse, and join non-empty results with `\n\n`. | | `hook_report()` | Map hook name → discovered adapter names. Backs `bub hooks`; read the hook reference before treating this order as runtime precedence. | | `admit_message(...)` | Call the `admit_message` hook and return the selected decision. Used by `ChannelManager`. | -| `steering(session_id)` | Return the per-session steering buffer exposed to model hooks. | -| `clear_steering(session_id)` | Clear an idle session's steering buffer. | | `running()` | Async context manager; resolves `provide_tape_store` once and binds the resulting store for the duration. | | `bind_outbound_router(router)` | Attach (or detach with `None`) the `OutboundChannelRouter`. The `ChannelManager` calls this on start/stop. | | `build_tape_context()` | Sync-call `build_tape_context` and return the resulting `TapeContext`. | @@ -247,7 +229,6 @@ From `src/bub/__init__.py`: | `BubFramework` | class | Framework runtime (above). | | `AdmitDecision` | dataclass | Decision returned by `admit_message`. | | `Settings` | class | Base class for plugin settings (re-exported from `bub.configure`). | -| `SteeringBuffer` | dataclass | Per-session steering queue handle exposed to model hooks. | | `TurnSnapshot` | dataclass | Snapshot passed to `admit_message`. | | `config` | decorator | `@config(name="...")` registers a settings class for YAML/env validation. | | `ensure_config` | function | `ensure_config(SettingsCls)` — return the singleton instance for that class. | diff --git a/website/src/content/docs/zh-cn/docs/reference/types.mdx b/website/src/content/docs/zh-cn/docs/reference/types.mdx index fc465d70..fd87afe9 100644 --- a/website/src/content/docs/zh-cn/docs/reference/types.mdx +++ b/website/src/content/docs/zh-cn/docs/reference/types.mdx @@ -51,7 +51,7 @@ def unpack_batch(batch: Any) -> list[Envelope] type State = dict[str, Any] ``` -per-turn 的 state dict。框架先以 `_runtime_workspace` 与 `_runtime_steering` 初始化,再合并所有 `load_state` 钩子的结果,然后才调用模型。同一个 dict 会被传给 `build_prompt`、`run_model[_stream]`、`save_state`、`render_outbound` 与 `system_prompt`。 +per-turn 的 state dict。框架先以 `_runtime_workspace` 初始化,再合并所有 `load_state` 钩子的结果,然后才调用模型。同一个 dict 会被传给 `build_prompt`、`run_model[_stream]`、`save_state`、`render_outbound` 与 `system_prompt`。 ## `MessageHandler` @@ -113,23 +113,9 @@ class TurnSnapshot: is_running: bool running_count: int pending_count: int - steering_count: int ``` -`admit_message` 实现返回 `AdmitDecision`,告诉 channel manager 对单条 inbound message 立即 process、drop、作为 follow-up input 排队,或 steer。 - -```python -@dataclass -class SteeringBuffer: - session_id: str - - @property - def count(self) -> int: ... - def get_nowait(self) -> Envelope | None: ... - def drain_nowait(self) -> list[Envelope]: ... -``` - -`SteeringBuffer` 会以 `state["_runtime_steering"]` 暴露给 model hooks。`get_nowait()` 取出一条;`drain_nowait()` 取出当前全部 queued messages。返回的消息由 model hook 接管,不会重放。 +`admit_message` 实现返回 `AdmitDecision`,告诉 channel manager 对单条 inbound message 立即 process、drop、作为 follow-up input 排队,或作为 steering input 路由。`steer` handler 当前只是占位,在专用 steering 处理恢复前会退化到 follow-up queue。 ## `Channel` @@ -206,8 +192,6 @@ class BubFramework: async def admit_message( self, *, session_id: str, message: Envelope, turn: TurnSnapshot ) -> AdmitDecision | None: ... - def steering(self, session_id: str) -> SteeringBuffer: ... - def clear_steering(self, session_id: str) -> None: ... @contextlib.asynccontextmanager async def running(self) -> AsyncGenerator[contextlib.AsyncExitStack, None]: ... @@ -229,8 +213,6 @@ class BubFramework: | `get_system_prompt(prompt, state)` | 同步调用 `system_prompt` 实现,反转后用 `\n\n` 拼接非空片段。 | | `hook_report()` | 返回 hook 名 → 已发现的 adapter 列表。`bub hooks` 的数据来源;不要只根据该输出顺序推断运行时优先级。 | | `admit_message(...)` | 调用 `admit_message` hook 并返回选中的 decision。由 `ChannelManager` 使用。 | -| `steering(session_id)` | 返回暴露给 model hooks 的 per-session steering buffer。 | -| `clear_steering(session_id)` | 清除 idle session 的 steering buffer。 | | `running()` | 异步 context manager;一次性解析 `provide_tape_store` 并在作用域内绑定 tape store。 | | `bind_outbound_router(router)` | 绑定(或传 `None` 解绑)`OutboundChannelRouter`。`ChannelManager` 在启停时调用。 | | `build_tape_context()` | 同步调用 `build_tape_context` 并返回 `TapeContext`。 | @@ -247,7 +229,6 @@ class BubFramework: | `BubFramework` | class | 框架运行时(见上)。 | | `AdmitDecision` | dataclass | `admit_message` 返回的 decision。 | | `Settings` | class | 插件配置基类(从 `bub.configure` 重新导出)。 | -| `SteeringBuffer` | dataclass | 暴露给 model hooks 的 per-session steering queue handle。 | | `TurnSnapshot` | dataclass | 传给 `admit_message` 的快照。 | | `config` | decorator | `@config(name="...")` 注册一个用于 YAML/env 验证的配置类。 | | `ensure_config` | function | `ensure_config(SettingsCls)` —— 返回该类的单例实例。 |