Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 101 additions & 4 deletions packages/bub-acp-server/src/bub_acp_server/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,12 @@
ResourceContentBlock,
SessionCapabilities,
SessionCloseCapabilities,
SessionConfigOptionSelect,
SessionConfigSelectOption,
SessionInfo,
SessionListCapabilities,
SessionResumeCapabilities,
SetSessionConfigOptionResponse,
SseMcpServer,
TextContentBlock,
ToolKind,
Expand Down Expand Up @@ -77,6 +80,7 @@ class ACPSession:
session_id: str
cwd: Path
additional_directories: list[str] = field(default_factory=list)
runtime: dict[str, str] = field(default_factory=dict)
title: str | None = None
updated_at: str | None = None

Expand All @@ -97,6 +101,7 @@ def to_json(self) -> dict[str, object]:
"session_id": self.session_id,
"cwd": str(self.cwd),
"additional_directories": list(self.additional_directories),
"runtime": dict(self.runtime),
"title": self.title,
"updated_at": self.updated_at,
}
Expand All @@ -116,10 +121,14 @@ def from_json(cls, data: Mapping[str, object]) -> ACPSession | None:

title = data.get("title")
updated_at = data.get("updated_at")
runtime = data.get("runtime")
if not isinstance(runtime, Mapping):
runtime = {}
return cls(
session_id=session_id,
cwd=Path(cwd).expanduser().resolve(),
additional_directories=[str(item) for item in additional_directories if isinstance(item, str)],
runtime={str(key): str(value) for key, value in runtime.items() if isinstance(key, str)},
title=title if isinstance(title, str) else None,
updated_at=updated_at if isinstance(updated_at, str) else None,
)
Expand Down Expand Up @@ -264,7 +273,10 @@ async def new_session(
session.touch()
self._sessions[session_id] = session
self._save_sessions()
return NewSessionResponse(session_id=session_id)
return NewSessionResponse(
session_id=session_id,
config_options=await self._session_config_options(session),
)

async def load_session(
self,
Expand All @@ -281,7 +293,7 @@ async def load_session(
additional_directories=additional_directories,
)
await self._attach_session_history(session)
return LoadSessionResponse()
return LoadSessionResponse(config_options=await self._session_config_options(session))

async def resume_session(
self,
Expand All @@ -292,12 +304,12 @@ async def resume_session(
**kwargs: Any,
) -> ResumeSessionResponse:
del mcp_servers, kwargs
self._load_or_adopt_session(
session = self._load_or_adopt_session(
session_id=session_id,
cwd=cwd,
additional_directories=additional_directories,
)
return ResumeSessionResponse()
return ResumeSessionResponse(config_options=await self._session_config_options(session))

async def list_sessions(
self,
Expand All @@ -321,6 +333,20 @@ async def cancel(self, session_id: str, **kwargs: Any) -> None:
del kwargs
await self.framework.quit_via_router(session_id)

async def set_config_option(
self,
config_id: str,
session_id: str,
value: str | bool,
**kwargs: Any,
) -> SetSessionConfigOptionResponse:
del kwargs
session = self._sessions.get(session_id) or self._adopt_session(session_id)
session.touch()
config_options = await self._set_session_runtime_option(session, config_id, value)
self._save_sessions()
return SetSessionConfigOptionResponse(config_options=config_options)

async def prompt(
self,
prompt: list[ACPPromptBlock],
Expand All @@ -345,6 +371,7 @@ async def prompt(
media=media,
context={"acp_session_id": session_id},
)
setattr(inbound, "runtime", dict(session.runtime))
if self.settings.send_user_message_updates:
await self._send_user_message_updates(prompt, session_id)

Expand Down Expand Up @@ -465,6 +492,31 @@ async def _load_tape_entries(self, session: ACPSession) -> list[TapeEntry]:
return list(cast(Iterable[TapeEntry], result))
return _load_tape_entries_from_file(bub.home.expanduser() / "tapes" / f"{tape_name}.jsonl")

async def _session_config_options(self, session: ACPSession) -> list[SessionConfigOptionSelect] | None:
runtime_options = await _framework_runtime_options(self.framework, session)
acp_options = _runtime_options_to_acp_config_options(runtime_options, session)
return acp_options or None

async def _set_session_runtime_option(
self,
session: ACPSession,
config_id: str,
value: str | bool,
) -> list[SessionConfigOptionSelect]:
if config_id != "model":
raise ValueError(f"unknown ACP config option: {config_id}")
if not isinstance(value, str):
raise ValueError(f"invalid value for ACP config option {config_id}: {value}")
config_options = await self._session_config_options(session)
selected_option = next((option for option in config_options or [] if option.id == "model"), None)
if selected_option is None:
raise ValueError(f"unknown ACP config option: {config_id}")
allowed_values = {option.value for option in selected_option.options}
if value not in allowed_values:
raise ValueError(f"invalid value for ACP config option {config_id}: {value}")
session.runtime["model"] = value
return await self._session_config_options(session) or []

async def _send_user_message_updates(self, prompt: list[ACPPromptBlock], session_id: str) -> None:
client = self._require_client()
for block in prompt:
Expand Down Expand Up @@ -619,6 +671,51 @@ def _framework_tape_store(framework: BubFramework) -> object | None:
return store if hasattr(store, "fetch_all") else None


async def _framework_runtime_options(framework: BubFramework, session: ACPSession) -> object | None:
get_runtime_options = getattr(framework, "get_runtime_options", None)
if get_runtime_options is None:
return None
result = get_runtime_options(session_id=session.session_id, workspace=session.cwd)
if inspect.isawaitable(result):
result = await result
return result


def _runtime_options_to_acp_config_options(runtime_options: object | None, session: ACPSession) -> list[SessionConfigOptionSelect]:
if runtime_options is None:
return []
choices = _list_payload(_block_value(runtime_options, "models", []))
if not choices:
return []
current_model = _block_value(runtime_options, "current_model", None)
current_value = session.runtime.get("model") or (
str(current_model) if current_model is not None else str(_block_value(choices[0], "id"))
)
return [
SessionConfigOptionSelect(
type="select",
id="model",
name="Model",
current_value=current_value,
options=[_runtime_choice_to_acp_option(choice) for choice in choices],
category="model",
)
]


def _runtime_choice_to_acp_option(choice: object) -> SessionConfigSelectOption:
choice_id = str(_block_value(choice, "id"))
name = _block_value(choice, "name", None)
description = _block_value(choice, "description", None)
meta = _block_value(choice, "meta", None)
return SessionConfigSelectOption(
value=choice_id,
name=str(name) if name is not None else choice_id,
description=str(description) if description is not None else None,
field_meta=dict(meta) if isinstance(meta, Mapping) else None,
)


def _session_tape_name(session_id: str, workspace: Path) -> str:
workspace_hash = hashlib.md5(str(workspace.resolve()).encode("utf-8"), usedforsecurity=False).hexdigest()[:16]
session_hash = hashlib.md5(session_id.encode("utf-8"), usedforsecurity=False).hexdigest()[:16]
Expand Down
80 changes: 80 additions & 0 deletions packages/bub-acp-server/tests/test_plugin.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from pathlib import Path
from types import SimpleNamespace
from typing import Any

import pytest
Expand Down Expand Up @@ -92,6 +93,22 @@ async def stream():
return TurnResult(session_id=inbound.session_id, prompt=inbound.content, model_output="late text")


class ConfigFramework(FakeFramework):
def __init__(self) -> None:
super().__init__()
self.runtime_queries: list[tuple[str, Path]] = []

async def get_runtime_options(self, *, session_id: str, workspace: Path) -> object:
self.runtime_queries.append((session_id, workspace))
return SimpleNamespace(
models=[
SimpleNamespace(id="openai:gpt-5", name="GPT-5", description="OpenAI model"),
SimpleNamespace(id="anthropic:claude-sonnet-4-5", name="Claude Sonnet"),
],
current_model="openai:gpt-5",
)


@pytest.mark.asyncio
async def test_initialize_advertises_session_capabilities() -> None:
agent = BubACPAgent(FakeFramework())
Expand Down Expand Up @@ -180,6 +197,69 @@ async def test_sessions_survive_agent_restart(tmp_path: Path) -> None:
assert sessions.sessions[0].cwd == str(tmp_path)


@pytest.mark.asyncio
async def test_session_lifecycle_returns_config_options(tmp_path: Path) -> None:
framework = ConfigFramework()
client = FakeClient()
agent = BubACPAgent(framework)
agent.on_connect(client)

created = await agent.new_session(cwd=str(tmp_path))
loaded = await agent.load_session(cwd=str(tmp_path), session_id=created.session_id)
resumed = await agent.resume_session(cwd=str(tmp_path), session_id=created.session_id)

assert created.config_options is not None
assert created.config_options[0].id == "model"
assert created.config_options[0].name == "Model"
assert created.config_options[0].current_value == "openai:gpt-5"
assert created.config_options[0].options[0].value == "openai:gpt-5"
assert len(created.config_options) == 1
assert loaded.config_options is not None
assert loaded.config_options[0].id == "model"
assert resumed.config_options is not None
assert resumed.config_options[0].id == "model"
assert framework.runtime_queries == [
(created.session_id, tmp_path),
(created.session_id, tmp_path),
(created.session_id, tmp_path),
]


@pytest.mark.asyncio
async def test_set_config_option_updates_session_runtime_and_returns_config_options(tmp_path: Path) -> None:
framework = ConfigFramework()
agent = BubACPAgent(framework)
created = await agent.new_session(cwd=str(tmp_path))

response = await agent.set_config_option(
config_id="model",
session_id=created.session_id,
value="anthropic:claude-sonnet-4-5",
)

assert agent._sessions[created.session_id].runtime == {"model": "anthropic:claude-sonnet-4-5"}
assert response.config_options[0].id == "model"
assert response.config_options[0].current_value == "anthropic:claude-sonnet-4-5"


@pytest.mark.asyncio
async def test_prompt_passes_acp_runtime_selection_to_bub_message(tmp_path: Path) -> None:
framework = ConfigFramework()
client = FakeClient()
agent = BubACPAgent(framework)
agent.on_connect(client)
created = await agent.new_session(cwd=str(tmp_path))
await agent.set_config_option(
config_id="model",
session_id=created.session_id,
value="anthropic:claude-sonnet-4-5",
)

await agent.prompt([TextContentBlock(type="text", text="hello")], session_id=created.session_id)

assert framework.messages[0].runtime == {"model": "anthropic:claude-sonnet-4-5"}


@pytest.mark.asyncio
async def test_session_store_expands_user_home(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("HOME", str(tmp_path))
Expand Down