From ae4acb4b9dae37a717a1ad9f35ccd5474fb56d03 Mon Sep 17 00:00:00 2001 From: Leon Zhao Date: Thu, 18 Jun 2026 21:09:36 +0800 Subject: [PATCH] feat: add ACP model config option --- .../src/bub_acp_server/plugin.py | 105 +++++++++++++++++- packages/bub-acp-server/tests/test_plugin.py | 80 +++++++++++++ 2 files changed, 181 insertions(+), 4 deletions(-) diff --git a/packages/bub-acp-server/src/bub_acp_server/plugin.py b/packages/bub-acp-server/src/bub_acp_server/plugin.py index b714e8d..8295d09 100644 --- a/packages/bub-acp-server/src/bub_acp_server/plugin.py +++ b/packages/bub-acp-server/src/bub_acp_server/plugin.py @@ -42,9 +42,12 @@ ResourceContentBlock, SessionCapabilities, SessionCloseCapabilities, + SessionConfigOptionSelect, + SessionConfigSelectOption, SessionInfo, SessionListCapabilities, SessionResumeCapabilities, + SetSessionConfigOptionResponse, SseMcpServer, TextContentBlock, ToolKind, @@ -77,6 +80,7 @@ class ACPSession: session_id: str cwd: Path additional_directories: list[str] = field(default_factory=list) + runtime: dict[str, str] = field(default_factory=dict) title: str | None = None updated_at: str | None = None @@ -97,6 +101,7 @@ def to_json(self) -> dict[str, object]: "session_id": self.session_id, "cwd": str(self.cwd), "additional_directories": list(self.additional_directories), + "runtime": dict(self.runtime), "title": self.title, "updated_at": self.updated_at, } @@ -116,10 +121,14 @@ def from_json(cls, data: Mapping[str, object]) -> ACPSession | None: title = data.get("title") updated_at = data.get("updated_at") + runtime = data.get("runtime") + if not isinstance(runtime, Mapping): + runtime = {} return cls( session_id=session_id, cwd=Path(cwd).expanduser().resolve(), additional_directories=[str(item) for item in additional_directories if isinstance(item, str)], + runtime={str(key): str(value) for key, value in runtime.items() if isinstance(key, str)}, title=title if isinstance(title, str) else None, updated_at=updated_at if isinstance(updated_at, str) else None, ) @@ -264,7 +273,10 @@ async def new_session( session.touch() self._sessions[session_id] = session self._save_sessions() - return NewSessionResponse(session_id=session_id) + return NewSessionResponse( + session_id=session_id, + config_options=await self._session_config_options(session), + ) async def load_session( self, @@ -281,7 +293,7 @@ async def load_session( additional_directories=additional_directories, ) await self._attach_session_history(session) - return LoadSessionResponse() + return LoadSessionResponse(config_options=await self._session_config_options(session)) async def resume_session( self, @@ -292,12 +304,12 @@ async def resume_session( **kwargs: Any, ) -> ResumeSessionResponse: del mcp_servers, kwargs - self._load_or_adopt_session( + session = self._load_or_adopt_session( session_id=session_id, cwd=cwd, additional_directories=additional_directories, ) - return ResumeSessionResponse() + return ResumeSessionResponse(config_options=await self._session_config_options(session)) async def list_sessions( self, @@ -321,6 +333,20 @@ async def cancel(self, session_id: str, **kwargs: Any) -> None: del kwargs await self.framework.quit_via_router(session_id) + async def set_config_option( + self, + config_id: str, + session_id: str, + value: str | bool, + **kwargs: Any, + ) -> SetSessionConfigOptionResponse: + del kwargs + session = self._sessions.get(session_id) or self._adopt_session(session_id) + session.touch() + config_options = await self._set_session_runtime_option(session, config_id, value) + self._save_sessions() + return SetSessionConfigOptionResponse(config_options=config_options) + async def prompt( self, prompt: list[ACPPromptBlock], @@ -345,6 +371,7 @@ async def prompt( media=media, context={"acp_session_id": session_id}, ) + setattr(inbound, "runtime", dict(session.runtime)) if self.settings.send_user_message_updates: await self._send_user_message_updates(prompt, session_id) @@ -465,6 +492,31 @@ async def _load_tape_entries(self, session: ACPSession) -> list[TapeEntry]: return list(cast(Iterable[TapeEntry], result)) return _load_tape_entries_from_file(bub.home.expanduser() / "tapes" / f"{tape_name}.jsonl") + async def _session_config_options(self, session: ACPSession) -> list[SessionConfigOptionSelect] | None: + runtime_options = await _framework_runtime_options(self.framework, session) + acp_options = _runtime_options_to_acp_config_options(runtime_options, session) + return acp_options or None + + async def _set_session_runtime_option( + self, + session: ACPSession, + config_id: str, + value: str | bool, + ) -> list[SessionConfigOptionSelect]: + if config_id != "model": + raise ValueError(f"unknown ACP config option: {config_id}") + if not isinstance(value, str): + raise ValueError(f"invalid value for ACP config option {config_id}: {value}") + config_options = await self._session_config_options(session) + selected_option = next((option for option in config_options or [] if option.id == "model"), None) + if selected_option is None: + raise ValueError(f"unknown ACP config option: {config_id}") + allowed_values = {option.value for option in selected_option.options} + if value not in allowed_values: + raise ValueError(f"invalid value for ACP config option {config_id}: {value}") + session.runtime["model"] = value + return await self._session_config_options(session) or [] + async def _send_user_message_updates(self, prompt: list[ACPPromptBlock], session_id: str) -> None: client = self._require_client() for block in prompt: @@ -619,6 +671,51 @@ def _framework_tape_store(framework: BubFramework) -> object | None: return store if hasattr(store, "fetch_all") else None +async def _framework_runtime_options(framework: BubFramework, session: ACPSession) -> object | None: + get_runtime_options = getattr(framework, "get_runtime_options", None) + if get_runtime_options is None: + return None + result = get_runtime_options(session_id=session.session_id, workspace=session.cwd) + if inspect.isawaitable(result): + result = await result + return result + + +def _runtime_options_to_acp_config_options(runtime_options: object | None, session: ACPSession) -> list[SessionConfigOptionSelect]: + if runtime_options is None: + return [] + choices = _list_payload(_block_value(runtime_options, "models", [])) + if not choices: + return [] + current_model = _block_value(runtime_options, "current_model", None) + current_value = session.runtime.get("model") or ( + str(current_model) if current_model is not None else str(_block_value(choices[0], "id")) + ) + return [ + SessionConfigOptionSelect( + type="select", + id="model", + name="Model", + current_value=current_value, + options=[_runtime_choice_to_acp_option(choice) for choice in choices], + category="model", + ) + ] + + +def _runtime_choice_to_acp_option(choice: object) -> SessionConfigSelectOption: + choice_id = str(_block_value(choice, "id")) + name = _block_value(choice, "name", None) + description = _block_value(choice, "description", None) + meta = _block_value(choice, "meta", None) + return SessionConfigSelectOption( + value=choice_id, + name=str(name) if name is not None else choice_id, + description=str(description) if description is not None else None, + field_meta=dict(meta) if isinstance(meta, Mapping) else None, + ) + + def _session_tape_name(session_id: str, workspace: Path) -> str: workspace_hash = hashlib.md5(str(workspace.resolve()).encode("utf-8"), usedforsecurity=False).hexdigest()[:16] session_hash = hashlib.md5(session_id.encode("utf-8"), usedforsecurity=False).hexdigest()[:16] diff --git a/packages/bub-acp-server/tests/test_plugin.py b/packages/bub-acp-server/tests/test_plugin.py index addee84..07573e5 100644 --- a/packages/bub-acp-server/tests/test_plugin.py +++ b/packages/bub-acp-server/tests/test_plugin.py @@ -1,6 +1,7 @@ from __future__ import annotations from pathlib import Path +from types import SimpleNamespace from typing import Any import pytest @@ -92,6 +93,22 @@ async def stream(): return TurnResult(session_id=inbound.session_id, prompt=inbound.content, model_output="late text") +class ConfigFramework(FakeFramework): + def __init__(self) -> None: + super().__init__() + self.runtime_queries: list[tuple[str, Path]] = [] + + async def get_runtime_options(self, *, session_id: str, workspace: Path) -> object: + self.runtime_queries.append((session_id, workspace)) + return SimpleNamespace( + models=[ + SimpleNamespace(id="openai:gpt-5", name="GPT-5", description="OpenAI model"), + SimpleNamespace(id="anthropic:claude-sonnet-4-5", name="Claude Sonnet"), + ], + current_model="openai:gpt-5", + ) + + @pytest.mark.asyncio async def test_initialize_advertises_session_capabilities() -> None: agent = BubACPAgent(FakeFramework()) @@ -180,6 +197,69 @@ async def test_sessions_survive_agent_restart(tmp_path: Path) -> None: assert sessions.sessions[0].cwd == str(tmp_path) +@pytest.mark.asyncio +async def test_session_lifecycle_returns_config_options(tmp_path: Path) -> None: + framework = ConfigFramework() + client = FakeClient() + agent = BubACPAgent(framework) + agent.on_connect(client) + + created = await agent.new_session(cwd=str(tmp_path)) + loaded = await agent.load_session(cwd=str(tmp_path), session_id=created.session_id) + resumed = await agent.resume_session(cwd=str(tmp_path), session_id=created.session_id) + + assert created.config_options is not None + assert created.config_options[0].id == "model" + assert created.config_options[0].name == "Model" + assert created.config_options[0].current_value == "openai:gpt-5" + assert created.config_options[0].options[0].value == "openai:gpt-5" + assert len(created.config_options) == 1 + assert loaded.config_options is not None + assert loaded.config_options[0].id == "model" + assert resumed.config_options is not None + assert resumed.config_options[0].id == "model" + assert framework.runtime_queries == [ + (created.session_id, tmp_path), + (created.session_id, tmp_path), + (created.session_id, tmp_path), + ] + + +@pytest.mark.asyncio +async def test_set_config_option_updates_session_runtime_and_returns_config_options(tmp_path: Path) -> None: + framework = ConfigFramework() + agent = BubACPAgent(framework) + created = await agent.new_session(cwd=str(tmp_path)) + + response = await agent.set_config_option( + config_id="model", + session_id=created.session_id, + value="anthropic:claude-sonnet-4-5", + ) + + assert agent._sessions[created.session_id].runtime == {"model": "anthropic:claude-sonnet-4-5"} + assert response.config_options[0].id == "model" + assert response.config_options[0].current_value == "anthropic:claude-sonnet-4-5" + + +@pytest.mark.asyncio +async def test_prompt_passes_acp_runtime_selection_to_bub_message(tmp_path: Path) -> None: + framework = ConfigFramework() + client = FakeClient() + agent = BubACPAgent(framework) + agent.on_connect(client) + created = await agent.new_session(cwd=str(tmp_path)) + await agent.set_config_option( + config_id="model", + session_id=created.session_id, + value="anthropic:claude-sonnet-4-5", + ) + + await agent.prompt([TextContentBlock(type="text", text="hello")], session_id=created.session_id) + + assert framework.messages[0].runtime == {"model": "anthropic:claude-sonnet-4-5"} + + @pytest.mark.asyncio async def test_session_store_expands_user_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("HOME", str(tmp_path))