From ece7b7dad8310dcf114894eddece522cbd56b75b Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 14 Jun 2026 22:35:29 +0800 Subject: [PATCH 1/2] refactor: bind native envelopes for model streams --- src/bub/builtin/agent.py | 156 ++++++++++++++++++++----------- src/bub/builtin/hook_impl.py | 21 +++-- src/bub/builtin/tape.py | 2 +- src/bub/builtin/tools.py | 15 ++- src/bub/channels/base.py | 4 +- src/bub/channels/cli/__init__.py | 27 +++--- src/bub/channels/manager.py | 3 +- src/bub/envelope.py | 2 + src/bub/errors.py | 40 ++++++++ src/bub/framework.py | 36 +++---- src/bub/hook_runtime.py | 50 +++++----- src/bub/hookspecs.py | 20 ++-- src/bub/runtime.py | 70 -------------- src/bub/tape.py | 2 +- src/bub/tools.py | 2 +- src/bub/types.py | 16 +++- tests/test_builtin_agent.py | 36 ++++--- tests/test_builtin_hook_impl.py | 19 ++-- tests/test_builtin_tools.py | 2 +- tests/test_channels.py | 13 ++- tests/test_framework.py | 36 +++++-- tests/test_hook_runtime.py | 35 +++++-- tests/test_subagent_tool.py | 8 +- 23 files changed, 348 insertions(+), 267 deletions(-) create mode 100644 src/bub/errors.py delete mode 100644 src/bub/runtime.py diff --git a/src/bub/builtin/agent.py b/src/bub/builtin/agent.py index 54a51ce1..65926ba1 100644 --- a/src/bub/builtin/agent.py +++ b/src/bub/builtin/agent.py @@ -7,7 +7,7 @@ import re import shlex import time -from collections.abc import AsyncGenerator, AsyncIterator, Callable, Collection, Coroutine, Iterable +from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Collection, Coroutine, Iterable from contextlib import AsyncExitStack from dataclasses import dataclass, replace from datetime import UTC, datetime @@ -32,8 +32,9 @@ from bub.builtin.settings import ModelCandidate, load_settings from bub.builtin.store import ForkTapeStore from bub.builtin.tape import TapeService +from bub.envelope import field_of +from bub.errors import BubError, ErrorKind from bub.framework import BubFramework -from bub.runtime import AsyncStreamEvents, BubError, ErrorKind, StreamEvent, StreamState from bub.skills import discover_skills, render_skills_prompt from bub.tape import InMemoryTapeStore, Tape from bub.tools import ( @@ -42,7 +43,7 @@ ToolContext, ToolExecutor, ) -from bub.types import State +from bub.types import Envelope, State from bub.utils import workspace_from_state CONTINUE_PROMPT = "Continue the task until all targets are completed." @@ -55,6 +56,45 @@ TOOL_ARGUMENTS_ADAPTER = TypeAdapter(dict[str, Any]) +def _stream_event(kind: str, data: dict[str, Any] | None = None) -> Envelope: + return {"kind": kind, "data": data or {}} + + +def _event_data(event: Envelope) -> dict[str, Any]: + data = field_of(event, "data", {}) + return data if isinstance(data, dict) else {} + + +class BuiltinModelStream: + """Builtin-owned stream envelope and binding.""" + + def __init__(self, events: AsyncIterable[Envelope]) -> None: + self._events = events + self._output_parts: list[str] = [] + self._stream_started = False + + def stream(self) -> AsyncIterable[Envelope] | None: + if self._stream_started: + return None + self._stream_started = True + + async def iterator() -> AsyncIterator[Envelope]: + async for event in self._events: + if field_of(event, "kind") == "text": + delta = str(_event_data(event).get("delta", "")) + self._output_parts.append(delta) + yield {"content": delta, "source": event} + elif field_of(event, "kind") == "final": + yield {"end": True, "source": event} + else: + yield event + + return iterator() + + def output(self) -> Envelope | None: + return "".join(self._output_parts) + + class Agent: """Agent that processes prompts using hooks, tools, tape, and any-llm-sdk.""" @@ -73,25 +113,25 @@ def tapes(self) -> TapeService: return TapeService(bub.home / "tapes", tape_store, self.framework.build_tape_context()) @staticmethod - def _events_from_iterable(iterable: Iterable) -> AsyncStreamEvents: - async def generator() -> AsyncIterator: + def _events_from_iterable(iterable: Iterable[Envelope]) -> AsyncIterable[Envelope]: + async def generator() -> AsyncIterator[Envelope]: for item in iterable: yield item - return AsyncStreamEvents(generator()) + return generator() @staticmethod def _events_with_callback( - events: AsyncStreamEvents, callback: Callable[[], Coroutine[Any, Any, Any]] - ) -> AsyncStreamEvents: - async def generator() -> AsyncIterator[StreamEvent]: + events: AsyncIterable[Envelope], callback: Callable[[], Coroutine[Any, Any, Any]] + ) -> AsyncIterable[Envelope]: + async def generator() -> AsyncIterator[Envelope]: try: async for event in events: yield event finally: await callback() - return AsyncStreamEvents(generator(), state=events._state) + return generator() async def run( self, @@ -112,9 +152,12 @@ async def run( allowed_skills=allowed_skills, allowed_tools=allowed_tools, ) - async for event in stream: - if event.kind == "text": - output.append(str(event.data.get("delta", ""))) + events = stream.stream() + if events is not None: + async for _event in events: + pass + if result := stream.output(): + output.append(str(result)) return "".join(output) async def run_stream( @@ -126,12 +169,14 @@ async def run_stream( model: str | None = None, allowed_skills: Collection[str] | None = None, allowed_tools: Collection[str] | None = None, - ) -> AsyncStreamEvents: + ) -> BuiltinModelStream: if not prompt: - return self._events_from_iterable([ - StreamEvent("text", {"delta": "error: empty prompt"}), - StreamEvent("final", {"text": "error: empty prompt", "ok": False}), - ]) + return BuiltinModelStream( + self._events_from_iterable([ + _stream_event("text", {"delta": "error: empty prompt"}), + _stream_event("final", {"text": "error: empty prompt", "ok": False}), + ]) + ) tape = self.tapes.session_tape(session_id, workspace_from_state(state)) tape.context = replace(tape.context, state=state) @@ -143,8 +188,8 @@ async def run_stream( if isinstance(prompt, str) and prompt.strip().startswith(","): result = await self._run_command(tape=tape, line=prompt.strip()) events = self._events_from_iterable([ - StreamEvent("text", {"delta": result}), - StreamEvent("final", {"text": result, "ok": True}), + _stream_event("text", {"delta": result}), + _stream_event("final", {"text": result, "ok": True}), ]) else: events = await self._agent_loop( @@ -154,7 +199,7 @@ async def run_stream( allowed_skills=allowed_skills, allowed_tools=allowed_tools, ) - return self._events_with_callback(events, callback=stack.aclose) + return BuiltinModelStream(self._events_with_callback(events, callback=stack.aclose)) async def _run_command(self, tape: Tape, *, line: str) -> str: line = line[1:].strip() @@ -204,7 +249,7 @@ async def _agent_loop( model: str | None = None, allowed_skills: Collection[str] | None = None, allowed_tools: Collection[str] | None = None, - ) -> AsyncStreamEvents: + ) -> AsyncIterable[Envelope]: next_prompt: str | list[dict] = prompt display_model = model or self.settings.model await self.tapes.append_event( @@ -217,7 +262,7 @@ async def _agent_loop( "allowed_tools": list(allowed_tools) if allowed_tools else None, }, ) - state = StreamState() + state: dict[str, Any] = {} iterator = self._stream_events_with_auto_handoff( tape=tape, prompt=next_prompt, @@ -226,17 +271,17 @@ async def _agent_loop( allowed_skills=allowed_skills, allowed_tools=allowed_tools, ) - return AsyncStreamEvents(iterator, state=state) + return iterator async def _stream_events_with_auto_handoff( self, tape: Tape, prompt: str | list[dict], - state: StreamState, + state: dict[str, Any], model: str | None = None, allowed_skills: Collection[str] | None = None, allowed_tools: Collection[str] | None = None, - ) -> AsyncGenerator[StreamEvent, None]: + ) -> AsyncGenerator[Envelope, None]: auto_handoff_remaining = MAX_AUTO_HANDOFF_RETRIES display_model = model or self.settings.model next_prompt = prompt @@ -255,7 +300,9 @@ async def _stream_events_with_auto_handoff( ) async for event in output: yield event - if event.kind == "error": + kind = field_of(event, "kind") + data = _event_data(event) + if kind == "error": elapsed_ms = int((time.monotonic() - start) * 1000) await self.tapes.append_event( tape.name, @@ -264,12 +311,12 @@ async def _stream_events_with_auto_handoff( "step": step, "elapsed_ms": elapsed_ms, "status": "error", - "error": event.data.get("message", ""), + "error": data.get("message", ""), "date": datetime.now(UTC).isoformat(), }, ) - elif event.kind == "final": - should_continue = bool(event.data.get("tool_calls") or event.data.get("tool_results")) + elif kind == "final": + should_continue = bool(data.get("tool_calls") or data.get("tool_results")) except Exception as exc: error_message = f"{exc!s}" elapsed_ms = int((time.monotonic() - start) * 1000) @@ -312,8 +359,6 @@ async def _stream_events_with_auto_handoff( ) raise - state.error = output.error - state.usage = output.usage elapsed_ms = int((time.monotonic() - start) * 1000) if not should_continue: await self.tapes.append_event( @@ -359,7 +404,7 @@ async def _run_once( model: str | None = None, allowed_tools: Collection[str] | None = None, allowed_skills: Collection[str] | None = None, - ) -> AsyncStreamEvents: + ) -> AsyncIterable[Envelope]: prompt_text = prompt if isinstance(prompt, str) else _extract_text_from_parts(prompt) if allowed_tools is not None: from bub.builtin.tools import resolve_tool_names @@ -390,10 +435,10 @@ async def _run_once_stream( model: str | None, allowed_skills: set[str] | None, tools: list[Tool], - ) -> AsyncStreamEvents: - state = StreamState() + ) -> AsyncIterable[Envelope]: + state: dict[str, Any] = {} - async def iterator() -> AsyncGenerator[StreamEvent, None]: + async def iterator() -> AsyncGenerator[Envelope, None]: system_prompt = self._system_prompt( prompt_text, state=tape.context.state, allowed_skills=allowed_skills, tools=tools ) @@ -447,7 +492,7 @@ async def iterator() -> AsyncGenerator[StreamEvent, None]: tool_invocations = [ _tool_invocation_from_native(tool_call, tool_map) for tool_call in native_tool_calls ] - yield StreamEvent("tool_call", {"tool_calls": serialized_tool_calls}) + yield _stream_event("tool_call", {"tool_calls": serialized_tool_calls}) context = ToolContext(tape=tape.name, run_id=run_id, state=tape.context.state) execution = await ToolExecutor().execute_async( tool_invocations, @@ -463,11 +508,12 @@ async def iterator() -> AsyncGenerator[StreamEvent, None]: tool_results=execution.tool_results, response=response, model=model or self.settings.model, - usage=state.usage, + usage=state.get("usage"), ) - yield StreamEvent("tool_result", {"tool_results": execution.tool_results}) - yield StreamEvent( - "final", {"ok": True, "tool_calls": serialized_tool_calls, "tool_results": execution.tool_results} + yield _stream_event("tool_result", {"tool_results": execution.tool_results}) + yield _stream_event( + "final", + {"ok": True, "tool_calls": serialized_tool_calls, "tool_results": execution.tool_results}, ) return @@ -479,11 +525,11 @@ async def iterator() -> AsyncGenerator[StreamEvent, None]: response_text=text, response=response, model=model or self.settings.model, - usage=state.usage, + usage=state.get("usage"), ) - yield StreamEvent("final", {"ok": True, "text": text}) + yield _stream_event("final", {"ok": True, "text": text}) - return AsyncStreamEvents(iterator(), state=state) + return iterator() def _build_llm(self, candidate: ModelCandidate) -> AnyLLM: return AnyLLM.create( @@ -609,13 +655,13 @@ def _parse_native_function_call(tool_call: ChatCompletionMessageToolCall) -> tup async def _completion_events( completion: ChatCompletion | ParsedChatCompletion[Any] | AsyncIterator[ChatCompletionChunk], - state: StreamState, + state: dict[str, Any], text_parts: list[str], tool_calls: _ToolCallAccumulator, -) -> AsyncGenerator[StreamEvent, None]: +) -> AsyncGenerator[Envelope, None]: if isinstance(completion, ChatCompletion): if usage := TapeService._extract_usage(completion): - state.usage = usage + state["usage"] = usage message = completion.choices[0].message for event in _completion_message_events(message, text_parts, tool_calls): yield event @@ -630,30 +676,30 @@ def _completion_message_events( message: ChatCompletionMessage, text_parts: list[str], tool_calls: _ToolCallAccumulator, -) -> Iterable[StreamEvent]: +) -> Iterable[Envelope]: if message.reasoning: - yield StreamEvent("reasoning", {"delta": _reasoning_text(message.reasoning)}) + yield _stream_event("reasoning", {"delta": _reasoning_text(message.reasoning)}) if message.content: text_parts.append(message.content) - yield StreamEvent("text", {"delta": message.content}) + yield _stream_event("text", {"delta": message.content}) tool_calls.add_message_calls(cast("Iterable[ChatCompletionMessageToolCall]", message.tool_calls or [])) async def _completion_chunk_events( chunk: ChatCompletionChunk, - state: StreamState, + state: dict[str, Any], text_parts: list[str], tool_calls: _ToolCallAccumulator, -) -> AsyncGenerator[StreamEvent, None]: +) -> AsyncGenerator[Envelope, None]: if usage := TapeService._extract_usage(chunk): - state.usage = usage + state["usage"] = usage for choice in chunk.choices: delta = choice.delta if delta.reasoning: - yield StreamEvent("reasoning", {"delta": _reasoning_text(delta.reasoning)}) + yield _stream_event("reasoning", {"delta": _reasoning_text(delta.reasoning)}) if delta.content: text_parts.append(delta.content) - yield StreamEvent("text", {"delta": delta.content}) + yield _stream_event("text", {"delta": delta.content}) if delta.tool_calls: tool_calls.merge_delta_calls(delta.tool_calls) diff --git a/src/bub/builtin/hook_impl.py b/src/bub/builtin/hook_impl.py index 0ed75721..e9a73ac7 100644 --- a/src/bub/builtin/hook_impl.py +++ b/src/bub/builtin/hook_impl.py @@ -7,7 +7,7 @@ from loguru import logger from bub import inquirer as bub_inquirer -from bub.builtin.agent import Agent +from bub.builtin.agent import Agent, BuiltinModelStream from bub.builtin.context import default_tape_context from bub.builtin.settings import DEFAULT_MODEL from bub.channels.base import Channel @@ -15,9 +15,8 @@ from bub.envelope import content_of, field_of from bub.framework import BubFramework from bub.hookspecs import hookimpl -from bub.runtime import AsyncStreamEvents from bub.tape import TapeContext, TapeStore -from bub.types import Envelope, MessageHandler, State +from bub.types import Envelope, EnvelopeBinding, MessageHandler, State AGENTS_FILE_NAME = "AGENTS.md" MODEL_PROVIDER_CHOICES: tuple[str, ...] = ( @@ -120,7 +119,7 @@ async def load_state(self, message: ChannelMessage, session_id: str) -> State: return state @hookimpl - async def save_state(self, session_id: str, state: State, message: ChannelMessage, model_output: str) -> None: + async def save_state(self, session_id: str, state: State, message: ChannelMessage, model_output: Envelope) -> None: tp, value, traceback = sys.exc_info() lifespan = field_of(message, "lifespan") if lifespan is not None: @@ -156,13 +155,19 @@ async def build_prompt(self, message: ChannelMessage, session_id: str, state: St return text @hookimpl - async def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> str: + async def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> Envelope: return await self._get_agent().run(session_id=session_id, prompt=prompt, state=state) @hookimpl - async def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> AsyncStreamEvents: + async def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> Envelope: return await self._get_agent().run_stream(session_id=session_id, prompt=prompt, state=state) + @hookimpl + def bind_envelope(self, envelope: Envelope, session_id: str, state: State) -> EnvelopeBinding | None: + if isinstance(envelope, BuiltinModelStream): + return envelope + return None + @hookimpl def register_cli_commands(self, app: typer.Typer) -> None: from bub.builtin import cli @@ -273,13 +278,13 @@ def render_outbound( message: Envelope, session_id: str, state: State, - model_output: str, + model_output: Envelope, ) -> list[ChannelMessage]: outbound = ChannelMessage( session_id=session_id, channel=field_of(message, "channel", "default"), chat_id=field_of(message, "chat_id", "default"), - content=model_output, + content=content_of(model_output), output_channel=field_of(message, "output_channel", "default"), kind=field_of(message, "kind", "normal"), ) diff --git a/src/bub/builtin/tape.py b/src/bub/builtin/tape.py index ca992521..c89d9bc3 100644 --- a/src/bub/builtin/tape.py +++ b/src/bub/builtin/tape.py @@ -11,7 +11,7 @@ from pydantic import BaseModel from bub.builtin.store import ForkTapeStore -from bub.runtime import BubError +from bub.errors import BubError from bub.tape import AsyncTapeStore, Tape, TapeContext, TapeEntry, TapeQuery, build_messages diff --git a/src/bub/builtin/tools.py b/src/bub/builtin/tools.py index 51484d52..88c89af8 100644 --- a/src/bub/builtin/tools.py +++ b/src/bub/builtin/tools.py @@ -359,20 +359,19 @@ async def run_subagent(param: SubAgentInput, *, context: ToolContext) -> str: subagent_session = param.session state = {**context.state, "session_id": subagent_session} allowed_tools = resolve_tool_names(param.allowed_tools or None, exclude={"subagent"}) - output = "" - async for event in await agent.run_stream( + stream = await agent.run_stream( session_id=subagent_session, prompt=param.prompt, state=state, model=param.model, allowed_tools=allowed_tools, allowed_skills=param.allowed_skills, - ): - if event.kind == "error": - output += f"[Error: {event.data.get('message', 'unknown error')}]" - elif event.kind == "text": - output += str(event.data.get("delta", "")) - return output + ) + events = stream.stream() + if events is not None: + async for _event in events: + pass + return str(stream.output() or "") @tool(name="help") diff --git a/src/bub/channels/base.py b/src/bub/channels/base.py index 920469ec..55112d96 100644 --- a/src/bub/channels/base.py +++ b/src/bub/channels/base.py @@ -4,7 +4,7 @@ from typing import ClassVar from bub.channels.message import ChannelMessage -from bub.runtime import StreamEvent +from bub.types import Envelope class Channel(ABC): @@ -35,7 +35,7 @@ async def send(self, message: ChannelMessage) -> None: # Do nothing by default return - def stream_events(self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: + def stream_events(self, message: ChannelMessage, stream: AsyncIterable[Envelope]) -> AsyncIterable[Envelope]: """Optionally wrap the output stream for this channel.""" return stream diff --git a/src/bub/channels/cli/__init__.py b/src/bub/channels/cli/__init__.py index 1b87324b..1df6d6b0 100644 --- a/src/bub/channels/cli/__init__.py +++ b/src/bub/channels/cli/__init__.py @@ -23,10 +23,9 @@ from bub.channels.base import Interface from bub.channels.cli.renderer import CliRenderer from bub.channels.message import ChannelMessage -from bub.envelope import field_of -from bub.runtime import StreamEvent +from bub.envelope import content_of, field_of from bub.tools import REGISTRY -from bub.types import MessageHandler +from bub.types import Envelope, MessageHandler class _StreamPrinter: @@ -39,16 +38,20 @@ def __init__(self, *, console, print_head: Callable[[], None], expand_thinking: self._reasoning_status: Status | None = None self.head_printed = False - def render(self, event: StreamEvent) -> bool: - if event.kind == "reasoning": - self._record_reasoning(str(event.data.get("delta", ""))) + def render(self, event: Envelope) -> bool: + kind = field_of(event, "kind") + data = field_of(event, "data", {}) + data = data if isinstance(data, dict) else {} + if kind == "reasoning": + self._record_reasoning(str(data.get("delta", ""))) return True - if event.kind == "text": - return self._print_content(str(event.data.get("delta", ""))) - elif event.kind == "tool_call": + content = content_of(event) + if content: + return self._print_content(content) + if kind == "tool_call": self._print_stream_boundary() - elif event.kind == "final": + elif kind == "final" or field_of(event, "end"): self._print_end() return True @@ -236,9 +239,7 @@ def _prompt_message(self) -> FormattedText: symbol = ">" if self._mode == "agent" else "," return FormattedText([("bold", f"{cwd} {symbol} ")]) - async def stream_events( - self, message: ChannelMessage, stream: AsyncIterable[StreamEvent] - ) -> AsyncIterable[StreamEvent]: + async def stream_events(self, message: ChannelMessage, stream: AsyncIterable[Envelope]) -> AsyncIterable[Envelope]: console = get_console() printer = _StreamPrinter( console=console, diff --git a/src/bub/channels/manager.py b/src/bub/channels/manager.py index 313cc2ce..3a0fa7cf 100644 --- a/src/bub/channels/manager.py +++ b/src/bub/channels/manager.py @@ -14,7 +14,6 @@ from bub.configure import Settings, ensure_config from bub.envelope import content_of, field_of from bub.framework import BubFramework -from bub.runtime import StreamEvent from bub.turn_admission import AdmitDecision, SessionTurnController from bub.types import Envelope, MessageHandler from bub.utils import wait_until_stopped @@ -105,7 +104,7 @@ async def dispatch_output(self, message: Envelope) -> bool: await channel.send(outbound) return True - def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: + def wrap_stream(self, message: Envelope, stream: AsyncIterable[Envelope]) -> AsyncIterable[Envelope]: channel_name = field_of(message, "output_channel", field_of(message, "channel")) if channel_name is None: return stream diff --git a/src/bub/envelope.py b/src/bub/envelope.py index 362d4bc5..48d1f0ab 100644 --- a/src/bub/envelope.py +++ b/src/bub/envelope.py @@ -19,6 +19,8 @@ def field_of(message: Envelope, key: str, default: Any = None) -> Any: def content_of(message: Envelope) -> str: """Get textual content from any envelope shape.""" + if isinstance(message, str): + return message return str(field_of(message, "content", "")) diff --git a/src/bub/errors.py b/src/bub/errors.py new file mode 100644 index 00000000..f8346e88 --- /dev/null +++ b/src/bub/errors.py @@ -0,0 +1,40 @@ +"""Shared Bub error types.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum +from typing import Any + + +class ErrorKind(StrEnum): + """Stable error kinds for Bub failures.""" + + INVALID_INPUT = "invalid_input" + CONFIG = "config" + PROVIDER = "provider" + TOOL = "tool" + TEMPORARY = "temporary" + NOT_FOUND = "not_found" + UNKNOWN = "unknown" + + +@dataclass(frozen=True) +class BubError(Exception): + """Public error type for Bub failures.""" + + kind: ErrorKind + message: str + details: dict[str, Any] | None = None + + def __str__(self) -> str: + return f"[{self.kind.value}] {self.message}" + + def as_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "kind": self.kind.value, + "message": self.message, + } + if self.details: + payload["details"] = self.details + return payload diff --git a/src/bub/framework.py b/src/bub/framework.py index cbece9d5..5a2f5041 100644 --- a/src/bub/framework.py +++ b/src/bub/framework.py @@ -17,7 +17,6 @@ from bub.envelope import content_of, field_of, unpack_batch from bub.hook_runtime import _SKIP_VALUE, HookRuntime from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs -from bub.runtime import BubError, ErrorKind from bub.tape import AsyncTapeStore, TapeContext, TapeStore from bub.turn_admission import AdmitDecision, SteeringBuffer, TurnSnapshot from bub.types import Envelope, MessageHandler, OutboundChannelRouter, TurnResult @@ -158,7 +157,7 @@ async def _run_model( session_id: str, state: dict[str, Any], stream_output: bool, - ) -> str: + ) -> Envelope: if not stream_output: output = await self._hook_runtime.run_model(prompt=prompt, session_id=session_id, state=state) if output is None: @@ -169,28 +168,29 @@ async def _run_model( ) return prompt if isinstance(prompt, str) else content_of(inbound) return output - stream = await self._hook_runtime.run_model_stream(prompt=prompt, session_id=session_id, state=state) - if stream is None: + stream_envelope = await self._hook_runtime.run_model_stream(prompt=prompt, session_id=session_id, state=state) + if stream_envelope is None: await self._hook_runtime.notify_error( stage="run_model", error=RuntimeError("no model skill returned output"), message=inbound, ) return prompt if isinstance(prompt, str) else content_of(inbound) + binding = await self._hook_runtime.bind_envelope(stream_envelope, session_id=session_id, state=state) + if binding is None: + return stream_envelope else: - parts: list[str] = [] - events = self._outbound_router.wrap_stream(inbound, stream) if self._outbound_router is not None else stream - async for event in events: - if event.kind == "text": - parts.append(str(event.data.get("delta", ""))) - elif event.kind == "error": - # Turn "kind" to enum type otherwise BubError's __str__ won't work well. - data = { - **event.data, - "kind": ErrorKind(event.data.get("kind", "unknown")), - } - await self._hook_runtime.notify_error(stage="run_model", error=BubError(**data), message=inbound) - return "".join(parts) + stream = binding.stream() + if stream is not None: + events = ( + self._outbound_router.wrap_stream(inbound, stream) if self._outbound_router is not None else stream + ) + async for _event in events: + pass + output = binding.output() + if output is None: + return prompt if isinstance(prompt, str) else content_of(inbound) + return output def hook_report(self) -> dict[str, list[str]]: """Return hook implementation summary for diagnostics.""" @@ -244,7 +244,7 @@ async def _collect_outbounds( message: Envelope, session_id: str, state: dict[str, Any], - model_output: str, + model_output: Envelope, ) -> list[Envelope]: batches = await self._hook_runtime.call_many( "render_outbound", diff --git a/src/bub/hook_runtime.py b/src/bub/hook_runtime.py index bdb0eaf7..42dec20d 100644 --- a/src/bub/hook_runtime.py +++ b/src/bub/hook_runtime.py @@ -3,14 +3,12 @@ from __future__ import annotations import inspect -from collections.abc import AsyncGenerator from typing import Any import pluggy from loguru import logger -from bub.runtime import AsyncStreamEvents, StreamEvent, StreamState -from bub.types import Envelope +from bub.types import Envelope, EnvelopeBinding class HookRuntime: @@ -160,41 +158,43 @@ def _iter_hookimpls(self, hook_name: str) -> list[Any]: def _kwargs_for_impl(impl: Any, kwargs: dict[str, Any]) -> dict[str, Any]: return {name: kwargs[name] for name in impl.argnames if name in kwargs} - async def run_model(self, prompt: str | list[dict], session_id: str, state: dict[str, Any]) -> str | None: + async def run_model(self, prompt: str | list[dict], session_id: str, state: dict[str, Any]) -> Envelope | None: """Run the first `run_model` hook found and return its result.""" for _, plugin in reversed(self._plugin_manager.list_name_plugin()): if hasattr(plugin, "run_model"): - output = await self.call_first("run_model", prompt=prompt, session_id=session_id, state=state) - if output is None or isinstance(output, str): - return output - raise TypeError("hook.run_model must return str or None") + return await self.call_first("run_model", prompt=prompt, session_id=session_id, state=state) elif hasattr(plugin, "run_model_stream"): - stream = await self.call_first("run_model_stream", prompt=prompt, session_id=session_id, state=state) - text = "" - async for event in stream: - if event.kind == "text": - text += str(event.data.get("delta", "")) - return text + stream_envelope = await self.run_model_stream(prompt=prompt, session_id=session_id, state=state) + binding = await self.bind_envelope(stream_envelope, session_id=session_id, state=state) + if binding is None: + return stream_envelope + stream = binding.stream() + if stream is not None: + async for _event in stream: + pass + return binding.output() return None async def run_model_stream( self, prompt: str | list[dict], session_id: str, state: dict[str, Any] - ) -> AsyncStreamEvents | None: + ) -> Envelope | None: """Run the first `run_model_stream` hook found and fallback to `run_model` hook.""" for _, plugin in reversed(self._plugin_manager.list_name_plugin()): if hasattr(plugin, "run_model_stream"): - stream = await self.call_first("run_model_stream", prompt=prompt, session_id=session_id, state=state) - if stream is None or isinstance(stream, AsyncStreamEvents): - return stream - raise TypeError("hook.run_model_stream must return AsyncStreamEvents or None") + return await self.call_first("run_model_stream", prompt=prompt, session_id=session_id, state=state) elif hasattr(plugin, "run_model"): - - async def iterator() -> AsyncGenerator[StreamEvent, None]: - result = await self.call_first("run_model", prompt=prompt, session_id=session_id, state=state) - yield StreamEvent("text", {"delta": result}) - - return AsyncStreamEvents(iterator(), state=StreamState()) + return await self.call_first("run_model", prompt=prompt, session_id=session_id, state=state) return None + async def bind_envelope( + self, envelope: Envelope, *, session_id: str, state: dict[str, Any] + ) -> EnvelopeBinding | None: + if envelope is None: + return None + binding = await self.call_first("bind_envelope", envelope=envelope, session_id=session_id, state=state) + if binding is None or isinstance(binding, EnvelopeBinding): + return binding + raise TypeError("hook.bind_envelope must return EnvelopeBinding or None") + _SKIP_VALUE = object() diff --git a/src/bub/hookspecs.py b/src/bub/hookspecs.py index 6bb5da2b..24186579 100644 --- a/src/bub/hookspecs.py +++ b/src/bub/hookspecs.py @@ -6,10 +6,9 @@ import pluggy -from bub.runtime import AsyncStreamEvents from bub.tape import AsyncTapeStore, TapeContext, TapeStore from bub.turn_admission import AdmitDecision, TurnSnapshot -from bub.types import Envelope, MessageHandler, State +from bub.types import Envelope, EnvelopeBinding, MessageHandler, State if TYPE_CHECKING: from bub.channels.base import Channel @@ -37,13 +36,18 @@ def build_prompt(self, message: Envelope, session_id: str, state: State) -> str raise NotImplementedError @hookspec(firstresult=True) - def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> str: - """Run model for one turn and return plain text output. Should not be implemented if `run_model_stream` is implemented.""" + def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> Envelope: + """Run model for one turn and return an output envelope. Should not be implemented if `run_model_stream` is implemented.""" raise NotImplementedError @hookspec(firstresult=True) - def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> AsyncStreamEvents: - """Run model for one turn and return a stream of events. Should not be implemented if `run_model` is implemented.""" + def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> Envelope: + """Run model for one turn and return an envelope with stream capabilities.""" + raise NotImplementedError + + @hookspec(firstresult=True) + def bind_envelope(self, envelope: Envelope, session_id: str, state: State) -> EnvelopeBinding | None: + """Bind producer-defined capabilities to an envelope.""" raise NotImplementedError @hookspec @@ -57,7 +61,7 @@ def save_state( session_id: str, state: State, message: Envelope, - model_output: str, + model_output: Envelope, ) -> None: """Persist state updates after one model turn.""" @@ -67,7 +71,7 @@ def render_outbound( message: Envelope, session_id: str, state: State, - model_output: str, + model_output: Envelope, ) -> list[Envelope]: """Render outbound messages from model output.""" raise NotImplementedError diff --git a/src/bub/runtime.py b/src/bub/runtime.py deleted file mode 100644 index 93b05ca5..00000000 --- a/src/bub/runtime.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Small runtime primitives shared by Bub core and channels.""" - -from __future__ import annotations - -from collections.abc import AsyncIterator -from dataclasses import dataclass -from enum import StrEnum -from typing import Any, Literal - - -class ErrorKind(StrEnum): - """Stable error kinds for runtime decisions.""" - - INVALID_INPUT = "invalid_input" - CONFIG = "config" - PROVIDER = "provider" - TOOL = "tool" - TEMPORARY = "temporary" - NOT_FOUND = "not_found" - UNKNOWN = "unknown" - - -@dataclass(frozen=True) -class BubError(Exception): - """Public error type for Bub runtime failures.""" - - kind: ErrorKind - message: str - details: dict[str, Any] | None = None - - def __str__(self) -> str: - return f"[{self.kind.value}] {self.message}" - - def as_dict(self) -> dict[str, Any]: - payload: dict[str, Any] = { - "kind": self.kind.value, - "message": self.message, - } - if self.details: - payload["details"] = self.details - return payload - - -@dataclass -class StreamState: - error: BubError | None = None - usage: dict[str, Any] | None = None - - -@dataclass(frozen=True) -class StreamEvent: - kind: Literal["text", "reasoning", "tool_call", "tool_result", "usage", "error", "final"] - data: dict[str, Any] - - -class AsyncStreamEvents: - def __init__(self, iterator: AsyncIterator[StreamEvent], *, state: StreamState | None = None) -> None: - self._iterator = iterator - self._state = state or StreamState() - - def __aiter__(self) -> AsyncIterator[StreamEvent]: - return self._iterator - - @property - def error(self) -> BubError | None: - return self._state.error - - @property - def usage(self) -> dict[str, Any] | None: - return self._state.usage diff --git a/src/bub/tape.py b/src/bub/tape.py index 29886348..23e5074f 100644 --- a/src/bub/tape.py +++ b/src/bub/tape.py @@ -13,7 +13,7 @@ from typing_extensions import TypeIs -from bub.runtime import BubError, ErrorKind +from bub.errors import BubError, ErrorKind def utc_now() -> str: diff --git a/src/bub/tools.py b/src/bub/tools.py index 3c6a4525..5104d9fa 100644 --- a/src/bub/tools.py +++ b/src/bub/tools.py @@ -11,7 +11,7 @@ from loguru import logger from pydantic import BaseModel, TypeAdapter, ValidationError, validate_call -from bub.runtime import BubError, ErrorKind +from bub.errors import BubError, ErrorKind @dataclass(frozen=True) diff --git a/src/bub/types.py b/src/bub/types.py index 65ad88af..14cd3d28 100644 --- a/src/bub/types.py +++ b/src/bub/types.py @@ -4,9 +4,7 @@ from collections.abc import AsyncIterable, Callable, Coroutine from dataclasses import dataclass, field -from typing import Any, Protocol - -from bub.runtime import StreamEvent +from typing import Any, Protocol, runtime_checkable type Envelope = Any type State = dict[str, Any] @@ -14,9 +12,17 @@ type OutboundDispatcher = Callable[[Envelope], Coroutine[Any, Any, bool]] +@runtime_checkable +class EnvelopeBinding(Protocol): + """Capabilities attached to an envelope by the producer that understands it.""" + + def stream(self) -> AsyncIterable[Envelope] | None: ... + def output(self) -> Envelope | None: ... + + class OutboundChannelRouter(Protocol): async def dispatch_output(self, message: Envelope) -> bool: ... - def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: ... + def wrap_stream(self, message: Envelope, stream: AsyncIterable[Envelope]) -> AsyncIterable[Envelope]: ... async def quit(self, session_id: str) -> None: ... @@ -26,5 +32,5 @@ class TurnResult: session_id: str prompt: str - model_output: str + model_output: Envelope outbounds: list[Envelope] = field(default_factory=list) diff --git a/tests/test_builtin_agent.py b/tests/test_builtin_agent.py index ba470306..b586d3bc 100644 --- a/tests/test_builtin_agent.py +++ b/tests/test_builtin_agent.py @@ -8,9 +8,9 @@ import pytest from any_llm.types.completion import ChatCompletionChunk -from bub.builtin.agent import Agent +from bub.builtin.agent import Agent, BuiltinModelStream from bub.builtin.settings import AgentSettings -from bub.runtime import BubError +from bub.errors import BubError from bub.tape import Tape, TapeContext from bub.tools import REGISTRY, tool @@ -60,6 +60,13 @@ async def _chat_stream(content: str) -> AsyncIterator[ChatCompletionChunk]: yield _chat_chunk(content) +async def _consume_stream(stream: BuiltinModelStream) -> list[Any]: + events = stream.stream() + if events is None: + return [] + return [event async for event in events] + + class _ForkCapture: """Captures fork_tape enter and exit behavior.""" @@ -146,7 +153,7 @@ async def test_agent_run_regular_session_merges_back() -> None: assert fork_capture.merge_back_values == [True] assert fork_capture.exit_count == 0 - [event async for event in result] + await _consume_stream(result) assert fork_capture.merge_back_values == [True] assert fork_capture.exit_count == 1 @@ -164,7 +171,7 @@ async def test_agent_run_temp_session_does_not_merge_back() -> None: assert fork_capture.merge_back_values == [False] assert fork_capture.exit_count == 0 - [event async for event in result] + await _consume_stream(result) assert fork_capture.merge_back_values == [False] assert fork_capture.exit_count == 1 @@ -184,7 +191,7 @@ async def test_agent_run_passes_model_to_llm() -> None: state={"_runtime_workspace": "/tmp"}, # noqa: S108 model="openai:gpt-4o", ) - [event async for event in result] + await _consume_stream(result) assert agent.completion_kwargs["model"] == "openai:gpt-4o" @@ -195,11 +202,14 @@ async def test_agent_run_empty_prompt_returns_error() -> None: agent.tapes = MagicMock() # type: ignore[assignment] result = await agent.run_stream(session_id="user/s1", prompt="", state={}) - events = [event async for event in result] - - assert [(event.kind, event.data) for event in events] == [ - ("text", {"delta": "error: empty prompt"}), - ("final", {"ok": False, "text": "error: empty prompt"}), + events = await _consume_stream(result) + + assert events == [ + { + "content": "error: empty prompt", + "source": {"kind": "text", "data": {"delta": "error: empty prompt"}}, + }, + {"end": True, "source": {"kind": "final", "data": {"ok": False, "text": "error: empty prompt"}}}, ] @@ -212,7 +222,7 @@ async def test_agent_run_model_defaults_to_none() -> None: agent.tapes = fake_tapes # type: ignore[assignment] result = await agent.run_stream(session_id="user/s1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 - [event async for event in result] + await _consume_stream(result) assert agent.completion_kwargs["model"] == "test:model" @@ -243,7 +253,7 @@ def denied_agent_tool() -> str: state={"_runtime_workspace": "/tmp"}, # noqa: S108 allowed_tools=[" tests_allowed_agent_tool "], ) - [event async for event in result] + await _consume_stream(result) assert agent.completion_kwargs is not None assert [tool.name for tool in agent.completion_kwargs["tools"]] == ["tests_allowed_agent_tool"] @@ -267,4 +277,4 @@ async def test_agent_run_rejects_unknown_allowed_tools() -> None: ) with pytest.raises(ValueError, match="tests_missing_agent_tool"): - [event async for event in stream] + await _consume_stream(stream) diff --git a/tests/test_builtin_hook_impl.py b/tests/test_builtin_hook_impl.py index 8c8310d1..f982a349 100644 --- a/tests/test_builtin_hook_impl.py +++ b/tests/test_builtin_hook_impl.py @@ -7,11 +7,11 @@ import pytest +from bub.builtin.agent import BuiltinModelStream from bub.builtin.hook_impl import AGENTS_FILE_NAME, DEFAULT_SYSTEM_PROMPT, BuiltinImpl from bub.builtin.store import FileTapeStore from bub.channels.message import ChannelMessage from bub.framework import BubFramework -from bub.runtime import AsyncStreamEvents, StreamEvent class RecordingLifespan: @@ -36,13 +36,13 @@ async def run(self, *, session_id: str, prompt: str, state: dict[str, object]) - self.run_calls.append((session_id, prompt, state)) return "agent-output" - async def run_stream(self, *, session_id: str, prompt: str, state: dict[str, object]) -> AsyncStreamEvents: + async def run_stream(self, *, session_id: str, prompt: str, state: dict[str, object]) -> BuiltinModelStream: self.run_stream_calls.append((session_id, prompt, state)) async def iterator(): - yield StreamEvent("text", {"delta": "agent-output"}) + yield {"kind": "text", "data": {"delta": "agent-output"}} - return AsyncStreamEvents(iterator()) + return BuiltinModelStream(iterator()) def _raise_value_error() -> None: @@ -169,9 +169,14 @@ async def test_run_model_stream_delegates_to_agent(tmp_path: Path) -> None: state = {"context": "ctx"} stream = await impl.run_model_stream(prompt="prompt", session_id="session", state=state) - events = [event async for event in stream] - - assert [(event.kind, event.data) for event in events] == [("text", {"delta": "agent-output"})] + binding = impl.bind_envelope(stream, session_id="session", state=state) + assert binding is stream + events = binding.stream() + assert events is not None + + assert [event async for event in events] == [ + {"content": "agent-output", "source": {"kind": "text", "data": {"delta": "agent-output"}}} + ] assert agent.run_stream_calls == [("session", "prompt", state)] assert agent.run_calls == [] diff --git a/tests/test_builtin_tools.py b/tests/test_builtin_tools.py index d5bc7e56..1268b06e 100644 --- a/tests/test_builtin_tools.py +++ b/tests/test_builtin_tools.py @@ -20,7 +20,7 @@ render_tools_prompt, resolve_tool_names, ) -from bub.runtime import ErrorKind +from bub.errors import ErrorKind from bub.tools import REGISTRY, Tool, ToolContext, ToolExecutor, tool diff --git a/tests/test_channels.py b/tests/test_channels.py index 079e1db6..5f14a591 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -15,7 +15,6 @@ from bub.channels.manager import ChannelManager from bub.channels.message import ChannelMessage from bub.channels.telegram import BubMessageFilter, TelegramChannel, TelegramMessageParser -from bub.runtime import StreamEvent from bub.turn_admission import AdmitDecision, SessionTurnController, SteeringBuffer @@ -674,17 +673,17 @@ async def test_cli_channel_stream_events_prints_stream_and_yields_events(monkeyp message = _message("ignored", channel="cli", kind="command", session_id="cli:1") - async def source() -> asyncio.AsyncIterator[StreamEvent]: - yield StreamEvent("text", {"delta": " "}) - yield StreamEvent("text", {"delta": "hel"}) - yield StreamEvent("text", {"delta": "lo"}) - yield StreamEvent("final", {}) + async def source() -> asyncio.AsyncIterator[dict[str, str]]: + yield {"content": " "} + yield {"content": "hel"} + yield {"content": "lo"} + yield {"end": True} yielded = [event async for event in channel.stream_events(message, source())] assert heads == ["command"] assert printed == [("hel", "", False), ("lo", "", False), ("", None, None)] - assert [event.kind for event in yielded] == ["text", "text", "final"] + assert yielded == [{"content": " "}, {"content": "hel"}, {"content": "lo"}, {"end": True}] def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None: diff --git a/tests/test_framework.py b/tests/test_framework.py index f39b8b03..efc2c2a0 100644 --- a/tests/test_framework.py +++ b/tests/test_framework.py @@ -19,7 +19,6 @@ from bub.configure import ensure_config from bub.framework import BubFramework from bub.hookspecs import hookimpl -from bub.runtime import AsyncStreamEvents, StreamEvent, StreamState from bub.turn_admission import AdmitDecision, SteeringBuffer, TurnSnapshot @@ -42,6 +41,22 @@ async def stop(self) -> None: return NamedChannelImpl() +class _BoundEnvelope: + def __init__(self, events: list[dict[str, Any]], output: Any) -> None: + self._events = events + self._output = output + + def stream(self): + async def iterator(): + for event in self._events: + yield event + + return iterator() + + def output(self): + return self._output + + def test_create_cli_app_sets_workspace_and_context(tmp_path: Path) -> None: framework = BubFramework() @@ -337,7 +352,7 @@ def admit_message(self, session_id, message, turn): async def test_process_inbound_streams_when_requested() -> None: # noqa: C901 framework = BubFramework() stream_calls: list[str] = [] - wrapped_events: list[str] = [] + wrapped_events: list[dict[str, Any]] = [] class StreamingPlugin: @hookimpl @@ -355,13 +370,14 @@ def build_prompt(self, message, session_id, state) -> str: @hookimpl async def run_model_stream(self, prompt, session_id, state): stream_calls.append(prompt) + return _BoundEnvelope( + [{"chunk": "stream"}, {"chunk": "ed"}, {"done": True}], + "streamed", + ) - async def iterator(): - yield StreamEvent("text", {"delta": "stream"}) - yield StreamEvent("text", {"delta": "ed"}) - yield StreamEvent("final", {"text": "streamed", "ok": True}) - - return AsyncStreamEvents(iterator(), state=StreamState()) + @hookimpl + def bind_envelope(self, envelope, session_id, state): + return envelope if isinstance(envelope, _BoundEnvelope) else None @hookimpl async def save_state(self, session_id, state, message, model_output) -> None: @@ -379,7 +395,7 @@ class RecordingRouter: def wrap_stream(self, message, stream): async def iterator(): async for event in stream: - wrapped_events.append(event.kind) + wrapped_events.append(event) yield event return iterator() @@ -399,5 +415,5 @@ async def quit(self, session_id: str) -> None: ) assert stream_calls == ["prompt"] - assert wrapped_events == ["text", "text", "final"] + assert wrapped_events == [{"chunk": "stream"}, {"chunk": "ed"}, {"done": True}] assert result.model_output == "streamed" diff --git a/tests/test_hook_runtime.py b/tests/test_hook_runtime.py index 78088858..3b19ac88 100644 --- a/tests/test_hook_runtime.py +++ b/tests/test_hook_runtime.py @@ -1,9 +1,11 @@ +from collections.abc import AsyncIterable, AsyncIterator + import pluggy import pytest from bub.hook_runtime import HookRuntime from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs, hookimpl -from bub.runtime import AsyncStreamEvents, StreamEvent +from bub.types import Envelope def _runtime_with_plugins(*plugins: tuple[str, object]) -> HookRuntime: @@ -14,6 +16,22 @@ def _runtime_with_plugins(*plugins: tuple[str, object]) -> HookRuntime: return HookRuntime(manager) +class _BoundEnvelope: + def __init__(self, events: list[Envelope], output: Envelope) -> None: + self._events = events + self._output = output + + def stream(self) -> AsyncIterable[Envelope] | None: + async def iterator() -> AsyncIterator[Envelope]: + for event in self._events: + yield event + + return iterator() + + def output(self) -> Envelope | None: + return self._output + + @pytest.mark.asyncio async def test_call_first_respects_priority_and_returns_first_non_none() -> None: called: list[str] = [] @@ -112,11 +130,14 @@ async def test_run_model_uses_streaming_hook_when_plain_hook_absent() -> None: class StreamPlugin: @hookimpl async def run_model_stream(self, prompt, session_id, state): - async def iterator(): - yield StreamEvent("text", {"delta": "stream"}) - yield StreamEvent("text", {"delta": "ed"}) + return _BoundEnvelope( + [{"content": "stream"}, {"content": "ed"}], + "streamed", + ) - return AsyncStreamEvents(iterator()) + @hookimpl + def bind_envelope(self, envelope, session_id, state): + return envelope if isinstance(envelope, _BoundEnvelope) else None runtime = _runtime_with_plugins(("stream", StreamPlugin())) @@ -136,6 +157,4 @@ async def run_model(self, prompt, session_id, state): stream = await runtime.run_model_stream(prompt="hello", session_id="s", state={}) - assert stream is not None - events = [event async for event in stream] - assert [(event.kind, event.data) for event in events] == [("text", {"delta": "plain"})] + assert stream == "plain" diff --git a/tests/test_subagent_tool.py b/tests/test_subagent_tool.py index 6cf3853b..61382b9c 100644 --- a/tests/test_subagent_tool.py +++ b/tests/test_subagent_tool.py @@ -5,8 +5,8 @@ import pytest +from bub.builtin.agent import BuiltinModelStream from bub.builtin.tools import run_subagent -from bub.runtime import AsyncStreamEvents, StreamEvent from bub.tools import REGISTRY, tool @@ -22,11 +22,11 @@ class FakeAgent: def __init__(self) -> None: self.run_stream = AsyncMock(side_effect=self._run_stream) - async def _run_stream(self, **kwargs: Any) -> AsyncStreamEvents: + async def _run_stream(self, **kwargs: Any) -> BuiltinModelStream: async def iterator(): - yield StreamEvent("text", {"delta": "agent result"}) + yield {"kind": "text", "data": {"delta": "agent result"}} - return AsyncStreamEvents(iterator()) + return BuiltinModelStream(iterator()) @pytest.mark.asyncio From 43bc6ec7ad4774d91c1fe784715098d970e75cfa Mon Sep 17 00:00:00 2001 From: Chojan Shang Date: Sun, 14 Jun 2026 23:40:59 +0800 Subject: [PATCH 2/2] refactor: align native runtime streaming --- src/bub/builtin/agent.py | 150 +++++++++++++++++-------------- src/bub/builtin/hook_impl.py | 3 +- src/bub/builtin/tools.py | 15 ++-- src/bub/channels/cli/__init__.py | 4 +- tests/test_builtin_agent.py | 40 ++++++--- tests/test_builtin_hook_impl.py | 10 +-- tests/test_subagent_tool.py | 8 +- 7 files changed, 134 insertions(+), 96 deletions(-) diff --git a/src/bub/builtin/agent.py b/src/bub/builtin/agent.py index 65926ba1..0c01f653 100644 --- a/src/bub/builtin/agent.py +++ b/src/bub/builtin/agent.py @@ -32,7 +32,6 @@ from bub.builtin.settings import ModelCandidate, load_settings from bub.builtin.store import ForkTapeStore from bub.builtin.tape import TapeService -from bub.envelope import field_of from bub.errors import BubError, ErrorKind from bub.framework import BubFramework from bub.skills import discover_skills, render_skills_prompt @@ -56,19 +55,43 @@ TOOL_ARGUMENTS_ADAPTER = TypeAdapter(dict[str, Any]) -def _stream_event(kind: str, data: dict[str, Any] | None = None) -> Envelope: - return {"kind": kind, "data": data or {}} +@dataclass(frozen=True) +class StreamEvent: + kind: Literal["text", "reasoning", "tool_call", "tool_result", "usage", "error", "final"] + data: dict[str, Any] + + +@dataclass +class StreamState: + error: BubError | None = None + usage: dict[str, Any] | None = None + + +class AsyncStreamEvents: + def __init__(self, iterator: AsyncIterator[StreamEvent], *, state: StreamState | None = None) -> None: + self._iterator = iterator + self._state = state or StreamState() + def __aiter__(self) -> AsyncIterator[StreamEvent]: + return self._iterator -def _event_data(event: Envelope) -> dict[str, Any]: - data = field_of(event, "data", {}) - return data if isinstance(data, dict) else {} + @property + def state(self) -> StreamState: + return self._state + + @property + def error(self) -> BubError | None: + return self._state.error + + @property + def usage(self) -> dict[str, Any] | None: + return self._state.usage class BuiltinModelStream: """Builtin-owned stream envelope and binding.""" - def __init__(self, events: AsyncIterable[Envelope]) -> None: + def __init__(self, events: AsyncStreamEvents) -> None: self._events = events self._output_parts: list[str] = [] self._stream_started = False @@ -80,11 +103,11 @@ def stream(self) -> AsyncIterable[Envelope] | None: async def iterator() -> AsyncIterator[Envelope]: async for event in self._events: - if field_of(event, "kind") == "text": - delta = str(_event_data(event).get("delta", "")) + if event.kind == "text": + delta = str(event.data.get("delta", "")) self._output_parts.append(delta) yield {"content": delta, "source": event} - elif field_of(event, "kind") == "final": + elif event.kind == "final": yield {"end": True, "source": event} else: yield event @@ -113,25 +136,25 @@ def tapes(self) -> TapeService: return TapeService(bub.home / "tapes", tape_store, self.framework.build_tape_context()) @staticmethod - def _events_from_iterable(iterable: Iterable[Envelope]) -> AsyncIterable[Envelope]: - async def generator() -> AsyncIterator[Envelope]: + def _events_from_iterable(iterable: Iterable[StreamEvent]) -> AsyncStreamEvents: + async def generator() -> AsyncIterator[StreamEvent]: for item in iterable: yield item - return generator() + return AsyncStreamEvents(generator()) @staticmethod def _events_with_callback( - events: AsyncIterable[Envelope], callback: Callable[[], Coroutine[Any, Any, Any]] - ) -> AsyncIterable[Envelope]: - async def generator() -> AsyncIterator[Envelope]: + events: AsyncStreamEvents, callback: Callable[[], Coroutine[Any, Any, Any]] + ) -> AsyncStreamEvents: + async def generator() -> AsyncIterator[StreamEvent]: try: async for event in events: yield event finally: await callback() - return generator() + return AsyncStreamEvents(generator(), state=events.state) async def run( self, @@ -152,12 +175,9 @@ async def run( allowed_skills=allowed_skills, allowed_tools=allowed_tools, ) - events = stream.stream() - if events is not None: - async for _event in events: - pass - if result := stream.output(): - output.append(str(result)) + async for event in stream: + if event.kind == "text": + output.append(str(event.data.get("delta", ""))) return "".join(output) async def run_stream( @@ -169,14 +189,12 @@ async def run_stream( model: str | None = None, allowed_skills: Collection[str] | None = None, allowed_tools: Collection[str] | None = None, - ) -> BuiltinModelStream: + ) -> AsyncStreamEvents: if not prompt: - return BuiltinModelStream( - self._events_from_iterable([ - _stream_event("text", {"delta": "error: empty prompt"}), - _stream_event("final", {"text": "error: empty prompt", "ok": False}), - ]) - ) + return self._events_from_iterable([ + StreamEvent("text", {"delta": "error: empty prompt"}), + StreamEvent("final", {"text": "error: empty prompt", "ok": False}), + ]) tape = self.tapes.session_tape(session_id, workspace_from_state(state)) tape.context = replace(tape.context, state=state) @@ -188,8 +206,8 @@ async def run_stream( if isinstance(prompt, str) and prompt.strip().startswith(","): result = await self._run_command(tape=tape, line=prompt.strip()) events = self._events_from_iterable([ - _stream_event("text", {"delta": result}), - _stream_event("final", {"text": result, "ok": True}), + StreamEvent("text", {"delta": result}), + StreamEvent("final", {"text": result, "ok": True}), ]) else: events = await self._agent_loop( @@ -199,7 +217,7 @@ async def run_stream( allowed_skills=allowed_skills, allowed_tools=allowed_tools, ) - return BuiltinModelStream(self._events_with_callback(events, callback=stack.aclose)) + return self._events_with_callback(events, callback=stack.aclose) async def _run_command(self, tape: Tape, *, line: str) -> str: line = line[1:].strip() @@ -249,7 +267,7 @@ async def _agent_loop( model: str | None = None, allowed_skills: Collection[str] | None = None, allowed_tools: Collection[str] | None = None, - ) -> AsyncIterable[Envelope]: + ) -> AsyncStreamEvents: next_prompt: str | list[dict] = prompt display_model = model or self.settings.model await self.tapes.append_event( @@ -262,7 +280,7 @@ async def _agent_loop( "allowed_tools": list(allowed_tools) if allowed_tools else None, }, ) - state: dict[str, Any] = {} + state = StreamState() iterator = self._stream_events_with_auto_handoff( tape=tape, prompt=next_prompt, @@ -271,17 +289,17 @@ async def _agent_loop( allowed_skills=allowed_skills, allowed_tools=allowed_tools, ) - return iterator + return AsyncStreamEvents(iterator, state=state) async def _stream_events_with_auto_handoff( self, tape: Tape, prompt: str | list[dict], - state: dict[str, Any], + state: StreamState, model: str | None = None, allowed_skills: Collection[str] | None = None, allowed_tools: Collection[str] | None = None, - ) -> AsyncGenerator[Envelope, None]: + ) -> AsyncGenerator[StreamEvent, None]: auto_handoff_remaining = MAX_AUTO_HANDOFF_RETRIES display_model = model or self.settings.model next_prompt = prompt @@ -300,9 +318,7 @@ async def _stream_events_with_auto_handoff( ) async for event in output: yield event - kind = field_of(event, "kind") - data = _event_data(event) - if kind == "error": + if event.kind == "error": elapsed_ms = int((time.monotonic() - start) * 1000) await self.tapes.append_event( tape.name, @@ -311,12 +327,12 @@ async def _stream_events_with_auto_handoff( "step": step, "elapsed_ms": elapsed_ms, "status": "error", - "error": data.get("message", ""), + "error": event.data.get("message", ""), "date": datetime.now(UTC).isoformat(), }, ) - elif kind == "final": - should_continue = bool(data.get("tool_calls") or data.get("tool_results")) + elif event.kind == "final": + should_continue = bool(event.data.get("tool_calls") or event.data.get("tool_results")) except Exception as exc: error_message = f"{exc!s}" elapsed_ms = int((time.monotonic() - start) * 1000) @@ -359,6 +375,8 @@ async def _stream_events_with_auto_handoff( ) raise + state.error = output.error + state.usage = output.usage elapsed_ms = int((time.monotonic() - start) * 1000) if not should_continue: await self.tapes.append_event( @@ -404,7 +422,7 @@ async def _run_once( model: str | None = None, allowed_tools: Collection[str] | None = None, allowed_skills: Collection[str] | None = None, - ) -> AsyncIterable[Envelope]: + ) -> AsyncStreamEvents: prompt_text = prompt if isinstance(prompt, str) else _extract_text_from_parts(prompt) if allowed_tools is not None: from bub.builtin.tools import resolve_tool_names @@ -435,10 +453,10 @@ async def _run_once_stream( model: str | None, allowed_skills: set[str] | None, tools: list[Tool], - ) -> AsyncIterable[Envelope]: - state: dict[str, Any] = {} + ) -> AsyncStreamEvents: + state = StreamState() - async def iterator() -> AsyncGenerator[Envelope, None]: + async def iterator() -> AsyncGenerator[StreamEvent, None]: system_prompt = self._system_prompt( prompt_text, state=tape.context.state, allowed_skills=allowed_skills, tools=tools ) @@ -492,7 +510,7 @@ async def iterator() -> AsyncGenerator[Envelope, None]: tool_invocations = [ _tool_invocation_from_native(tool_call, tool_map) for tool_call in native_tool_calls ] - yield _stream_event("tool_call", {"tool_calls": serialized_tool_calls}) + yield StreamEvent("tool_call", {"tool_calls": serialized_tool_calls}) context = ToolContext(tape=tape.name, run_id=run_id, state=tape.context.state) execution = await ToolExecutor().execute_async( tool_invocations, @@ -508,10 +526,10 @@ async def iterator() -> AsyncGenerator[Envelope, None]: tool_results=execution.tool_results, response=response, model=model or self.settings.model, - usage=state.get("usage"), + usage=state.usage, ) - yield _stream_event("tool_result", {"tool_results": execution.tool_results}) - yield _stream_event( + yield StreamEvent("tool_result", {"tool_results": execution.tool_results}) + yield StreamEvent( "final", {"ok": True, "tool_calls": serialized_tool_calls, "tool_results": execution.tool_results}, ) @@ -525,11 +543,11 @@ async def iterator() -> AsyncGenerator[Envelope, None]: response_text=text, response=response, model=model or self.settings.model, - usage=state.get("usage"), + usage=state.usage, ) - yield _stream_event("final", {"ok": True, "text": text}) + yield StreamEvent("final", {"ok": True, "text": text}) - return iterator() + return AsyncStreamEvents(iterator(), state=state) def _build_llm(self, candidate: ModelCandidate) -> AnyLLM: return AnyLLM.create( @@ -655,13 +673,13 @@ def _parse_native_function_call(tool_call: ChatCompletionMessageToolCall) -> tup async def _completion_events( completion: ChatCompletion | ParsedChatCompletion[Any] | AsyncIterator[ChatCompletionChunk], - state: dict[str, Any], + state: StreamState, text_parts: list[str], tool_calls: _ToolCallAccumulator, -) -> AsyncGenerator[Envelope, None]: +) -> AsyncGenerator[StreamEvent, None]: if isinstance(completion, ChatCompletion): if usage := TapeService._extract_usage(completion): - state["usage"] = usage + state.usage = usage message = completion.choices[0].message for event in _completion_message_events(message, text_parts, tool_calls): yield event @@ -676,30 +694,30 @@ def _completion_message_events( message: ChatCompletionMessage, text_parts: list[str], tool_calls: _ToolCallAccumulator, -) -> Iterable[Envelope]: +) -> Iterable[StreamEvent]: if message.reasoning: - yield _stream_event("reasoning", {"delta": _reasoning_text(message.reasoning)}) + yield StreamEvent("reasoning", {"delta": _reasoning_text(message.reasoning)}) if message.content: text_parts.append(message.content) - yield _stream_event("text", {"delta": message.content}) + yield StreamEvent("text", {"delta": message.content}) tool_calls.add_message_calls(cast("Iterable[ChatCompletionMessageToolCall]", message.tool_calls or [])) async def _completion_chunk_events( chunk: ChatCompletionChunk, - state: dict[str, Any], + state: StreamState, text_parts: list[str], tool_calls: _ToolCallAccumulator, -) -> AsyncGenerator[Envelope, None]: +) -> AsyncGenerator[StreamEvent, None]: if usage := TapeService._extract_usage(chunk): - state["usage"] = usage + state.usage = usage for choice in chunk.choices: delta = choice.delta if delta.reasoning: - yield _stream_event("reasoning", {"delta": _reasoning_text(delta.reasoning)}) + yield StreamEvent("reasoning", {"delta": _reasoning_text(delta.reasoning)}) if delta.content: text_parts.append(delta.content) - yield _stream_event("text", {"delta": delta.content}) + yield StreamEvent("text", {"delta": delta.content}) if delta.tool_calls: tool_calls.merge_delta_calls(delta.tool_calls) diff --git a/src/bub/builtin/hook_impl.py b/src/bub/builtin/hook_impl.py index e9a73ac7..18d8d6d3 100644 --- a/src/bub/builtin/hook_impl.py +++ b/src/bub/builtin/hook_impl.py @@ -160,7 +160,8 @@ async def run_model(self, prompt: str | list[dict], session_id: str, state: Stat @hookimpl async def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> Envelope: - return await self._get_agent().run_stream(session_id=session_id, prompt=prompt, state=state) + stream = await self._get_agent().run_stream(session_id=session_id, prompt=prompt, state=state) + return BuiltinModelStream(stream) @hookimpl def bind_envelope(self, envelope: Envelope, session_id: str, state: State) -> EnvelopeBinding | None: diff --git a/src/bub/builtin/tools.py b/src/bub/builtin/tools.py index 88c89af8..51484d52 100644 --- a/src/bub/builtin/tools.py +++ b/src/bub/builtin/tools.py @@ -359,19 +359,20 @@ async def run_subagent(param: SubAgentInput, *, context: ToolContext) -> str: subagent_session = param.session state = {**context.state, "session_id": subagent_session} allowed_tools = resolve_tool_names(param.allowed_tools or None, exclude={"subagent"}) - stream = await agent.run_stream( + output = "" + async for event in await agent.run_stream( session_id=subagent_session, prompt=param.prompt, state=state, model=param.model, allowed_tools=allowed_tools, allowed_skills=param.allowed_skills, - ) - events = stream.stream() - if events is not None: - async for _event in events: - pass - return str(stream.output() or "") + ): + if event.kind == "error": + output += f"[Error: {event.data.get('message', 'unknown error')}]" + elif event.kind == "text": + output += str(event.data.get("delta", "")) + return output @tool(name="help") diff --git a/src/bub/channels/cli/__init__.py b/src/bub/channels/cli/__init__.py index 1df6d6b0..d68799c7 100644 --- a/src/bub/channels/cli/__init__.py +++ b/src/bub/channels/cli/__init__.py @@ -247,8 +247,8 @@ async def stream_events(self, message: ChannelMessage, stream: AsyncIterable[Env expand_thinking=self._expand_thinking, ) async for event in stream: - if printer.render(event): - yield event + printer.render(event) + yield event def _build_prompt(self, workspace: Path) -> PromptSession[str]: kb = KeyBindings() diff --git a/tests/test_builtin_agent.py b/tests/test_builtin_agent.py index b586d3bc..bc436fe7 100644 --- a/tests/test_builtin_agent.py +++ b/tests/test_builtin_agent.py @@ -8,7 +8,7 @@ import pytest from any_llm.types.completion import ChatCompletionChunk -from bub.builtin.agent import Agent, BuiltinModelStream +from bub.builtin.agent import Agent, AsyncStreamEvents, StreamEvent, StreamState from bub.builtin.settings import AgentSettings from bub.errors import BubError from bub.tape import Tape, TapeContext @@ -60,11 +60,8 @@ async def _chat_stream(content: str) -> AsyncIterator[ChatCompletionChunk]: yield _chat_chunk(content) -async def _consume_stream(stream: BuiltinModelStream) -> list[Any]: - events = stream.stream() - if events is None: - return [] - return [event async for event in events] +async def _consume_stream(stream: AsyncStreamEvents) -> list[StreamEvent]: + return [event async for event in stream] class _ForkCapture: @@ -205,14 +202,35 @@ async def test_agent_run_empty_prompt_returns_error() -> None: events = await _consume_stream(result) assert events == [ - { - "content": "error: empty prompt", - "source": {"kind": "text", "data": {"delta": "error: empty prompt"}}, - }, - {"end": True, "source": {"kind": "final", "data": {"ok": False, "text": "error: empty prompt"}}}, + StreamEvent("text", {"delta": "error: empty prompt"}), + StreamEvent("final", {"ok": False, "text": "error: empty prompt"}), ] +@pytest.mark.asyncio +async def test_agent_loop_preserves_run_once_stream_state() -> None: + agent = _make_agent() + fork_capture = _ForkCapture() + fake_tapes = _FakeTapeService(fork_capture) + agent.tapes = fake_tapes # type: ignore[assignment] + state = StreamState(usage={"total_tokens": 3}) + + async def iterator() -> AsyncIterator[StreamEvent]: + yield StreamEvent("text", {"delta": "done"}) + yield StreamEvent("final", {"ok": True, "text": "done"}) + + async def fake_run_once(**kwargs: Any) -> AsyncStreamEvents: + return AsyncStreamEvents(iterator(), state=state) + + agent._run_once = fake_run_once # type: ignore[method-assign] + tape = Tape(name="test-tape", context=TapeContext(state={})) + + events = await agent._agent_loop(tape=tape, prompt="hello") + [event async for event in events] + + assert events.usage == {"total_tokens": 3} + + @pytest.mark.asyncio async def test_agent_run_model_defaults_to_none() -> None: """When model is not specified, settings.model is used for any-llm.""" diff --git a/tests/test_builtin_hook_impl.py b/tests/test_builtin_hook_impl.py index f982a349..a3ecf497 100644 --- a/tests/test_builtin_hook_impl.py +++ b/tests/test_builtin_hook_impl.py @@ -7,7 +7,7 @@ import pytest -from bub.builtin.agent import BuiltinModelStream +from bub.builtin.agent import AsyncStreamEvents, StreamEvent from bub.builtin.hook_impl import AGENTS_FILE_NAME, DEFAULT_SYSTEM_PROMPT, BuiltinImpl from bub.builtin.store import FileTapeStore from bub.channels.message import ChannelMessage @@ -36,13 +36,13 @@ async def run(self, *, session_id: str, prompt: str, state: dict[str, object]) - self.run_calls.append((session_id, prompt, state)) return "agent-output" - async def run_stream(self, *, session_id: str, prompt: str, state: dict[str, object]) -> BuiltinModelStream: + async def run_stream(self, *, session_id: str, prompt: str, state: dict[str, object]) -> AsyncStreamEvents: self.run_stream_calls.append((session_id, prompt, state)) async def iterator(): - yield {"kind": "text", "data": {"delta": "agent-output"}} + yield StreamEvent("text", {"delta": "agent-output"}) - return BuiltinModelStream(iterator()) + return AsyncStreamEvents(iterator()) def _raise_value_error() -> None: @@ -175,7 +175,7 @@ async def test_run_model_stream_delegates_to_agent(tmp_path: Path) -> None: assert events is not None assert [event async for event in events] == [ - {"content": "agent-output", "source": {"kind": "text", "data": {"delta": "agent-output"}}} + {"content": "agent-output", "source": StreamEvent("text", {"delta": "agent-output"})} ] assert agent.run_stream_calls == [("session", "prompt", state)] assert agent.run_calls == [] diff --git a/tests/test_subagent_tool.py b/tests/test_subagent_tool.py index 61382b9c..7ee9f220 100644 --- a/tests/test_subagent_tool.py +++ b/tests/test_subagent_tool.py @@ -5,7 +5,7 @@ import pytest -from bub.builtin.agent import BuiltinModelStream +from bub.builtin.agent import AsyncStreamEvents, StreamEvent from bub.builtin.tools import run_subagent from bub.tools import REGISTRY, tool @@ -22,11 +22,11 @@ class FakeAgent: def __init__(self) -> None: self.run_stream = AsyncMock(side_effect=self._run_stream) - async def _run_stream(self, **kwargs: Any) -> BuiltinModelStream: + async def _run_stream(self, **kwargs: Any) -> AsyncStreamEvents: async def iterator(): - yield {"kind": "text", "data": {"delta": "agent result"}} + yield StreamEvent("text", {"delta": "agent result"}) - return BuiltinModelStream(iterator()) + return AsyncStreamEvents(iterator()) @pytest.mark.asyncio