Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/bub/builtin/hook_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from bub.hookspecs import hookimpl
from bub.runtime import AsyncStreamEvents
from bub.tape import TapeContext, TapeStore
from bub.turn_admission import AdmitDecision, TurnSnapshot
from bub.types import Envelope, MessageHandler, State

AGENTS_FILE_NAME = "AGENTS.md"
Expand Down Expand Up @@ -291,3 +292,15 @@ def provide_tape_store(self) -> TapeStore:
@hookimpl
def build_tape_context(self) -> TapeContext:
return default_tape_context()

@hookimpl
def admit_message(
self,
session_id: str,
message: Envelope,
turn: TurnSnapshot,
) -> AdmitDecision | None:
outbound_router = self.framework._outbound_router
if outbound_router is None:
return None
return outbound_router.admit_channel_message(session_id=session_id, message=message, turn=turn)
11 changes: 11 additions & 0 deletions src/bub/channels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from bub.channels.message import ChannelMessage
from bub.runtime import StreamEvent
from bub.turn_admission import AdmitDecision, TurnSnapshot
from bub.types import Envelope


class Channel(ABC):
Expand Down Expand Up @@ -39,6 +41,15 @@ def stream_events(self, message: ChannelMessage, stream: AsyncIterable[StreamEve
"""Optionally wrap the output stream for this channel."""
return stream

def admit_message(
self,
session_id: str,
message: Envelope,
turn: TurnSnapshot,
) -> AdmitDecision | None:
"""Optionally admit or reject an incoming message before processing."""
return None


class Interface(Channel):
"""User-facing inbound/outbound surface managed by the channel runtime."""
Expand Down
174 changes: 112 additions & 62 deletions src/bub/channels/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,19 @@
from datetime import datetime
from hashlib import md5
from pathlib import Path
from time import monotonic
from typing import Any

from loguru import logger
from prompt_toolkit import PromptSession
from prompt_toolkit.application import run_in_terminal
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.formatted_text import FormattedText
from prompt_toolkit.history import FileHistory
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.patch_stdout import patch_stdout
from rich import get_console
from rich.status import Status
from rich.spinner import SPINNERS
from rich.text import Text
from rich.tree import Tree

Expand All @@ -25,8 +28,12 @@
from bub.channels.message import ChannelMessage
from bub.envelope import field_of
from bub.runtime import StreamEvent
from bub.tools import REGISTRY
from bub.types import MessageHandler
from bub.tools import REGISTRY, tool_call_reporter
from bub.turn_admission import AdmitDecision, TurnSnapshot
from bub.types import Envelope, MessageHandler

_GENERATION_SPINNER: str = SPINNERS["dots"]["frames"] # type: ignore[assignment]
_PROMPT_REFRESH_INTERVAL: float = SPINNERS["dots"]["interval"] / 1000.0 # type: ignore[operator]


class _StreamPrinter:
Expand All @@ -36,89 +43,91 @@ def __init__(self, *, console, print_head: Callable[[], None], expand_thinking:
self._expand_thinking = expand_thinking
self._reasoning_chars = 0
self._reasoning_streaming = False
self._reasoning_status: Status | None = None
self.head_printed = False

def render(self, event: StreamEvent) -> bool:
async def render(self, event: StreamEvent) -> bool:
if event.kind == "reasoning":
self._record_reasoning(str(event.data.get("delta", "")))
await self._record_reasoning(str(event.data.get("delta", "")))
return True

if event.kind == "text":
return self._print_content(str(event.data.get("delta", "")))
return await self._print_content(str(event.data.get("delta", "")))
elif event.kind == "tool_call":
self._print_stream_boundary()
await self._print_stream_boundary()
elif event.kind == "final":
self._print_end()
await self._print_end()
return True

def _record_reasoning(self, reasoning: str) -> None:
async def _record_reasoning(self, reasoning: str) -> None:
if not self._expand_thinking:
if self._reasoning_chars == 0:
self._ensure_head()
self._start_reasoning_status()
await self._ensure_head()
self._reasoning_chars += len(reasoning)
return

self._ensure_head()
await self._ensure_head()
if not self._reasoning_streaming:
self._console.print(Text("[-] Thinking", style="dim"))
await self._print(Text("[-] Thinking", style="dim"))
self._reasoning_streaming = True
self._console.print(Text(reasoning, style="dim"), end="", highlight=False)
await self._print(Text(reasoning, style="dim"), end="", highlight=False)

def _print_content(self, content: str) -> bool:
async def _print_content(self, content: str) -> bool:
if not (content.strip() or self.head_printed or self._reasoning_chars or self._reasoning_streaming):
return False
self._ensure_head()
self._close_reasoning_stream()
self._flush_reasoning()
self._console.print(content, end="", highlight=False)
await self._ensure_head()
await self._close_reasoning_stream()
await self._flush_reasoning()
await self._print(content, end="", highlight=False)
return True

def _print_end(self) -> None:
async def _print_end(self) -> None:
if self._reasoning_chars:
self._ensure_head()
self._flush_reasoning()
await self._ensure_head()
await self._flush_reasoning()
if self.head_printed:
self._console.print("")
await self._print("")

def _print_stream_boundary(self) -> None:
self._close_reasoning_stream()
self._flush_reasoning()
async def _print_stream_boundary(self) -> None:
await self._close_reasoning_stream()
await self._flush_reasoning()
if self.head_printed:
self._console.print("")
await self._print("")

def _ensure_head(self) -> None:
async def _ensure_head(self) -> None:
if self.head_printed:
return
self._print_head()
await run_in_terminal(self._print_head, render_cli_done=False)
self.head_printed = True

def _close_reasoning_stream(self) -> None:
async def _close_reasoning_stream(self) -> None:
if not self._reasoning_streaming:
return
self._console.print("")
await self._print("")
self._reasoning_streaming = False

def _flush_reasoning(self) -> None:
async def _flush_reasoning(self) -> None:
if self._reasoning_chars <= 0:
return
self._stop_reasoning_status()
label = Text(f"[+] Thinking ({self._reasoning_chars} chars hidden)", style="dim")
self._console.print(Tree(label, guide_style="dim", expanded=False))
await self._print(Tree(label, guide_style="dim", expanded=False))
self._reasoning_chars = 0

def _start_reasoning_status(self) -> None:
if self._reasoning_status is not None:
return
self._reasoning_status = self._console.status(Text("Thinking", style="dim"), spinner_style="dim")
self._reasoning_status.start()
async def _print(self, *args: Any, **kwargs: Any) -> None:
await run_in_terminal(lambda: self._console.print(*args, **kwargs), render_cli_done=False)

def _stop_reasoning_status(self) -> None:
if self._reasoning_status is None:
return
self._reasoning_status.stop()
self._reasoning_status = None

class _CliToolCallReporter:
def __init__(self, renderer: CliRenderer) -> None:
self._renderer = renderer

def start(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None:
self._renderer.tool_call_start(name=name, args=args, kwargs=kwargs)

def success(self, name: str, result: object, elapsed_ms: float) -> None:
self._renderer.tool_call_success(name=name, result=result, elapsed_ms=elapsed_ms)

def error(self, name: str, error: BaseException, elapsed_ms: float) -> None:
self._renderer.tool_call_error(name=name, error=error, elapsed_ms=elapsed_ms)


class CliChannel(Interface):
Expand All @@ -137,16 +146,16 @@ def __init__(self, on_receive: MessageHandler, agent: Agent) -> None:
}
self._mode = "agent" # or "shell"
self._expand_thinking = False
self._llm_loop_running = False
self._main_task: asyncio.Task | None = None
self._renderer = CliRenderer(get_console())
self._last_tape_info: TapeInfo | None = None
self._workspace = self._agent.framework.workspace
self._prompt = self._build_prompt(self._workspace)

def _install_log_sink(self) -> int:
def _suppress_logs(self) -> None:
with contextlib.suppress(ValueError):
logger.remove()
return logger.add(self._renderer.log, colorize=False, format="{level:<8} | {message}")

async def _refresh_tape_info(self) -> None:
tape = self._agent.tape.session_tape(self._message_template["session_id"], self._workspace)
Expand All @@ -160,7 +169,7 @@ def set_metadata(self, session_id: str | None = None, chat_id: str | None = None
self._message_template["chat_id"] = chat_id

async def start(self, stop_event: asyncio.Event) -> None:
self._log_handler_id = self._install_log_sink()
self._suppress_logs()
self._stop_event = stop_event
self._main_task = asyncio.create_task(self._main_loop())

Expand All @@ -169,8 +178,6 @@ async def stop(self) -> None:
self._main_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self._main_task
with contextlib.suppress(ValueError):
logger.remove(self._log_handler_id)

async def send(self, message: ChannelMessage) -> None:
if message.kind != "error":
Expand All @@ -180,12 +187,16 @@ async def send(self, message: ChannelMessage) -> None:
async def _main_loop(self) -> None:
self._renderer.welcome(model=self._agent.settings.model, workspace=str(self._workspace))
await self._refresh_tape_info()
request_completed = asyncio.Event()

while not self._stop_event.is_set():
try:
with patch_stdout(raw=True):
raw = (await self._prompt.prompt_async(self._prompt_message())).strip()
raw = (
await self._prompt.prompt_async(
self._prompt_message,
refresh_interval=_PROMPT_REFRESH_INTERVAL,
)
).strip()
except KeyboardInterrupt:
self._renderer.info("Interrupted. Use ',quit' to exit.")
continue
Expand All @@ -197,32 +208,38 @@ async def _main_loop(self) -> None:
if raw in {",quit", ",exit"}:
break
if raw == ",thinking":
self._renderer.input_echo(self._prompt_label(), raw)
self._toggle_thinking()
continue

request = self._normalize_input(raw)
self._renderer.input_echo(self._prompt_label(), raw)

message = ChannelMessage(
session_id=self._message_template["session_id"],
channel=self._message_template["channel"],
chat_id=self._message_template["chat_id"],
content=request,
lifespan=self.message_lifespan(request_completed),
lifespan=self.message_lifespan(),
)
await self._on_receive(message)
await request_completed.wait()
request_completed.clear()
self._set_llm_loop_running(True)
try:
await self._on_receive(message)
except Exception:
self._set_llm_loop_running(False)
raise

self._renderer.info("Bye.")
self._stop_event.set()

@contextlib.asynccontextmanager
async def message_lifespan(self, request_completed: asyncio.Event) -> AsyncGenerator[None, None]:
async def message_lifespan(self) -> AsyncGenerator[None, None]:
self._set_llm_loop_running(True)
try:
yield
finally:
await self._refresh_tape_info()
request_completed.set()
self._set_llm_loop_running(False)

def _normalize_input(self, raw: str) -> str:
if self._mode != "shell":
Expand All @@ -232,9 +249,20 @@ def _normalize_input(self, raw: str) -> str:
return f",{raw}"

def _prompt_message(self) -> FormattedText:
prompt = self._prompt_label()
if not self._llm_loop_running:
return FormattedText([("bold", prompt)])
index = int(monotonic() / _PROMPT_REFRESH_INTERVAL) % len(_GENERATION_SPINNER)
spinner = _GENERATION_SPINNER[index]
return FormattedText([
("blue", f"\n{spinner} Generating\n"),
("bold", prompt),
])

def _prompt_label(self) -> str:
cwd = Path.cwd().name
symbol = ">" if self._mode == "agent" else ","
return FormattedText([("bold", f"{cwd} {symbol} ")])
return f"{cwd} {symbol} "

async def stream_events(
self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]
Expand All @@ -245,9 +273,10 @@ async def stream_events(
print_head=lambda: self._renderer.print_head(message.kind),
expand_thinking=self._expand_thinking,
)
async for event in stream:
if printer.render(event):
yield event
with tool_call_reporter(_CliToolCallReporter(self._renderer)):
async for event in stream:
if await printer.render(event):
yield event

def _build_prompt(self, workspace: Path) -> PromptSession[str]:
kb = KeyBindings()
Expand All @@ -272,6 +301,7 @@ def _tool_sort_key(tool_name: str) -> tuple[str, str]:
key_bindings=kb,
history=history,
bottom_toolbar=self._render_bottom_toolbar,
erase_when_done=True,
)

def _render_bottom_toolbar(self) -> FormattedText:
Expand All @@ -292,7 +322,27 @@ def _toggle_thinking(self) -> None:
state = "expanded" if self._expand_thinking else "collapsed"
self._renderer.info(f"Thinking output is now {state}.")

def _invalidate_prompt(self) -> None:
with contextlib.suppress(Exception):
self._prompt.app.invalidate()

def _set_llm_loop_running(self, running: bool) -> None:
if self._llm_loop_running == running:
return
self._llm_loop_running = running
self._invalidate_prompt()

@staticmethod
def _history_file(home: Path, workspace: Path) -> Path:
workspace_hash = md5(str(workspace).encode("utf-8"), usedforsecurity=False).hexdigest()
return home / "history" / f"{workspace_hash}.history"

def admit_message(
self,
session_id: str,
message: Envelope,
turn: TurnSnapshot,
) -> AdmitDecision | None:
if not turn.is_running:
return None
return AdmitDecision("follow_up", reason="cli session is already generating")
Loading
Loading