Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions src/bub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,12 @@
from bub.framework import DEFAULT_HOME, BubFramework
from bub.hookspecs import hookimpl
from bub.tools import tool
from bub.turn_admission import AdmitDecision, SteeringBuffer, TurnSnapshot
from bub.turn_admission import AdmitDecision, TurnSnapshot

__all__ = [
"AdmitDecision",
"BubFramework",
"Settings",
"SteeringBuffer",
"TurnSnapshot",
"config",
"ensure_config",
Expand Down
38 changes: 38 additions & 0 deletions src/bub/builtin/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from __future__ import annotations

import asyncio
import inspect
import re
import shlex
import time
from collections import deque
from collections.abc import AsyncGenerator, AsyncIterator, Callable, Collection, Coroutine, Iterable
from contextlib import AsyncExitStack
from dataclasses import dataclass, replace
Expand All @@ -22,6 +24,7 @@
)
from bub.builtin.settings import load_settings
from bub.builtin.tape import Tape
from bub.envelope import field_of
from bub.framework import BubFramework
from bub.runtime import AsyncStreamEvents, StreamEvent, StreamState
from bub.skills import discover_skills, render_skills_prompt
Expand All @@ -46,6 +49,7 @@ def __init__(self, framework: BubFramework) -> None:
self.settings = load_settings()
self.framework = framework
self.model_runner = ModelRunner(self.settings)
self._steering_messages: dict[str, deque[dict[str, Any]]] = {}

@cached_property
def tape(self) -> Tape:
Expand Down Expand Up @@ -79,6 +83,21 @@ async def generator() -> AsyncIterator[StreamEvent]:

return AsyncStreamEvents(generator(), state=events._state)

def enqueue_steering_message(self, thread_id: str, message: dict[str, Any]) -> bool:
    """Queue a steering message for a thread that has a registered queue.

    Args:
        thread_id: Key of the steering queue in ``self._steering_messages``.
        message: The message to append to that queue.

    Returns:
        ``True`` when the message was appended; ``False`` when no queue is
        registered for ``thread_id`` (nothing is stored in that case).
    """
    pending = self._steering_messages.get(thread_id)
    if pending is None:
        # No queue registered under this key: refuse rather than create one.
        return False
    pending.append(message)
    return True

def _drain_steering_messages(self, thread_id: str) -> list[dict[str, Any]]:
    """Return all queued steering messages for ``thread_id`` and empty the queue.

    The queue stays registered in ``self._steering_messages`` after draining.
    ``enqueue_steering_message`` only accepts messages for thread ids that are
    present in that mapping, so removing the entry here would make every
    steering message sent after the first drain be rejected.

    Args:
        thread_id: Key of the steering queue to drain.

    Returns:
        The queued messages in arrival order, or an empty list when the queue
        is empty or no queue is registered for ``thread_id``.
    """
    queue = self._steering_messages.get(thread_id)
    if not queue:
        return []
    drained = list(queue)
    # Empty in place instead of popping the key so the thread keeps accepting
    # new steering messages.
    queue.clear()
    return drained

def _has_steering_messages(self, thread_id: str) -> bool:
    """Report whether at least one steering message is queued for ``thread_id``.

    Returns ``False`` both when the queue is empty and when no queue is
    registered under ``thread_id``.
    """
    pending = self._steering_messages.get(thread_id)
    return True if pending else False

