Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/bub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@
from bub.configure import Settings, config, ensure_config
from bub.framework import DEFAULT_HOME, BubFramework
from bub.hookspecs import hookimpl
from bub.runtime_options import RuntimeChoice, RuntimeOptions
from bub.tools import tool
from bub.turn_admission import AdmitDecision, SteeringBuffer, TurnSnapshot

__all__ = [
"AdmitDecision",
"BubFramework",
"RuntimeChoice",
"RuntimeOptions",
"Settings",
"SteeringBuffer",
"TurnSnapshot",
Expand Down
34 changes: 32 additions & 2 deletions src/bub/builtin/hook_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
from bub import inquirer as bub_inquirer
from bub.builtin.agent import Agent
from bub.builtin.context import default_tape_context
from bub.builtin.settings import DEFAULT_MODEL
from bub.builtin.settings import DEFAULT_MODEL, load_settings
from bub.channels.base import Channel
from bub.channels.message import ChannelMessage, MediaItem
from bub.envelope import content_of, field_of
from bub.framework import BubFramework
from bub.hookspecs import hookimpl
from bub.runtime import AsyncStreamEvents
from bub.runtime_options import RuntimeChoice, RuntimeOptions
from bub.tape import TapeContext, TapeStore
from bub.turn_admission import AdmitDecision, TurnSnapshot
from bub.types import Envelope, MessageHandler, State
Expand Down Expand Up @@ -101,6 +102,12 @@ def _default_enabled_channels(current_value: object, available_channels: list[st
return selected
return available_channels

@staticmethod
def _configured_models() -> list[str]:
settings = load_settings()
models = [settings.model, *(settings.fallback_models or [])]
return list(dict.fromkeys(model for model in models if model))

@hookimpl
def resolve_session(self, message: ChannelMessage) -> str:
session_id = field_of(message, "session_id")
Expand All @@ -118,6 +125,8 @@ async def load_state(self, message: ChannelMessage, session_id: str) -> State:
state = {"session_id": session_id, "_runtime_agent": self._get_agent()}
if context := field_of(message, "context_str"):
state["context"] = context
if model := field_of(message, "runtime", {}).get("model"):
state["_runtime_model"] = model
return state

@hookimpl
Expand Down Expand Up @@ -158,7 +167,12 @@ async def build_prompt(self, message: ChannelMessage, session_id: str, state: St

@hookimpl
async def run_model_stream(self, prompt: str | list[dict], session_id: str, state: State) -> AsyncStreamEvents:
return await self._get_agent().run_stream(session_id=session_id, prompt=prompt, state=state)
return await self._get_agent().run_stream(
session_id=session_id,
prompt=prompt,
state=state,
model=state.get("_runtime_model"),
)

@hookimpl
def register_cli_commands(self, app: typer.Typer) -> None:
Expand Down Expand Up @@ -219,6 +233,22 @@ def onboard_config(self, current_config: dict[str, object]) -> dict[str, object]
config["api_base"] = api_base
return config

@hookimpl
def provide_runtime_options(
self,
session_id: str,
workspace: Path | None = None,
) -> RuntimeOptions | None:
del session_id, workspace
models = self._configured_models()
if not models:
return None

return RuntimeOptions(
models=[RuntimeChoice(id=model, name=model) for model in models],
current_model=models[0],
)

def _read_agents_file(self, state: State) -> str:
workspace = state.get("_runtime_workspace", str(Path.cwd()))
prompt_path = Path(workspace) / AGENTS_FILE_NAME
Expand Down
1 change: 1 addition & 0 deletions src/bub/channels/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class ChannelMessage:
is_active: bool = False
kind: MessageKind = "normal"
context: dict[str, Any] = field(default_factory=dict)
runtime: dict[str, Any] = field(default_factory=dict)
media: list[MediaItem] = field(default_factory=list)
lifespan: contextlib.AbstractAsyncContextManager | None = None
output_channel: str = ""
Expand Down
33 changes: 33 additions & 0 deletions src/bub/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from bub.hook_runtime import _SKIP_VALUE, HookRuntime
from bub.hookspecs import BUB_HOOK_NAMESPACE, BubHookSpecs
from bub.runtime import BubError, ErrorKind
from bub.runtime_options import RuntimeOptions
from bub.tape import AsyncTapeStore, TapeContext, TapeStore
from bub.turn_admission import AdmitDecision, SteeringBuffer, TurnSnapshot
from bub.types import Envelope, MessageHandler, OutboundChannelRouter, TurnResult
Expand Down Expand Up @@ -220,6 +221,38 @@ async def admit_message(self, *, session_id: str, message: Envelope, turn: TurnS
return decision
raise TypeError("hook.admit_message must return AdmitDecision or None")

async def get_runtime_options(
self,
*,
session_id: str,
workspace: str | Path | None = None,
) -> RuntimeOptions:
"""Collect protocol-neutral runtime choices for one session."""

resolved_workspace = self._resolve_workspace(workspace)
results = await self._hook_runtime.call_many(
"provide_runtime_options",
session_id=session_id,
workspace=resolved_workspace,
)

merged = RuntimeOptions()
for result in results:
if result is None:
continue
if not isinstance(result, RuntimeOptions):
raise TypeError("hook.provide_runtime_options must return RuntimeOptions or None")
merged = RuntimeOptions(
models=[*merged.models, *result.models],
current_model=merged.current_model or result.current_model,
)
return merged

def _resolve_workspace(self, workspace: str | Path | None) -> Path:
if workspace is None:
return self.workspace
return Path(workspace).expanduser().resolve()

def steering(self, session_id: str) -> SteeringBuffer:
buffer = self._steering_buffers.get(session_id)
if buffer is None:
Expand Down
10 changes: 10 additions & 0 deletions src/bub/hookspecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Any

import pluggy

from bub.runtime import AsyncStreamEvents
from bub.runtime_options import RuntimeOptions
from bub.tape import AsyncTapeStore, TapeContext, TapeStore
from bub.turn_admission import AdmitDecision, TurnSnapshot
from bub.types import Envelope, MessageHandler, State
Expand Down Expand Up @@ -85,6 +87,14 @@ def register_cli_commands(self, app: Any) -> None:
def onboard_config(self, current_config: dict[str, Any]) -> dict[str, Any] | None:
"""Collect a plugin config fragment for the interactive onboarding command."""

@hookspec
def provide_runtime_options(
self,
session_id: str,
workspace: Path | None,
) -> RuntimeOptions | None:
"""Provide protocol-neutral runtime choices for a session."""

@hookspec
def on_error(self, stage: str, error: Exception, message: Envelope | None) -> None:
"""Observe framework errors from any stage."""
Expand Down
24 changes: 24 additions & 0 deletions src/bub/runtime_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Protocol-neutral runtime option types."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any


@dataclass(frozen=True)
class RuntimeChoice:
"""One selectable runtime value."""

id: str
name: str | None = None
description: str | None = None
meta: dict[str, Any] | None = None


@dataclass(frozen=True)
class RuntimeOptions:
"""Runtime choices that a channel or adapter may present to a user."""

models: list[RuntimeChoice] = field(default_factory=list)
current_model: str | None = None
59 changes: 53 additions & 6 deletions tests/test_builtin_hook_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,21 @@ class FakeAgent:
def __init__(self, home: Path) -> None:
self.settings = SimpleNamespace(home=home)
self.run_calls: list[tuple[str, str, dict[str, object]]] = []
self.run_stream_calls: list[tuple[str, str, dict[str, object]]] = []
self.run_stream_calls: list[tuple[str, str, dict[str, object], str | None]] = []

async def run(self, *, session_id: str, prompt: str, state: dict[str, object]) -> str:
self.run_calls.append((session_id, prompt, state))
return "agent-output"

async def run_stream(self, *, session_id: str, prompt: str, state: dict[str, object]) -> AsyncStreamEvents:
self.run_stream_calls.append((session_id, prompt, state))
async def run_stream(
self,
*,
session_id: str,
prompt: str,
state: dict[str, object],
model: str | None = None,
) -> AsyncStreamEvents:
self.run_stream_calls.append((session_id, prompt, state, model))

async def iterator():
yield StreamEvent("text", {"delta": "agent-output"})
Expand All @@ -49,8 +56,8 @@ def _raise_value_error() -> None:
raise ValueError("boom")


def _build_impl(tmp_path: Path) -> tuple[BubFramework, BuiltinImpl, FakeAgent]:
framework = BubFramework()
def _build_impl(tmp_path: Path, config_file: Path | None = None) -> tuple[BubFramework, BuiltinImpl, FakeAgent]:
framework = BubFramework(config_file=config_file) if config_file is not None else BubFramework()
impl = BuiltinImpl(framework)
agent = FakeAgent(tmp_path)
impl._agent = agent
Expand Down Expand Up @@ -160,10 +167,50 @@ async def test_run_model_stream_delegates_to_agent(tmp_path: Path) -> None:
events = [event async for event in stream]

assert [(event.kind, event.data) for event in events] == [("text", {"delta": "agent-output"})]
assert agent.run_stream_calls == [("session", "prompt", state)]
assert agent.run_stream_calls == [("session", "prompt", state, None)]
assert agent.run_calls == []


@pytest.mark.asyncio
async def test_runtime_model_override_is_passed_to_agent(tmp_path: Path) -> None:
_, impl, agent = _build_impl(tmp_path)
message = ChannelMessage(
session_id="session",
channel="cli",
chat_id="room",
content="hello",
runtime={"model": "anthropic:claude-sonnet-4-5"},
)

state = await impl.load_state(message=message, session_id="session")
stream = await impl.run_model_stream(prompt="prompt", session_id="session", state=state)
events = [event async for event in stream]

assert [(event.kind, event.data) for event in events] == [("text", {"delta": "agent-output"})]
assert agent.run_stream_calls == [("session", "prompt", state, "anthropic:claude-sonnet-4-5")]


def test_builtin_provides_model_runtime_options(tmp_path: Path, load_config) -> None:
with pytest.MonkeyPatch.context() as monkeypatch:
monkeypatch.delenv("BUB_MODEL", raising=False)
monkeypatch.delenv("BUB_FALLBACK_MODELS", raising=False)
config_file = load_config(
"""
model: openai:gpt-5
fallback_models:
- anthropic:claude-sonnet-4-5
- openai:gpt-5
""".strip()
)
_, impl, _ = _build_impl(tmp_path, config_file=config_file)

options = impl.provide_runtime_options(session_id="session")

assert options is not None
assert options.current_model == "openai:gpt-5"
assert [item.id for item in options.models] == ["openai:gpt-5", "anthropic:claude-sonnet-4-5"]


def test_system_prompt_appends_workspace_agents_file(tmp_path: Path) -> None:
_, impl, _ = _build_impl(tmp_path)
(tmp_path / AGENTS_FILE_NAME).write_text("local rules", encoding="utf-8")
Expand Down
38 changes: 38 additions & 0 deletions tests/test_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from bub.framework import BubFramework
from bub.hookspecs import hookimpl
from bub.runtime import AsyncStreamEvents, StreamEvent, StreamState
from bub.runtime_options import RuntimeChoice, RuntimeOptions
from bub.turn_admission import AdmitDecision, SteeringBuffer, TurnSnapshot


Expand Down Expand Up @@ -333,6 +334,43 @@ def admit_message(self, session_id, message, turn):
assert decision == AdmitDecision("follow_up", reason="busy")


@pytest.mark.asyncio
async def test_get_runtime_options_collects_models_by_priority(tmp_path: Path) -> None:
framework = BubFramework()

class LowPriorityPlugin:
@hookimpl
def provide_runtime_options(self, session_id, workspace):
assert session_id == "session"
assert workspace == tmp_path.resolve()
return RuntimeOptions(
models=[RuntimeChoice(id="low", name="Low")],
current_model="low",
)

class HighPriorityPlugin:
@hookimpl
def provide_runtime_options(self, session_id, workspace):
assert session_id == "session"
assert workspace == tmp_path.resolve()
return RuntimeOptions(
models=[RuntimeChoice(id="high", name="High"), RuntimeChoice(id="mid", name="Mid")],
current_model="high",
)

framework._plugin_manager.register(LowPriorityPlugin(), name="low")
framework._plugin_manager.register(HighPriorityPlugin(), name="high")

options = await framework.get_runtime_options(session_id="session", workspace=tmp_path)

assert [(choice.id, choice.name) for choice in options.models] == [
("high", "High"),
("mid", "Mid"),
("low", "Low"),
]
assert options.current_model == "high"


@pytest.mark.asyncio
async def test_process_inbound_streams_when_requested() -> None: # noqa: C901
framework = BubFramework()
Expand Down
1 change: 1 addition & 0 deletions website/src/content/docs/docs/reference/hooks.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ For the *why* and *how* of each stage see [Turn pipeline](/docs/concepts/turn-pi
| `dispatch_outbound` | broadcast | `(message: Envelope) -> bool` | sent flag | `process_inbound` per outbound | Each outbound is fanned out to every impl. |
| `register_cli_commands` | sync-only consumer | `(app: typer.Typer) -> None` | none | `BubFramework.create_cli_app` (`call_many_sync`) | Bootstrap only; async impls log a warning and are skipped. |
| `onboard_config` | sync-only consumer (custom merge) | `(current_config: dict) -> dict \| None` | config fragment | `BubFramework.collect_onboard_config` | Iterated by priority; each fragment is merged via `configure.merge`. Non-dict returns raise `TypeError`. |
| `provide_runtime_options` | broadcast | `(session_id: str, workspace: Path \| None) -> RuntimeOptions \| None` | runtime choices | `BubFramework.get_runtime_options` | Model choices are appended in hook priority order. Selection state is owned by the caller or adapter. |
| `on_error` | observer | `(stage: str, error: Exception, message: Envelope \| None) -> None` | none | `HookRuntime.notify_error` / `notify_error_sync` | Failures inside an `on_error` impl are caught and logged so other observers still run. |
| `system_prompt` | broadcast (joined) | `(prompt, state) -> str` | prompt fragment | `BubFramework.get_system_prompt` (`call_many_sync`) | Results are reversed and joined with `\n\n`; truthy fragments only. |
| `provide_tape_store` | firstresult | `() -> TapeStore \| AsyncTapeStore` | tape store | `BubFramework.running()` | Resolved once when the runtime scope opens; sync/async iterators are entered as context managers. |
Expand Down
2 changes: 1 addition & 1 deletion website/src/content/docs/docs/reference/index.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ This page indexes the four reference tables for Bub's public surface. Each page
| [Hooks](/docs/reference/hooks/) | Every spec in `BubHookSpecs` with kind, parameters, return type, invocation site. |
| [CLI](/docs/reference/cli/) | Every `bub` command and subcommand, with options, defaults, and behavior notes. |
| [Settings](/docs/reference/settings/) | All `BUB_*` environment variables, `pydantic-settings` classes, and `~/.bub/config.yml` keys. |
| [Types](/docs/reference/types/) | Public types from `bub`: `Envelope`, `State`, `TurnResult`, `Channel`, `OutboundChannelRouter`, `BubFramework`. |
| [Types](/docs/reference/types/) | Public types from `bub`: `Envelope`, `State`, `TurnResult`, runtime options, `Channel`, `OutboundChannelRouter`, `BubFramework`. |
Loading