diff --git a/src/bub/builtin/agent.py b/src/bub/builtin/agent.py index aebed863..8743fa30 100644 --- a/src/bub/builtin/agent.py +++ b/src/bub/builtin/agent.py @@ -2,7 +2,6 @@ from __future__ import annotations -import asyncio import inspect import re import shlex @@ -13,46 +12,31 @@ from datetime import UTC, datetime from functools import cached_property from pathlib import Path -from typing import Any, Literal, cast - -from any_llm import AnyLLM -from any_llm.types.completion import ( - ChatCompletion, - ChatCompletionChunk, - ChatCompletionMessage, - ChatCompletionMessageFunctionToolCall, - ChatCompletionMessageToolCall, - ChoiceDeltaToolCall, - Function, - ParsedChatCompletion, -) +from typing import Any + from loguru import logger -from pydantic import TypeAdapter, ValidationError -from bub.builtin.settings import ModelCandidate, load_settings -from bub.builtin.store import ForkTapeStore -from bub.builtin.tape import TapeService +from bub.builtin.model_runner import ( + ModelRunner, + is_context_length_error, +) +from bub.builtin.settings import load_settings +from bub.builtin.tape import Tape from bub.framework import BubFramework -from bub.runtime import AsyncStreamEvents, BubError, ErrorKind, StreamEvent, StreamState +from bub.runtime import AsyncStreamEvents, StreamEvent, StreamState from bub.skills import discover_skills, render_skills_prompt -from bub.tape import InMemoryTapeStore, Tape +from bub.tape import AsyncTapeStoreAdapter, InMemoryTapeStore, is_async_tape_store from bub.tools import ( REGISTRY, Tool, ToolContext, - ToolExecutor, ) from bub.types import State from bub.utils import workspace_from_state CONTINUE_PROMPT = "Continue the task until all targets are completed." HINT_RE = re.compile(r"\$([A-Za-z0-9_.-]+)") -_CONTEXT_LENGTH_PATTERNS = re.compile( - r"context.{0,20}(?:length|window)|maximum.{0,20}context|token.{0,10}limit|prompt.{0,10}too long|tokens? > \d+ maximum", - re.IGNORECASE, -) MAX_AUTO_HANDOFF_RETRIES = 1 -TOOL_ARGUMENTS_ADAPTER = TypeAdapter(dict[str, Any]) class Agent: @@ -61,16 +45,18 @@ class Agent: def __init__(self, framework: BubFramework) -> None: self.settings = load_settings() self.framework = framework + self.model_runner = ModelRunner(self.settings) @cached_property - def tapes(self) -> TapeService: + def tape(self) -> Tape: import bub tape_store = self.framework.get_tape_store() if tape_store is None: tape_store = InMemoryTapeStore() - tape_store = ForkTapeStore(tape_store) - return TapeService(bub.home / "tapes", tape_store, self.framework.build_tape_context()) + if not is_async_tape_store(tape_store): + tape_store = AsyncTapeStoreAdapter(tape_store) + return Tape(bub.home / "tapes", tape_store, self.framework.build_tape_context()) @staticmethod def _events_from_iterable(iterable: Iterable) -> AsyncStreamEvents: @@ -109,13 +95,14 @@ async def run_stream( StreamEvent("final", {"text": "error: empty prompt", "ok": False}), ]) - tape = self.tapes.session_tape(session_id, workspace_from_state(state)) - tape.context = replace(tape.context, state=state) + tape = self.tape.session_tape( + session_id, workspace_from_state(state), context=replace(self.tape.context, state=state) + ) merge_back = not session_id.startswith("temp/") stack = AsyncExitStack() # The fork_tape context manager must not be exited until the last chunk of the stream is consumed. - await stack.enter_async_context(self.tapes.fork_tape(tape.name, merge_back=merge_back)) - await self.tapes.ensure_bootstrap_anchor(tape.name) + tape = await stack.enter_async_context(tape.fork_tape(merge_back=merge_back)) + await tape.ensure_bootstrap_anchor() if isinstance(prompt, str) and prompt.strip().startswith(","): result = await self._run_command(tape=tape, line=prompt.strip()) events = self._events_from_iterable([ @@ -139,7 +126,7 @@ async def _run_command(self, tape: Tape, *, line: str) -> str: name, arg_tokens = _parse_internal_command(line) start = time.monotonic() - context = ToolContext(tape=tape.name, run_id="run_command", state=tape.context.state) + context = ToolContext(tape=tape, run_id="run_command", state=tape.context.state) output = "" status = "ok" try: @@ -170,7 +157,7 @@ async def _run_command(self, tape: Tape, *, line: str) -> str: "output": output_text, "date": datetime.now(UTC).isoformat(), } - await self.tapes.append_event(tape.name, "command", event_payload) + await tape.append_event("command", event_payload) async def _agent_loop( self, @@ -183,8 +170,7 @@ async def _agent_loop( ) -> AsyncStreamEvents: next_prompt: str | list[dict] = prompt display_model = model or self.settings.model - await self.tapes.append_event( - tape.name, + await tape.append_event( "loop.start", { "model": display_model, @@ -220,7 +206,7 @@ async def _stream_events_with_auto_handoff( start = time.monotonic() should_continue = False logger.info("loop.step step={} tape={} model={}", step, tape.name, display_model) - await self.tapes.append_event(tape.name, "loop.step.start", {"step": step, "prompt": next_prompt}) + await tape.append_event("loop.step.start", {"step": step, "prompt": next_prompt}) try: output = await self._run_once( tape=tape, @@ -233,8 +219,7 @@ async def _stream_events_with_auto_handoff( yield event if event.kind == "error": elapsed_ms = int((time.monotonic() - start) * 1000) - await self.tapes.append_event( - tape.name, + await tape.append_event( "loop.step", { "step": step, @@ -249,20 +234,18 @@ async def _stream_events_with_auto_handoff( except Exception as exc: error_message = f"{exc!s}" elapsed_ms = int((time.monotonic() - start) * 1000) - if auto_handoff_remaining > 0 and _is_context_length_error(error_message): + if auto_handoff_remaining > 0 and is_context_length_error(error_message): auto_handoff_remaining -= 1 logger.warning( "auto_handoff: context length exceeded, performing automatic handoff. tape={} step={}", tape.name, step, ) - await self.tapes.handoff( - tape.name, + await tape.handoff( name="auto_handoff/context_overflow", state={"reason": "context_length_exceeded", "error": error_message}, ) - await self.tapes.append_event( - tape.name, + await tape.append_event( "loop.step", { "step": step, @@ -275,8 +258,7 @@ async def _stream_events_with_auto_handoff( next_prompt = prompt continue - await self.tapes.append_event( - tape.name, + await tape.append_event( "loop.step", { "step": step, @@ -292,8 +274,7 @@ async def _stream_events_with_auto_handoff( state.usage = output.usage elapsed_ms = int((time.monotonic() - start) * 1000) if not should_continue: - await self.tapes.append_event( - tape.name, + await tape.append_event( "loop.step", { "step": step, @@ -305,8 +286,7 @@ async def _stream_events_with_auto_handoff( return next_prompt = self._continue_prompt(tape) - await self.tapes.append_event( - tape.name, + await tape.append_event( "loop.step", { "step": step, @@ -367,130 +347,21 @@ async def _run_once_stream( allowed_skills: set[str] | None, tools: list[Tool], ) -> AsyncStreamEvents: - state = StreamState() - - async def iterator() -> AsyncGenerator[StreamEvent, None]: - system_prompt = self._system_prompt( - prompt_text, state=tape.context.state, allowed_skills=allowed_skills, tools=tools - ) - prompt_message: dict[str, Any] = {"role": "user", "content": prompt} - run_id = f"run-{datetime.now(UTC).strftime('%Y%m%dT%H%M%S%fZ')}" - try: - messages = await self.tapes.read_messages(tape) - except BubError as exc: - await self.tapes.record_chat( - tape=tape.name, - run_id=run_id, - system_prompt=system_prompt, - context_error=exc, - new_messages=[], - response_text=None, - error=exc, - model=model or self.settings.model, - ) - raise - if system_prompt: - messages = [{"role": "system", "content": system_prompt}, *messages] - messages.append(prompt_message) - - from bub.builtin.tools import model_tools - - model_tools_for_call = model_tools(tools) - text_parts: list[str] = [] - tool_calls = _ToolCallAccumulator() - response: ChatCompletion | ParsedChatCompletion[Any] | None = None - async with asyncio.timeout(self.settings.model_timeout_seconds): - completion = await self._completion_response( - model=model or self.settings.model, - messages=messages, - tools=model_tools_for_call, - ) - if isinstance(completion, ChatCompletion): - response = completion - async for event in _completion_events(completion, state, text_parts, tool_calls): - yield event - - assistant_message = response.choices[0].message if response is not None else None - text = ( - assistant_message.content - if assistant_message and assistant_message.content is not None - else "".join(text_parts) - ) - native_tool_calls = tool_calls.as_native() - if native_tool_calls: - tool_map = {tool_item.name: tool_item for tool_item in model_tools_for_call} - serialized_tool_calls = [tool_call.model_dump(exclude_none=True) for tool_call in native_tool_calls] - tool_invocations = [ - _tool_invocation_from_native(tool_call, tool_map) for tool_call in native_tool_calls - ] - yield StreamEvent("tool_call", {"tool_calls": serialized_tool_calls}) - context = ToolContext(tape=tape.name, run_id=run_id, state=tape.context.state) - execution = await ToolExecutor().execute_async( - tool_invocations, - context=context, - ) - await self.tapes.record_chat( - tape=tape.name, - run_id=run_id, - system_prompt=system_prompt, - new_messages=[prompt_message], - response_text=None, - tool_calls=serialized_tool_calls, - tool_results=execution.tool_results, - response=response, - model=model or self.settings.model, - usage=state.usage, - ) - yield StreamEvent("tool_result", {"tool_results": execution.tool_results}) - yield StreamEvent( - "final", {"ok": True, "tool_calls": serialized_tool_calls, "tool_results": execution.tool_results} - ) - return - - await self.tapes.record_chat( - tape=tape.name, - run_id=run_id, - system_prompt=system_prompt, - new_messages=[prompt_message], - response_text=text, - response=response, - model=model or self.settings.model, - usage=state.usage, - ) - yield StreamEvent("final", {"ok": True, "text": text}) + from bub.builtin.tools import model_tools - return AsyncStreamEvents(iterator(), state=state) - - def _build_llm(self, candidate: ModelCandidate) -> AnyLLM: - return AnyLLM.create( - candidate.provider, - **self.settings.model_client_kwargs(candidate.provider), + system_prompt = self._system_prompt( + prompt_text, state=tape.context.state, allowed_skills=allowed_skills, tools=tools ) + resolved_model = model or self.settings.model - async def _completion_response( - self, *, model: str, messages: list[dict[str, Any]], tools: list[Tool] - ) -> ChatCompletion | ParsedChatCompletion[Any] | AsyncIterator[ChatCompletionChunk]: - from bub.builtin.tools import completion_tools - - tool_payloads = completion_tools(tools) or None - completion_messages: list[dict[str, Any] | ChatCompletionMessage] = list(messages) - candidates = self.settings.model_candidates(model) - for index, candidate in enumerate(candidates): - try: - llm = self._build_llm(candidate) - return await llm.acompletion( - model=candidate.model_id, - messages=completion_messages, - tools=tool_payloads, - max_tokens=self.settings.max_tokens, - stream=llm.SUPPORTS_COMPLETION_STREAMING, - ) - except Exception as exc: - if index == len(candidates) - 1: - raise - logger.warning("model candidate failed; trying fallback model={} error={}", candidate.name, exc) - - raise RuntimeError("no model candidates available") + model_tools_for_call = model_tools(tools) + return self.model_runner.run( + tape=tape, + model=resolved_model, + tools=model_tools_for_call, + system_prompt=system_prompt, + prompt=prompt, + ) def _system_prompt( self, prompt: str, state: State, allowed_skills: set[str] | None = None, tools: Iterable[Tool] | None = None @@ -514,131 +385,6 @@ def _continue_prompt(self, tape: Tape) -> str: return CONTINUE_PROMPT -@dataclass -class _StreamToolCall: - id: str | None = None - type: Literal["function"] | None = None - name: str | None = None - arguments: str = "" - - def merge(self, delta: ChoiceDeltaToolCall) -> None: - if delta.id: - self.id = delta.id - if delta.type: - self.type = delta.type - if delta.function is None: - return - if delta.function.name: - if self.name is None or self.name == delta.function.name: - self.name = delta.function.name - else: - self.name += delta.function.name - if delta.function.arguments: - self.arguments += delta.function.arguments - - def as_tool_call(self, index: int) -> ChatCompletionMessageFunctionToolCall: - return ChatCompletionMessageFunctionToolCall( - id=self.id or f"call_{index}", - type=self.type or "function", - function=Function(name=self.name or "", arguments=self.arguments or "{}"), - ) - - -class _ToolCallAccumulator: - def __init__(self) -> None: - self._message_calls: list[ChatCompletionMessageToolCall] = [] - self._stream_calls: dict[int, _StreamToolCall] = {} - - def add_message_calls(self, calls: Iterable[ChatCompletionMessageToolCall]) -> None: - self._message_calls.extend(calls) - - def merge_delta_calls(self, deltas: Iterable[ChoiceDeltaToolCall]) -> None: - for delta in deltas: - self._stream_calls.setdefault(delta.index, _StreamToolCall()).merge(delta) - - def as_native(self) -> list[ChatCompletionMessageToolCall]: - if self._message_calls: - return list(self._message_calls) - return [self._stream_calls[index].as_tool_call(index) for index in sorted(self._stream_calls)] - - -def _tool_invocation_from_native( - tool_call: ChatCompletionMessageToolCall, - tool_map: dict[str, Tool], -) -> tuple[Tool, dict[str, Any]]: - tool_name, arguments = _parse_native_function_call(tool_call) - tool_obj = tool_map.get(tool_name) - if tool_obj is None: - raise BubError(ErrorKind.TOOL, f"Unknown tool name: {tool_name}.") - return tool_obj, arguments - - -def _parse_native_function_call(tool_call: ChatCompletionMessageToolCall) -> tuple[str, dict[str, Any]]: - if not isinstance(tool_call, ChatCompletionMessageFunctionToolCall): - raise BubError(ErrorKind.INVALID_INPUT, "Expected a function tool call with JSON object arguments.") - try: - arguments = TOOL_ARGUMENTS_ADAPTER.validate_json(tool_call.function.arguments or "{}") - except ValidationError as exc: - raise BubError(ErrorKind.INVALID_INPUT, "Expected a function tool call with JSON object arguments.") from exc - return tool_call.function.name, arguments - - -async def _completion_events( - completion: ChatCompletion | ParsedChatCompletion[Any] | AsyncIterator[ChatCompletionChunk], - state: StreamState, - text_parts: list[str], - tool_calls: _ToolCallAccumulator, -) -> AsyncGenerator[StreamEvent, None]: - if isinstance(completion, ChatCompletion): - if usage := TapeService._extract_usage(completion): - state.usage = usage - message = completion.choices[0].message - for event in _completion_message_events(message, text_parts, tool_calls): - yield event - return - - async for chunk in completion: - async for event in _completion_chunk_events(chunk, state, text_parts, tool_calls): - yield event - - -def _completion_message_events( - message: ChatCompletionMessage, - text_parts: list[str], - tool_calls: _ToolCallAccumulator, -) -> Iterable[StreamEvent]: - if message.reasoning: - yield StreamEvent("reasoning", {"delta": _reasoning_text(message.reasoning)}) - if message.content: - text_parts.append(message.content) - yield StreamEvent("text", {"delta": message.content}) - tool_calls.add_message_calls(cast("Iterable[ChatCompletionMessageToolCall]", message.tool_calls or [])) - - -async def _completion_chunk_events( - chunk: ChatCompletionChunk, - state: StreamState, - text_parts: list[str], - tool_calls: _ToolCallAccumulator, -) -> AsyncGenerator[StreamEvent, None]: - if usage := TapeService._extract_usage(chunk): - state.usage = usage - for choice in chunk.choices: - delta = choice.delta - if delta.reasoning: - yield StreamEvent("reasoning", {"delta": _reasoning_text(delta.reasoning)}) - if delta.content: - text_parts.append(delta.content) - yield StreamEvent("text", {"delta": delta.content}) - if delta.tool_calls: - tool_calls.merge_delta_calls(delta.tool_calls) - - -def _reasoning_text(reasoning: object) -> str: - content = getattr(reasoning, "content", reasoning) - return "" if content is None else str(content) - - @dataclass(frozen=True) class Args: positional: list[str] @@ -669,11 +415,6 @@ def _parse_args(args_tokens: list[str]) -> Args: return Args(positional=positional, kwargs=kwargs) -def _is_context_length_error(error_msg: str) -> bool: - """Check whether an error message indicates a context-length / prompt-too-long failure.""" - return bool(_CONTEXT_LENGTH_PATTERNS.search(error_msg)) - - def _extract_text_from_parts(parts: list[dict]) -> str: """Extract text content from multimodal content parts.""" return "\n".join(p.get("text", "") for p in parts if p.get("type") == "text") diff --git a/src/bub/builtin/model_runner.py b/src/bub/builtin/model_runner.py new file mode 100644 index 00000000..0684f8fc --- /dev/null +++ b/src/bub/builtin/model_runner.py @@ -0,0 +1,368 @@ +"""LLM completion and model-output helpers for the builtin agent.""" + +from __future__ import annotations + +import asyncio +import re +from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator +from dataclasses import dataclass +from datetime import UTC, datetime +from typing import Any, Literal, cast + +from any_llm import AnyLLM +from any_llm.types.completion import ( + ChatCompletion, + ChatCompletionChunk, + ChatCompletionMessage, + ChatCompletionMessageFunctionToolCall, + ChatCompletionMessageToolCall, + ChoiceDeltaToolCall, + Function, + ParsedChatCompletion, +) +from loguru import logger +from pydantic import TypeAdapter, ValidationError + +from bub.builtin.settings import AgentSettings, ModelCandidate +from bub.builtin.tape import Tape +from bub.runtime import AsyncStreamEvents, BubError, ErrorKind, StreamEvent, StreamState +from bub.tools import Tool, ToolContext, ToolExecutor + +CONTEXT_LENGTH_PATTERNS = re.compile( + r"context.{0,20}(?:length|window)|maximum.{0,20}context|token.{0,10}limit|prompt.{0,10}too long|tokens? > \d+ maximum", + re.IGNORECASE, +) +TOOL_ARGUMENTS_ADAPTER = TypeAdapter(dict[str, Any]) +CompletionResult = ChatCompletion | ParsedChatCompletion[Any] | AsyncIterator[ChatCompletionChunk] + + +class ModelRunner: + def __init__(self, settings: AgentSettings) -> None: + self.settings = settings + + def iter_llm_clients(self, model: str) -> Iterator[tuple[ModelCandidate, AnyLLM]]: + for candidate in self.settings.model_candidates(model): + yield ( + candidate, + AnyLLM.create( + candidate.provider, + **self.settings.model_client_kwargs(candidate.provider), + ), + ) + + async def completion_response( + self, *, model: str, messages: list[dict[str, Any]], tools: list[Tool] + ) -> CompletionResult: + from bub.builtin.tools import completion_tools + + tool_payloads = completion_tools(tools) or None + completion_messages: list[dict[str, Any] | ChatCompletionMessage] = list(messages) + clients = list(self.iter_llm_clients(model)) + completion_error: Exception | None = None + for index, (candidate, llm) in enumerate(clients): + try: + return await llm.acompletion( + model=candidate.model_id, + messages=completion_messages, + tools=tool_payloads, + max_tokens=self.settings.max_tokens, + stream=llm.SUPPORTS_COMPLETION_STREAMING, + ) + except Exception as exc: + if completion_error is None: + completion_error = exc + if index == len(clients) - 1: + raise completion_error from None + logger.warning("model candidate failed; trying fallback model={} error={}", candidate.name, exc) + + raise RuntimeError("no model candidates available") + + def run( + self, + *, + tape: Tape, + model: str, + tools: list[Tool], + system_prompt: str | None, + prompt: str | list[dict], + ) -> AsyncStreamEvents: + state = StreamState() + + async def iterator() -> AsyncGenerator[StreamEvent, None]: + run_id = self.generate_run_id() + messages, new_messages = await self.build_messages( + tape=tape, + run_id=run_id, + system_prompt=system_prompt, + prompt=prompt, + model=model, + ) + output = ModelOutputAccumulator() + async with asyncio.timeout(self.settings.model_timeout_seconds): + completion = await self.completion_response(model=model, messages=messages, tools=tools) + async for event in self._completion_events(completion, state, output): + yield event + + tool_calls = output.tool_calls + if tool_calls: + tool_map = {tool_item.name: tool_item for tool_item in tools} + serialized_tool_calls = [tool_call.model_dump(exclude_none=True) for tool_call in tool_calls] + tool_invocations = [tool_invocation_from_native(tool_call, tool_map) for tool_call in tool_calls] + yield StreamEvent("tool_call", {"tool_calls": serialized_tool_calls}) + context = ToolContext(tape=tape, run_id=run_id, state=tape.context.state) + execution = await ToolExecutor().execute_async( + tool_invocations, + context=context, + ) + await self.record_chat( + tape=tape, + run_id=run_id, + system_prompt=system_prompt, + new_messages=new_messages, + response_text=None, + tool_calls=serialized_tool_calls, + tool_results=execution.tool_results, + response=output.response, + model=model, + usage=state.usage, + ) + yield StreamEvent("tool_result", {"tool_results": execution.tool_results}) + yield StreamEvent( + "final", {"ok": True, "tool_calls": serialized_tool_calls, "tool_results": execution.tool_results} + ) + return + + text = output.text + await self.record_chat( + tape=tape, + run_id=run_id, + system_prompt=system_prompt, + new_messages=new_messages, + response_text=text, + response=output.response, + model=model, + usage=state.usage, + ) + yield StreamEvent("final", {"ok": True, "text": text}) + + return AsyncStreamEvents(iterator(), state=state) + + @staticmethod + def generate_run_id() -> str: + return f"run-{datetime.now(UTC).strftime('%Y%m%dT%H%M%S%fZ')}" + + async def build_messages( + self, + *, + tape: Tape, + run_id: str, + system_prompt: str | None, + prompt: str | list[dict], + model: str, + ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + prompt_message: dict[str, Any] = {"role": "user", "content": prompt} + try: + messages = await tape.read_messages() + except BubError as exc: + await self.record_context_error( + tape=tape, + run_id=run_id, + system_prompt=system_prompt, + error=exc, + model=model, + ) + raise + if system_prompt: + messages = [{"role": "system", "content": system_prompt}, *messages] + messages.append(prompt_message) + return messages, [prompt_message] + + async def record_context_error( + self, + *, + tape: Tape, + run_id: str, + system_prompt: str | None, + error: BubError, + model: str, + ) -> None: + await self.record_chat( + tape=tape, + run_id=run_id, + system_prompt=system_prompt, + context_error=error, + new_messages=[], + response_text=None, + error=error, + model=model, + ) + + async def record_chat( + self, + *, + tape: Tape, + run_id: str, + system_prompt: str | None, + new_messages: list[dict[str, Any]], + response_text: str | None, + context_error: BubError | None = None, + tool_calls: list[dict[str, Any]] | None = None, + tool_results: list[Any] | None = None, + error: BubError | None = None, + response: Any | None = None, + provider: str | None = None, + model: str | None = None, + usage: dict[str, Any] | None = None, + ) -> None: + await tape.record_chat( + run_id=run_id, + system_prompt=system_prompt, + new_messages=new_messages, + response_text=response_text, + context_error=context_error, + tool_calls=tool_calls, + tool_results=tool_results, + error=error, + response=response, + provider=provider, + model=model, + usage=usage, + ) + + async def _completion_events( + self, + completion: CompletionResult, + state: StreamState, + output: ModelOutputAccumulator, + ) -> AsyncGenerator[StreamEvent, None]: + if isinstance(completion, ChatCompletion): + if usage := Tape._extract_usage(completion): + state.usage = usage + output.response = completion + message = completion.choices[0].message + for event in self._completion_message_events(message, output): + yield event + return + + async for chunk in completion: + async for event in self._completion_chunk_events(chunk, state, output): + yield event + + def _completion_message_events( + self, + message: ChatCompletionMessage, + output: ModelOutputAccumulator, + ) -> Iterable[StreamEvent]: + if message.reasoning: + yield StreamEvent("reasoning", {"delta": self.reasoning_text(message.reasoning)}) + if message.content: + output.add_text(message.content) + yield StreamEvent("text", {"delta": message.content}) + output.add_message_tool_calls(cast("Iterable[ChatCompletionMessageToolCall]", message.tool_calls or [])) + + async def _completion_chunk_events( + self, + chunk: ChatCompletionChunk, + state: StreamState, + output: ModelOutputAccumulator, + ) -> AsyncGenerator[StreamEvent, None]: + if usage := Tape._extract_usage(chunk): + state.usage = usage + for choice in chunk.choices: + delta = choice.delta + if delta.reasoning: + yield StreamEvent("reasoning", {"delta": self.reasoning_text(delta.reasoning)}) + if delta.content: + output.add_text(delta.content) + yield StreamEvent("text", {"delta": delta.content}) + if delta.tool_calls: + output.merge_delta_tool_calls(delta.tool_calls) + + @staticmethod + def reasoning_text(reasoning: object) -> str: + content = getattr(reasoning, "content", reasoning) + return "" if content is None else str(content) + + +@dataclass +class StreamToolCall: + id: str | None = None + type: Literal["function"] | None = None + name: str | None = None + arguments: str = "" + + def merge(self, delta: ChoiceDeltaToolCall) -> None: + if delta.id: + self.id = delta.id + if delta.type: + self.type = delta.type + if delta.function is None: + return + if delta.function.name: + if self.name is None or self.name == delta.function.name: + self.name = delta.function.name + else: + self.name += delta.function.name + if delta.function.arguments: + self.arguments += delta.function.arguments + + def as_tool_call(self, index: int) -> ChatCompletionMessageFunctionToolCall: + return ChatCompletionMessageFunctionToolCall( + id=self.id or f"call_{index}", + type=self.type or "function", + function=Function(name=self.name or "", arguments=self.arguments or "{}"), + ) + + +class ModelOutputAccumulator: + def __init__(self) -> None: + self.response: ChatCompletion | ParsedChatCompletion[Any] | None = None + self._text_parts: list[str] = [] + self._message_calls: list[ChatCompletionMessageToolCall] = [] + self._stream_calls: dict[int, StreamToolCall] = {} + + def add_text(self, text: str) -> None: + self._text_parts.append(text) + + def add_message_tool_calls(self, calls: Iterable[ChatCompletionMessageToolCall]) -> None: + self._message_calls.extend(calls) + + def merge_delta_tool_calls(self, deltas: Iterable[ChoiceDeltaToolCall]) -> None: + for delta in deltas: + self._stream_calls.setdefault(delta.index, StreamToolCall()).merge(delta) + + @property + def text(self) -> str: + return "".join(self._text_parts) + + @property + def tool_calls(self) -> list[ChatCompletionMessageToolCall]: + if self._message_calls: + return list(self._message_calls) + return [self._stream_calls[index].as_tool_call(index) for index in sorted(self._stream_calls)] + + +def tool_invocation_from_native( + tool_call: ChatCompletionMessageToolCall, + tool_map: dict[str, Tool], +) -> tuple[Tool, dict[str, Any]]: + tool_name, arguments = parse_native_function_call(tool_call) + tool_obj = tool_map.get(tool_name) + if tool_obj is None: + raise BubError(ErrorKind.TOOL, f"Unknown tool name: {tool_name}.") + return tool_obj, arguments + + +def parse_native_function_call(tool_call: ChatCompletionMessageToolCall) -> tuple[str, dict[str, Any]]: + if not isinstance(tool_call, ChatCompletionMessageFunctionToolCall): + raise BubError(ErrorKind.INVALID_INPUT, "Expected a function tool call with JSON object arguments.") + try: + arguments = TOOL_ARGUMENTS_ADAPTER.validate_json(tool_call.function.arguments or "{}") + except ValidationError as exc: + raise BubError(ErrorKind.INVALID_INPUT, "Expected a function tool call with JSON object arguments.") from exc + return tool_call.function.name, arguments + + +def is_context_length_error(error_msg: str) -> bool: + """Check whether an error message indicates a context-length / prompt-too-long failure.""" + return bool(CONTEXT_LENGTH_PATTERNS.search(error_msg)) diff --git a/src/bub/builtin/store.py b/src/bub/builtin/store.py index 457bc606..c5f33adf 100644 --- a/src/bub/builtin/store.py +++ b/src/bub/builtin/store.py @@ -1,12 +1,10 @@ from __future__ import annotations -import contextlib -import contextvars import itertools import json import re import threading -from collections.abc import AsyncGenerator, Iterable +from collections.abc import Iterable from dataclasses import asdict, replace from datetime import UTC, datetime from pathlib import Path @@ -16,19 +14,13 @@ from bub.tape import ( AsyncTapeStore, - AsyncTapeStoreAdapter, InMemoryQueryMixin, InMemoryTapeStore, TapeEntry, TapeQuery, - TapeStore, - is_async_tape_store, ) from bub.utils import get_entry_text -current_store: contextvars.ContextVar[TapeStore] = contextvars.ContextVar("current_store") -current_fork_tape: contextvars.ContextVar[str | None] = contextvars.ContextVar("current_fork_tape", default=None) -current_tape_was_reset: contextvars.ContextVar[bool] = contextvars.ContextVar("current_tape_was_reset", default=False) WORD_PATTERN = re.compile(r"[a-z0-9_/-]+") MIN_FUZZY_QUERY_LENGTH = 3 MIN_FUZZY_SCORE = 80 @@ -36,52 +28,39 @@ class ForkTapeStore: - def __init__(self, parent: AsyncTapeStore | TapeStore) -> None: - if is_async_tape_store(parent): - self._parent = parent - else: - self._parent = AsyncTapeStoreAdapter(parent) - - @property - def _current(self) -> TapeStore: - return current_store.get(_empty_store) - - @property - def _fork_tape(self) -> str | None: - return current_fork_tape.get() - - @property - def _current_was_reset(self) -> bool: - return current_tape_was_reset.get() + def __init__(self, parent: AsyncTapeStore, tape: str) -> None: + self._parent = parent + self._store = InMemoryTapeStore() + self._tape = tape + self._tape_was_reset = False async def list_tapes(self) -> list[str]: return await self._parent.list_tapes() async def reset(self, tape: str) -> None: - self._current.reset(tape) - if self._current is _empty_store or self._fork_tape != tape: + if tape != self._tape: await self._parent.reset(tape) return - current_tape_was_reset.set(True) + self._store.reset(tape) + self._tape_was_reset = True async def fetch_all(self, query: TapeQuery[AsyncTapeStore]) -> Iterable[TapeEntry]: parent_entries: Iterable[TapeEntry] = [] - if not (query.tape == self._fork_tape and self._current_was_reset): + if not (query.tape == self._tape and self._tape_was_reset): try: parent_entries = await self._parent.fetch_all(query) except Exception: parent_entries = [] this_entries: list[TapeEntry] = [] - if isinstance(self._current, InMemoryQueryMixin): - for entry in self._current.read(query.tape) or []: - if query._kinds and entry.kind not in query._kinds: + for entry in self._store.read(query.tape) or []: + if query._kinds and entry.kind not in query._kinds: + continue + if entry.kind == "anchor": # noqa: SIM102 + if query._after_last or (query._after_anchor and entry.payload.get("name") == query._after_anchor): + this_entries.clear() + parent_entries = [] continue - if entry.kind == "anchor": # noqa: SIM102 - if query._after_last or (query._after_anchor and entry.payload.get("name") == query._after_anchor): - this_entries.clear() - parent_entries = [] - continue - this_entries.append(entry) + this_entries.append(entry) return itertools.chain(parent_entries, this_entries) @staticmethod @@ -103,55 +82,18 @@ def _redact_payload(payload: dict) -> None: async def append(self, tape: str, entry: TapeEntry) -> None: self._redact_payload(entry.payload) - self._current.append(tape, entry) - - @contextlib.asynccontextmanager - async def fork(self, tape: str, merge_back: bool = True) -> AsyncGenerator[None, None]: - store = InMemoryTapeStore() - # Save/restore instead of ContextVar.reset(token) to avoid - # "Token was created in a different Context" when cleanup - # runs in a different asyncio Task (e.g. cancellation, TaskGroup). - prev_store = current_store.get(_empty_store) - prev_fork_tape = current_fork_tape.get() - prev_was_reset = current_tape_was_reset.get() - current_store.set(store) - current_fork_tape.set(tape) - current_tape_was_reset.set(False) - try: - yield - finally: - was_reset = current_tape_was_reset.get() - current_store.set(prev_store) - current_fork_tape.set(prev_fork_tape) - current_tape_was_reset.set(prev_was_reset) - if merge_back: - if was_reset: - await self._parent.reset(tape) - entries = store.read(tape) - if entries: - count = len(entries) - for entry in entries: - await self._parent.append(tape, entry) - logger.info(f'Merged {count} entries into tape "{tape}"') - - -class EmptyTapeStore: - """Sync TapeStore sentinel that always returns empty results.""" - - def list_tapes(self) -> list[str]: - return [] + self._store.append(tape, entry) - def reset(self, tape: str) -> None: - pass - - def fetch_all(self, query: TapeQuery) -> Iterable[TapeEntry]: - return [] - - def append(self, tape: str, entry: TapeEntry) -> None: - pass - - -_empty_store = EmptyTapeStore() + async def merge_back(self) -> None: + if self._tape_was_reset: + await self._parent.reset(self._tape) + entries = self._store.read(self._tape) + if not entries: + return + count = len(entries) + for entry in entries: + await self._parent.append(self._tape, entry) + logger.info(f'Merged {count} entries into tape "{self._tape}"') class FileTapeStore(InMemoryQueryMixin): diff --git a/src/bub/builtin/tape.py b/src/bub/builtin/tape.py index ca992521..22281c9b 100644 --- a/src/bub/builtin/tape.py +++ b/src/bub/builtin/tape.py @@ -1,9 +1,11 @@ +from __future__ import annotations + import contextlib import hashlib import inspect import json from collections.abc import AsyncGenerator -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field, replace from datetime import UTC, datetime from pathlib import Path from typing import Any @@ -12,7 +14,13 @@ from bub.builtin.store import ForkTapeStore from bub.runtime import BubError -from bub.tape import AsyncTapeStore, Tape, TapeContext, TapeEntry, TapeQuery, build_messages +from bub.tape import ( + AsyncTapeStore, + TapeContext, + TapeEntry, + TapeQuery, + build_messages, +) @dataclass(frozen=True) @@ -35,17 +43,32 @@ class AnchorSummary: state: dict[str, object] -class TapeService: - def __init__(self, archive_path: Path, store: ForkTapeStore, context: TapeContext | None = None) -> None: - self._archive_path = archive_path - self._store = store - self._context = context or TapeContext() +@dataclass(frozen=True) +class Tape: + """Tape abstraction for recording agent interactions.""" + + archive_path: Path + store: AsyncTapeStore + context: TapeContext + _name: str | None = field(default=None, repr=False) + + @property + def name(self) -> str: + if self._name is None: + raise ValueError("tape is not scoped") + return self._name + + def with_context(self, context: TapeContext) -> Tape: + return replace(self, context=context) + + def scoped(self, name: str, context: TapeContext | None = None) -> Tape: + return replace(self, context=context or self.context, _name=name) - def query(self, tape_name: str) -> TapeQuery[AsyncTapeStore]: - return TapeQuery(tape=tape_name, store=self._store) + def query(self) -> TapeQuery[AsyncTapeStore]: + return TapeQuery(tape=self.name, store=self.store) - async def info(self, tape_name: str) -> TapeInfo: - entries = list(await self._store.fetch_all(self.query(tape_name))) + async def info(self) -> TapeInfo: + entries = list(await self.store.fetch_all(self.query())) anchors = [(i, entry) for i, entry in enumerate(entries) if entry.kind == "anchor"] if anchors: last_anchor = anchors[-1][1].payload.get("name") @@ -62,7 +85,7 @@ async def info(self, tape_name: str) -> TapeInfo: last_token_usage = token_usage break return TapeInfo( - name=tape_name, + name=self.name, entries=len(entries), anchors=len(anchors), last_anchor=str(last_anchor) if last_anchor else None, @@ -70,13 +93,13 @@ async def info(self, tape_name: str) -> TapeInfo: last_token_usage=last_token_usage, ) - async def ensure_bootstrap_anchor(self, tape_name: str) -> None: - anchors = list(await self._store.fetch_all(self.query(tape_name).kinds("anchor"))) + async def ensure_bootstrap_anchor(self) -> None: + anchors = list(await self.store.fetch_all(self.query().kinds("anchor"))) if not anchors: - await self.handoff(tape_name, name="session/start", state={"owner": "human"}) + await self.handoff(name="session/start", state={"owner": "human"}) - async def anchors(self, tape_name: str, limit: int = 20) -> list[AnchorSummary]: - entries = list(await self._store.fetch_all(self.query(tape_name).kinds("anchor"))) + async def anchors(self, limit: int = 20) -> list[AnchorSummary]: + entries = list(await self.store.fetch_all(self.query().kinds("anchor"))) results: list[AnchorSummary] = [] for entry in entries[-limit:]: name = str(entry.payload.get("name", "-")) @@ -86,37 +109,36 @@ async def anchors(self, tape_name: str, limit: int = 20) -> list[AnchorSummary]: return results async def search(self, query: TapeQuery[AsyncTapeStore]) -> list[TapeEntry]: - return list(await self._store.fetch_all(query)) + return list(await self.store.fetch_all(query)) - async def append_event(self, tape_name: str, name: str, payload: dict[str, Any], **meta: Any) -> None: - await self._store.append(tape_name, TapeEntry.event(name, payload, **meta)) + async def append_event(self, name: str, payload: dict[str, Any], **meta: Any) -> None: + await self.store.append(self.name, TapeEntry.event(name, payload, **meta)) - async def read_messages(self, tape: Tape) -> list[dict[str, Any]]: - query = tape.context.build_query(self.query(tape.name)) - entries = await self._store.fetch_all(query) - messages = build_messages(entries, tape.context) + async def read_messages(self) -> list[dict[str, Any]]: + query = self.context.build_query(self.query()) + entries = await self.store.fetch_all(query) + messages = build_messages(entries, self.context) if inspect.isawaitable(messages): messages = await messages return messages async def handoff( self, - tape_name: str, *, name: str, state: dict[str, Any] | None = None, **meta: Any, ) -> list[TapeEntry]: + tape_name = self.name entry = TapeEntry.anchor(name, state=state, **meta) event = TapeEntry.event("handoff", {"name": name, "state": state or {}}, **meta) - await self._store.append(tape_name, entry) - await self._store.append(tape_name, event) + await self.store.append(tape_name, entry) + await self.store.append(tape_name, event) return [entry, event] async def record_chat( # noqa: C901 self, *, - tape: str, run_id: str, system_prompt: str | None, new_messages: list[dict[str, Any]], @@ -130,21 +152,24 @@ async def record_chat( # noqa: C901 model: str | None = None, usage: dict[str, Any] | None = None, ) -> None: + tape_name = self.name meta = {"run_id": run_id} if system_prompt: - await self._store.append(tape, TapeEntry.system(system_prompt, **meta)) + await self.store.append(tape_name, TapeEntry.system(system_prompt, **meta)) if context_error is not None: - await self._store.append(tape, TapeEntry.error(context_error, **meta)) + await self.store.append(tape_name, TapeEntry.error(context_error, **meta)) for message in new_messages: - await self._store.append(tape, TapeEntry.message(message, **meta)) + await self.store.append(tape_name, TapeEntry.message(message, **meta)) if tool_calls: - await self._store.append(tape, TapeEntry.tool_call(tool_calls, **meta)) + await self.store.append(tape_name, TapeEntry.tool_call(tool_calls, **meta)) if tool_results is not None: - await self._store.append(tape, TapeEntry.tool_result(tool_results, **meta)) + await self.store.append(tape_name, TapeEntry.tool_result(tool_results, **meta)) if error is not None and error is not context_error: - await self._store.append(tape, TapeEntry.error(error, **meta)) + await self.store.append(tape_name, TapeEntry.error(error, **meta)) if response_text is not None: - await self._store.append(tape, TapeEntry.message({"role": "assistant", "content": response_text}, **meta)) + await self.store.append( + tape_name, TapeEntry.message({"role": "assistant", "content": response_text}, **meta) + ) data: dict[str, Any] = {"status": "error" if error is not None else "ok"} resolved_usage = usage or self._extract_usage(response) @@ -154,7 +179,7 @@ async def record_chat( # noqa: C901 data["provider"] = provider if model: data["model"] = model - await self._store.append(tape, TapeEntry.event("run", data, **meta)) + await self.store.append(tape_name, TapeEntry.event("run", data, **meta)) @staticmethod def _extract_usage(response: object) -> dict[str, Any] | None: @@ -168,34 +193,40 @@ def _extract_usage(response: object) -> dict[str, Any] | None: return payload if isinstance(payload, dict) else None return None - async def _archive(self, tape_name: str) -> Path: + async def _archive(self) -> Path: + tape_name = self.name stamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ") - self._archive_path.mkdir(parents=True, exist_ok=True) - archive_path = self._archive_path / f"{tape_name}.jsonl.{stamp}.bak" + self.archive_path.mkdir(parents=True, exist_ok=True) + archive_path = self.archive_path / f"{tape_name}.jsonl.{stamp}.bak" with archive_path.open("w", encoding="utf-8") as f: - for entry in await self._store.fetch_all(self.query(tape_name)): + for entry in await self.store.fetch_all(self.query()): f.write(json.dumps(asdict(entry), ensure_ascii=False) + "\n") return archive_path - async def reset(self, tape_name: str, *, archive: bool = False) -> str: + async def reset(self, *, archive: bool = False) -> str: archive_path: Path | None = None if archive: - archive_path = await self._archive(tape_name) - await self._store.reset(tape_name) + archive_path = await self._archive() + await self.store.reset(self.name) state = {"owner": "human"} if archive_path is not None: state["archived"] = str(archive_path) - await self.handoff(tape_name, name="session/start", state=state) + await self.handoff(name="session/start", state=state) return f"Archived: {archive_path}" if archive_path else "ok" - def session_tape(self, session_id: str, workspace: Path) -> Tape: + def session_tape(self, session_id: str, workspace: Path, context: TapeContext | None = None) -> Tape: workspace_hash = hashlib.md5(str(workspace.resolve()).encode("utf-8"), usedforsecurity=False).hexdigest()[:16] tape_name = ( workspace_hash + "__" + hashlib.md5(session_id.encode("utf-8"), usedforsecurity=False).hexdigest()[:16] ) - return Tape(name=tape_name, context=self._context) + return self.scoped(tape_name, context=context) @contextlib.asynccontextmanager - async def fork_tape(self, tape_name: str, merge_back: bool = True) -> AsyncGenerator[None, None]: - async with self._store.fork(tape_name, merge_back=merge_back): - yield + async def fork_tape(self, merge_back: bool = True) -> AsyncGenerator[Tape, None]: + fork_store = ForkTapeStore(self.store, self.name) + forked = replace(self, store=fork_store) + try: + yield forked + finally: + if merge_back: + await fork_store.merge_back() diff --git a/src/bub/builtin/tools.py b/src/bub/builtin/tools.py index 51484d52..700c7aa1 100644 --- a/src/bub/builtin/tools.py +++ b/src/bub/builtin/tools.py @@ -272,8 +272,7 @@ def skill_describe(name: str, *, context: ToolContext) -> str: @tool(context=True, name="tape.info") async def tape_info(context: ToolContext) -> str: """Get information about the current tape, such as number of entries and anchors.""" - agent = _get_agent(context) - info = await agent.tapes.info(context.tape or "") + info = await context.tape.info() return ( f"name: {info.name}\n" f"entries: {info.entries}\n" @@ -287,12 +286,11 @@ async def tape_info(context: ToolContext) -> str: @tool(context=True, name="tape.search", model=SearchInput) async def tape_search(param: SearchInput, *, context: ToolContext) -> str: """Search for entries in the current tape that match the query. Returns a list of matching entries.""" - agent = _get_agent(context) - query = agent.tapes.query(context.tape or "").query(param.query).kinds(*param.kinds).limit(param.limit) + query = context.tape.query().query(param.query).kinds(*param.kinds).limit(param.limit) if param.start or param.end: query = query.between_dates(param.start or "", param.end or "") - entries = await agent.tapes.search(query) + entries = await context.tape.search(query) lines: list[str] = [] for entry in entries: entry_str = json.dumps({"date": entry.date, "content": entry.payload}) @@ -307,24 +305,21 @@ async def tape_search(param: SearchInput, *, context: ToolContext) -> str: @tool(context=True, name="tape.reset") async def tape_reset(archive: bool = False, *, context: ToolContext) -> str: """Reset the current tape, optionally archiving it.""" - agent = _get_agent(context) - result = await agent.tapes.reset(context.tape or "", archive=archive) + result = await context.tape.reset(archive=archive) return result @tool(context=True, name="tape.handoff") async def tape_handoff(name: str = "handoff", summary: str = "", *, context: ToolContext) -> str: """Add a handoff anchor to the current tape.""" - agent = _get_agent(context) - await agent.tapes.handoff(context.tape or "", name=name, state={"summary": summary}) + await context.tape.handoff(name=name, state={"summary": summary}) return f"anchor added: {name}" @tool(context=True, name="tape.anchors") async def tape_anchors(*, context: ToolContext) -> str: """List anchors in the current tape.""" - agent = _get_agent(context) - anchors = await agent.tapes.anchors(context.tape or "") + anchors = await context.tape.anchors() if not anchors: return "(no anchors)" return "\n".join(f"- {anchor.name}" for anchor in anchors) diff --git a/src/bub/channels/cli/__init__.py b/src/bub/channels/cli/__init__.py index 1b87324b..9119d3be 100644 --- a/src/bub/channels/cli/__init__.py +++ b/src/bub/channels/cli/__init__.py @@ -149,8 +149,8 @@ def _install_log_sink(self) -> int: return logger.add(self._renderer.log, colorize=False, format="{level:<8} | {message}") async def _refresh_tape_info(self) -> None: - tape = self._agent.tapes.session_tape(self._message_template["session_id"], self._workspace) - info = await self._agent.tapes.info(tape.name) + tape = self._agent.tape.session_tape(self._message_template["session_id"], self._workspace) + info = await tape.info() self._last_tape_info = info def set_metadata(self, session_id: str | None = None, chat_id: str | None = None) -> None: diff --git a/src/bub/tape.py b/src/bub/tape.py index 29886348..de973f3c 100644 --- a/src/bub/tape.py +++ b/src/bub/tape.py @@ -170,14 +170,6 @@ def build_query(self, query: TapeQuery) -> TapeQuery: return query.after_anchor(self.anchor) -@dataclass -class Tape: - """A scoped conversation tape used by the agent runtime.""" - - name: str - context: TapeContext - - def build_messages(entries: Iterable[TapeEntry], context: TapeContext) -> SelectedMessages: if context.select is not None: return context.select(entries, context) diff --git a/src/bub/tools.py b/src/bub/tools.py index 3c6a4525..3a68358e 100644 --- a/src/bub/tools.py +++ b/src/bub/tools.py @@ -9,8 +9,9 @@ from typing import Any, overload from loguru import logger -from pydantic import BaseModel, TypeAdapter, ValidationError, validate_call +from pydantic import BaseModel, ConfigDict, TypeAdapter, ValidationError, validate_call +from bub.builtin.tape import Tape from bub.runtime import BubError, ErrorKind @@ -18,7 +19,7 @@ class ToolContext: """Runtime context passed to tools that opt into context.""" - tape: str | None = None + tape: Tape run_id: str | None = None state: dict[str, Any] = field(default_factory=dict) @@ -61,6 +62,23 @@ def _schema_from_signature(signature: inspect.Signature, *, ignore_params: set[s return schema +def _signature_without_context(signature: inspect.Signature) -> inspect.Signature: + parameters = [param for param in signature.parameters.values() if param.name != "context"] + return signature.replace(parameters=parameters) + + +def _validate_without_context(func: Callable[..., Any], signature: inspect.Signature) -> Callable[..., Any]: + def validate_target(*args: Any, **kwargs: Any) -> tuple[tuple[Any, ...], dict[str, Any]]: + return args, kwargs + + validate_target.__name__ = _callable_name(func) + validate_target.__qualname__ = getattr(func, "__qualname__", validate_target.__name__) + validate_target.__annotations__ = dict(getattr(func, "__annotations__", {})) + validate_target.__annotations__.pop("context", None) + validate_target.__signature__ = _signature_without_context(signature) # type: ignore[attr-defined] + return validate_call(validate_target) + + @dataclass(frozen=True) class Tool: """A callable unit the model can invoke.""" @@ -89,7 +107,16 @@ def from_callable( tool_name = name or _to_snake_case(_callable_name(func)) tool_description = description if description is not None else (inspect.getdoc(func) or "") parameters = _schema_from_signature(signature, ignore_params={"context"} if context else None) - validated = validate_call(func) + if context: + validate_args = _validate_without_context(func, signature) + + def validated(*args: Any, **kwargs: Any) -> Any: + tool_context = kwargs.pop("context") + validated_args, validated_kwargs = validate_args(*args, **kwargs) + return func(*validated_args, context=tool_context, **validated_kwargs) + + else: + validated = validate_call(config=ConfigDict(arbitrary_types_allowed=True))(func) return cls( name=tool_name, description=tool_description, diff --git a/tests/test_builtin_agent.py b/tests/test_builtin_agent.py index ba470306..bbb4cb2d 100644 --- a/tests/test_builtin_agent.py +++ b/tests/test_builtin_agent.py @@ -9,9 +9,10 @@ from any_llm.types.completion import ChatCompletionChunk from bub.builtin.agent import Agent +from bub.builtin.model_runner import ModelRunner from bub.builtin.settings import AgentSettings from bub.runtime import BubError -from bub.tape import Tape, TapeContext +from bub.tape import TapeContext from bub.tools import REGISTRY, tool # --------------------------------------------------------------------------- @@ -19,6 +20,16 @@ # --------------------------------------------------------------------------- +class _FakeModelRunner(ModelRunner): + def __init__(self, settings: AgentSettings) -> None: + super().__init__(settings) + self.completion_kwargs: dict[str, Any] | None = None + + async def completion_response(self, **kwargs: Any) -> AsyncIterator[ChatCompletionChunk]: + self.completion_kwargs = kwargs + return _chat_stream("done") + + def _make_agent() -> Agent: """Build an Agent with a mocked framework, bypassing real LLM/tape init.""" framework = MagicMock() @@ -30,14 +41,13 @@ def _make_agent() -> Agent: agent.settings = AgentSettings.model_construct(model="test:model", api_key="k", api_base="b", client_args={}) agent.framework = framework + agent.model_runner = _FakeModelRunner(agent.settings) + return agent - async def fake_completion_response(**kwargs: Any) -> AsyncIterator[ChatCompletionChunk]: - agent.completion_kwargs = kwargs - return _chat_stream("done") - agent.completion_kwargs = None - agent._completion_response = fake_completion_response # type: ignore[method-assign] - return agent +def _model_runner(agent: Agent) -> _FakeModelRunner: + assert isinstance(agent.model_runner, _FakeModelRunner) + return agent.model_runner def _chat_chunk(content: str) -> ChatCompletionChunk: @@ -76,35 +86,33 @@ async def fork_tape(self, tape_name: str, merge_back: bool = True) -> AsyncGener self.exit_count += 1 -class _FakeTapeService: - """Minimal TapeService stand-in for testing Agent.run().""" +class _FakeTape: + """Scoped tape stand-in for testing Agent.run().""" def __init__(self, fork_capture: _ForkCapture) -> None: self._fork = fork_capture + self.name = "test-tape" + self.context = TapeContext(state={}) self.messages: list[dict[str, Any]] = [] self.events: list[tuple[str, str, dict[str, Any]]] = [] - def session_tape(self, session_id: str, workspace: Any) -> Tape: - return Tape(name="test-tape", context=TapeContext(state={})) - - async def ensure_bootstrap_anchor(self, tape_name: str) -> None: + async def ensure_bootstrap_anchor(self) -> None: pass @contextlib.asynccontextmanager - async def fork_tape(self, tape_name: str, merge_back: bool = True) -> AsyncGenerator[None, None]: - async with self._fork.fork_tape(tape_name, merge_back=merge_back): - yield + async def fork_tape(self, merge_back: bool = True) -> AsyncGenerator[_FakeTape, None]: + async with self._fork.fork_tape(self.name, merge_back=merge_back): + yield self - async def read_messages(self, tape: Tape) -> list[dict[str, Any]]: + async def read_messages(self) -> list[dict[str, Any]]: return list(self.messages) - async def append_event(self, tape_name: str, name: str, payload: dict[str, Any], **meta: Any) -> None: - self.events.append((tape_name, name, payload)) + async def append_event(self, name: str, payload: dict[str, Any], **meta: Any) -> None: + self.events.append((self.name, name, payload)) async def record_chat( self, *, - tape: str, run_id: str, system_prompt: str | None, new_messages: list[dict[str, Any]], @@ -119,19 +127,33 @@ async def record_chat( usage: dict[str, Any] | None = None, ) -> None: if system_prompt: - self.events.append((tape, "system", {"content": system_prompt})) + self.events.append((self.name, "system", {"content": system_prompt})) if context_error is not None: - self.events.append((tape, "error", context_error.as_dict())) + self.events.append((self.name, "error", context_error.as_dict())) self.messages.extend(new_messages) if tool_calls: - self.events.append((tape, "tool_call", {"calls": tool_calls})) + self.events.append((self.name, "tool_call", {"calls": tool_calls})) if tool_results is not None: - self.events.append((tape, "tool_result", {"results": tool_results})) + self.events.append((self.name, "tool_result", {"results": tool_results})) if error is not None and error is not context_error: - self.events.append((tape, "error", error.as_dict())) + self.events.append((self.name, "error", error.as_dict())) if response_text is not None: self.messages.append({"role": "assistant", "content": response_text}) - self.events.append((tape, "run", {"run_id": run_id, "model": model, "error": error is not None})) + self.events.append((self.name, "run", {"run_id": run_id, "model": model, "error": error is not None})) + + +class _FakeTapeFactory: + """Minimal tape factory stand-in for testing Agent.run().""" + + def __init__(self, fork_capture: _ForkCapture) -> None: + self.tape = _FakeTape(fork_capture) + self.context = self.tape.context + + def session_tape(self, session_id: str, workspace: Any, context: TapeContext | None = None) -> _FakeTape: + if context is not None: + self.tape.context = context + self.context = context + return self.tape @pytest.mark.asyncio @@ -139,7 +161,7 @@ async def test_agent_run_regular_session_merges_back() -> None: """A regular (non-temp) session should merge tape entries back.""" agent = _make_agent() fork_capture = _ForkCapture() - agent.tapes = _FakeTapeService(fork_capture) # type: ignore[assignment] + agent.tape = _FakeTapeFactory(fork_capture) # type: ignore[assignment] result = await agent.run_stream(session_id="user/session1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 @@ -157,7 +179,7 @@ async def test_agent_run_temp_session_does_not_merge_back() -> None: """A temp/ session should NOT merge tape entries back.""" agent = _make_agent() fork_capture = _ForkCapture() - agent.tapes = _FakeTapeService(fork_capture) # type: ignore[assignment] + agent.tape = _FakeTapeFactory(fork_capture) # type: ignore[assignment] result = await agent.run_stream(session_id="temp/abc123", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 @@ -175,8 +197,8 @@ async def test_agent_run_passes_model_to_llm() -> None: """The model parameter should be forwarded to any-llm.""" agent = _make_agent() fork_capture = _ForkCapture() - fake_tapes = _FakeTapeService(fork_capture) - agent.tapes = fake_tapes # type: ignore[assignment] + fake_tapes = _FakeTapeFactory(fork_capture) + agent.tape = fake_tapes # type: ignore[assignment] result = await agent.run_stream( session_id="user/s1", @@ -186,13 +208,15 @@ async def test_agent_run_passes_model_to_llm() -> None: ) [event async for event in result] - assert agent.completion_kwargs["model"] == "openai:gpt-4o" + completion_kwargs = _model_runner(agent).completion_kwargs + assert completion_kwargs is not None + assert completion_kwargs["model"] == "openai:gpt-4o" @pytest.mark.asyncio async def test_agent_run_empty_prompt_returns_error() -> None: agent = _make_agent() - agent.tapes = MagicMock() # type: ignore[assignment] + agent.tape = MagicMock() result = await agent.run_stream(session_id="user/s1", prompt="", state={}) events = [event async for event in result] @@ -208,13 +232,15 @@ async def test_agent_run_model_defaults_to_none() -> None: """When model is not specified, settings.model is used for any-llm.""" agent = _make_agent() fork_capture = _ForkCapture() - fake_tapes = _FakeTapeService(fork_capture) - agent.tapes = fake_tapes # type: ignore[assignment] + fake_tapes = _FakeTapeFactory(fork_capture) + agent.tape = fake_tapes # type: ignore[assignment] result = await agent.run_stream(session_id="user/s1", prompt="hello", state={"_runtime_workspace": "/tmp"}) # noqa: S108 [event async for event in result] - assert agent.completion_kwargs["model"] == "test:model" + completion_kwargs = _model_runner(agent).completion_kwargs + assert completion_kwargs is not None + assert completion_kwargs["model"] == "test:model" @pytest.mark.asyncio @@ -234,8 +260,8 @@ def denied_agent_tool() -> str: agent = _make_agent() fork_capture = _ForkCapture() - fake_tapes = _FakeTapeService(fork_capture) - agent.tapes = fake_tapes # type: ignore[assignment] + fake_tapes = _FakeTapeFactory(fork_capture) + agent.tape = fake_tapes # type: ignore[assignment] result = await agent.run_stream( session_id="user/s1", @@ -245,9 +271,10 @@ def denied_agent_tool() -> str: ) [event async for event in result] - assert agent.completion_kwargs is not None - assert [tool.name for tool in agent.completion_kwargs["tools"]] == ["tests_allowed_agent_tool"] - system_prompt = agent.completion_kwargs["messages"][0]["content"] + completion_kwargs = _model_runner(agent).completion_kwargs + assert completion_kwargs is not None + assert [tool.name for tool in completion_kwargs["tools"]] == ["tests_allowed_agent_tool"] + system_prompt = completion_kwargs["messages"][0]["content"] assert "- tests_allowed_agent_tool(): Allowed tool" in system_prompt assert "tests_denied_agent_tool" not in system_prompt @@ -256,8 +283,8 @@ def denied_agent_tool() -> str: async def test_agent_run_rejects_unknown_allowed_tools() -> None: agent = _make_agent() fork_capture = _ForkCapture() - fake_tapes = _FakeTapeService(fork_capture) - agent.tapes = fake_tapes # type: ignore[assignment] + fake_tapes = _FakeTapeFactory(fork_capture) + agent.tape = fake_tapes # type: ignore[assignment] stream = await agent.run_stream( session_id="user/s1", diff --git a/tests/test_builtin_tape.py b/tests/test_builtin_tape.py new file mode 100644 index 00000000..236cf9a1 --- /dev/null +++ b/tests/test_builtin_tape.py @@ -0,0 +1,34 @@ +from __future__ import annotations + +from pathlib import Path + +import pytest + +from bub.builtin.store import ForkTapeStore +from bub.builtin.tape import Tape +from bub.tape import AsyncTapeStoreAdapter, InMemoryTapeStore, TapeContext + + +@pytest.mark.asyncio +async def test_tape_fork_binds_temporary_fork_store_to_scoped_tape(tmp_path: Path) -> None: + parent = InMemoryTapeStore() + root = Tape(tmp_path, AsyncTapeStoreAdapter(parent), TapeContext()).scoped("test-tape") + + async with root.fork_tape(merge_back=True) as forked: + first_store = forked.store + + assert isinstance(first_store, ForkTapeStore) + assert first_store is not root.store + + await forked.append_event("step", {"value": 1}) + assert parent.read("test-tape") is None + + assert [entry.payload["name"] for entry in parent.read("test-tape") or []] == ["step"] + + async with root.fork_tape(merge_back=False) as forked: + second_store = forked.store + await forked.append_event("step", {"value": 2}) + + assert isinstance(second_store, ForkTapeStore) + assert second_store is not first_store + assert [entry.payload["data"]["value"] for entry in parent.read("test-tape") or []] == [1] diff --git a/tests/test_builtin_tools.py b/tests/test_builtin_tools.py index d5bc7e56..88ca07ac 100644 --- a/tests/test_builtin_tools.py +++ b/tests/test_builtin_tools.py @@ -10,6 +10,7 @@ import bub.builtin.tools as builtin_tools from bub.builtin.shell_manager import ShellManager +from bub.builtin.tape import Tape from bub.builtin.tools import ( bash, bash_output, @@ -21,11 +22,13 @@ resolve_tool_names, ) from bub.runtime import ErrorKind +from bub.tape import AsyncTapeStoreAdapter, InMemoryTapeStore, TapeContext from bub.tools import REGISTRY, Tool, ToolContext, ToolExecutor, tool def _tool_context(tmp_path, **state) -> ToolContext: - return ToolContext(tape="test-tape", run_id="test-run", state={"_runtime_workspace": str(tmp_path), **state}) + tape = Tape(tmp_path, AsyncTapeStoreAdapter(InMemoryTapeStore()), TapeContext()).scoped("test-tape") + return ToolContext(tape=tape, run_id="test-run", state={"_runtime_workspace": str(tmp_path), **state}) def _python_shell(code: str) -> str: diff --git a/tests/test_file_tape_store_entry_ids.py b/tests/test_file_tape_store_entry_ids.py index da8ec666..d891de97 100644 --- a/tests/test_file_tape_store_entry_ids.py +++ b/tests/test_file_tape_store_entry_ids.py @@ -3,19 +3,20 @@ import pytest from bub.builtin.store import FileTapeStore, ForkTapeStore -from bub.tape import TapeEntry +from bub.tape import AsyncTapeStoreAdapter, TapeEntry @pytest.mark.asyncio async def test_file_tape_store_assigns_monotonic_ids_when_merging_forked_entries(tmp_path) -> None: parent = FileTapeStore(directory=tmp_path) - store = ForkTapeStore(parent) + store = ForkTapeStore(AsyncTapeStoreAdapter(parent), "tape") - async with store.fork("tape", merge_back=True): - await store.append("tape", TapeEntry.event(name="first", data={"n": 1})) + await store.append("tape", TapeEntry.event(name="first", data={"n": 1})) + await store.merge_back() - async with store.fork("tape", merge_back=True): - await store.append("tape", TapeEntry.event(name="second", data={"n": 2})) + store = ForkTapeStore(AsyncTapeStoreAdapter(parent), "tape") + await store.append("tape", TapeEntry.event(name="second", data={"n": 2})) + await store.merge_back() entries = parent.read("tape") or [] assert [entry.id for entry in entries] == [1, 2] diff --git a/tests/test_fork_store_merge_back.py b/tests/test_fork_store_merge_back.py index 4652ca53..c0f3d157 100644 --- a/tests/test_fork_store_merge_back.py +++ b/tests/test_fork_store_merge_back.py @@ -3,18 +3,18 @@ import pytest from bub.builtin.store import ForkTapeStore -from bub.tape import InMemoryTapeStore, TapeEntry, TapeQuery +from bub.tape import AsyncTapeStoreAdapter, InMemoryTapeStore, TapeEntry, TapeQuery @pytest.mark.asyncio async def test_fork_merge_back_true_merges_entries() -> None: """With merge_back=True (default), forked entries are merged into the parent.""" parent = InMemoryTapeStore() - store = ForkTapeStore(parent) + store = ForkTapeStore(AsyncTapeStoreAdapter(parent), "test-tape") - async with store.fork("test-tape", merge_back=True): - await store.append("test-tape", TapeEntry.event(name="step", data={"x": 1})) - await store.append("test-tape", TapeEntry.event(name="step", data={"x": 2})) + await store.append("test-tape", TapeEntry.event(name="step", data={"x": 1})) + await store.append("test-tape", TapeEntry.event(name="step", data={"x": 2})) + await store.merge_back() entries = parent.read("test-tape") assert entries is not None @@ -25,10 +25,9 @@ async def test_fork_merge_back_true_merges_entries() -> None: async def test_fork_merge_back_false_discards_entries() -> None: """With merge_back=False, forked entries are NOT merged into the parent.""" parent = InMemoryTapeStore() - store = ForkTapeStore(parent) + store = ForkTapeStore(AsyncTapeStoreAdapter(parent), "test-tape") - async with store.fork("test-tape", merge_back=False): - await store.append("test-tape", TapeEntry.event(name="step", data={"x": 1})) + await store.append("test-tape", TapeEntry.event(name="step", data={"x": 1})) entries = parent.read("test-tape") # No entries should have been merged @@ -36,28 +35,24 @@ async def test_fork_merge_back_false_discards_entries() -> None: @pytest.mark.asyncio -async def test_fork_default_merge_back_is_true() -> None: - """The default value of merge_back should be True.""" +async def test_merge_back_can_be_called_without_entries() -> None: parent = InMemoryTapeStore() - store = ForkTapeStore(parent) + store = ForkTapeStore(AsyncTapeStoreAdapter(parent), "test-tape") - async with store.fork("test-tape"): - await store.append("test-tape", TapeEntry.event(name="step", data={"v": 1})) + await store.merge_back() entries = parent.read("test-tape") - assert entries is not None - assert len(entries) == 1 + assert entries is None or len(entries) == 0 @pytest.mark.asyncio async def test_fork_reset_with_merge_back_false_preserves_parent_entries() -> None: parent = InMemoryTapeStore() - store = ForkTapeStore(parent) + store = ForkTapeStore(AsyncTapeStoreAdapter(parent), "test-tape") parent.append("test-tape", TapeEntry.event(name="before", data={"x": 1})) - async with store.fork("test-tape", merge_back=False): - await store.reset("test-tape") - await store.append("test-tape", TapeEntry.event(name="inside", data={"x": 2})) + await store.reset("test-tape") + await store.append("test-tape", TapeEntry.event(name="inside", data={"x": 2})) entries = parent.read("test-tape") assert entries is not None @@ -67,12 +62,12 @@ async def test_fork_reset_with_merge_back_false_preserves_parent_entries() -> No @pytest.mark.asyncio async def test_fork_reset_with_merge_back_true_replaces_parent_entries() -> None: parent = InMemoryTapeStore() - store = ForkTapeStore(parent) + store = ForkTapeStore(AsyncTapeStoreAdapter(parent), "test-tape") parent.append("test-tape", TapeEntry.event(name="before", data={"x": 1})) - async with store.fork("test-tape", merge_back=True): - await store.reset("test-tape") - await store.append("test-tape", TapeEntry.event(name="inside", data={"x": 2})) + await store.reset("test-tape") + await store.append("test-tape", TapeEntry.event(name="inside", data={"x": 2})) + await store.merge_back() entries = parent.read("test-tape") assert entries is not None @@ -82,23 +77,22 @@ async def test_fork_reset_with_merge_back_true_replaces_parent_entries() -> None @pytest.mark.asyncio async def test_fork_reset_hides_parent_entries_during_fetch() -> None: parent = InMemoryTapeStore() - store = ForkTapeStore(parent) + store = ForkTapeStore(AsyncTapeStoreAdapter(parent), "test-tape") parent.append("test-tape", TapeEntry.event(name="before", data={"x": 1})) - async with store.fork("test-tape", merge_back=False): - await store.reset("test-tape") - await store.append("test-tape", TapeEntry.event(name="inside", data={"x": 2})) + await store.reset("test-tape") + await store.append("test-tape", TapeEntry.event(name="inside", data={"x": 2})) - query = TapeQuery(tape="test-tape", store=store) - entries = list(await store.fetch_all(query)) + query = TapeQuery(tape="test-tape", store=store) + entries = list(await store.fetch_all(query)) assert [entry.payload["name"] for entry in entries] == ["inside"] @pytest.mark.asyncio -async def test_reset_outside_fork_resets_parent_immediately() -> None: +async def test_reset_for_unbound_tape_resets_parent_immediately() -> None: parent = InMemoryTapeStore() - store = ForkTapeStore(parent) + store = ForkTapeStore(AsyncTapeStoreAdapter(parent), "other-tape") parent.append("test-tape", TapeEntry.event(name="before", data={"x": 1})) await store.reset("test-tape") diff --git a/tests/test_tape_search_output.py b/tests/test_tape_search_output.py index 5159234a..acc55319 100644 --- a/tests/test_tape_search_output.py +++ b/tests/test_tape_search_output.py @@ -19,7 +19,10 @@ class _FakeTapes: def __init__(self, entries: list[_FakeEntry]) -> None: self._entries = entries - def query(self, _tape: str) -> _FakeQuery: + def scoped(self, _tape: str) -> _FakeTapes: + return self + + def query(self) -> _FakeQuery: return _FakeQuery() async def search(self, _query: object) -> list[_FakeEntry]: @@ -53,7 +56,7 @@ async def test_tape_search_reports_shown_matches_and_filtered_count(monkeypatch) ] monkeypatch.setattr(builtin_tools, "_get_agent", lambda _context: _FakeAgent(entries)) - output = await tape_search.run(query="x", context=ToolContext(tape="tape", run_id="run", state={})) + output = await tape_search.run(query="x", context=ToolContext(tape=_FakeTapes(entries), run_id="run", state={})) assert output.splitlines()[0] == "[tape.search]: 1 matches (1 filtered)" @@ -66,6 +69,6 @@ async def test_tape_search_reports_zero_filtered_explicitly(monkeypatch) -> None ] monkeypatch.setattr(builtin_tools, "_get_agent", lambda _context: _FakeAgent(entries)) - output = await tape_search.run(query="x", context=ToolContext(tape="tape", run_id="run", state={})) + output = await tape_search.run(query="x", context=ToolContext(tape=_FakeTapes(entries), run_id="run", state={})) assert output.splitlines()[0] == "[tape.search]: 2 matches (0 filtered)"