From 8268777ac8ad0b5b6c42994df4714a4478081eba Mon Sep 17 00:00:00 2001 From: Frost Ming Date: Wed, 17 Jun 2026 17:26:32 +0800 Subject: [PATCH] fix(channel/ci): fix stream output mix Signed-off-by: Frost Ming --- src/bub/channels/cli/__init__.py | 95 ++++++++++++++++++-- tests/test_channels.py | 148 ++++++++++++++++++++++++++++++- 2 files changed, 233 insertions(+), 10 deletions(-) diff --git a/src/bub/channels/cli/__init__.py b/src/bub/channels/cli/__init__.py index 57ed738e..a277ecbc 100644 --- a/src/bub/channels/cli/__init__.py +++ b/src/bub/channels/cli/__init__.py @@ -15,6 +15,7 @@ from prompt_toolkit.history import FileHistory from prompt_toolkit.key_binding import KeyBindings from prompt_toolkit.patch_stdout import patch_stdout +from prompt_toolkit.utils import get_cwidth from rich import get_console from rich.spinner import SPINNERS from rich.text import Text @@ -43,6 +44,9 @@ def __init__(self, *, console, print_head: Callable[[], None], expand_thinking: self._expand_thinking = expand_thinking self._reasoning_chars = 0 self._reasoning_streaming = False + self._current_text_line = "" + self._rendered_text_line: str | None = None + self._live_text_rows = 0 self.head_printed = False async def render(self, event: StreamEvent) -> bool: @@ -77,19 +81,23 @@ async def _print_content(self, content: str) -> bool: await self._ensure_head() await self._close_reasoning_stream() await self._flush_reasoning() - await self._print(content, end="", highlight=False) + await self._write_text(content) return True async def _print_end(self) -> None: if self._reasoning_chars: await self._ensure_head() await self._flush_reasoning() - if self.head_printed: + if self._current_text_line: + await self._commit_text_line() + elif self.head_printed and not self._live_text_rows: await self._print("") async def _print_stream_boundary(self) -> None: await self._close_reasoning_stream() await self._flush_reasoning() + if self._current_text_line or self._live_text_rows: + await self._commit_text_line() if self.head_printed: await self._print("") @@ -112,6 +120,65 @@ async def _flush_reasoning(self) -> None: await self._print(Tree(label, guide_style="dim", expanded=False)) self._reasoning_chars = 0 + async def _write_text(self, text: str) -> None: + parts = text.split("\n") + for index, part in enumerate(parts): + self._current_text_line += part + if index < len(parts) - 1: + await self._commit_text_line() + + if self._current_text_line: + await self._render_live_text_line() + + async def _commit_text_line(self) -> None: + if self._live_text_rows and self._rendered_text_line == self._current_text_line: + self._current_text_line = "" + self._rendered_text_line = None + self._live_text_rows = 0 + return + self._live_text_rows = await self._render_text_line(self._current_text_line) + self._current_text_line = "" + self._rendered_text_line = None + self._live_text_rows = 0 + + async def commit_live_text(self) -> None: + if self._current_text_line or self._live_text_rows: + await self._commit_text_line() + + async def _render_live_text_line(self) -> None: + self._live_text_rows = await self._render_text_line(self._current_text_line) + self._rendered_text_line = self._current_text_line + + async def _render_text_line(self, text: str) -> int: + previous_rows = self._live_text_rows + rows = self._display_rows(text) + + def render() -> None: + self._rewind_live_text(previous_rows) + self._console.print(f"{text}\n", end="", highlight=False) + + await run_in_terminal(render, render_cli_done=False) + return rows + + def _display_rows(self, text: str) -> int: + columns = max(1, int(getattr(self._console, "width", 80) or 80)) + return max(1, (get_cwidth(text) + columns - 1) // columns) + + def _rewind_live_text(self, rows: int) -> None: + if rows <= 0: + return + output = getattr(self._console, "file", None) + if output is None: + return + output.write(f"\x1b[{rows}A\r") + for row in range(rows): + output.write("\x1b[2K") + if row < rows - 1: + output.write("\x1b[1B\r") + if rows > 1: + output.write(f"\x1b[{rows - 1}A\r") + output.flush() + async def _print(self, *args: Any, **kwargs: Any) -> None: await run_in_terminal(lambda: self._console.print(*args, **kwargs), render_cli_done=False) @@ -148,6 +215,7 @@ def __init__(self, on_receive: MessageHandler, agent: Agent) -> None: self._expand_thinking = False self._llm_loop_running = False self._main_task: asyncio.Task | None = None + self._stream_printer: _StreamPrinter | None = None self._renderer = CliRenderer(get_console()) self._last_tape_info: TapeInfo | None = None self._workspace = self._agent.framework.workspace @@ -208,12 +276,12 @@ async def _main_loop(self) -> None: if raw in {",quit", ",exit"}: break if raw == ",thinking": - self._renderer.input_echo(self._prompt_label(), raw) + await self._echo_input(raw) self._toggle_thinking() continue request = self._normalize_input(raw) - self._renderer.input_echo(self._prompt_label(), raw) + await self._echo_input(raw) message = ChannelMessage( session_id=self._message_template["session_id"], @@ -264,6 +332,12 @@ def _prompt_label(self) -> str: symbol = ">" if self._mode == "agent" else "," return f"{cwd} {symbol} " + async def _echo_input(self, raw: str) -> None: + stream_printer = getattr(self, "_stream_printer", None) + if stream_printer is not None: + await stream_printer.commit_live_text() + self._renderer.input_echo(self._prompt_label(), raw) + async def stream_events( self, message: ChannelMessage, stream: AsyncIterable[StreamEvent] ) -> AsyncIterable[StreamEvent]: @@ -273,10 +347,15 @@ async def stream_events( print_head=lambda: self._renderer.print_head(message.kind), expand_thinking=self._expand_thinking, ) - with tool_call_reporter(_CliToolCallReporter(self._renderer)): - async for event in stream: - if await printer.render(event): - yield event + self._stream_printer = printer + try: + with tool_call_reporter(_CliToolCallReporter(self._renderer)): + async for event in stream: + if await printer.render(event): + yield event + finally: + if self._stream_printer is printer: + self._stream_printer = None def _build_prompt(self, workspace: Path) -> PromptSession[str]: kb = KeyBindings() diff --git a/tests/test_channels.py b/tests/test_channels.py index df055eb3..54faa432 100644 --- a/tests/test_channels.py +++ b/tests/test_channels.py @@ -2,6 +2,14 @@ import asyncio import contextlib +import os +import pty +import re +import select +import subprocess +import sys +import textwrap +import time from datetime import datetime from pathlib import Path from types import SimpleNamespace @@ -18,6 +26,8 @@ from bub.runtime import StreamEvent from bub.turn_admission import AdmitDecision, SessionTurnController, SteeringBuffer, TurnSnapshot +ANSI_RE = re.compile(r"\x1b(?:\[[0-?]*[ -/]*[@-~]|\][^\x07]*(?:\x07|\x1b\\)|[()][A-Za-z])") + def _load_channel_config( load_config, @@ -35,6 +45,32 @@ def _load_channel_config( load_config(content) +def _read_pty_until_exit(master_fd: int, process: subprocess.Popen[bytes], *, timeout: float = 3.0) -> bytes: + chunks: list[bytes] = [] + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if process.poll() is not None: + with contextlib.suppress(OSError): + chunks.append(os.read(master_fd, 65536)) + break + readable, _, _ = select.select([master_fd], [], [], 0.05) + if not readable: + continue + try: + chunk = os.read(master_fd, 65536) + except OSError: + break + if not chunk: + break + chunks.append(chunk) + return b"".join(chunks) + + +def _plain_terminal_text(raw: bytes) -> str: + text = raw.decode(errors="replace") + return ANSI_RE.sub("", text).replace("\r", "\n") + + class _FakeChannelMixin: def __init__(self, name: str, *, needs_debounce: bool = False) -> None: self.name = name @@ -803,10 +839,118 @@ async def source() -> asyncio.AsyncIterator[StreamEvent]: yielded = [event async for event in channel.stream_events(message, source())] assert heads == ["command"] - assert printed == [("hel", "", False), ("lo", "", False), ("", None, None)] + assert printed == [("hel\n", "", False), ("hello\n", "", False)] assert [event.kind for event in yielded] == ["text", "text", "final"] +def test_cli_stream_output_does_not_overlap_active_pty_prompt() -> None: + script = textwrap.dedent( + """ + import asyncio + + from prompt_toolkit import PromptSession + from prompt_toolkit.patch_stdout import patch_stdout + from rich.console import Console + + from bub.channels.cli import _StreamPrinter + from bub.runtime import StreamEvent + + + async def main(): + console = Console(force_terminal=True, color_system=None, width=80) + printer = _StreamPrinter( + console=console, + print_head=lambda: console.print("Assistant >"), + expand_thinking=False, + ) + session = PromptSession(erase_when_done=True) + + async def stream(): + chunks = [ + "春风一夜入江城\\n", + "细雨无声湿客", + "程\\n", + "莫问归帆何处", + "去\\n", + "明朝山色满", + "前庭", + ] + for index, chunk in enumerate(chunks): + await asyncio.sleep(0.03) + await printer.render(StreamEvent("text", {"delta": chunk})) + if index == 3: + await printer.commit_live_text() + console.print("bub > steer now") + await asyncio.sleep(0.03) + await printer.render(StreamEvent("final", {})) + + task = asyncio.create_task(stream()) + with patch_stdout(raw=True): + await session.prompt_async( + lambda: [("", "\\n* Generating\\nbub > ")], + refresh_interval=0.02, + ) + await task + + + asyncio.run(main()) + """ + ) + master_fd, slave_fd = pty.openpty() + env = os.environ.copy() + env["PYTHONPATH"] = f"{Path.cwd() / 'src'}{os.pathsep}{env.get('PYTHONPATH', '')}" + process = subprocess.Popen( + [sys.executable, "-c", script], + stdin=slave_fd, + stdout=slave_fd, + stderr=slave_fd, + cwd=Path.cwd(), + env=env, + close_fds=True, + ) + os.close(slave_fd) + try: + time.sleep(0.25) + os.write(master_fd, b"next\n") + raw_output = _read_pty_until_exit(master_fd, process) + finally: + if process.poll() is None: + process.terminate() + with contextlib.suppress(subprocess.TimeoutExpired): + process.wait(timeout=1) + os.close(master_fd) + + assert process.wait(timeout=1) == 0, raw_output.decode(errors="replace") + output = _plain_terminal_text(raw_output) + + assert "春风一夜入江城" in output + assert "细雨无声湿客程" in output + assert "莫问归帆何处" in output + assert "去" in output + assert "明朝山色满前庭" in output + assert "bub > steer now" in output + assert "明朝山色满前庭bub >" not in output + assert "明朝山色满前庭* Generating" not in output + + +@pytest.mark.asyncio +async def test_cli_channel_input_echo_commits_active_stream_line() -> None: + channel = CliChannel.__new__(CliChannel) + calls: list[str] = [] + + class FakeStreamPrinter: + async def commit_live_text(self) -> None: + calls.append("commit") + + channel._stream_printer = FakeStreamPrinter() + channel._mode = "agent" + channel._renderer = SimpleNamespace(input_echo=lambda prompt, text: calls.append(f"echo:{text}")) + + await channel._echo_input("steer now") + + assert calls == ["commit", "echo:steer now"] + + @pytest.mark.asyncio async def test_cli_channel_collapsed_reasoning_does_not_start_status_spinner( monkeypatch: pytest.MonkeyPatch, @@ -838,7 +982,7 @@ async def source() -> asyncio.AsyncIterator[StreamEvent]: assert [event.kind for event in yielded] == ["reasoning", "text", "final"] assert printed - assert "hello" in [str(item) for item in printed] + assert any("hello" in str(item) for item in printed) def test_cli_channel_history_file_uses_workspace_hash(tmp_path: Path) -> None: