diff --git a/src/bub/builtin/agent.py b/src/bub/builtin/agent.py index 54a51ce1..0c01f653 100644 --- a/src/bub/builtin/agent.py +++ b/src/bub/builtin/agent.py @@ -7,7 +7,7 @@ import re import shlex import time -from collections.abc import AsyncGenerator, AsyncIterator, Callable, Collection, Coroutine, Iterable +from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator, Callable, Collection, Coroutine, Iterable from contextlib import AsyncExitStack from dataclasses import dataclass, replace from datetime import UTC, datetime @@ -32,8 +32,8 @@ from bub.builtin.settings import ModelCandidate, load_settings from bub.builtin.store import ForkTapeStore from bub.builtin.tape import TapeService +from bub.errors import BubError, ErrorKind from bub.framework import BubFramework -from bub.runtime import AsyncStreamEvents, BubError, ErrorKind, StreamEvent, StreamState from bub.skills import discover_skills, render_skills_prompt from bub.tape import InMemoryTapeStore, Tape from bub.tools import ( @@ -42,7 +42,7 @@ ToolContext, ToolExecutor, ) -from bub.types import State +from bub.types import Envelope, State from bub.utils import workspace_from_state CONTINUE_PROMPT = "Continue the task until all targets are completed." @@ -55,6 +55,69 @@ TOOL_ARGUMENTS_ADAPTER = TypeAdapter(dict[str, Any]) +@dataclass(frozen=True) +class StreamEvent: + kind: Literal["text", "reasoning", "tool_call", "tool_result", "usage", "error", "final"] + data: dict[str, Any] + + +@dataclass +class StreamState: + error: BubError | None = None + usage: dict[str, Any] | None = None + + +class AsyncStreamEvents: + def __init__(self, iterator: AsyncIterator[StreamEvent], *, state: StreamState | None = None) -> None: + self._iterator = iterator + self._state = state or StreamState() + + def __aiter__(self) -> AsyncIterator[StreamEvent]: + return self._iterator + + @property + def state(self) -> StreamState: + return self._state + + @property + def error(self) -> BubError | None: + return self._state.error + + @property + def usage(self) -> dict[str, Any] | None: + return self._state.usage + + +class BuiltinModelStream: + """Builtin-owned stream envelope and binding.""" + + def __init__(self, events: AsyncStreamEvents) -> None: + self._events = events + self._output_parts: list[str] = [] + self._stream_started = False + + def stream(self) -> AsyncIterable[Envelope] | None: + if self._stream_started: + return None + self._stream_started = True + + async def iterator() -> AsyncIterator[Envelope]: + async for event in self._events: + if event.kind == "text": + delta = str(event.data.get("delta", "")) + self._output_parts.append(delta) + yield {"content": delta, "source": event} + elif event.kind == "final": + yield {"end": True, "source": event} + else: + yield event + + return iterator() + + def output(self) -> Envelope | None: + return "".join(self._output_parts) + + class Agent: """Agent that processes prompts using hooks, tools, tape, and any-llm-sdk.""" @@ -73,8 +136,8 @@ def tapes(self) -> TapeService: return TapeService(bub.home / "tapes", tape_store, self.framework.build_tape_context()) @staticmethod - def _events_from_iterable(iterable: Iterable) -> AsyncStreamEvents: - async def generator() -> AsyncIterator: + def _events_from_iterable(iterable: Iterable[StreamEvent]) -> AsyncStreamEvents: + async def generator() -> AsyncIterator[StreamEvent]: for item in iterable: yield item @@ -91,7 +154,7 @@ async def generator() -> AsyncIterator[StreamEvent]: finally: await callback() - return AsyncStreamEvents(generator(), state=events._state) + return AsyncStreamEvents(generator(), state=events.state) async def run( self, @@ -467,7 +530,8 @@ async def iterator() -> AsyncGenerator[StreamEvent, None]: ) yield StreamEvent("tool_result", {"tool_results": execution.tool_results}) yield StreamEvent( - "final", {"ok": True, "tool_calls": serialized_tool_calls, "tool_results": execution.tool_results} + "final", + {"ok": True, "tool_calls": serialized_tool_calls, "tool_results": execution.tool_results}, ) return diff --git a/src/bub/builtin/hook_impl.py b/src/bub/builtin/hook_impl.py index 0ed75721..18d8d6d3 100644 --- a/src/bub/builtin/hook_impl.py +++ b/src/bub/builtin/hook_impl.py @@ -7,7 +7,7 @@ from loguru import logger from bub import inquirer as bub_inquirer -from bub.builtin.agent import Agent +from bub.builtin.agent import Agent, BuiltinModelStream from bub.builtin.context import default_tape_context from bub.builtin.settings import DEFAULT_MODEL from bub.channels.base import Channel @@ -15,9 +15,8 @@ from bub.envelope import content_of, field_of from bub.framework import BubFramework from bub.hookspecs import hookimpl -from bub.runtime import AsyncStreamEvents from bub.tape import TapeContext, TapeStore -from bub.types import Envelope, MessageHandler, State +from bub.types import Envelope, EnvelopeBinding, MessageHandler, State AGENTS_FILE_NAME = "AGENTS.md" MODEL_PROVIDER_CHOICES: tuple[str, ...] = ( @@ -120,7 +119,7 @@ async def load_state(self, message: ChannelMessage, session_id: str) -> State: return state @hookimpl - async def save_state(self, session_id: str, state: State, message: ChannelMessage, model_output: str) -> None: + async def save_state(self, session_id: str, state: State, message: ChannelMessage, model_output: Envelope) -> None: tp, value, traceback = sys.exc_info() lifespan = field_of(message, "lifespan") if lifespan is not None: @@ -156,12 +155,19 @@ async def build_prompt(self, message: ChannelMessage, session_id: str, state: St return text @hookimpl - async def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> str: + async def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> Envelope: return await self._get_agent().run(session_id=session_id, prompt=prompt, state=state) @hookimpl - async def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> AsyncStreamEvents: - return await self._get_agent().run_stream(session_id=session_id, prompt=prompt, state=state) + async def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> Envelope: + stream = await self._get_agent().run_stream(session_id=session_id, prompt=prompt, state=state) + return BuiltinModelStream(stream) + + @hookimpl + def bind_envelope(self, envelope: Envelope, session_id: str, state: State) -> EnvelopeBinding | None: + if isinstance(envelope, BuiltinModelStream): + return envelope + return None @hookimpl def register_cli_commands(self, app: typer.Typer) -> None: @@ -273,13 +279,13 @@ def render_outbound( message: Envelope, session_id: str, state: State, - model_output: str, + model_output: Envelope, ) -> list[ChannelMessage]: outbound = ChannelMessage( session_id=session_id, channel=field_of(message, "channel", "default"), chat_id=field_of(message, "chat_id", "default"), - content=model_output, + content=content_of(model_output), output_channel=field_of(message, "output_channel", "default"), kind=field_of(message, "kind", "normal"), ) diff --git a/src/bub/builtin/tape.py b/src/bub/builtin/tape.py index ca992521..c89d9bc3 100644 --- a/src/bub/builtin/tape.py +++ b/src/bub/builtin/tape.py @@ -11,7 +11,7 @@ from pydantic import BaseModel from bub.builtin.store import ForkTapeStore -from bub.runtime import BubError +from bub.errors import BubError from bub.tape import AsyncTapeStore, Tape, TapeContext, TapeEntry, TapeQuery, build_messages diff --git a/src/bub/channels/base.py b/src/bub/channels/base.py index 920469ec..55112d96 100644 --- a/src/bub/channels/base.py +++ b/src/bub/channels/base.py @@ -4,7 +4,7 @@ from typing import ClassVar from bub.channels.message import ChannelMessage -from bub.runtime import StreamEvent +from bub.types import Envelope class Channel(ABC): @@ -35,7 +35,7 @@ async def send(self, message: ChannelMessage) -> None: # Do nothing by default return - def stream_events(self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: + def stream_events(self, message: ChannelMessage, stream: AsyncIterable[Envelope]) -> AsyncIterable[Envelope]: """Optionally wrap the output stream for this channel.""" return stream diff --git a/src/bub/channels/cli/__init__.py b/src/bub/channels/cli/__init__.py index 1b87324b..d68799c7 100644 --- a/src/bub/channels/cli/__init__.py +++ b/src/bub/channels/cli/__init__.py @@ -23,10 +23,9 @@ from bub.channels.base import Interface from bub.channels.cli.renderer import CliRenderer from bub.channels.message import ChannelMessage -from bub.envelope import field_of -from bub.runtime import StreamEvent +from bub.envelope import content_of, field_of from bub.tools import REGISTRY -from bub.types import MessageHandler +from bub.types import Envelope, MessageHandler class _StreamPrinter: @@ -39,16 +38,20 @@ def __init__(self, *, console, print_head: Callable[[], None], expand_thinking: self._reasoning_status: Status | None = None self.head_printed = False - def render(self, event: StreamEvent) -> bool: - if event.kind == "reasoning": - self._record_reasoning(str(event.data.get("delta", ""))) + def render(self, event: Envelope) -> bool: + kind = field_of(event, "kind") + data = field_of(event, "data", {}) + data = data if isinstance(data, dict) else {} + if kind == "reasoning": + self._record_reasoning(str(data.get("delta", ""))) return True - if event.kind == "text": - return self._print_content(str(event.data.get("delta", ""))) - elif event.kind == "tool_call": + content = content_of(event) + if content: + return self._print_content(content) + if kind == "tool_call": self._print_stream_boundary() - elif event.kind == "final": + elif kind == "final" or field_of(event, "end"): self._print_end() return True @@ -236,9 +239,7 @@ def _prompt_message(self) -> FormattedText: symbol = ">" if self._mode == "agent" else "," return FormattedText([("bold", f"{cwd} {symbol} ")]) - async def stream_events( - self, message: ChannelMessage, stream: AsyncIterable[StreamEvent] - ) -> AsyncIterable[StreamEvent]: + async def stream_events(self, message: ChannelMessage, stream: AsyncIterable[Envelope]) -> AsyncIterable[Envelope]: console = get_console() printer = _StreamPrinter( console=console, @@ -246,8 +247,8 @@ async def stream_events( expand_thinking=self._expand_thinking, ) async for event in stream: - if printer.render(event): - yield event + printer.render(event) + yield event def _build_prompt(self, workspace: Path) -> PromptSession[str]: kb = KeyBindings() diff --git a/src/bub/channels/manager.py b/src/bub/channels/manager.py index 313cc2ce..3a0fa7cf 100644 --- a/src/bub/channels/manager.py +++ b/src/bub/channels/manager.py @@ -14,7 +14,6 @@ from bub.configure import Settings, ensure_config from bub.envelope import content_of, field_of from bub.framework import BubFramework -from bub.runtime import StreamEvent from bub.turn_admission import AdmitDecision, SessionTurnController from bub.types import Envelope, MessageHandler from bub.utils import wait_until_stopped @@ -105,7 +104,7 @@ async def dispatch_output(self, message: Envelope) -> bool: await channel.send(outbound) return True - def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: + def wrap_stream(self, message: Envelope, stream: AsyncIterable[Envelope]) -> AsyncIterable[Envelope]: channel_name = field_of(message, "output_channel", field_of(message, "channel")) if channel_name is None: return stream diff --git a/src/bub/envelope.py b/src/bub/envelope.py index 362d4bc5..48d1f0ab 100644 --- a/src/bub/envelope.py +++ b/src/bub/envelope.py @@ -19,6 +19,8 @@ def field_of(message: Envelope, key: str, default: Any = None) -> Any: def content_of(message: Envelope) -> str: """Get textual content from any envelope shape.""" + if isinstance(message, str): + return message return str(field_of(message, "content", "")) diff --git a/src/bub/errors.py b/src/bub/errors.py new file mode 100644 index 00000000..f8346e88 --- /dev/null +++ b/src/bub/errors.py @@ -0,0 +1,40 @@ +"""Shared Bub error types.""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import StrEnum +from typing import Any + + +class ErrorKind(StrEnum): + """Stable error kinds for Bub failures.""" + + INVALID_INPUT = "invalid_input" + CONFIG = "config" + PROVIDER = "provider" + TOOL = "tool" + TEMPORARY = "temporary" + NOT_FOUND = "not_found" + UNKNOWN = "unknown" + + +@dataclass(frozen=True) +class BubError(Exception): + """Public error type for Bub failures.""" + + kind: ErrorKind + message: str + details: dict[str, Any] | None = None + + def __str__(self) -> str: + return f"[{self.kind.value}] {self.message}" + + def as_dict(self) -> dict[str, Any]: + payload: dict[str, Any] = { + "kind": self.kind.value, + "message": self.message, + } + if self.details: + payload["details"] = self.details + return payload diff --git a/src/bub/framework.py b/src/bub/framework.py index cbece9d5..5a2f5041 100644 --- a/src/bub/framework.py +++ b/src/bub/framework.py @@ -17,7 +17,6 @@ from bub.envelope import content_of, field_of, unpack_batch from bub.hook_runtime import _SKIP_VALUE, HookRuntime from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs -from bub.runtime import BubError, ErrorKind from bub.tape import AsyncTapeStore, TapeContext, TapeStore from bub.turn_admission import AdmitDecision, SteeringBuffer, TurnSnapshot from bub.types import Envelope, MessageHandler, OutboundChannelRouter, TurnResult @@ -158,7 +157,7 @@ async def _run_model( session_id: str, state: dict[str, Any], stream_output: bool, - ) -> str: + ) -> Envelope: if not stream_output: output = await self._hook_runtime.run_model(prompt=prompt, session_id=session_id, state=state) if output is None: @@ -169,28 +168,29 @@ async def _run_model( ) return prompt if isinstance(prompt, str) else content_of(inbound) return output - stream = await self._hook_runtime.run_model_stream(prompt=prompt, session_id=session_id, state=state) - if stream is None: + stream_envelope = await self._hook_runtime.run_model_stream(prompt=prompt, session_id=session_id, state=state) + if stream_envelope is None: await self._hook_runtime.notify_error( stage="run_model", error=RuntimeError("no model skill returned output"), message=inbound, ) return prompt if isinstance(prompt, str) else content_of(inbound) + binding = await self._hook_runtime.bind_envelope(stream_envelope, session_id=session_id, state=state) + if binding is None: + return stream_envelope else: - parts: list[str] = [] - events = self._outbound_router.wrap_stream(inbound, stream) if self._outbound_router is not None else stream - async for event in events: - if event.kind == "text": - parts.append(str(event.data.get("delta", ""))) - elif event.kind == "error": - # Turn "kind" to enum type otherwise BubError's __str__ won't work well. - data = { - **event.data, - "kind": ErrorKind(event.data.get("kind", "unknown")), - } - await self._hook_runtime.notify_error(stage="run_model", error=BubError(**data), message=inbound) - return "".join(parts) + stream = binding.stream() + if stream is not None: + events = ( + self._outbound_router.wrap_stream(inbound, stream) if self._outbound_router is not None else stream + ) + async for _event in events: + pass + output = binding.output() + if output is None: + return prompt if isinstance(prompt, str) else content_of(inbound) + return output def hook_report(self) -> dict[str, list[str]]: """Return hook implementation summary for diagnostics.""" @@ -244,7 +244,7 @@ async def _collect_outbounds( message: Envelope, session_id: str, state: dict[str, Any], - model_output: str, + model_output: Envelope, ) -> list[Envelope]: batches = await self._hook_runtime.call_many( "render_outbound", diff --git a/src/bub/hook_runtime.py b/src/bub/hook_runtime.py index bdb0eaf7..42dec20d 100644 --- a/src/bub/hook_runtime.py +++ b/src/bub/hook_runtime.py @@ -3,14 +3,12 @@ from __future__ import annotations import inspect -from collections.abc import AsyncGenerator from typing import Any import pluggy from loguru import logger -from bub.runtime import AsyncStreamEvents, StreamEvent, StreamState -from bub.types import Envelope +from bub.types import Envelope, EnvelopeBinding class HookRuntime: @@ -160,41 +158,43 @@ def _iter_hookimpls(self, hook_name: str) -> list[Any]: def _kwargs_for_impl(impl: Any, kwargs: dict[str, Any]) -> dict[str, Any]: return {name: kwargs[name] for name in impl.argnames if name in kwargs} - async def run_model(self, prompt: str | list[dict], session_id: str, state: dict[str, Any]) -> str | None: + async def run_model(self, prompt: str | list[dict], session_id: str, state: dict[str, Any]) -> Envelope | None: """Run the first `run_model` hook found and return its result.""" for _, plugin in reversed(self._plugin_manager.list_name_plugin()): if hasattr(plugin, "run_model"): - output = await self.call_first("run_model", prompt=prompt, session_id=session_id, state=state) - if output is None or isinstance(output, str): - return output - raise TypeError("hook.run_model must return str or None") + return await self.call_first("run_model", prompt=prompt, session_id=session_id, state=state) elif hasattr(plugin, "run_model_stream"): - stream = await self.call_first("run_model_stream", prompt=prompt, session_id=session_id, state=state) - text = "" - async for event in stream: - if event.kind == "text": - text += str(event.data.get("delta", "")) - return text + stream_envelope = await self.run_model_stream(prompt=prompt, session_id=session_id, state=state) + binding = await self.bind_envelope(stream_envelope, session_id=session_id, state=state) + if binding is None: + return stream_envelope + stream = binding.stream() + if stream is not None: + async for _event in stream: + pass + return binding.output() return None async def run_model_stream( self, prompt: str | list[dict], session_id: str, state: dict[str, Any] - ) -> AsyncStreamEvents | None: + ) -> Envelope | None: """Run the first `run_model_stream` hook found and fallback to `run_model` hook.""" for _, plugin in reversed(self._plugin_manager.list_name_plugin()): if hasattr(plugin, "run_model_stream"): - stream = await self.call_first("run_model_stream", prompt=prompt, session_id=session_id, state=state) - if stream is None or isinstance(stream, AsyncStreamEvents): - return stream - raise TypeError("hook.run_model_stream must return AsyncStreamEvents or None") + return await self.call_first("run_model_stream", prompt=prompt, session_id=session_id, state=state) elif hasattr(plugin, "run_model"): - - async def iterator() -> AsyncGenerator[StreamEvent, None]: - result = await self.call_first("run_model", prompt=prompt, session_id=session_id, state=state) - yield StreamEvent("text", {"delta": result}) - - return AsyncStreamEvents(iterator(), state=StreamState()) + return await self.call_first("run_model", prompt=prompt, session_id=session_id, state=state) return None + async def bind_envelope( + self, envelope: Envelope, *, session_id: str, state: dict[str, Any] + ) -> EnvelopeBinding | None: + if envelope is None: + return None + binding = await self.call_first("bind_envelope", envelope=envelope, session_id=session_id, state=state) + if binding is None or isinstance(binding, EnvelopeBinding): + return binding + raise TypeError("hook.bind_envelope must return EnvelopeBinding or None") + _SKIP_VALUE = object() diff --git a/src/bub/hookspecs.py b/src/bub/hookspecs.py index 6bb5da2b..24186579 100644 --- a/src/bub/hookspecs.py +++ b/src/bub/hookspecs.py @@ -6,10 +6,9 @@ import pluggy -from bub.runtime import AsyncStreamEvents from bub.tape import AsyncTapeStore, TapeContext, TapeStore from bub.turn_admission import AdmitDecision, TurnSnapshot -from bub.types import Envelope, MessageHandler, State +from bub.types import Envelope, EnvelopeBinding, MessageHandler, State if TYPE_CHECKING: from bub.channels.base import Channel @@ -37,13 +36,18 @@ def build_prompt(self, message: Envelope, session_id: str, state: State) -> str raise NotImplementedError @hookspec(firstresult=True) - def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> str: - """Run model for one turn and return plain text output. Should not be implemented if `run_model_stream` is implemented.""" + def run_model(self, prompt: str | list[dict], session_id: str, state: State) -> Envelope: + """Run model for one turn and return an output envelope. Should not be implemented if `run_model_stream` is implemented.""" raise NotImplementedError @hookspec(firstresult=True) - def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> AsyncStreamEvents: - """Run model for one turn and return a stream of events. Should not be implemented if `run_model` is implemented.""" + def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> Envelope: + """Run model for one turn and return an envelope with stream capabilities.""" + raise NotImplementedError + + @hookspec(firstresult=True) + def bind_envelope(self, envelope: Envelope, session_id: str, state: State) -> EnvelopeBinding | None: + """Bind producer-defined capabilities to an envelope.""" raise NotImplementedError @hookspec @@ -57,7 +61,7 @@ def save_state( session_id: str, state: State, message: Envelope, - model_output: str, + model_output: Envelope, ) -> None: """Persist state updates after one model turn.""" @@ -67,7 +71,7 @@ def render_outbound( message: Envelope, session_id: str, state: State, - model_output: str, + model_output: Envelope, ) -> list[Envelope]: """Render outbound messages from model output.""" raise NotImplementedError diff --git a/src/bub/runtime.py b/src/bub/runtime.py deleted file mode 100644 index 93b05ca5..00000000 --- a/src/bub/runtime.py +++ /dev/null @@ -1,70 +0,0 @@ -"""Small runtime primitives shared by Bub core and channels.""" - -from __future__ import annotations - -from collections.abc import AsyncIterator -from dataclasses import dataclass -from enum import StrEnum -from typing import Any, Literal - - -class ErrorKind(StrEnum): - """Stable error kinds for runtime decisions.""" - - INVALID_INPUT = "invalid_input" - CONFIG = "config" - PROVIDER = "provider" - TOOL = "tool" - TEMPORARY = "temporary" - NOT_FOUND = "not_found" - UNKNOWN = "unknown" - - -@dataclass(frozen=True) -class BubError(Exception): - """Public error type for Bub runtime failures.""" - - kind: ErrorKind - message: str - details: dict[str, Any] | None = None - - def __str__(self) -> str: - return f"[{self.kind.value}] {self.message}" - - def as_dict(self) -> dict[str, Any]: - payload: dict[str, Any] = { - "kind": self.kind.value, - "message": self.message, - } - if self.details: - payload["details"] = self.details - return payload - - -@dataclass -class StreamState: - error: BubError | None = None - usage: dict[str, Any] | None = None - - -@dataclass(frozen=True) -class StreamEvent: - kind: Literal["text", "reasoning", "tool_call", "tool_result", "usage", "error", "final"] - data: dict[str, Any] - - -class AsyncStreamEvents: - def __init__(self, iterator: AsyncIterator[StreamEvent], *, state: StreamState | None = None) -> None: - self._iterator = iterator - self._state = state or StreamState() - - def __aiter__(self) -> AsyncIterator[StreamEvent]: - return self._iterator - - @property - def error(self) -> BubError | None: - return self._state.error - - @property - def usage(self) -> dict[str, Any] | None: - return self._state.usage diff --git a/src/bub/tape.py b/src/bub/tape.py index 29886348..23e5074f 100644 --- a/src/bub/tape.py +++ b/src/bub/tape.py @@ -13,7 +13,7 @@ from typing_extensions import TypeIs -from bub.runtime import BubError, ErrorKind +from bub.errors import BubError, ErrorKind def utc_now() -> str: diff --git a/src/bub/tools.py b/src/bub/tools.py index 3c6a4525..5104d9fa 100644 --- a/src/bub/tools.py +++ b/src/bub/tools.py @@ -11,7 +11,7 @@ from loguru import logger from pydantic import BaseModel, TypeAdapter, ValidationError, validate_call -from bub.runtime import BubError, ErrorKind +from bub.errors import BubError, ErrorKind @dataclass(frozen=True) diff --git a/src/bub/types.py b/src/bub/types.py index 65ad88af..14cd3d28 100644 --- a/src/bub/types.py +++ b/src/bub/types.py @@ -4,9 +4,7 @@ from collections.abc import AsyncIterable, Callable, Coroutine from dataclasses import dataclass, field -from typing import Any, Protocol - -from bub.runtime import StreamEvent +from typing import Any, Protocol, runtime_checkable type Envelope = Any type State = dict[str, Any] @@ -14,9 +12,17 @@ type OutboundDispatcher = Callable[[Envelope], Coroutine[Any, Any, bool]] +@runtime_checkable +class EnvelopeBinding(Protocol): + """Capabilities attached to an envelope by the producer that understands it.""" + + def stream(self) -> AsyncIterable[Envelope] | None: ... + def output(self) -> Envelope | None: ... + + class OutboundChannelRouter(Protocol): async def dispatch_output(self, message: Envelope) -> bool: ... - def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: ... + def wrap_stream(self, message: Envelope, stream: AsyncIterable[Envelope]) -> AsyncIterable[Envelope]: ... async def quit(self, session_id: str) -> None: ... @@ -26,5 +32,5 @@ class TurnResult: session_id: str prompt: str - model_output: str + model_output: Envelope outbounds: list[Envelope] = field(default_factory=list) diff --git a/tests/test_builtin_agent.py b/tests/test_builtin_agent.py index ba470306..bc436fe7 100644 --- a/tests/test_builtin_agent.py +++ b/tests/test_builtin_agent.py @@ -8,9 +8,9 @@ import pytest from any_llm.types.completion import ChatCompletionChunk -from bub.builtin.agent import Agent +from bub.builtin.agent import Agent, AsyncStreamEvents, StreamEvent, StreamState from bub.builtin.settings import AgentSettings -from bub.runtime import BubError +from bub.errors import BubError from bub.tape import Tape, TapeContext from bub.tools import REGISTRY, tool @@ -60,6 +60,10 @@ async def _chat_stream(content: str) -> AsyncIterator[ChatCompletionChunk]: yield _chat_chunk(content) +async def _consume_stream(stream: AsyncStreamEvents) -> list[StreamEvent]: + return [event async for event in stream] + + class _ForkCapture: """Captures fork_tape enter and exit behavior.""" @@ -146,7 +150,7 @@ async def test_agent_run_regular_session_merges_back() -> None: assert fork_capture.merge_back_values == [True] assert fork_capture.exit_count == 0 - [event async for event in result] + await _consume_stream(result) assert fork_capture.merge_back_values == [True] assert fork_capture.exit_count == 1 @@ -164,7 +168,7 @@ async def test_agent_run_temp_session_does_not_merge_back() -> None: assert fork_capture.merge_back_values == [False] assert fork_capture.exit_count == 0 - [event async for event in result] + await _consume_stream(result) assert fork_capture.merge_back_values == [False] assert fork_capture.exit_count == 1 @@ -184,7 +188,7 @@ async def test_agent_run_passes_model_to_llm() -> None: state={"_runtime_workspace": "/tmp"}, # noqa: S108 model="openai:gpt-4o", ) - [event async for event in result] + await _consume_stream(result) assert agent.completion_kwargs["model"] == "openai:gpt-4o" @@ -195,14 +199,38 @@ async def test_agent_run_empty_prompt_returns_error() -> None: agent.tapes = MagicMock() # type: ignore[assignment] result = await agent.run_stream(session_id="user/s1", prompt="", state={}) - events = [event async for event in result] + events = await _consume_stream(result) - assert [(event.kind, event.data) for event in events] == [ - ("text", {"delta": "error: empty prompt"}), - ("final", {"ok": False, "text": "error: empty prompt"}), + assert events == [ + StreamEvent("text", {"delta": "error: empty prompt"}), + StreamEvent("final", {"ok": False, "text": "error: empty prompt"}), ] +@pytest.mark.asyncio +async def test_agent_loop_preserves_run_once_stream_state() -> None: + agent = _make_agent() + fork_capture = _ForkCapture() + fake_tapes = _FakeTapeService(fork_capture) + agent.tapes = fake_tapes # type: ignore[assignment] + state = StreamState(usage={"total_tokens": 3}) + + async def iterator() -> AsyncIterator[StreamEvent]: + yield StreamEvent("text", {"delta": "done"}) + yield StreamEvent("final", {"ok": True, "text": "done"}) + + async def fake_run_once(**kwargs: Any) -> AsyncStreamEvents: + return AsyncStreamEvents(iterator(), state=state) + + agent._run_once = fake_run_once # type: ignore[method-assign] + tape = Tape(name="test-tape", context=TapeContext(state={})) + + events = await agent._agent_loop(tape=tape, prompt="hello") + [event async for event in events] + + assert events.usage == {"total_tokens": 3} + + @pytest.mark.asyncio async def test_agent_run_model_defaults_to_none() -> None: """When model is not specified, settings.model is used for any-llm.""" @@ -212,7 +240,7 @@ async def test_agent_run_model_defaults_to_none() -> None: agent.tapes = fake_tapes # type: ignore[assignment] result = await agent.run_stream(session_id="user/s1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 - [event async for event in result] + await _consume_stream(result) assert agent.completion_kwargs["model"] == "test:model" @@ -243,7 +271,7 @@ def denied_agent_tool() -> str: state={"_runtime_workspace": "/tmp"}, # noqa: S108 allowed_tools=[" tests_allowed_agent_tool "], ) - [event async for event in result] + await _consume_stream(result) assert agent.completion_kwargs is not None assert [tool.name for tool in agent.completion_kwargs["tools"]] == ["tests_allowed_agent_tool"] @@ -267,4 +295,4 @@ async def test_agent_run_rejects_unknown_allowed_tools() -> None: ) with pytest.raises(ValueError, match="tests_missing_agent_tool"): - [event async for event in stream] + await _consume_stream(stream) diff --git a/tests/test_builtin_hook_impl.py b/tests/test_builtin_hook_impl.py index 8c8310d1..a3ecf497 100644 --- a/tests/test_builtin_hook_impl.py +++ b/tests/test_builtin_hook_impl.py @@ -7,11 +7,11 @@ import pytest +from bub.builtin.agent import AsyncStreamEvents, StreamEvent from bub.builtin.hook_impl import AGENTS_FILE_NAME, DEFAULT_SYSTEM_PROMPT, BuiltinImpl from bub.builtin.store import FileTapeStore from bub.channels.message import ChannelMessage from bub.framework import BubFramework -from bub.runtime import AsyncStreamEvents, StreamEvent class RecordingLifespan: @@ -169,9 +169,14 @@ async def test_run_model_stream_delegates_to_agent(tmp_path: Path) -> None: state = {"context": "ctx"} stream = await impl.run_model_stream(prompt="prompt", session_id="session", state=state) - events = [event async for event in stream] - - assert [(event.kind, event.data) for event in events] == [("text", {"delta": "agent-output"})] + binding = impl.bind_envelope(stream, session_id="session", state=state) + assert binding is stream + events = binding.stream() + assert events is not None + + assert [event async for event in events] == [ + {"content": "agent-output", "source": StreamEvent("text", {"delta": "agent-output"})} + ] assert agent.run_stream_calls == [("session", "prompt", state)] assert agent.run_calls == [] diff --git a/tests/test_builtin_tools.py b/tests/test_builtin_tools.py index d5bc7e56..1268b06e 100644 --- a/tests/test_builtin_tools.py +++ b/tests/test_builtin_tools.py @@ -20,7 +20,7 @@ render_tools_prompt, resolve_tool_names, ) -from bub.runtime import ErrorKind +from bub.errors import ErrorKind from bub.tools import REGISTRY, Tool, ToolContext, ToolExecutor, tool diff --git a/tests/test_channels.py b/tests/test_channels.py index 079e1db6..5f14a591 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -15,7 +15,6 @@ from bub.channels.manager import ChannelManager from bub.channels.message import ChannelMessage from bub.channels.telegram import BubMessageFilter, TelegramChannel, TelegramMessageParser -from bub.runtime import StreamEvent from bub.turn_admission import AdmitDecision, SessionTurnController, SteeringBuffer @@ -674,17 +673,17 @@ async def test_cli_channel_stream_events_prints_stream_and_yields_events(monkeyp message = _message("ignored", channel="cli", kind="command", session_id="cli:1") - async def source() -> asyncio.AsyncIterator[StreamEvent]: - yield StreamEvent("text", {"delta": " "}) - yield StreamEvent("text", {"delta": "hel"}) - yield StreamEvent("text", {"delta": "lo"}) - yield StreamEvent("final", {}) + async def source() -> asyncio.AsyncIterator[dict[str, str]]: + yield {"content": " "} + yield {"content": "hel"} + yield {"content": "lo"} + yield {"end": True} yielded = [event async for event in channel.stream_events(message, source())] assert heads == ["command"] assert printed == [("hel", "", False), ("lo", "", False), ("", None, None)] - assert [event.kind for event in yielded] == ["text", "text", "final"] + assert yielded == [{"content": " "}, {"content": "hel"}, {"content": "lo"}, {"end": True}] def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None: diff --git a/tests/test_framework.py b/tests/test_framework.py index f39b8b03..efc2c2a0 100644 --- a/tests/test_framework.py +++ b/tests/test_framework.py @@ -19,7 +19,6 @@ from bub.configure import ensure_config from bub.framework import BubFramework from bub.hookspecs import hookimpl -from bub.runtime import AsyncStreamEvents, StreamEvent, StreamState from bub.turn_admission import AdmitDecision, SteeringBuffer, TurnSnapshot @@ -42,6 +41,22 @@ async def stop(self) -> None: return NamedChannelImpl() +class _BoundEnvelope: + def __init__(self, events: list[dict[str, Any]], output: Any) -> None: + self._events = events + self._output = output + + def stream(self): + async def iterator(): + for event in self._events: + yield event + + return iterator() + + def output(self): + return self._output + + def test_create_cli_app_sets_workspace_and_context(tmp_path: Path) -> None: framework = BubFramework() @@ -337,7 +352,7 @@ def admit_message(self, session_id, message, turn): async def test_process_inbound_streams_when_requested() -> None: # noqa: C901 framework = BubFramework() stream_calls: list[str] = [] - wrapped_events: list[str] = [] + wrapped_events: list[dict[str, Any]] = [] class StreamingPlugin: @hookimpl @@ -355,13 +370,14 @@ def build_prompt(self, message, session_id, state) -> str: @hookimpl async def run_model_stream(self, prompt, session_id, state): stream_calls.append(prompt) + return _BoundEnvelope( + [{"chunk": "stream"}, {"chunk": "ed"}, {"done": True}], + "streamed", + ) - async def iterator(): - yield StreamEvent("text", {"delta": "stream"}) - yield StreamEvent("text", {"delta": "ed"}) - yield StreamEvent("final", {"text": "streamed", "ok": True}) - - return AsyncStreamEvents(iterator(), state=StreamState()) + @hookimpl + def bind_envelope(self, envelope, session_id, state): + return envelope if isinstance(envelope, _BoundEnvelope) else None @hookimpl async def save_state(self, session_id, state, message, model_output) -> None: @@ -379,7 +395,7 @@ class RecordingRouter: def wrap_stream(self, message, stream): async def iterator(): async for event in stream: - wrapped_events.append(event.kind) + wrapped_events.append(event) yield event return iterator() @@ -399,5 +415,5 @@ async def quit(self, session_id: str) -> None: ) assert stream_calls == ["prompt"] - assert wrapped_events == ["text", "text", "final"] + assert wrapped_events == [{"chunk": "stream"}, {"chunk": "ed"}, {"done": True}] assert result.model_output == "streamed" diff --git a/tests/test_hook_runtime.py b/tests/test_hook_runtime.py index 78088858..3b19ac88 100644 --- a/tests/test_hook_runtime.py +++ b/tests/test_hook_runtime.py @@ -1,9 +1,11 @@ +from collections.abc import AsyncIterable, AsyncIterator + import pluggy import pytest from bub.hook_runtime import HookRuntime from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs, hookimpl -from bub.runtime import AsyncStreamEvents, StreamEvent +from bub.types import Envelope def _runtime_with_plugins(*plugins: tuple[str, object]) -> HookRuntime: @@ -14,6 +16,22 @@ def _runtime_with_plugins(*plugins: tuple[str, object]) -> HookRuntime: return HookRuntime(manager) +class _BoundEnvelope: + def __init__(self, events: list[Envelope], output: Envelope) -> None: + self._events = events + self._output = output + + def stream(self) -> AsyncIterable[Envelope] | None: + async def iterator() -> AsyncIterator[Envelope]: + for event in self._events: + yield event + + return iterator() + + def output(self) -> Envelope | None: + return self._output + + @pytest.mark.asyncio async def test_call_first_respects_priority_and_returns_first_non_none() -> None: called: list[str] = [] @@ -112,11 +130,14 @@ async def test_run_model_uses_streaming_hook_when_plain_hook_absent() -> None: class StreamPlugin: @hookimpl async def run_model_stream(self, prompt, session_id, state): - async def iterator(): - yield StreamEvent("text", {"delta": "stream"}) - yield StreamEvent("text", {"delta": "ed"}) + return _BoundEnvelope( + [{"content": "stream"}, {"content": "ed"}], + "streamed", + ) - return AsyncStreamEvents(iterator()) + @hookimpl + def bind_envelope(self, envelope, session_id, state): + return envelope if isinstance(envelope, _BoundEnvelope) else None runtime = _runtime_with_plugins(("stream", StreamPlugin())) @@ -136,6 +157,4 @@ async def run_model(self, prompt, session_id, state): stream = await runtime.run_model_stream(prompt="hello", session_id="s", state={}) - assert stream is not None - events = [event async for event in stream] - assert [(event.kind, event.data) for event in events] == [("text", {"delta": "plain"})] + assert stream == "plain" diff --git a/tests/test_subagent_tool.py b/tests/test_subagent_tool.py index 6cf3853b..7ee9f220 100644 --- a/tests/test_subagent_tool.py +++ b/tests/test_subagent_tool.py @@ -5,8 +5,8 @@ import pytest +from bub.builtin.agent import AsyncStreamEvents, StreamEvent from bub.builtin.tools import run_subagent -from bub.runtime import AsyncStreamEvents, StreamEvent from bub.tools import REGISTRY, tool