diff --git a/src/bub/builtin/hook_impl.py b/src/bub/builtin/hook_impl.py index fead4905..a786fca9 100644 --- a/src/bub/builtin/hook_impl.py +++ b/src/bub/builtin/hook_impl.py @@ -17,6 +17,7 @@ from bub.hookspecs import hookimpl from bub.runtime import AsyncStreamEvents from bub.tape import TapeContext, TapeStore +from bub.turn_admission import AdmitDecision, TurnSnapshot from bub.types import Envelope, MessageHandler, State AGENTS_FILE_NAME = "AGENTS.md" @@ -291,3 +292,15 @@ def provide_tape_store(self) -> TapeStore: @hookimpl def build_tape_context(self) -> TapeContext: return default_tape_context() + + @hookimpl + def admit_message( + self, + session_id: str, + message: Envelope, + turn: TurnSnapshot, + ) -> AdmitDecision | None: + outbound_router = self.framework._outbound_router + if outbound_router is None: + return None + return outbound_router.admit_channel_message(session_id=session_id, message=message, turn=turn) diff --git a/src/bub/channels/base.py b/src/bub/channels/base.py index 920469ec..83bbb8b0 100644 --- a/src/bub/channels/base.py +++ b/src/bub/channels/base.py @@ -5,6 +5,8 @@ from bub.channels.message import ChannelMessage from bub.runtime import StreamEvent +from bub.turn_admission import AdmitDecision, TurnSnapshot +from bub.types import Envelope class Channel(ABC): @@ -39,6 +41,15 @@ def stream_events(self, message: ChannelMessage, stream: AsyncIterable[StreamEve """Optionally wrap the output stream for this channel.""" return stream + def admit_message( + self, + session_id: str, + message: Envelope, + turn: TurnSnapshot, + ) -> AdmitDecision | None: + """Optionally admit or reject an incoming message before processing.""" + return None + class Interface(Channel): """User-facing inbound/outbound surface managed by the channel runtime.""" diff --git a/src/bub/channels/cli/__init__.py b/src/bub/channels/cli/__init__.py index 9119d3be..57ed738e 100644 --- a/src/bub/channels/cli/__init__.py +++ b/src/bub/channels/cli/__init__.py @@ -4,16 +4,19 @@ from datetime import datetime from hashlib import md5 from pathlib import Path +from time import monotonic +from typing import Any from loguru import logger from prompt_toolkit import PromptSession +from prompt_toolkit.application import run_in_terminal from prompt_toolkit.completion import WordCompleter from prompt_toolkit.formatted_text import FormattedText from prompt_toolkit.history import FileHistory from prompt_toolkit.key_binding import KeyBindings from prompt_toolkit.patch_stdout import patch_stdout from rich import get_console -from rich.status import Status +from rich.spinner import SPINNERS from rich.text import Text from rich.tree import Tree @@ -25,8 +28,12 @@ from bub.channels.message import ChannelMessage from bub.envelope import field_of from bub.runtime import StreamEvent -from bub.tools import REGISTRY -from bub.types import MessageHandler +from bub.tools import REGISTRY, tool_call_reporter +from bub.turn_admission import AdmitDecision, TurnSnapshot +from bub.types import Envelope, MessageHandler + +_GENERATION_SPINNER: str = SPINNERS["dots"]["frames"] # type: ignore[assignment] +_PROMPT_REFRESH_INTERVAL: float = SPINNERS["dots"]["interval"] / 1000.0 # type: ignore[operator] class _StreamPrinter: @@ -36,89 +43,91 @@ def __init__(self, *, console, print_head: Callable[[], None], expand_thinking: self._expand_thinking = expand_thinking self._reasoning_chars = 0 self._reasoning_streaming = False - self._reasoning_status: Status | None = None self.head_printed = False - def render(self, event: StreamEvent) -> bool: + async def render(self, event: StreamEvent) -> bool: if event.kind == "reasoning": - self._record_reasoning(str(event.data.get("delta", ""))) + await self._record_reasoning(str(event.data.get("delta", ""))) return True if event.kind == "text": - return self._print_content(str(event.data.get("delta", ""))) + return await self._print_content(str(event.data.get("delta", ""))) elif event.kind == "tool_call": - self._print_stream_boundary() + await self._print_stream_boundary() elif event.kind == "final": - self._print_end() + await self._print_end() return True - def _record_reasoning(self, reasoning: str) -> None: + async def _record_reasoning(self, reasoning: str) -> None: if not self._expand_thinking: if self._reasoning_chars == 0: - self._ensure_head() - self._start_reasoning_status() + await self._ensure_head() self._reasoning_chars += len(reasoning) return - self._ensure_head() + await self._ensure_head() if not self._reasoning_streaming: - self._console.print(Text("[-] Thinking", style="dim")) + await self._print(Text("[-] Thinking", style="dim")) self._reasoning_streaming = True - self._console.print(Text(reasoning, style="dim"), end="", highlight=False) + await self._print(Text(reasoning, style="dim"), end="", highlight=False) - def _print_content(self, content: str) -> bool: + async def _print_content(self, content: str) -> bool: if not (content.strip() or self.head_printed or self._reasoning_chars or self._reasoning_streaming): return False - self._ensure_head() - self._close_reasoning_stream() - self._flush_reasoning() - self._console.print(content, end="", highlight=False) + await self._ensure_head() + await self._close_reasoning_stream() + await self._flush_reasoning() + await self._print(content, end="", highlight=False) return True - def _print_end(self) -> None: + async def _print_end(self) -> None: if self._reasoning_chars: - self._ensure_head() - self._flush_reasoning() + await self._ensure_head() + await self._flush_reasoning() if self.head_printed: - self._console.print("") + await self._print("") - def _print_stream_boundary(self) -> None: - self._close_reasoning_stream() - self._flush_reasoning() + async def _print_stream_boundary(self) -> None: + await self._close_reasoning_stream() + await self._flush_reasoning() if self.head_printed: - self._console.print("") + await self._print("") - def _ensure_head(self) -> None: + async def _ensure_head(self) -> None: if self.head_printed: return - self._print_head() + await run_in_terminal(self._print_head, render_cli_done=False) self.head_printed = True - def _close_reasoning_stream(self) -> None: + async def _close_reasoning_stream(self) -> None: if not self._reasoning_streaming: return - self._console.print("") + await self._print("") self._reasoning_streaming = False - def _flush_reasoning(self) -> None: + async def _flush_reasoning(self) -> None: if self._reasoning_chars <= 0: return - self._stop_reasoning_status() label = Text(f"[+] Thinking ({self._reasoning_chars} chars hidden)", style="dim") - self._console.print(Tree(label, guide_style="dim", expanded=False)) + await self._print(Tree(label, guide_style="dim", expanded=False)) self._reasoning_chars = 0 - def _start_reasoning_status(self) -> None: - if self._reasoning_status is not None: - return - self._reasoning_status = self._console.status(Text("Thinking", style="dim"), spinner_style="dim") - self._reasoning_status.start() + async def _print(self, *args: Any, **kwargs: Any) -> None: + await run_in_terminal(lambda: self._console.print(*args, **kwargs), render_cli_done=False) - def _stop_reasoning_status(self) -> None: - if self._reasoning_status is None: - return - self._reasoning_status.stop() - self._reasoning_status = None + +class _CliToolCallReporter: + def __init__(self, renderer: CliRenderer) -> None: + self._renderer = renderer + + def start(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None: + self._renderer.tool_call_start(name=name, args=args, kwargs=kwargs) + + def success(self, name: str, result: object, elapsed_ms: float) -> None: + self._renderer.tool_call_success(name=name, result=result, elapsed_ms=elapsed_ms) + + def error(self, name: str, error: BaseException, elapsed_ms: float) -> None: + self._renderer.tool_call_error(name=name, error=error, elapsed_ms=elapsed_ms) class CliChannel(Interface): @@ -137,16 +146,16 @@ def __init__(self, on_receive: MessageHandler, agent: Agent) -> None: } self._mode = "agent" # or "shell" self._expand_thinking = False + self._llm_loop_running = False self._main_task: asyncio.Task | None = None self._renderer = CliRenderer(get_console()) self._last_tape_info: TapeInfo | None = None self._workspace = self._agent.framework.workspace self._prompt = self._build_prompt(self._workspace) - def _install_log_sink(self) -> int: + def _suppress_logs(self) -> None: with contextlib.suppress(ValueError): logger.remove() - return logger.add(self._renderer.log, colorize=False, format="{level:<8} | {message}") async def _refresh_tape_info(self) -> None: tape = self._agent.tape.session_tape(self._message_template["session_id"], self._workspace) @@ -160,7 +169,7 @@ def set_metadata(self, session_id: str | None = None, chat_id: str | None = None self._message_template["chat_id"] = chat_id async def start(self, stop_event: asyncio.Event) -> None: - self._log_handler_id = self._install_log_sink() + self._suppress_logs() self._stop_event = stop_event self._main_task = asyncio.create_task(self._main_loop()) @@ -169,8 +178,6 @@ async def stop(self) -> None: self._main_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._main_task - with contextlib.suppress(ValueError): - logger.remove(self._log_handler_id) async def send(self, message: ChannelMessage) -> None: if message.kind != "error": @@ -180,12 +187,16 @@ async def send(self, message: ChannelMessage) -> None: async def _main_loop(self) -> None: self._renderer.welcome(model=self._agent.settings.model, workspace=str(self._workspace)) await self._refresh_tape_info() - request_completed = asyncio.Event() while not self._stop_event.is_set(): try: with patch_stdout(raw=True): - raw = (await self._prompt.prompt_async(self._prompt_message())).strip() + raw = ( + await self._prompt.prompt_async( + self._prompt_message, + refresh_interval=_PROMPT_REFRESH_INTERVAL, + ) + ).strip() except KeyboardInterrupt: self._renderer.info("Interrupted. Use ',quit' to exit.") continue @@ -197,32 +208,38 @@ async def _main_loop(self) -> None: if raw in {",quit", ",exit"}: break if raw == ",thinking": + self._renderer.input_echo(self._prompt_label(), raw) self._toggle_thinking() continue request = self._normalize_input(raw) + self._renderer.input_echo(self._prompt_label(), raw) message = ChannelMessage( session_id=self._message_template["session_id"], channel=self._message_template["channel"], chat_id=self._message_template["chat_id"], content=request, - lifespan=self.message_lifespan(request_completed), + lifespan=self.message_lifespan(), ) - await self._on_receive(message) - await request_completed.wait() - request_completed.clear() + self._set_llm_loop_running(True) + try: + await self._on_receive(message) + except Exception: + self._set_llm_loop_running(False) + raise self._renderer.info("Bye.") self._stop_event.set() @contextlib.asynccontextmanager - async def message_lifespan(self, request_completed: asyncio.Event) -> AsyncGenerator[None, None]: + async def message_lifespan(self) -> AsyncGenerator[None, None]: + self._set_llm_loop_running(True) try: yield finally: await self._refresh_tape_info() - request_completed.set() + self._set_llm_loop_running(False) def _normalize_input(self, raw: str) -> str: if self._mode != "shell": @@ -232,9 +249,20 @@ def _normalize_input(self, raw: str) -> str: return f",{raw}" def _prompt_message(self) -> FormattedText: + prompt = self._prompt_label() + if not self._llm_loop_running: + return FormattedText([("bold", prompt)]) + index = int(monotonic() / _PROMPT_REFRESH_INTERVAL) % len(_GENERATION_SPINNER) + spinner = _GENERATION_SPINNER[index] + return FormattedText([ + ("blue", f"\n{spinner} Generating\n"), + ("bold", prompt), + ]) + + def _prompt_label(self) -> str: cwd = Path.cwd().name symbol = ">" if self._mode == "agent" else "," - return FormattedText([("bold", f"{cwd} {symbol} ")]) + return f"{cwd} {symbol} " async def stream_events( self, message: ChannelMessage, stream: AsyncIterable[StreamEvent] @@ -245,9 +273,10 @@ async def stream_events( print_head=lambda: self._renderer.print_head(message.kind), expand_thinking=self._expand_thinking, ) - async for event in stream: - if printer.render(event): - yield event + with tool_call_reporter(_CliToolCallReporter(self._renderer)): + async for event in stream: + if await printer.render(event): + yield event def _build_prompt(self, workspace: Path) -> PromptSession[str]: kb = KeyBindings() @@ -272,6 +301,7 @@ def _tool_sort_key(tool_name: str) -> tuple[str, str]: key_bindings=kb, history=history, bottom_toolbar=self._render_bottom_toolbar, + erase_when_done=True, ) def _render_bottom_toolbar(self) -> FormattedText: @@ -292,7 +322,27 @@ def _toggle_thinking(self) -> None: state = "expanded" if self._expand_thinking else "collapsed" self._renderer.info(f"Thinking output is now {state}.") + def _invalidate_prompt(self) -> None: + with contextlib.suppress(Exception): + self._prompt.app.invalidate() + + def _set_llm_loop_running(self, running: bool) -> None: + if self._llm_loop_running == running: + return + self._llm_loop_running = running + self._invalidate_prompt() + @staticmethod def _history_file(home: Path, workspace: Path) -> Path: workspace_hash = md5(str(workspace).encode("utf-8"), usedforsecurity=False).hexdigest() return home / "history" / f"{workspace_hash}.history" + + def admit_message( + self, + session_id: str, + message: Envelope, + turn: TurnSnapshot, + ) -> AdmitDecision | None: + if not turn.is_running: + return None + return AdmitDecision("follow_up", reason="cli session is already generating") diff --git a/src/bub/channels/cli/renderer.py b/src/bub/channels/cli/renderer.py index 7fd2e442..97eec2b9 100644 --- a/src/bub/channels/cli/renderer.py +++ b/src/bub/channels/cli/renderer.py @@ -2,7 +2,9 @@ from __future__ import annotations +import json from dataclasses import dataclass +from typing import Any from rich.console import Console from rich.panel import Panel @@ -10,6 +12,9 @@ from bub.channels.message import MessageKind +MAX_TOOL_PAYLOAD_CHARS = 4000 +MAX_TOOL_CALL_CHARS = 1200 + @dataclass class CliRenderer: @@ -47,15 +52,79 @@ def error(self, text: str) -> None: return self.console.print(f"[red bold]Error >[/]\n{text}") + def input_echo(self, prompt: str, text: str) -> None: + if not text.strip(): + return + self.console.print(f"[bold]{prompt}[/]{text}", new_line_start=True) + + def tool_call_start(self, *, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None: + self.console.print(Text(_format_tool_call(name, args, kwargs), style="magenta"), new_line_start=True) + + def tool_call_success(self, *, name: str, result: Any, elapsed_ms: float) -> None: + rendered = _format_tool_payload(result) + self._tool_result(f"completed in {elapsed_ms:.0f} ms", rendered, style="green") + + def tool_call_error(self, *, name: str, error: BaseException, elapsed_ms: float) -> None: + rendered = _format_tool_payload({"type": error.__class__.__name__, "message": str(error)}) + self._tool_result(f"failed in {elapsed_ms:.0f} ms", rendered, style="red") + def print_head(self, kind: MessageKind) -> None: if kind == "command": - self.console.print("[cyan bold]Command >[/]") + self.console.print("[cyan bold]Command >[/]", new_line_start=True) elif kind == "error": - self.console.print("[red bold]Error >[/]") + self.console.print("[red bold]Error >[/]", new_line_start=True) else: - self.console.print("[blue bold]Assistant >[/]") + self.console.print("[blue bold]Assistant >[/]", new_line_start=True) def log(self, message: object) -> None: text = str(message).rstrip() if text: self.console.print(text, new_line_start=True) + + def _tool_result(self, label: str, rendered: str, *, style: str) -> None: + lines = rendered.splitlines() or [""] + self.console.print(Text(f" ⎿ {label}", style=style), highlight=False) + for line in lines: + self.console.print(Text(f" {line}", style="bright_black"), highlight=False) + + +def _format_tool_call(name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> str: + params = _format_tool_params(args, kwargs) + if not params: + return f"● {name}()" + + inline = f"● {name}({', '.join(params)})" + if len(inline) <= 120 and "\n" not in inline: + return inline + + body = "\n".join(f" {param}," for param in params) + return _truncate(f"● {name}(\n{body}\n)", max_chars=MAX_TOOL_CALL_CHARS) + + +def _format_tool_params(args: tuple[Any, ...], kwargs: dict[str, Any]) -> list[str]: + params: list[str] = [] + for index, value in enumerate(args, start=1): + params.append(f"arg{index}: {_format_tool_value(value)}") + for key, value in kwargs.items(): + params.append(f"{key}: {_format_tool_value(value)}") + return params + + +def _format_tool_payload(payload: Any, *, max_chars: int = MAX_TOOL_PAYLOAD_CHARS) -> str: + return _format_tool_value(payload, max_chars=max_chars, indent=2) + + +def _format_tool_value(payload: Any, *, max_chars: int = 800, indent: int | None = None) -> str: + try: + rendered = json.dumps(payload, ensure_ascii=False, indent=indent, default=repr) + except TypeError: + rendered = repr(payload) + return _truncate(rendered, max_chars=max_chars) + + +def _truncate(text: str, *, max_chars: int) -> str: + if len(text) <= max_chars: + return text + omitted = len(text) - max_chars + suffix = f"\n... truncated {omitted} chars" + return text[: max_chars - len(suffix)].rstrip() + suffix diff --git a/src/bub/channels/manager.py b/src/bub/channels/manager.py index 313cc2ce..b3ec6726 100644 --- a/src/bub/channels/manager.py +++ b/src/bub/channels/manager.py @@ -15,7 +15,7 @@ from bub.envelope import content_of, field_of from bub.framework import BubFramework from bub.runtime import StreamEvent -from bub.turn_admission import AdmitDecision, SessionTurnController +from bub.turn_admission import AdmitDecision, SessionTurnController, TurnSnapshot from bub.types import Envelope, MessageHandler from bub.utils import wait_until_stopped @@ -341,3 +341,17 @@ async def shutdown(self) -> None: logger.info(f"channel.manager cancelled {count} in-flight tasks") for channel in self.enabled_channels(): await channel.stop() + + def admit_channel_message( + self, + session_id: str, + message: Envelope, + turn: TurnSnapshot, + ) -> AdmitDecision | None: + channel_name = field_of(message, "channel") + if channel_name is None: + return None + channel = self.get_channel(str(channel_name)) + if channel is None: + return None + return channel.admit_message(session_id=session_id, message=message, turn=turn) diff --git a/src/bub/tools.py b/src/bub/tools.py index 3a68358e..5735deee 100644 --- a/src/bub/tools.py +++ b/src/bub/tools.py @@ -1,12 +1,14 @@ from __future__ import annotations import asyncio +import contextlib +import contextvars import inspect import json import time from collections.abc import Callable, Sequence from dataclasses import dataclass, field, replace -from typing import Any, overload +from typing import Any, Protocol, overload from loguru import logger from pydantic import BaseModel, ConfigDict, TypeAdapter, ValidationError, validate_call @@ -132,6 +134,28 @@ class ToolExecution: error: BubError | None = None +class ToolCallReporter(Protocol): + def start(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None: ... + + def success(self, name: str, result: Any, elapsed_ms: float) -> None: ... + + def error(self, name: str, error: BaseException, elapsed_ms: float) -> None: ... + + +_TOOL_CALL_REPORTER: contextvars.ContextVar[ToolCallReporter | None] = contextvars.ContextVar( + "bub_tool_call_reporter", default=None +) + + +@contextlib.contextmanager +def tool_call_reporter(reporter: ToolCallReporter): + token = _TOOL_CALL_REPORTER.set(reporter) + try: + yield + finally: + _TOOL_CALL_REPORTER.reset(token) + + class ToolExecutor: """Execute already-resolved Bub tool invocations.""" @@ -220,20 +244,30 @@ async def wrapped(*args, **kwargs): call_kwargs = kwargs.copy() if tool.context: call_kwargs.pop("context", None) - _log_tool_call(tool.name, args, call_kwargs) + reporter = _TOOL_CALL_REPORTER.get() + if reporter is None: + _log_tool_call(tool.name, args, call_kwargs) + else: + reporter.start(tool.name, args, call_kwargs) start = time.monotonic() try: result = handler(*args, **kwargs) if inspect.isawaitable(result): result = await result - except Exception: + except Exception as exc: elapsed_time = (time.monotonic() - start) * 1000 - logger.exception("tool.call.error name={} elapsed_time={:.2f}ms", tool.name, elapsed_time) + if reporter is None: + logger.exception("tool.call.error name={} elapsed_time={:.2f}ms", tool.name, elapsed_time) + else: + reporter.error(tool.name, exc, elapsed_time) raise else: elapsed_time = (time.monotonic() - start) * 1000 - logger.info("tool.call.success name={} elapsed_time={:.2f}ms", tool.name, elapsed_time) + if reporter is None: + logger.info("tool.call.success name={} elapsed_time={:.2f}ms", tool.name, elapsed_time) + else: + reporter.success(tool.name, result, elapsed_time) return result return replace(tool, handler=wrapped) diff --git a/src/bub/types.py b/src/bub/types.py index 65ad88af..7fb661a9 100644 --- a/src/bub/types.py +++ b/src/bub/types.py @@ -4,10 +4,13 @@ from collections.abc import AsyncIterable, Callable, Coroutine from dataclasses import dataclass, field -from typing import Any, Protocol +from typing import TYPE_CHECKING, Any, Protocol from bub.runtime import StreamEvent +if TYPE_CHECKING: + from bub.turn_admission import AdmitDecision, TurnSnapshot + type Envelope = Any type State = dict[str, Any] type MessageHandler = Callable[[Envelope], Coroutine[Any, Any, None]] @@ -18,6 +21,12 @@ class OutboundChannelRouter(Protocol): async def dispatch_output(self, message: Envelope) -> bool: ... def wrap_stream(self, message: Envelope, stream: AsyncIterable[StreamEvent]) -> AsyncIterable[StreamEvent]: ... async def quit(self, session_id: str) -> None: ... + def admit_channel_message( + self, + session_id: str, + message: Envelope, + turn: TurnSnapshot, + ) -> AdmitDecision | None: ... @dataclass(frozen=True) diff --git a/tests/test_channels.py b/tests/test_channels.py index 079e1db6..df055eb3 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -16,7 +16,7 @@ from bub.channels.message import ChannelMessage from bub.channels.telegram import BubMessageFilter, TelegramChannel, TelegramMessageParser from bub.runtime import StreamEvent -from bub.turn_admission import AdmitDecision, SessionTurnController, SteeringBuffer +from bub.turn_admission import AdmitDecision, SessionTurnController, SteeringBuffer, TurnSnapshot def _load_channel_config( @@ -292,6 +292,126 @@ def test_channel_manager_selects_real_channel_types(load_config) -> None: assert [channel.name for channel in manager.enabled_channels()] == ["telegram"] +@pytest.mark.asyncio +async def test_cli_channel_accepts_input_while_previous_message_is_running() -> None: + received: list[ChannelMessage] = [] + + class FakePrompt: + def __init__(self) -> None: + self.inputs = iter(["first", "second", ",quit"]) + self.refresh_intervals: list[float | None] = [] + self.messages: list[str] = [] + self.received_callables: list[bool] = [] + + async def prompt_async(self, message, *, refresh_interval=None): + self.refresh_intervals.append(refresh_interval) + self.received_callables.append(callable(message)) + rendered = message() if callable(message) else message + self.messages.append("".join(part for _, part in rendered)) + return next(self.inputs) + + async def on_receive(message: ChannelMessage) -> None: + received.append(message) + + channel = CliChannel.__new__(CliChannel) + channel._on_receive = on_receive + channel._stop_event = asyncio.Event() + channel._message_template = { + "chat_id": "cli_chat", + "channel": "cli", + "session_id": "cli_session", + } + channel._agent = SimpleNamespace(settings=SimpleNamespace(model="test-model")) + channel._workspace = Path.cwd() + channel._mode = "agent" + channel._llm_loop_running = False + channel._prompt = FakePrompt() + echoed: list[tuple[str, str]] = [] + channel._renderer = SimpleNamespace( + welcome=lambda **kwargs: None, + info=lambda message: None, + input_echo=lambda prompt, text: echoed.append((prompt, text)), + ) + channel._refresh_tape_info = _async_return(None) + + await asyncio.wait_for(channel._main_loop(), timeout=1) + + import bub.channels.cli as cli_module + + assert [message.content for message in received] == ["first", "second"] + assert channel._prompt.refresh_intervals == [cli_module._PROMPT_REFRESH_INTERVAL] * 3 + assert channel._prompt.received_callables == [True, True, True] + assert "Generating\n" not in channel._prompt.messages[0] + assert "Generating\n" in channel._prompt.messages[1] + assert echoed == [(f"{Path.cwd().name} > ", "first"), (f"{Path.cwd().name} > ", "second")] + assert all(message.lifespan is not None for message in received) + + +def test_cli_channel_build_prompt_erases_submitted_prompt(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: + captured: dict[str, object] = {} + + class FakePromptSession: + def __init__(self, **kwargs) -> None: + captured.update(kwargs) + + monkeypatch.setattr("bub.channels.cli.PromptSession", FakePromptSession) + channel = CliChannel.__new__(CliChannel) + channel._mode = "agent" + channel._expand_thinking = False + channel._agent = SimpleNamespace(settings=SimpleNamespace(model="test-model")) + channel._last_tape_info = None + + prompt = channel._build_prompt(tmp_path) + + assert isinstance(prompt, FakePromptSession) + assert captured["erase_when_done"] is True + + +def test_cli_channel_generating_spinner_renders_above_input_not_toolbar(monkeypatch: pytest.MonkeyPatch) -> None: + channel = CliChannel.__new__(CliChannel) + channel._llm_loop_running = True + channel._mode = "agent" + channel._expand_thinking = False + channel._last_tape_info = None + channel._agent = SimpleNamespace(settings=SimpleNamespace(model="test-model")) + + prompt_text = "".join(part for _, part in channel._prompt_message()) + toolbar_text = "".join(part for _, part in channel._render_bottom_toolbar()) + + assert "\n" in prompt_text + assert "Generating\n" in prompt_text + assert prompt_text.endswith(f"{Path.cwd().name} > ") + assert "Generating" not in toolbar_text + + import bub.channels.cli as cli_module + + monkeypatch.setattr(cli_module, "monotonic", lambda: 0.0) + first_frame = "".join(part for _, part in channel._prompt_message()) + monkeypatch.setattr(cli_module, "monotonic", lambda: 0.2) + second_frame = "".join(part for _, part in channel._prompt_message()) + + assert first_frame != second_frame + + +def test_cli_channel_admit_message_queues_follow_up_when_turn_is_running() -> None: + channel = CliChannel.__new__(CliChannel) + turn = TurnSnapshot( + session_id="cli_session", + is_running=True, + running_count=1, + pending_count=0, + steering_count=0, + ) + + decision = channel.admit_message( + session_id="cli_session", + message=_message("second", channel="cli", session_id="cli_session"), + turn=turn, + ) + + assert decision == AdmitDecision("follow_up", reason="cli session is already generating") + + @pytest.mark.asyncio async def test_channel_manager_on_receive_uses_buffer_for_debounced_channel( monkeypatch: pytest.MonkeyPatch, load_config @@ -687,6 +807,40 @@ async def source() -> asyncio.AsyncIterator[StreamEvent]: assert [event.kind for event in yielded] == ["text", "text", "final"] +@pytest.mark.asyncio +async def test_cli_channel_collapsed_reasoning_does_not_start_status_spinner( + monkeypatch: pytest.MonkeyPatch, +) -> None: + channel = CliChannel.__new__(CliChannel) + channel._renderer = SimpleNamespace(print_head=lambda kind: None) + channel._expand_thinking = False + printed: list[object] = [] + + def status(*args, **kwargs): + raise AssertionError("status spinner should not start while prompt is active") + + monkeypatch.setattr( + "bub.channels.cli.get_console", + lambda: SimpleNamespace( + print=lambda content, end=None, highlight=None: printed.append(content), + status=status, + ), + ) + + message = _message("ignored", channel="cli", kind="normal", session_id="cli:1") + + async def source() -> asyncio.AsyncIterator[StreamEvent]: + yield StreamEvent("reasoning", {"delta": "hidden"}) + yield StreamEvent("text", {"delta": "hello"}) + yield StreamEvent("final", {}) + + yielded = [event async for event in channel.stream_events(message, source())] + + assert [event.kind for event in yielded] == ["reasoning", "text", "final"] + assert printed + assert "hello" in [str(item) for item in printed] + + def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None: home = tmp_path / "home" workspace = tmp_path / "workspace" @@ -706,12 +860,16 @@ def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None: ], ) def test_cli_renderer_print_head_uses_message_kind(kind: str, expected: str) -> None: - printed: list[str] = [] - renderer = CliRenderer(SimpleNamespace(print=printed.append)) # type: ignore[arg-type] + printed: list[tuple[str, bool | None]] = [] + + def print_message(message: str, *, new_line_start: bool | None = None) -> None: + printed.append((message, new_line_start)) + + renderer = CliRenderer(SimpleNamespace(print=print_message)) # type: ignore[arg-type] renderer.print_head(kind) # type: ignore[arg-type] - assert printed == [expected] + assert printed == [(expected, True)] def test_bub_message_filter_accepts_private_messages() -> None: diff --git a/tests/test_cli_renderer.py b/tests/test_cli_renderer.py new file mode 100644 index 00000000..675192b3 --- /dev/null +++ b/tests/test_cli_renderer.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +from rich.console import Console + +from bub.channels.cli.renderer import CliRenderer + + +def test_tool_call_renderer_uses_compact_claude_style_output() -> None: + console = Console(record=True, force_terminal=False, width=120) + renderer = CliRenderer(console) + + renderer.tool_call_start(name="demo.echo", args=(), kwargs={"value": "hello"}) + renderer.tool_call_success(name="demo.echo", result={"ok": True}, elapsed_ms=12.3) + + output = console.export_text() + + assert '● demo.echo(value: "hello")' in output + assert " ⎿ completed in 12 ms" in output + assert '"ok": true' in output + assert "Tool call:" not in output + assert "Tool result:" not in output + + +def test_input_echo_prints_submitted_prompt_as_terminal_output() -> None: + console = Console(record=True, force_terminal=False, width=120) + renderer = CliRenderer(console) + + renderer.input_echo("bub > ", "hello") + + output = console.export_text() + + assert "bub > hello" in output diff --git a/tests/test_tools.py b/tests/test_tools.py index 6f770ffd..e1aa0c2c 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -6,7 +6,7 @@ from loguru import logger from pydantic import BaseModel -from bub.tools import REGISTRY, tool +from bub.tools import REGISTRY, tool, tool_call_reporter class EchoInput(BaseModel): @@ -74,6 +74,44 @@ def failing_tool() -> str: assert errors[0].startswith("tool.call.error name=tests.failing_tool elapsed_time=") +@pytest.mark.asyncio +async def test_tool_wrapper_uses_reporter_instead_of_logs(monkeypatch: pytest.MonkeyPatch) -> None: + tool_name = "tests.reported_tool" + REGISTRY.pop(tool_name, None) + logged: list[str] = [] + reported: list[tuple[str, str, Any]] = [] + + def record_log(message: str, *args: Any, **kwargs: Any) -> None: + logged.append(message.format(*args, **kwargs)) + + class Reporter: + def start(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> None: + reported.append(("start", name, {"args": args, "kwargs": kwargs})) + + def success(self, name: str, result: Any, elapsed_ms: float) -> None: + reported.append(("success", name, {"result": result, "elapsed_ms": elapsed_ms})) + + def error(self, name: str, error: BaseException, elapsed_ms: float) -> None: + reported.append(("error", name, {"error": error, "elapsed_ms": elapsed_ms})) + + monkeypatch.setattr(logger, "info", record_log) + monkeypatch.setattr(logger, "exception", record_log) + + @tool(name=tool_name) + def reported_tool(value: str) -> str: + return value.upper() + + with tool_call_reporter(Reporter()): + result = await reported_tool.run("hello") + + assert result == "HELLO" + assert logged == [] + assert reported[0] == ("start", tool_name, {"args": ("hello",), "kwargs": {}}) + assert reported[1][0] == "success" + assert reported[1][1] == tool_name + assert reported[1][2]["result"] == "HELLO" + + @pytest.mark.asyncio async def test_tool_direct_call_registers_wrapped_instance_in_registry() -> None: tool_name = "tests.direct_call"