Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 87 additions & 8 deletions src/bub/channels/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from prompt_toolkit.history import FileHistory
from prompt_toolkit.key_binding import KeyBindings
from prompt_toolkit.patch_stdout import patch_stdout
from prompt_toolkit.utils import get_cwidth
from rich import get_console
from rich.spinner import SPINNERS
from rich.text import Text
Expand Down Expand Up @@ -43,6 +44,9 @@ def __init__(self, *, console, print_head: Callable[[], None], expand_thinking:
self._expand_thinking = expand_thinking
self._reasoning_chars = 0
self._reasoning_streaming = False
self._current_text_line = ""
self._rendered_text_line: str | None = None
self._live_text_rows = 0
self.head_printed = False

async def render(self, event: StreamEvent) -> bool:
Expand Down Expand Up @@ -77,19 +81,23 @@ async def _print_content(self, content: str) -> bool:
await self._ensure_head()
await self._close_reasoning_stream()
await self._flush_reasoning()
await self._print(content, end="", highlight=False)
await self._write_text(content)
return True

async def _print_end(self) -> None:
if self._reasoning_chars:
await self._ensure_head()
await self._flush_reasoning()
if self.head_printed:
if self._current_text_line:
await self._commit_text_line()
elif self.head_printed and not self._live_text_rows:
await self._print("")

async def _print_stream_boundary(self) -> None:
await self._close_reasoning_stream()
await self._flush_reasoning()
if self._current_text_line or self._live_text_rows:
await self._commit_text_line()
if self.head_printed:
await self._print("")

Expand All @@ -112,6 +120,65 @@ async def _flush_reasoning(self) -> None:
await self._print(Tree(label, guide_style="dim", expanded=False))
self._reasoning_chars = 0

async def _write_text(self, text: str) -> None:
parts = text.split("\n")
for index, part in enumerate(parts):
self._current_text_line += part
if index < len(parts) - 1:
await self._commit_text_line()

if self._current_text_line:
await self._render_live_text_line()

async def _commit_text_line(self) -> None:
if self._live_text_rows and self._rendered_text_line == self._current_text_line:
self._current_text_line = ""
self._rendered_text_line = None
self._live_text_rows = 0
return
self._live_text_rows = await self._render_text_line(self._current_text_line)
self._current_text_line = ""
self._rendered_text_line = None
self._live_text_rows = 0

async def commit_live_text(self) -> None:
if self._current_text_line or self._live_text_rows:
await self._commit_text_line()

async def _render_live_text_line(self) -> None:
self._live_text_rows = await self._render_text_line(self._current_text_line)
self._rendered_text_line = self._current_text_line

async def _render_text_line(self, text: str) -> int:
previous_rows = self._live_text_rows
rows = self._display_rows(text)

def render() -> None:
self._rewind_live_text(previous_rows)
self._console.print(f"{text}\n", end="", highlight=False)

await run_in_terminal(render, render_cli_done=False)
return rows

def _display_rows(self, text: str) -> int:
columns = max(1, int(getattr(self._console, "width", 80) or 80))
return max(1, (get_cwidth(text) + columns - 1) // columns)

def _rewind_live_text(self, rows: int) -> None:
if rows <= 0:
return
output = getattr(self._console, "file", None)
if output is None:
return
output.write(f"\x1b[{rows}A\r")
for row in range(rows):
output.write("\x1b[2K")
if row < rows - 1:
output.write("\x1b[1B\r")
if rows > 1:
output.write(f"\x1b[{rows - 1}A\r")
output.flush()

async def _print(self, *args: Any, **kwargs: Any) -> None:
await run_in_terminal(lambda: self._console.print(*args, **kwargs), render_cli_done=False)

Expand Down Expand Up @@ -148,6 +215,7 @@ def __init__(self, on_receive: MessageHandler, agent: Agent) -> None:
self._expand_thinking = False
self._llm_loop_running = False
self._main_task: asyncio.Task | None = None
self._stream_printer: _StreamPrinter | None = None
self._renderer = CliRenderer(get_console())
self._last_tape_info: TapeInfo | None = None
self._workspace = self._agent.framework.workspace
Expand Down Expand Up @@ -208,12 +276,12 @@ async def _main_loop(self) -> None:
if raw in {",quit", ",exit"}:
break
if raw == ",thinking":
self._renderer.input_echo(self._prompt_label(), raw)
await self._echo_input(raw)
self._toggle_thinking()
continue

request = self._normalize_input(raw)
self._renderer.input_echo(self._prompt_label(), raw)
await self._echo_input(raw)

message = ChannelMessage(
session_id=self._message_template["session_id"],
Expand Down Expand Up @@ -264,6 +332,12 @@ def _prompt_label(self) -> str:
symbol = ">" if self._mode == "agent" else ","
return f"{cwd} {symbol} "

async def _echo_input(self, raw: str) -> None:
stream_printer = getattr(self, "_stream_printer", None)
if stream_printer is not None:
await stream_printer.commit_live_text()
self._renderer.input_echo(self._prompt_label(), raw)

async def stream_events(
self, message: ChannelMessage, stream: AsyncIterable[StreamEvent]
) -> AsyncIterable[StreamEvent]:
Expand All @@ -273,10 +347,15 @@ async def stream_events(
print_head=lambda: self._renderer.print_head(message.kind),
expand_thinking=self._expand_thinking,
)
with tool_call_reporter(_CliToolCallReporter(self._renderer)):
async for event in stream:
if await printer.render(event):
yield event
self._stream_printer = printer
try:
with tool_call_reporter(_CliToolCallReporter(self._renderer)):
async for event in stream:
if await printer.render(event):
yield event
finally:
if self._stream_printer is printer:
self._stream_printer = None

def _build_prompt(self, workspace: Path) -> PromptSession[str]:
kb = KeyBindings()
Expand Down
148 changes: 146 additions & 2 deletions tests/test_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@

import asyncio
import contextlib
import os
import pty
import re
import select
import subprocess
import sys
import textwrap
import time
from datetime import datetime
from pathlib import Path
from types import SimpleNamespace
Expand All @@ -18,6 +26,8 @@
from bub.runtime import StreamEvent
from bub.turn_admission import AdmitDecision, SessionTurnController, SteeringBuffer, TurnSnapshot

ANSI_RE = re.compile(r"\x1b(?:\[[0-?]*[ -/]*[@-~]|\][^\x07]*(?:\x07|\x1b\\)|[()][A-Za-z])")


def _load_channel_config(
load_config,
Expand All @@ -35,6 +45,32 @@ def _load_channel_config(
load_config(content)


def _read_pty_until_exit(master_fd: int, process: subprocess.Popen[bytes], *, timeout: float = 3.0) -> bytes:
chunks: list[bytes] = []
deadline = time.monotonic() + timeout
while time.monotonic() < deadline:
if process.poll() is not None:
with contextlib.suppress(OSError):
chunks.append(os.read(master_fd, 65536))
break
readable, _, _ = select.select([master_fd], [], [], 0.05)
if not readable:
continue
try:
chunk = os.read(master_fd, 65536)
except OSError:
break
if not chunk:
break
chunks.append(chunk)
return b"".join(chunks)


def _plain_terminal_text(raw: bytes) -> str:
text = raw.decode(errors="replace")
return ANSI_RE.sub("", text).replace("\r", "\n")


class _FakeChannelMixin:
def __init__(self, name: str, *, needs_debounce: bool = False) -> None:
self.name = name
Expand Down Expand Up @@ -803,10 +839,118 @@ async def source() -> asyncio.AsyncIterator[StreamEvent]:
yielded = [event async for event in channel.stream_events(message, source())]

assert heads == ["command"]
assert printed == [("hel", "", False), ("lo", "", False), ("", None, None)]
assert printed == [("hel\n", "", False), ("hello\n", "", False)]
assert [event.kind for event in yielded] == ["text", "text", "final"]


def test_cli_stream_output_does_not_overlap_active_pty_prompt() -> None:
script = textwrap.dedent(
"""
import asyncio

from prompt_toolkit import PromptSession
from prompt_toolkit.patch_stdout import patch_stdout
from rich.console import Console

from bub.channels.cli import _StreamPrinter
from bub.runtime import StreamEvent


async def main():
console = Console(force_terminal=True, color_system=None, width=80)
printer = _StreamPrinter(
console=console,
print_head=lambda: console.print("Assistant >"),
expand_thinking=False,
)
session = PromptSession(erase_when_done=True)

async def stream():
chunks = [
"春风一夜入江城\\n",
"细雨无声湿客",
"程\\n",
"莫问归帆何处",
"去\\n",
"明朝山色满",
"前庭",
]
for index, chunk in enumerate(chunks):
await asyncio.sleep(0.03)
await printer.render(StreamEvent("text", {"delta": chunk}))
if index == 3:
await printer.commit_live_text()
console.print("bub > steer now")
await asyncio.sleep(0.03)
await printer.render(StreamEvent("final", {}))

task = asyncio.create_task(stream())
with patch_stdout(raw=True):
await session.prompt_async(
lambda: [("", "\\n* Generating\\nbub > ")],
refresh_interval=0.02,
)
await task


asyncio.run(main())
"""
)
master_fd, slave_fd = pty.openpty()
env = os.environ.copy()
env["PYTHONPATH"] = f"{Path.cwd() / 'src'}{os.pathsep}{env.get('PYTHONPATH', '')}"
process = subprocess.Popen(
[sys.executable, "-c", script],
stdin=slave_fd,
stdout=slave_fd,
stderr=slave_fd,
cwd=Path.cwd(),
env=env,
close_fds=True,
)
os.close(slave_fd)
try:
time.sleep(0.25)
os.write(master_fd, b"next\n")
raw_output = _read_pty_until_exit(master_fd, process)
finally:
if process.poll() is None:
process.terminate()
with contextlib.suppress(subprocess.TimeoutExpired):
process.wait(timeout=1)
os.close(master_fd)

assert process.wait(timeout=1) == 0, raw_output.decode(errors="replace")
output = _plain_terminal_text(raw_output)

assert "春风一夜入江城" in output
assert "细雨无声湿客程" in output
assert "莫问归帆何处" in output
assert "去" in output
assert "明朝山色满前庭" in output
assert "bub > steer now" in output
assert "明朝山色满前庭bub >" not in output
assert "明朝山色满前庭* Generating" not in output


@pytest.mark.asyncio
async def test_cli_channel_input_echo_commits_active_stream_line() -> None:
channel = CliChannel.__new__(CliChannel)
calls: list[str] = []

class FakeStreamPrinter:
async def commit_live_text(self) -> None:
calls.append("commit")

channel._stream_printer = FakeStreamPrinter()
channel._mode = "agent"
channel._renderer = SimpleNamespace(input_echo=lambda prompt, text: calls.append(f"echo:{text}"))

await channel._echo_input("steer now")

assert calls == ["commit", "echo:steer now"]


@pytest.mark.asyncio
async def test_cli_channel_collapsed_reasoning_does_not_start_status_spinner(
monkeypatch: pytest.MonkeyPatch,
Expand Down Expand Up @@ -838,7 +982,7 @@ async def source() -> asyncio.AsyncIterator[StreamEvent]:

assert [event.kind for event in yielded] == ["reasoning", "text", "final"]
assert printed
assert "hello" in [str(item) for item in printed]
assert any("hello" in str(item) for item in printed)


def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None:
Expand Down
Loading