async def run_stream(
self,
*,
Expand All @@ -98,6 +117,8 @@ async def run_stream(
tape = self.tape.session_tape(
session_id, workspace_from_state(state), context=replace(self.tape.context, state=state)
)
thread_id = state.get("_runtime_thread_id", tape.name)
self._steering_messages.setdefault(thread_id, deque())
merge_back = not session_id.startswith("temp/")
stack = AsyncExitStack()
# The fork_tape context manager must not be exited until the last chunk of the stream is consumed.
Expand All @@ -117,6 +138,11 @@ async def run_stream(
allowed_skills=allowed_skills,
allowed_tools=allowed_tools,
)

@stack.callback
def cleanup() -> None:
self._steering_messages.pop(thread_id, None)

return self._events_with_callback(events, callback=stack.aclose)

async def _run_command(self, tape: Tape, *, line: str) -> str:
Expand Down Expand Up @@ -273,6 +299,8 @@ async def _stream_events_with_auto_handoff(
state.error = output.error
state.usage = output.usage
elapsed_ms = int((time.monotonic() - start) * 1000)
thread_id = tape.context.state.get("_runtime_thread_id", tape.name)
should_continue = should_continue or self._has_steering_messages(thread_id)
if not should_continue:
await tape.append_event(
"loop.step",
Expand Down Expand Up @@ -355,12 +383,22 @@ async def _run_once_stream(
resolved_model = model or self.settings.model

model_tools_for_call = model_tools(tools)
thread_id = tape.context.state.get("_runtime_thread_id", tape.name)
steering_messages = list(
await asyncio.gather(*[
self.framework.build_prompt(
message, session_id=field_of(message, "session_id"), state=tape.context.state
)
for message in self._drain_steering_messages(thread_id)
])
)
return self.model_runner.run(
tape=tape,
model=resolved_model,
tools=model_tools_for_call,
system_prompt=system_prompt,
prompt=prompt,
steering_messages=steering_messages,
)

def _system_prompt(
Expand Down
14 changes: 12 additions & 2 deletions src/bub/builtin/hook_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ async def load_state(self, message: ChannelMessage, session_id: str) -> State:
state = {"session_id": session_id, "_runtime_agent": self._get_agent()}
if context := field_of(message, "context_str"):
state["context"] = context
if thread_id := field_of(message, "context", {}).get("thread_id"):
state["_runtime_thread_id"] = thread_id
return state

@hookimpl
Expand Down Expand Up @@ -294,7 +296,7 @@ def build_tape_context(self) -> TapeContext:
return default_tape_context()

@hookimpl
def admit_message(
async def admit_message(
self,
session_id: str,
message: Envelope,
Expand All @@ -303,4 +305,12 @@ def admit_message(
outbound_router = self.framework._outbound_router
if outbound_router is None:
return None
return outbound_router.admit_channel_message(session_id=session_id, message=message, turn=turn)
return await outbound_router.admit_channel_message(session_id=session_id, message=message, turn=turn)

@hookimpl
async def handle_steering(self, message: Envelope, reason: str | None) -> bool:
    """Forward a message admitted with the "steer" action to the agent's steering queue.

    The queue key is the ``thread_id`` found in the message's ``context`` when
    that context is a dict and the value is truthy; otherwise the message's
    ``session_id`` (empty string if absent) is used. ``reason`` is accepted
    but not used here.

    Returns:
        Whatever ``enqueue_steering_message`` reports for that key.
    """
    agent = self._get_agent()
    message_context = field_of(message, "context", {})
    thread_id = None
    if isinstance(message_context, dict):
        thread_id = field_of(message_context, "thread_id")
    if not thread_id:
        # No usable thread id on the message: fall back to the session id.
        thread_id = field_of(message, "session_id", "")
    return agent.enqueue_steering_message(str(thread_id), message)
9 changes: 7 additions & 2 deletions src/bub/builtin/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def run(
tools: list[Tool],
system_prompt: str | None,
prompt: str | list[dict],
steering_messages: list[list[dict[str, Any]] | str] | None = None,
) -> AsyncStreamEvents:
state = StreamState()

Expand All @@ -96,6 +97,7 @@ async def iterator() -> AsyncGenerator[StreamEvent, None]:
system_prompt=system_prompt,
prompt=prompt,
model=model,
steering_messages=steering_messages,
)
output = ModelOutputAccumulator()
async with asyncio.timeout(self.settings.model_timeout_seconds):
Expand Down Expand Up @@ -159,6 +161,7 @@ async def build_messages(
system_prompt: str | None,
prompt: str | list[dict],
model: str,
steering_messages: list[list[dict[str, Any]] | str] | None = None,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
prompt_message: dict[str, Any] = {"role": "user", "content": prompt}
try:
Expand All @@ -172,10 +175,12 @@ async def build_messages(
model=model,
)
raise
steering_messages_native = [{"role": "user", "content": message} for message in (steering_messages or [])]
if system_prompt:
messages = [{"role": "system", "content": system_prompt}, *messages]
messages.append(prompt_message)
return messages, [prompt_message]
new_messages = [*steering_messages_native, prompt_message]
messages.extend(new_messages)
return messages, new_messages

async def record_context_error(
self,
Expand Down
2 changes: 1 addition & 1 deletion src/bub/channels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def stream_events(self, message: ChannelMessage, stream: AsyncIterable[StreamEve
"""Optionally wrap the output stream for this channel."""
return stream

def admit_message(
async def admit_message(
self,
session_id: str,
message: Envelope,
Expand Down
11 changes: 6 additions & 5 deletions src/bub/channels/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,12 @@ async def _main_loop(self) -> None:
continue

request = self._normalize_input(raw)
await self._echo_input(raw)

message = ChannelMessage(
session_id=self._message_template["session_id"],
channel=self._message_template["channel"],
chat_id=self._message_template["chat_id"],
context={"thread_id": self._message_template["session_id"]}, # use the same thread_id for all messages
content=request,
lifespan=self.message_lifespan(),
)
Expand Down Expand Up @@ -332,11 +332,11 @@ def _prompt_label(self) -> str:
symbol = ">" if self._mode == "agent" else ","
return f"{cwd} {symbol} "

async def _echo_input(self, raw: str) -> None:
async def _echo_input(self, raw: str, steering: bool = False) -> None:
stream_printer = getattr(self, "_stream_printer", None)
if stream_printer is not None:
await stream_printer.commit_live_text()
self._renderer.input_echo(self._prompt_label(), raw)
self._renderer.input_echo(self._prompt_label(), raw, steering=steering)

async def stream_events(
self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]
Expand Down Expand Up @@ -416,12 +416,13 @@ def _history_file(home: Path, workspace: Path) -> Path:
workspace_hash = md5(str(workspace).encode("utf-8"), usedforsecurity=False).hexdigest()
return home / "history" / f"{workspace_hash}.history"

def admit_message(
async def admit_message(
self,
session_id: str,
message: Envelope,
turn: TurnSnapshot,
) -> AdmitDecision | None:
await self._echo_input(message.content, steering=turn.is_running)
if not turn.is_running:
return None
return AdmitDecision("follow_up", reason="cli session is already generating")
return AdmitDecision("steer", reason="cli session is already generating")
5 changes: 3 additions & 2 deletions src/bub/channels/cli/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@ def error(self, text: str) -> None:
return
self.console.print(f"[red bold]Error >[/]\n{text}")

def input_echo(self, prompt: str, text: str) -> None:
def input_echo(self, prompt: str, text: str, steering: bool = False) -> None:
if not text.strip():
return
self.console.print(f"[bold]{prompt}[/]{text}", new_line_start=True)
mid = "[grey](steering)[/] " if steering else ""
self.console.print(f"[dim][bold]{prompt}[/]{mid}{text}[/]", new_line_start=True)

def tool_call_start(self, *, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
self.console.print(Text(_format_tool_call(name, args, kwargs), style="magenta"), new_line_start=True)
Expand Down
25 changes: 5 additions & 20 deletions src/bub/channels/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,9 @@ def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) ->
async def quit(self, session_id: str) -> None:
controller = self._session_controllers.get(session_id)
if controller is None:
self.framework.clear_steering(session_id)
logger.info(f"channel.manager quit session_id={session_id}, cancelled 0 tasks")
return
controller.clear_pending()
controller.steering.drain_nowait()
tasks = set(controller.active_tasks)
current_task = asyncio.current_task()
cancelled_count = 0
Expand Down Expand Up @@ -177,18 +175,17 @@ def enabled_channels(self) -> list[Channel]:
def _controller(self, session_id: str) -> SessionTurnController:
controller = self._session_controllers.get(session_id)
if controller is None:
controller = SessionTurnController(session_id=session_id, steering=self.framework.steering(session_id))
controller = SessionTurnController(session_id=session_id)
self._session_controllers[session_id] = controller
return controller

def _drop_empty_controller(self, session_id: str) -> None:
controller = self._session_controllers.get(session_id)
if controller is None:
return
if controller.active() or controller.pending_queue or controller.steering.count > 0:
if controller.active() or controller.pending_queue:
return
self._session_controllers.pop(session_id, None)
self.framework.clear_steering(session_id)

def _on_task_done(self, session_id: str, task: asyncio.Task) -> None:
if task.cancelled():
Expand All @@ -199,8 +196,6 @@ def _on_task_done(self, session_id: str, task: asyncio.Task) -> None:
if controller is None:
return
controller.active_tasks.discard(task)
if not controller.active():
controller.promote_steering_to_pending()
self._schedule_pending(session_id)
self._drop_empty_controller(session_id)

Expand Down Expand Up @@ -249,13 +244,7 @@ async def _apply_admission_decision(
if action == "follow_up":
return self._queue_pending(controller, message, decision.reason)
if action == "steer":
if controller.active():
controller.steering.put_nowait(message)
logger.info(
"channel.manager admission steer session_id={} reason={}",
message.session_id,
decision.reason,
)
if await self.framework.steer_message(message, decision.reason):
return False
return self._queue_pending(controller, message, decision.reason)
logger.warning("channel.manager admission unknown action={} session_id={}", decision.action, message.session_id)
Expand Down Expand Up @@ -326,23 +315,19 @@ async def listen_and_run(self) -> None:

async def shutdown(self) -> None:
count = 0
session_ids = list(self._session_controllers)
for controller in list(self._session_controllers.values()):
controller.clear_pending()
controller.steering.drain_nowait()
for task in set(controller.active_tasks):
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task
count += 1
self._session_controllers.clear()
for session_id in session_ids:
self.framework.clear_steering(session_id)
logger.info(f"channel.manager cancelled {count} in-flight tasks")
for channel in self.enabled_channels():
await channel.stop()

def admit_channel_message(
async def admit_channel_message(
self,
session_id: str,
message: Envelope,
Expand All @@ -354,4 +339,4 @@ def admit_channel_message(
channel = self.get_channel(str(channel_name))
if channel is None:
return None
return channel.admit_message(session_id=session_id, message=message, turn=turn)
return await channel.admit_message(session_id=session_id, message=message, turn=turn)
4 changes: 3 additions & 1 deletion src/bub/channels/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def __post_init__(self) -> None:
@property
def context_str(self) -> str:
"""String representation of the context for prompt building."""
return "|".join(f"{key}={value}" for key, value in self.context.items())
return "|".join(
f"{key}={value}" for key, value in self.context.items() if not key.startswith("_")
) # ignore internal keys

@classmethod
def from_batch(cls, batch: list[ChannelMessage]) -> ChannelMessage:
Expand Down
Loading
Loading