diff --git a/.envrc b/.envrc new file mode 100644 index 00000000..3550a30f --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/.gitignore b/.gitignore index c883e6e6..85b14354 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # Node / Electron node_modules/ +.pnpm/ dist/ out/ build/ @@ -9,6 +10,7 @@ yarn-debug.log* yarn-error.log* # Go +.go/ bin/ *.test *.out @@ -30,6 +32,7 @@ session-events.jsonl.* .ao/ # Environment +.direnv/ .env .env.* !.env.example diff --git a/backend/.golangci.yml b/backend/.golangci.yml index 438dd020..2a471bf5 100644 --- a/backend/.golangci.yml +++ b/backend/.golangci.yml @@ -85,6 +85,7 @@ linters: excludes: - G104 # unchecked errors — errcheck owns this - G304 # file inclusion via variable — paths are config/run-file/worktree-derived, not user input + - G703 # path traversal via taint — same rationale as G304: paths are daemon-owned exclusions: generated: lax # skip sqlc/codegen ("Code generated ... DO NOT EDIT") @@ -107,7 +108,6 @@ linters: formatters: enable: - - gofmt - goimports settings: goimports: diff --git a/backend/internal/adapters/agent/agent.go b/backend/internal/adapters/agent/agent.go new file mode 100644 index 00000000..a0ef0b48 --- /dev/null +++ b/backend/internal/adapters/agent/agent.go @@ -0,0 +1,134 @@ +package agent + +import ( + "context" +) + +// Agent defines the behavior every CLI coding agent adapter must provide. +type Agent interface { + // GetConfigSpec describes the agent-specific config keys Better-AO can + // expose to users in ~/.better-ao/config.yaml. + GetConfigSpec(ctx context.Context) (ConfigSpec, error) + + // GetLaunchCommand builds the command Better-AO should run to start this agent. + GetLaunchCommand(ctx context.Context, cfg LaunchConfig) (cmd []string, err error) + + // GetPromptDeliveryStrategy tells Better-AO whether the prompt is included in + // the launch command or must be sent after the agent process starts. 
// ConfigSpec describes the agent-specific config keys Better-AO can expose to
// users. It is the value an adapter returns from Agent.GetConfigSpec; a spec
// with no Fields means the adapter exposes no agent-specific keys.
type ConfigSpec struct {
	// Fields lists the exposed keys, one ConfigField per user-facing key.
	Fields []ConfigField
}
// LaunchConfig carries inputs needed to build a new agent launch command.
//
// It is the argument to Agent.GetLaunchCommand and
// Agent.GetPromptDeliveryStrategy. How each field is consumed is up to the
// adapter; the field notes below mention what the Claude Code adapter does
// where that is known.
type LaunchConfig struct {
	// Config holds the values loaded from the selected agent's config section.
	Config Config
	// IssueID presumably identifies the tracker issue the session works on —
	// TODO confirm against the caller.
	IssueID string
	// Permissions selects how much review the agent requires before acting.
	// Claude Code maps it onto a --permission-mode flag.
	Permissions PermissionMode
	// Prompt is the initial task prompt. Claude Code receives it as the
	// positional argument of the launch command.
	Prompt string
	// SessionID is the Better-AO session id. Claude Code derives its native
	// --session-id from it.
	SessionID string
	// SystemPrompt is inline system prompt text.
	SystemPrompt string
	// SystemPromptFile is the path of a file holding the system prompt. Claude
	// Code reads it from disk and prefers it over SystemPrompt when both are set.
	SystemPromptFile string
	// WorkspacePath is the directory the agent runs in. Claude Code marks it as
	// trusted before launch.
	WorkspacePath string
}
// Known PromptDeliveryStrategy values.
const (
	// PromptDeliveryInCommand means the prompt is included in the launch
	// command itself.
	PromptDeliveryInCommand PromptDeliveryStrategy = "in_command"
	// PromptDeliveryAfterStart means the prompt must be sent after the agent
	// process starts.
	PromptDeliveryAfterStart PromptDeliveryStrategy = "after_start"
)
// Plugin is the Claude Code adapter. Obtain one from New, which returns a
// pointer to a zero-valued Plugin. Always use it through that pointer: the
// struct holds a sync.Mutex and must not be copied.
type Plugin struct {
	// binaryMu guards resolvedBinary.
	binaryMu sync.Mutex
	// resolvedBinary caches the path found by ResolveClaudeBinary; the empty
	// string means the binary has not been resolved yet.
	resolvedBinary string
}
+func New() *Plugin { + return &Plugin{} +} + +var _ adapters.Adapter = (*Plugin)(nil) +var _ agent.Agent = (*Plugin)(nil) + +// Manifest reports the adapter's self-describing record. +func (p *Plugin) Manifest() adapters.Manifest { + return adapters.Manifest{ + ID: adapterID, + Name: "Claude Code", + Description: "Run Claude Code worker sessions.", + Version: "0.0.1", + Capabilities: []adapters.Capability{ + adapters.CapabilityAgent, + }, + } +} + +// GetConfigSpec returns the agent-specific config keys this adapter exposes. +// Claude Code has none today. +func (p *Plugin) GetConfigSpec(ctx context.Context) (agent.ConfigSpec, error) { + if err := ctx.Err(); err != nil { + return agent.ConfigSpec{}, err + } + return agent.ConfigSpec{}, nil +} + +// GetLaunchCommand builds the argv to start an interactive Claude Code +// session. Shape: +// +// claude [--session-id ] \ +// [--permission-mode ] \ +// [--append-system-prompt ] \ +// [-- ] +// +// --session-id pins Claude's native session UUID to a value derived from the +// better-ao session id, so the session is resumable later (see +// GetRestoreCommand) and its transcript is locatable (see SessionInfo) without +// a separate capture step. +// +// is acceptEdits, auto, or bypassPermissions. better-ao's "default" +// mode emits no --permission-mode flag, so Claude's TUI resolves the starting +// mode from ~/.claude/settings.json exactly as a normal launch. +// +// The prompt is passed after `--` so a prompt beginning with "-" is not +// mistaken for a flag. 
+func (p *Plugin) GetLaunchCommand(ctx context.Context, cfg agent.LaunchConfig) (cmd []string, err error) { + binary, err := p.claudeBinary(ctx) + if err != nil { + return nil, err + } + + cmd = []string{binary} + if cfg.SessionID != "" { + cmd = append(cmd, "--session-id", claudeSessionUUID(cfg.SessionID)) + } + appendPermissionFlags(&cmd, cfg.Permissions) + + systemPrompt, err := resolveSystemPrompt(cfg) + if err != nil { + return nil, err + } + if systemPrompt != "" { + // Append rather than replace: Claude Code's default system prompt + // carries its tool-use and coding instructions, which we want to + // keep. The orchestrator prompt layers on top. + cmd = append(cmd, "--append-system-prompt", systemPrompt) + } + + if cfg.Prompt != "" { + cmd = append(cmd, "--", cfg.Prompt) + } + + return cmd, nil +} + +// GetPromptDeliveryStrategy reports how Better-AO should deliver the initial +// prompt. Claude Code accepts it in the launch command. +func (p *Plugin) GetPromptDeliveryStrategy(ctx context.Context, cfg agent.LaunchConfig) (agent.PromptDeliveryStrategy, error) { + if err := ctx.Err(); err != nil { + return "", err + } + return agent.PromptDeliveryInCommand, nil +} + +// PreLaunch is an optional capability the spawn engine invokes (via type +// assertion) immediately before creating the session. Claude Code shows a +// blocking "do you trust this folder?" dialog the first time it runs in any +// directory. Every better-ao worktree is a fresh path, so without this the +// agent would hang at that prompt with no one to answer it. +// +// A better-ao worktree is derived from the repo the user is already running +// better-ao in, so it is inherently trusted. PreLaunch records that trust in +// ~/.claude.json before launch, additively and atomically, so it cannot +// clobber a concurrently-running Claude instance's config. 
+func (p *Plugin) PreLaunch(ctx context.Context, cfg agent.LaunchConfig) error { + if err := ctx.Err(); err != nil { + return err + } + if cfg.WorkspacePath == "" { + return nil + } + cfgPath, err := claudeConfigPath() + if err != nil { + return err + } + return ensureWorkspaceTrusted(cfgPath, cfg.WorkspacePath) +} + +// GetRestoreCommand rebuilds the argv that continues an existing Claude Code +// session: `claude [--permission-mode ] --resume `. It +// prefers the hook-captured native session id from +// cfg.Session.Metadata["agentSessionId"]; for sessions created before hooks +// captured it, it falls back to the deterministic UUID better-ao pins via +// --session-id at launch. ok is false only when neither is available, so the +// caller fresh-spawns. The command re-applies the permission mode (resume +// otherwise reverts to the configured default) but not the prompt/system +// prompt, which the session already carries. +func (p *Plugin) GetRestoreCommand(ctx context.Context, cfg agent.RestoreConfig) (cmd []string, ok bool, err error) { + if err := ctx.Err(); err != nil { + return nil, false, err + } + + sessionID := strings.TrimSpace(cfg.Session.Metadata[agent.MetadataKeyAgentSessionID]) + if sessionID == "" && cfg.Session.ID != "" { + // Explicit fallback for pre-hook sessions: the id better-ao + // deterministically pinned via --session-id at launch. + sessionID = claudeSessionUUID(cfg.Session.ID) + } + if sessionID == "" { + return nil, false, nil + } + + binary, err := p.claudeBinary(ctx) + if err != nil { + return nil, false, err + } + cmd = make([]string, 0, 5) + cmd = append(cmd, binary) + appendPermissionFlags(&cmd, cfg.Permissions) + cmd = append(cmd, "--resume", sessionID) + return cmd, true, nil +} + +// SessionInfo surfaces the normalized session metadata that the Claude Code +// hooks persisted into Better-AO's store: the native session id, the title (the +// first user prompt), and the summary (the final assistant message). 
It reads +// only from session.Metadata — never from transcript files — and returns +// ok=false when none of those fields are present. Metadata is intentionally nil: +// there is no Claude-specific field callers need beyond the normalized ones. +func (p *Plugin) SessionInfo(ctx context.Context, session agent.SessionRef) (agent.SessionInfo, bool, error) { + if err := ctx.Err(); err != nil { + return agent.SessionInfo{}, false, err + } + info := agent.SessionInfo{ + AgentSessionID: session.Metadata[agent.MetadataKeyAgentSessionID], + Title: session.Metadata[claudeTitleMetadataKey], + Summary: session.Metadata[claudeSummaryMetadataKey], + } + if info.AgentSessionID == "" && info.Title == "" && info.Summary == "" { + return agent.SessionInfo{}, false, nil + } + return info, true, nil +} + +// claudeSessionUUID maps a better-ao session id onto a stable Claude Code +// session UUID via UUIDv5 over a fixed namespace, so the same better-ao session +// always resolves to the same Claude session. +func claudeSessionUUID(betterAoSessionID string) string { + return uuid.NewSHA1(claudeSessionNamespace, []byte(betterAoSessionID)).String() +} + +// resolveSystemPrompt returns the system prompt text to append, preferring +// SystemPromptFile (read from disk) over an inline SystemPrompt. +func resolveSystemPrompt(cfg agent.LaunchConfig) (string, error) { + if cfg.SystemPromptFile != "" { + data, err := os.ReadFile(cfg.SystemPromptFile) + if err != nil { + return "", fmt.Errorf("claude-code: read system prompt file: %w", err) + } + return strings.TrimRight(string(data), "\n"), nil + } + return cfg.SystemPrompt, nil +} + +// appendPermissionFlags maps better-ao's permission modes onto Claude Code's +// --permission-mode values: +// - default → no flag. Claude's TUI resolves the starting mode +// from ~/.claude/settings.json (defaultMode), exactly as a normal launch. 
+// - accept-edits → --permission-mode acceptEdits (auto-accept edits + +// safe filesystem bash; still prompts for network/system bash, MCP, web) +// - auto → --permission-mode auto (classifier-gated +// auto-approval; auto-runs what a safety model deems safe) +// - bypass-permissions → --permission-mode bypassPermissions (skip all +// checks; equivalent to --dangerously-skip-permissions) +// +// Empty/unrecognized normalizes to default, so no flag is emitted. +func appendPermissionFlags(cmd *[]string, permissions agent.PermissionMode) { + switch normalizePermissionMode(permissions) { + case agent.PermissionModeDefault: + // No flag: defer to the user's settings.json defaultMode. + case agent.PermissionModeAcceptEdits: + *cmd = append(*cmd, "--permission-mode", "acceptEdits") + case agent.PermissionModeAuto: + *cmd = append(*cmd, "--permission-mode", "auto") + case agent.PermissionModeBypassPermissions: + *cmd = append(*cmd, "--permission-mode", "bypassPermissions") + } +} + +func normalizePermissionMode(mode agent.PermissionMode) agent.PermissionMode { + switch mode { + case agent.PermissionModeDefault, + agent.PermissionModeAcceptEdits, + agent.PermissionModeAuto, + agent.PermissionModeBypassPermissions: + return mode + default: + // Empty or unrecognized: defer to settings.json (no flag). + return agent.PermissionModeDefault + } +} + +// ResolveClaudeBinary finds the `claude` binary, searching PATH then a few +// well-known install locations (the native installer's ~/.local/bin, npm +// global, Homebrew). Returns "claude" as a last resort so callers get a +// clear "command not found" rather than an empty argv. 
+func ResolveClaudeBinary(ctx context.Context) (string, error) { + if err := ctx.Err(); err != nil { + return "", err + } + + if runtime.GOOS == "windows" { + for _, name := range []string{"claude.cmd", "claude.exe", "claude"} { + if path, err := exec.LookPath(name); err == nil && path != "" { + return path, nil + } + } + candidates := []string{} + if appData := os.Getenv("APPDATA"); appData != "" { + candidates = append(candidates, + filepath.Join(appData, "npm", "claude.cmd"), + filepath.Join(appData, "npm", "claude.exe"), + ) + } + for _, candidate := range candidates { + if fileExists(candidate) { + return candidate, nil + } + } + return "claude", nil + } + + if path, err := exec.LookPath("claude"); err == nil && path != "" { + return path, nil + } + + candidates := []string{ + "/usr/local/bin/claude", + "/opt/homebrew/bin/claude", + } + if home, err := os.UserHomeDir(); err == nil { + candidates = append(candidates, + filepath.Join(home, ".local", "bin", "claude"), + filepath.Join(home, ".npm", "bin", "claude"), + filepath.Join(home, ".claude", "local", "claude"), + ) + } + for _, candidate := range candidates { + if fileExists(candidate) { + return candidate, nil + } + if err := ctx.Err(); err != nil { + return "", err + } + } + + return "claude", nil +} + +func (p *Plugin) claudeBinary(ctx context.Context) (string, error) { + p.binaryMu.Lock() + defer p.binaryMu.Unlock() + + if p.resolvedBinary != "" { + return p.resolvedBinary, nil + } + + binary, err := ResolveClaudeBinary(ctx) + if err != nil { + return "", err + } + p.resolvedBinary = binary + return binary, nil +} + +// claudeConfigPath returns the path to Claude Code's global config file, +// ~/.claude.json. 
// ensureWorkspaceTrusted records workspacePath as trusted in Claude Code's
// config so the interactive trust dialog does not block a spawned session.
//
// It reads the config at configPath, sets only
// projects[workspacePath].hasTrustDialogAccepted = true (keeping the rest of
// that entry, every other project and every other top-level key), and writes
// the result back via a temp file in the same directory followed by a rename.
// If the path is already trusted, nothing is written.
//
// Input handling:
//   - a missing, empty or whitespace-only file is treated as an empty config;
//   - a file containing JSON null is treated as an empty config rather than
//     leaving a nil map to be written to;
//   - numbers are decoded as json.Number so values that do not fit a float64
//     exactly (large integer ids, counters) are written back unchanged;
//   - any other non-object content, or data after the top-level value, is a
//     parse error and the file is left untouched.
//
// A "projects" value or project entry that is not a JSON object is replaced
// by a fresh object.
//
// It returns a wrapped error when the config cannot be read, parsed, encoded
// or replaced.
func ensureWorkspaceTrusted(configPath, workspacePath string) error {
	root := map[string]any{}
	data, err := os.ReadFile(configPath)
	switch {
	case err == nil:
		text := string(data)
		if strings.TrimSpace(text) != "" {
			dec := json.NewDecoder(strings.NewReader(text))
			// Keep numeric literals verbatim instead of round-tripping them
			// through float64.
			dec.UseNumber()
			if err := dec.Decode(&root); err != nil {
				return fmt.Errorf("claude-code: parse %s: %w", configPath, err)
			}
			// A Decoder stops after the first value; reject trailing content
			// the way json.Unmarshal would.
			if strings.TrimSpace(text[dec.InputOffset():]) != "" {
				return fmt.Errorf("claude-code: parse %s: unexpected data after top-level JSON value", configPath)
			}
		}
		if root == nil {
			// JSON null decodes to a nil map; writing into it would panic.
			root = map[string]any{}
		}
	case os.IsNotExist(err):
		// Treat as empty config; we'll create it.
	default:
		return fmt.Errorf("claude-code: read %s: %w", configPath, err)
	}

	projects, _ := root["projects"].(map[string]any)
	if projects == nil {
		projects = map[string]any{}
		root["projects"] = projects
	}

	entry, _ := projects[workspacePath].(map[string]any)
	if entry == nil {
		entry = map[string]any{}
		projects[workspacePath] = entry
	}

	if trusted, ok := entry["hasTrustDialogAccepted"].(bool); ok && trusted {
		// Already trusted — no write needed, so no race window at all.
		return nil
	}
	entry["hasTrustDialogAccepted"] = true

	out, err := json.MarshalIndent(root, "", " ")
	if err != nil {
		return fmt.Errorf("claude-code: encode %s: %w", configPath, err)
	}

	// Atomic write: temp file in the same directory, then rename, so a reader
	// never observes a partially written config and concurrent updates are
	// last-writer-wins rather than corrupting.
	dir := filepath.Dir(configPath)
	tmp, err := os.CreateTemp(dir, ".claude.json.tmp-*")
	if err != nil {
		return fmt.Errorf("claude-code: create temp config: %w", err)
	}
	tmpName := tmp.Name()
	defer func() { _ = os.Remove(tmpName) }() // no-op once renamed

	if _, err := tmp.Write(out); err != nil {
		_ = tmp.Close()
		return fmt.Errorf("claude-code: write temp config: %w", err)
	}
	if err := tmp.Close(); err != nil {
		return fmt.Errorf("claude-code: close temp config: %w", err)
	}
	if err := os.Rename(tmpName, configPath); err != nil {
		return fmt.Errorf("claude-code: replace config: %w", err)
	}
	return nil
}
name string + permission agent.PermissionMode + want []string + notExpected string + }{ + {"default omits flag (defers to settings.json)", agent.PermissionModeDefault, nil, "--permission-mode"}, + {"accept-edits", agent.PermissionModeAcceptEdits, []string{"--permission-mode", "acceptEdits"}, ""}, + {"auto", agent.PermissionModeAuto, []string{"--permission-mode", "auto"}, ""}, + {"bypass-permissions", agent.PermissionModeBypassPermissions, []string{"--permission-mode", "bypassPermissions"}, ""}, + {"empty omits permission flags", "", nil, "--permission-mode"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := &Plugin{resolvedBinary: "claude"} + cmd, err := p.GetLaunchCommand(context.Background(), agent.LaunchConfig{ + Permissions: tt.permission, + }) + if err != nil { + t.Fatal(err) + } + if len(tt.want) > 0 && !containsSubsequence(cmd, tt.want) { + t.Fatalf("command %#v does not contain %#v", cmd, tt.want) + } + if tt.notExpected != "" && contains(cmd, tt.notExpected) { + t.Fatalf("command %#v unexpectedly contains %q", cmd, tt.notExpected) + } + }) + } +} + +func TestGetLaunchCommandAppendsSystemPromptFromFile(t *testing.T) { + dir := t.TempDir() + promptFile := filepath.Join(dir, "system.md") + if err := os.WriteFile(promptFile, []byte("You are an orchestrator.\n"), 0o644); err != nil { + t.Fatal(err) + } + + p := &Plugin{resolvedBinary: "claude"} + cmd, err := p.GetLaunchCommand(context.Background(), agent.LaunchConfig{ + SystemPromptFile: promptFile, + Prompt: "do the thing", + }) + if err != nil { + t.Fatal(err) + } + + want := []string{ + "claude", + "--append-system-prompt", "You are an orchestrator.", + "--", "do the thing", + } + if !reflect.DeepEqual(cmd, want) { + t.Fatalf("unexpected command\nwant: %#v\n got: %#v", want, cmd) + } +} + +func TestGetLaunchCommandInlineSystemPrompt(t *testing.T) { + p := &Plugin{resolvedBinary: "claude"} + cmd, err := p.GetLaunchCommand(context.Background(), agent.LaunchConfig{ + SystemPrompt: 
"inline instructions", + }) + if err != nil { + t.Fatal(err) + } + if !containsSubsequence(cmd, []string{"--append-system-prompt", "inline instructions"}) { + t.Fatalf("command %#v does not append inline system prompt", cmd) + } +} + +func TestGetLaunchCommandMissingSystemPromptFileErrors(t *testing.T) { + p := &Plugin{resolvedBinary: "claude"} + _, err := p.GetLaunchCommand(context.Background(), agent.LaunchConfig{ + SystemPromptFile: filepath.Join(t.TempDir(), "does-not-exist.md"), + }) + if err == nil { + t.Fatal("expected error for missing system prompt file") + } +} + +func TestGetLaunchCommandInjectsSessionID(t *testing.T) { + p := &Plugin{resolvedBinary: "claude"} + cmd, err := p.GetLaunchCommand(context.Background(), agent.LaunchConfig{ + SessionID: "e0tt49", + Prompt: "do the thing", + }) + if err != nil { + t.Fatal(err) + } + wantUUID := claudeSessionUUID("e0tt49") + if !containsSubsequence(cmd, []string{"--session-id", wantUUID}) { + t.Fatalf("command %#v missing --session-id %q", cmd, wantUUID) + } + + // No SessionID → no --session-id flag. 
+ cmd, err = p.GetLaunchCommand(context.Background(), agent.LaunchConfig{Prompt: "x"}) + if err != nil { + t.Fatal(err) + } + if contains(cmd, "--session-id") { + t.Fatalf("command %#v unexpectedly contains --session-id", cmd) + } +} + +func TestClaudeSessionUUIDDeterministicAndUnique(t *testing.T) { + a1 := claudeSessionUUID("alpha") + a2 := claudeSessionUUID("alpha") + b := claudeSessionUUID("beta") + if a1 != a2 { + t.Fatalf("derivation not deterministic: %q != %q", a1, a2) + } + if a1 == b { + t.Fatalf("distinct ids collided: both %q", a1) + } + if _, err := uuid.Parse(a1); err != nil { + t.Fatalf("derived value is not a valid UUID: %q (%v)", a1, err) + } +} + +func TestGetAgentHooksInstallsClaudeHooks(t *testing.T) { + p := &Plugin{resolvedBinary: "claude"} + workspace := t.TempDir() + settingsDir := filepath.Join(workspace, ".claude") + if err := os.MkdirAll(settingsDir, 0o755); err != nil { + t.Fatal(err) + } + settingsPath := filepath.Join(settingsDir, "settings.local.json") + // Pre-seed a user's own Stop hook + an unrelated setting; both must survive. + existing := `{"hooks":{"Stop":[{"hooks":[{"type":"command","command":"my own stop hook","timeout":5}]}]},"permissions":{"defaultMode":"plan"}}` + if err := os.WriteFile(settingsPath, []byte(existing), 0o644); err != nil { + t.Fatal(err) + } + + cfg := agent.WorkspaceHookConfig{DataDir: t.TempDir(), SessionID: "sess-1", WorkspacePath: workspace} + if err := p.GetAgentHooks(context.Background(), cfg); err != nil { + t.Fatal(err) + } + // A second install must not duplicate Better-AO hook commands. 
+ if err := p.GetAgentHooks(context.Background(), cfg); err != nil { + t.Fatal(err) + } + + data, err := os.ReadFile(settingsPath) + if err != nil { + t.Fatal(err) + } + var config struct { + Hooks map[string][]claudeMatcherGroup `json:"hooks"` + Permissions json.RawMessage `json:"permissions"` + } + if err := json.Unmarshal(data, &config); err != nil { + t.Fatal(err) + } + if config.Hooks == nil { + t.Fatalf("hooks object missing: %s", data) + } + + // Every command in the embedded template is installed exactly once. + templateHooks, err := claudeEmbeddedHookGroups() + if err != nil { + t.Fatal(err) + } + for event, templateGroups := range templateHooks { + for _, group := range templateGroups { + for _, hook := range group.Hooks { + if got := countClaudeHookCommand(config.Hooks[event], hook.Command); got != 1 { + t.Fatalf("%s command %q count = %d, want 1", event, hook.Command, got) + } + } + } + } + // Existing user hook preserved. + if countClaudeHookCommand(config.Hooks["Stop"], "my own stop hook") != 1 { + t.Fatalf("existing Stop hook not preserved: %#v", config.Hooks["Stop"]) + } + // Unrelated settings preserved. + if len(config.Permissions) == 0 { + t.Fatalf("unrelated settings clobbered: %s", data) + } + // SessionStart carries the required matcher; UserPromptSubmit omits it. 
+ if m := matcherForCommand(config.Hooks["SessionStart"], "better-ao hooks claude-code session-start"); m == nil || *m != "startup" { + t.Fatalf("SessionStart matcher = %v, want startup", m) + } + if m := matcherForCommand(config.Hooks["UserPromptSubmit"], "better-ao hooks claude-code user-prompt-submit"); m != nil { + t.Fatalf("UserPromptSubmit matcher = %v, want none", m) + } +} + +func TestSessionInfoReadsHookMetadata(t *testing.T) { + info, ok, err := (&Plugin{resolvedBinary: "claude"}).SessionInfo(context.Background(), agent.SessionRef{ + WorkspacePath: "/some/path", + Metadata: map[string]string{ + agent.MetadataKeyAgentSessionID: "claude-native-1", + claudeTitleMetadataKey: "Fix login redirect", + claudeSummaryMetadataKey: "Updated the auth callback and tests.", + "ignored": "not returned", + }, + }) + if err != nil || !ok { + t.Fatalf("SessionInfo = (ok=%v, err=%v), want ok", ok, err) + } + if info.AgentSessionID != "claude-native-1" { + t.Fatalf("AgentSessionID = %q", info.AgentSessionID) + } + if info.Title != "Fix login redirect" { + t.Fatalf("Title = %q", info.Title) + } + if info.Summary != "Updated the auth callback and tests." { + t.Fatalf("Summary = %q", info.Summary) + } + if info.Metadata != nil { + t.Fatalf("Metadata = %#v, want nil for Claude", info.Metadata) + } +} + +func TestSessionInfoFalseWhenNoHookMetadata(t *testing.T) { + info, ok, err := (&Plugin{resolvedBinary: "claude"}).SessionInfo(context.Background(), agent.SessionRef{ + WorkspacePath: "/some/path", + Metadata: map[string]string{}, + }) + if err != nil { + t.Fatalf("err = %v", err) + } + if ok { + t.Fatalf("ok = true, want false") + } + if !reflect.DeepEqual(info, agent.SessionInfo{}) { + t.Fatalf("info = %#v, want zero", info) + } +} + +// countClaudeHookCommand counts how many hook entries under one event register +// the given command — used to prove no duplicate Better-AO hooks. 
+func countClaudeHookCommand(groups []claudeMatcherGroup, command string) int { + count := 0 + for _, group := range groups { + for _, hook := range group.Hooks { + if hook.Command == command { + count++ + } + } + } + return count +} + +// matcherForCommand returns the matcher on the group that registers the given +// command (nil if the group has no matcher). +func matcherForCommand(groups []claudeMatcherGroup, command string) *string { + for _, group := range groups { + for _, hook := range group.Hooks { + if hook.Command == command { + return group.Matcher + } + } + } + return nil +} + +func TestGetRestoreCommandReadsAgentSessionID(t *testing.T) { + cmd, ok, err := (&Plugin{resolvedBinary: "claude"}).GetRestoreCommand(context.Background(), agent.RestoreConfig{ + Permissions: agent.PermissionModeBypassPermissions, + Session: agent.SessionRef{ + ID: "sess-r", + Metadata: map[string]string{agent.MetadataKeyAgentSessionID: "claude-native-1"}, + }, + }) + if err != nil || !ok { + t.Fatalf("restore = (ok=%v, err=%v), want ok", ok, err) + } + // The hook-captured native id wins over the derived fallback. + want := []string{"claude", "--permission-mode", "bypassPermissions", "--resume", "claude-native-1"} + if !reflect.DeepEqual(cmd, want) { + t.Fatalf("restore cmd\nwant: %#v\n got: %#v", want, cmd) + } +} + +func TestGetRestoreCommandFallsBackToDerivedUUID(t *testing.T) { + // No agentSessionId captured (pre-hook session) → derive deterministically + // from the better-ao session id, the explicit fallback. 
+ cmd, ok, err := (&Plugin{resolvedBinary: "claude"}).GetRestoreCommand(context.Background(), agent.RestoreConfig{ + Permissions: agent.PermissionModeBypassPermissions, + Session: agent.SessionRef{ID: "sess-r"}, + }) + if err != nil || !ok { + t.Fatalf("restore = (ok=%v, err=%v), want ok", ok, err) + } + want := []string{"claude", "--permission-mode", "bypassPermissions", "--resume", claudeSessionUUID("sess-r")} + if !reflect.DeepEqual(cmd, want) { + t.Fatalf("restore cmd\nwant: %#v\n got: %#v", want, cmd) + } +} + +func TestGetRestoreCommandFalseWithoutSessionID(t *testing.T) { + cases := []struct { + name string + ref agent.SessionRef + }{ + {"empty ref", agent.SessionRef{}}, + {"blank agent session, no id", agent.SessionRef{Metadata: map[string]string{agent.MetadataKeyAgentSessionID: " "}}}, + {"workspace path only", agent.SessionRef{WorkspacePath: "/some/path"}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cmd, ok, err := (&Plugin{resolvedBinary: "claude"}).GetRestoreCommand(context.Background(), + agent.RestoreConfig{Permissions: agent.PermissionModeBypassPermissions, Session: tc.ref}) + if err != nil || ok || cmd != nil { + t.Fatalf("restore = (%#v, %v, %v), want (nil,false,nil)", cmd, ok, err) + } + }) + } +} + +func TestManifestID(t *testing.T) { + if got := New().Manifest().ID; got != "claude-code" { + t.Fatalf("manifest id = %q, want claude-code", got) + } +} + +func TestEnsureWorkspaceTrustedCreatesEntry(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, ".claude.json") + // Seed an existing config with another project + a top-level key, to + // prove we preserve unrelated state. 
+ seed := `{"userID":"abc","projects":{"/existing/proj":{"hasTrustDialogAccepted":true,"lastCost":1.5}}}` + if err := os.WriteFile(cfgPath, []byte(seed), 0o600); err != nil { + t.Fatal(err) + } + + work := "/Users/me/.better-ao/worktrees/01ABC" + if err := ensureWorkspaceTrusted(cfgPath, work); err != nil { + t.Fatalf("ensureWorkspaceTrusted: %v", err) + } + + root := readJSON(t, cfgPath) + projects := root["projects"].(map[string]any) + + // New entry trusted. + newEntry := projects[work].(map[string]any) + if newEntry["hasTrustDialogAccepted"] != true { + t.Fatalf("new entry not trusted: %#v", newEntry) + } + // Existing project preserved (including its other fields). + existing := projects["/existing/proj"].(map[string]any) + if existing["hasTrustDialogAccepted"] != true || existing["lastCost"].(float64) != 1.5 { + t.Fatalf("existing project clobbered: %#v", existing) + } + // Top-level key preserved. + if root["userID"] != "abc" { + t.Fatalf("top-level key clobbered: %#v", root["userID"]) + } +} + +func TestEnsureWorkspaceTrustedIsIdempotentAndNoWriteWhenAlreadyTrusted(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, ".claude.json") + work := "/w" + if err := os.WriteFile(cfgPath, []byte(`{"projects":{"/w":{"hasTrustDialogAccepted":true}}}`), 0o600); err != nil { + t.Fatal(err) + } + info1, err := os.Stat(cfgPath) + if err != nil { + t.Fatal(err) + } + + if err := ensureWorkspaceTrusted(cfgPath, work); err != nil { + t.Fatalf("ensureWorkspaceTrusted: %v", err) + } + + // Already trusted → no rewrite → mtime unchanged. 
+ info2, err := os.Stat(cfgPath) + if err != nil { + t.Fatal(err) + } + if !info1.ModTime().Equal(info2.ModTime()) { + t.Fatal("expected no rewrite when already trusted") + } +} + +func TestEnsureWorkspaceTrustedCreatesMissingConfig(t *testing.T) { + dir := t.TempDir() + cfgPath := filepath.Join(dir, ".claude.json") // does not exist yet + work := "/fresh/worktree" + + if err := ensureWorkspaceTrusted(cfgPath, work); err != nil { + t.Fatalf("ensureWorkspaceTrusted: %v", err) + } + + root := readJSON(t, cfgPath) + projects := root["projects"].(map[string]any) + entry := projects[work].(map[string]any) + if entry["hasTrustDialogAccepted"] != true { + t.Fatalf("entry not trusted in freshly-created config: %#v", entry) + } +} + +func readJSON(t *testing.T, path string) map[string]any { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + t.Fatalf("parse %s: %v", path, err) + } + return m +} + +func contains(values []string, needle string) bool { + for _, v := range values { + if v == needle { + return true + } + } + return false +} + +func containsSubsequence(values, needle []string) bool { + if len(needle) == 0 { + return true + } + for start := 0; start+len(needle) <= len(values); start++ { + ok := true + for i, w := range needle { + if values[start+i] != w { + ok = false + break + } + } + if ok { + return true + } + } + return false +} diff --git a/backend/internal/adapters/agent/claudecode/hooks.go b/backend/internal/adapters/agent/claudecode/hooks.go new file mode 100644 index 00000000..16be91fe --- /dev/null +++ b/backend/internal/adapters/agent/claudecode/hooks.go @@ -0,0 +1,187 @@ +package claudecode + +import ( + "context" + "embed" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/aoagents/agent-orchestrator/backend/internal/adapters/agent" +) + +const ( + claudeSettingsDirName = ".claude" + claudeSettingsFileName = 
"settings.local.json" + claudeHooksTemplate = ".claude/settings.local.json" +) + +//go:embed .claude/settings.local.json +var claudeHookTemplateFS embed.FS + +type claudeHookFile struct { + Hooks map[string][]claudeMatcherGroup `json:"hooks"` +} + +type claudeMatcherGroup struct { + // Matcher is a pointer so it round-trips exactly: SessionStart requires a + // real matcher ("startup"); UserPromptSubmit/Stop omit it (Claude ignores + // matcher for those events). omitempty drops a nil matcher on write. + Matcher *string `json:"matcher,omitempty"` + Hooks []claudeHookEntry `json:"hooks"` +} + +type claudeHookEntry struct { + Type string `json:"type"` + Command string `json:"command"` + Timeout int `json:"timeout,omitempty"` +} + +// GetAgentHooks installs Better-AO's Claude Code hooks into the worktree-local +// .claude/settings.local.json file (the per-session local settings, not the +// shared .claude/settings.json). The hooks (SessionStart, UserPromptSubmit, +// Stop) report normalized session metadata back into Better-AO's store. Existing +// hooks and unrelated settings are preserved, and duplicate Better-AO commands +// are not appended, so the install is idempotent. +func (p *Plugin) GetAgentHooks(ctx context.Context, cfg agent.WorkspaceHookConfig) error { + if err := ctx.Err(); err != nil { + return err + } + if strings.TrimSpace(cfg.WorkspacePath) == "" { + return errors.New("claude-code.GetAgentHooks: WorkspacePath is required") + } + + settingsPath := filepath.Join(cfg.WorkspacePath, claudeSettingsDirName, claudeSettingsFileName) + // Preserve every top-level setting (permissions, model, …) and every hook + // event we don't touch by keeping them as raw JSON. 
+ topLevel := map[string]json.RawMessage{} + rawHooks := map[string]json.RawMessage{} + + if existingData, err := os.ReadFile(settingsPath); err == nil { + if strings.TrimSpace(string(existingData)) != "" { + if err := json.Unmarshal(existingData, &topLevel); err != nil { + return fmt.Errorf("claude-code.GetAgentHooks: parse %s: %w", settingsPath, err) + } + if hooksRaw, ok := topLevel["hooks"]; ok { + if err := json.Unmarshal(hooksRaw, &rawHooks); err != nil { + return fmt.Errorf("claude-code.GetAgentHooks: parse hooks in %s: %w", settingsPath, err) + } + } + } + } else if !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("claude-code.GetAgentHooks: read %s: %w", settingsPath, err) + } + + templateHooks, err := claudeEmbeddedHookGroups() + if err != nil { + return err + } + for event, templateGroups := range templateHooks { + var existingGroups []claudeMatcherGroup + if err := parseClaudeHookType(rawHooks, event, &existingGroups); err != nil { + return err + } + for _, group := range templateGroups { + for _, hook := range group.Hooks { + if !claudeHookCommandExists(existingGroups, hook.Command) { + existingGroups = addClaudeHook(existingGroups, hook, group.Matcher) + } + } + } + if err := marshalClaudeHookType(rawHooks, event, existingGroups); err != nil { + return err + } + } + + hooksJSON, err := json.Marshal(rawHooks) + if err != nil { + return fmt.Errorf("claude-code.GetAgentHooks: encode hooks: %w", err) + } + topLevel["hooks"] = hooksJSON + + if err := os.MkdirAll(filepath.Dir(settingsPath), 0o750); err != nil { + return fmt.Errorf("claude-code.GetAgentHooks: create settings dir: %w", err) + } + data, err := json.MarshalIndent(topLevel, "", " ") + if err != nil { + return fmt.Errorf("claude-code.GetAgentHooks: encode %s: %w", settingsPath, err) + } + data = append(data, '\n') + if err := os.WriteFile(settingsPath, data, 0o600); err != nil { + return fmt.Errorf("claude-code.GetAgentHooks: write %s: %w", settingsPath, err) + } + return nil +} + +func 
claudeEmbeddedHookGroups() (map[string][]claudeMatcherGroup, error) { + data, err := claudeHookTemplateFS.ReadFile(claudeHooksTemplate) + if err != nil { + return nil, fmt.Errorf("claude-code.GetAgentHooks: read embedded %s: %w", claudeHooksTemplate, err) + } + var file claudeHookFile + if err := json.Unmarshal(data, &file); err != nil { + return nil, fmt.Errorf("claude-code.GetAgentHooks: parse embedded %s: %w", claudeHooksTemplate, err) + } + if file.Hooks == nil { + return map[string][]claudeMatcherGroup{}, nil + } + return file.Hooks, nil +} + +func parseClaudeHookType(rawHooks map[string]json.RawMessage, event string, target *[]claudeMatcherGroup) error { + data, ok := rawHooks[event] + if !ok { + return nil + } + if err := json.Unmarshal(data, target); err != nil { + return fmt.Errorf("claude-code.GetAgentHooks: parse %s hooks: %w", event, err) + } + return nil +} + +func marshalClaudeHookType(rawHooks map[string]json.RawMessage, event string, groups []claudeMatcherGroup) error { + if len(groups) == 0 { + delete(rawHooks, event) + return nil + } + data, err := json.Marshal(groups) + if err != nil { + return fmt.Errorf("claude-code.GetAgentHooks: encode %s hooks: %w", event, err) + } + rawHooks[event] = data + return nil +} + +func claudeHookCommandExists(groups []claudeMatcherGroup, command string) bool { + for _, group := range groups { + for _, hook := range group.Hooks { + if hook.Command == command { + return true + } + } + } + return false +} + +// addClaudeHook appends hook to an existing group with the same matcher (so a +// SessionStart hook lands under its "startup" matcher), creating that group if +// none matches. 
+func addClaudeHook(groups []claudeMatcherGroup, hook claudeHookEntry, matcher *string) []claudeMatcherGroup { + for i, group := range groups { + if matchersEqual(group.Matcher, matcher) { + groups[i].Hooks = append(groups[i].Hooks, hook) + return groups + } + } + return append(groups, claudeMatcherGroup{Matcher: matcher, Hooks: []claudeHookEntry{hook}}) +} + +func matchersEqual(a, b *string) bool { + if a == nil || b == nil { + return a == nil && b == nil + } + return *a == *b +} diff --git a/backend/internal/adapters/agent/codex/.codex/hooks.json b/backend/internal/adapters/agent/codex/.codex/hooks.json new file mode 100644 index 00000000..aaf1660b --- /dev/null +++ b/backend/internal/adapters/agent/codex/.codex/hooks.json @@ -0,0 +1,40 @@ +{ + "hooks": { + "SessionStart": [ + { + "matcher": null, + "hooks": [ + { + "type": "command", + "command": "better-ao hooks codex session-start", + "timeout": 30 + } + ] + } + ], + "UserPromptSubmit": [ + { + "matcher": null, + "hooks": [ + { + "type": "command", + "command": "better-ao hooks codex user-prompt-submit", + "timeout": 30 + } + ] + } + ], + "Stop": [ + { + "matcher": null, + "hooks": [ + { + "type": "command", + "command": "better-ao hooks codex stop", + "timeout": 30 + } + ] + } + ] + } +} diff --git a/backend/internal/adapters/agent/codex/codex.go b/backend/internal/adapters/agent/codex/codex.go new file mode 100644 index 00000000..bc04fd23 --- /dev/null +++ b/backend/internal/adapters/agent/codex/codex.go @@ -0,0 +1,257 @@ +// Package codex implements the Codex agent adapter: launching new sessions, +// resuming hook-tracked sessions, installing workspace-local hooks, and reading +// hook-derived session info. +// +// Better-AO-managed sessions derive native session identity and display +// metadata from Codex hooks instead of transcript/cache scans. 
+package codex + +import ( + "context" + "os" + "os/exec" + "path/filepath" + "runtime" + "strings" + "sync" + + "github.com/aoagents/agent-orchestrator/backend/internal/adapters" + "github.com/aoagents/agent-orchestrator/backend/internal/adapters/agent" +) + +const ( + codexAgentSessionIDMetadataKey = agent.MetadataKeyAgentSessionID + codexTitleMetadataKey = "title" + codexSummaryMetadataKey = "summary" +) + +// Plugin is the Codex adapter. The zero value is not usable; call New. +type Plugin struct { + binaryMu sync.Mutex + resolvedBinary string +} + +// New constructs a Codex adapter instance. +func New() *Plugin { + return &Plugin{} +} + +var _ adapters.Adapter = (*Plugin)(nil) +var _ agent.Agent = (*Plugin)(nil) + +// Manifest reports the adapter's self-describing record. +func (p *Plugin) Manifest() adapters.Manifest { + return adapters.Manifest{ + ID: "codex", + Name: "Codex", + Description: "Run Codex worker sessions.", + Version: "0.0.1", + Capabilities: []adapters.Capability{ + adapters.CapabilityAgent, + }, + } +} + +// GetConfigSpec returns the agent-specific config keys this adapter exposes. +// Codex has none today. +func (p *Plugin) GetConfigSpec(ctx context.Context) (agent.ConfigSpec, error) { + if err := ctx.Err(); err != nil { + return agent.ConfigSpec{}, err + } + return agent.ConfigSpec{}, nil +} + +// GetLaunchCommand builds the argv to start a fresh Codex session. 
+func (p *Plugin) GetLaunchCommand(ctx context.Context, cfg agent.LaunchConfig) (cmd []string, err error) { + binary, err := p.codexBinary(ctx) + if err != nil { + return nil, err + } + + cmd = []string{binary} + appendNoUpdateCheckFlag(&cmd) + appendApprovalFlags(&cmd, cfg.Permissions) + + if cfg.SystemPromptFile != "" { + cmd = append(cmd, "-c", "model_instructions_file="+cfg.SystemPromptFile) + } else if cfg.SystemPrompt != "" { + cmd = append(cmd, "-c", "developer_instructions="+cfg.SystemPrompt) + } + + if cfg.Prompt != "" { + cmd = append(cmd, "--", cfg.Prompt) + } + + return cmd, nil +} + +// GetPromptDeliveryStrategy reports how Better-AO should deliver the initial +// prompt. Codex accepts it in the launch command. +func (p *Plugin) GetPromptDeliveryStrategy(ctx context.Context, cfg agent.LaunchConfig) (agent.PromptDeliveryStrategy, error) { + if err := ctx.Err(); err != nil { + return "", err + } + return agent.PromptDeliveryInCommand, nil +} + +// GetRestoreCommand rebuilds the argv that continues an existing Codex +// session: `codex resume `. ok is false when the hook-derived +// native session id has not landed yet, so callers can fall back to fresh +// launch behavior. +func (p *Plugin) GetRestoreCommand(ctx context.Context, cfg agent.RestoreConfig) (cmd []string, ok bool, err error) { + if err := ctx.Err(); err != nil { + return nil, false, err + } + agentSessionID := strings.TrimSpace(cfg.Session.Metadata[codexAgentSessionIDMetadataKey]) + if agentSessionID == "" { + return nil, false, nil + } + + binary, err := p.codexBinary(ctx) + if err != nil { + return nil, false, err + } + + cmd = make([]string, 0, 5) + cmd = append(cmd, binary, "resume") + appendNoUpdateCheckFlag(&cmd) + appendApprovalFlags(&cmd, cfg.Permissions) + cmd = append(cmd, agentSessionID) + return cmd, true, nil +} + +// SessionInfo surfaces Codex hook-derived metadata. Metadata is intentionally +// nil for Codex: callers get the normalized fields directly. 
// ResolveCodexBinary returns the path to the codex binary on this machine. It
// searches PATH first, then a handful of well-known install locations
// (Homebrew, Cargo, npm global). When nothing is found it returns the bare name
// "codex" so callers see a clear "command not found" rather than an empty argv.
//
// The only error it returns is ctx's error, checked before the search starts
// and after each unsuccessful probe.
func ResolveCodexBinary(ctx context.Context) (string, error) {
	if err := ctx.Err(); err != nil {
		return "", err
	}

	for _, name := range codexLookupNames() {
		if path, err := exec.LookPath(name); err == nil && path != "" {
			return path, nil
		}
		if err := ctx.Err(); err != nil {
			return "", err
		}
	}

	for _, candidate := range codexFallbackPaths() {
		if fileExists(candidate) {
			return candidate, nil
		}
		if err := ctx.Err(); err != nil {
			return "", err
		}
	}

	return "codex", nil
}

// codexLookupNames lists the executable names to try on PATH, in priority
// order, for the current operating system.
func codexLookupNames() []string {
	if runtime.GOOS == "windows" {
		return []string{"codex.cmd", "codex.exe", "codex"}
	}
	return []string{"codex"}
}

// codexFallbackPaths lists well-known install locations to probe when codex is
// not on PATH. Locations that depend on an unavailable environment value
// (APPDATA, the home directory) are left out.
func codexFallbackPaths() []string {
	home, homeErr := os.UserHomeDir()

	if runtime.GOOS == "windows" {
		var candidates []string
		if appData := os.Getenv("APPDATA"); appData != "" {
			candidates = append(candidates,
				filepath.Join(appData, "npm", "codex.cmd"),
				filepath.Join(appData, "npm", "codex.exe"),
			)
		}
		if homeErr == nil {
			candidates = append(candidates, filepath.Join(home, ".cargo", "bin", "codex.exe"))
		}
		return candidates
	}

	candidates := []string{
		"/usr/local/bin/codex",
		"/opt/homebrew/bin/codex",
	}
	if homeErr == nil {
		candidates = append(candidates,
			filepath.Join(home, ".cargo", "bin", "codex"),
			filepath.Join(home, ".npm", "bin", "codex"),
		)
	}
	return candidates
}

// fileExists reports whether path can be stat'ed and is not a directory.
func fileExists(path string) bool {
	info, err := os.Stat(path)
	return err == nil && !info.IsDir()
}
-0,0 +1,335 @@ +package codex + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "reflect" + "strings" + "testing" + + "github.com/aoagents/agent-orchestrator/backend/internal/adapters/agent" +) + +func TestGetLaunchCommandBuildsCrossPlatformArgv(t *testing.T) { + plugin := &Plugin{resolvedBinary: "codex"} + + cmd, err := plugin.GetLaunchCommand(context.Background(), agent.LaunchConfig{ + Permissions: agent.PermissionModeBypassPermissions, + Prompt: "-fix this", + SystemPromptFile: filepath.Join("tmp", "prompt with spaces.md"), + SystemPrompt: "ignored", + }) + if err != nil { + t.Fatal(err) + } + + want := []string{ + "codex", + "-c", "check_for_update_on_startup=false", + "--dangerously-bypass-approvals-and-sandbox", + "-c", "model_instructions_file=" + filepath.Join("tmp", "prompt with spaces.md"), + "--", "-fix this", + } + if !reflect.DeepEqual(cmd, want) { + t.Fatalf("unexpected command\nwant: %#v\n got: %#v", want, cmd) + } +} + +func TestGetLaunchCommandMapsApprovalModes(t *testing.T) { + tests := []struct { + name string + permission agent.PermissionMode + want []string + notExpected string + }{ + { + name: "default", + permission: agent.PermissionModeDefault, + notExpected: "--ask-for-approval", + }, + { + name: "accept-edits", + permission: agent.PermissionModeAcceptEdits, + want: []string{"--ask-for-approval", "on-request"}, + }, + { + name: "auto", + permission: agent.PermissionModeAuto, + want: []string{"--ask-for-approval", "on-request", "-c", `approvals_reviewer="auto_review"`}, + }, + { + name: "bypass-permissions", + permission: agent.PermissionModeBypassPermissions, + want: []string{"--dangerously-bypass-approvals-and-sandbox"}, + }, + { + name: "empty", + permission: "", + notExpected: "--ask-for-approval", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + plugin := &Plugin{resolvedBinary: "codex"} + cmd, err := plugin.GetLaunchCommand(context.Background(), agent.LaunchConfig{ + Permissions: 
tt.permission, + }) + if err != nil { + t.Fatal(err) + } + if len(tt.want) > 0 && !containsSubsequence(cmd, tt.want) { + t.Fatalf("command %#v does not contain %#v", cmd, tt.want) + } + if tt.notExpected != "" && contains(cmd, tt.notExpected) { + t.Fatalf("command %#v contains %q", cmd, tt.notExpected) + } + }) + } +} + +func TestGetPromptDeliveryStrategyIsInCommand(t *testing.T) { + plugin := &Plugin{resolvedBinary: "codex"} + + got, err := plugin.GetPromptDeliveryStrategy(context.Background(), agent.LaunchConfig{}) + if err != nil { + t.Fatal(err) + } + if got != agent.PromptDeliveryInCommand { + t.Fatalf("unexpected strategy: %q", got) + } +} + +func TestGetConfigSpecHasNoCustomFieldsYet(t *testing.T) { + plugin := &Plugin{resolvedBinary: "codex"} + + spec, err := plugin.GetConfigSpec(context.Background()) + if err != nil { + t.Fatal(err) + } + if len(spec.Fields) != 0 { + t.Fatalf("unexpected config fields: %#v", spec.Fields) + } +} + +func TestGetAgentHooksInstallsCodexHooks(t *testing.T) { + plugin := &Plugin{resolvedBinary: "codex"} + workspace := t.TempDir() + hooksDir := filepath.Join(workspace, ".codex") + if err := os.MkdirAll(hooksDir, 0o755); err != nil { + t.Fatal(err) + } + hooksPath := filepath.Join(hooksDir, "hooks.json") + existing := `{"hooks":{"Stop":[{"matcher":null,"hooks":[{"type":"command","command":"custom stop hook","timeout":3}]}]}}` + if err := os.WriteFile(hooksPath, []byte(existing), 0o644); err != nil { + t.Fatal(err) + } + + cfg := agent.WorkspaceHookConfig{ + DataDir: t.TempDir(), + SessionID: "sess-1", + WorkspacePath: workspace, + } + if err := plugin.GetAgentHooks(context.Background(), cfg); err != nil { + t.Fatal(err) + } + // A second install must not duplicate Better-AO hook commands. 
+ if err := plugin.GetAgentHooks(context.Background(), cfg); err != nil { + t.Fatal(err) + } + + data, err := os.ReadFile(hooksPath) + if err != nil { + t.Fatal(err) + } + var config codexHookFile + if err := json.Unmarshal(data, &config); err != nil { + t.Fatal(err) + } + if config.Hooks == nil { + t.Fatalf("hooks config missing hooks object: %#v", config) + } + templateHooks, err := codexEmbeddedHookGroups() + if err != nil { + t.Fatal(err) + } + for event, templateGroups := range templateHooks { + entries := config.Hooks[event] + for _, templateGroup := range templateGroups { + for _, hook := range templateGroup.Hooks { + count := countCodexHookCommand(entries, hook.Command) + if count != 1 { + t.Fatalf("%s command count = %d, want 1 in %#v", event, count, entries) + } + } + } + } + stopEntries := config.Hooks["Stop"] + if countCodexHookCommand(stopEntries, "custom stop hook") != 1 { + t.Fatalf("existing Stop hook was not preserved: %#v", stopEntries) + } + + configData, err := os.ReadFile(filepath.Join(workspace, ".codex", "config.toml")) + if err != nil { + t.Fatal(err) + } + if !strings.Contains(string(configData), codexHooksFeatureLine) { + t.Fatalf("config.toml missing hooks feature flag: %s", configData) + } +} + +func TestGetRestoreCommandReadsAgentSessionID(t *testing.T) { + plugin := &Plugin{resolvedBinary: "codex"} + + cmd, ok, err := plugin.GetRestoreCommand(context.Background(), agent.RestoreConfig{ + Permissions: agent.PermissionModeAuto, + Session: agent.SessionRef{ + Metadata: map[string]string{codexAgentSessionIDMetadataKey: "thread-123"}, + }, + }) + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if !ok { + t.Fatal("ok = false, want true") + } + want := []string{ + "codex", + "resume", + "-c", "check_for_update_on_startup=false", + "--ask-for-approval", "on-request", + "-c", `approvals_reviewer="auto_review"`, + "thread-123", + } + if !reflect.DeepEqual(cmd, want) { + t.Fatalf("restore cmd\nwant: %#v\n got: %#v", want, cmd) + } +} + 
+func TestGetRestoreCommandFalseWithoutAgentSessionID(t *testing.T) { + plugin := &Plugin{resolvedBinary: "codex"} + + cases := []struct { + name string + ref agent.SessionRef + }{ + {"empty session ref", agent.SessionRef{}}, + {"empty metadata", agent.SessionRef{Metadata: map[string]string{}}}, + {"blank agent session metadata", agent.SessionRef{Metadata: map[string]string{codexAgentSessionIDMetadataKey: " "}}}, + {"workspace path only", agent.SessionRef{WorkspacePath: "/some/path"}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + cmd, ok, err := plugin.GetRestoreCommand(context.Background(), agent.RestoreConfig{ + Permissions: agent.PermissionModeAuto, + Session: tc.ref, + }) + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if ok { + t.Fatalf("ok = true, want false") + } + if cmd != nil { + t.Fatalf("cmd = %#v, want nil", cmd) + } + }) + } +} + +func TestSessionInfoReadsHookMetadata(t *testing.T) { + plugin := &Plugin{resolvedBinary: "codex"} + + info, ok, err := plugin.SessionInfo(context.Background(), agent.SessionRef{ + WorkspacePath: "/some/path", + Metadata: map[string]string{ + codexAgentSessionIDMetadataKey: "thread-123", + codexTitleMetadataKey: "Fix login redirect", + codexSummaryMetadataKey: "Updated the auth callback and tests.", + "ignored": "not returned", + }, + }) + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if !ok { + t.Fatalf("ok = false, want true") + } + if info.AgentSessionID != "thread-123" { + t.Fatalf("AgentSessionID = %q, want native id", info.AgentSessionID) + } + if info.Title != "Fix login redirect" { + t.Fatalf("Title = %q, want hook title", info.Title) + } + if info.Summary != "Updated the auth callback and tests." 
{ + t.Fatalf("Summary = %q, want hook summary", info.Summary) + } + if info.Metadata != nil { + t.Fatalf("Metadata = %#v, want nil for Codex", info.Metadata) + } +} + +func TestSessionInfoFalseWhenNoHookMetadata(t *testing.T) { + plugin := &Plugin{resolvedBinary: "codex"} + + info, ok, err := plugin.SessionInfo(context.Background(), agent.SessionRef{ + WorkspacePath: "/some/path", + Metadata: map[string]string{}, + }) + if err != nil { + t.Fatalf("err = %v, want nil", err) + } + if ok { + t.Fatalf("ok = true, want false") + } + if !reflect.DeepEqual(info, agent.SessionInfo{}) { + t.Fatalf("info = %#v, want zero value", info) + } +} + +func contains(values []string, needle string) bool { + for _, value := range values { + if value == needle { + return true + } + } + return false +} + +func containsSubsequence(values []string, needle []string) bool { + if len(needle) == 0 { + return true + } + + for start := range values { + if start+len(needle) > len(values) { + return false + } + ok := true + for offset, want := range needle { + if values[start+offset] != want { + ok = false + break + } + } + if ok { + return true + } + } + + return false +} + +func countCodexHookCommand(entries []codexMatcherGroup, command string) int { + count := 0 + for _, entry := range entries { + for _, hook := range entry.Hooks { + if hook.Command == command { + count++ + } + } + } + return count +} diff --git a/backend/internal/adapters/agent/codex/hooks.go b/backend/internal/adapters/agent/codex/hooks.go new file mode 100644 index 00000000..d7b0ee89 --- /dev/null +++ b/backend/internal/adapters/agent/codex/hooks.go @@ -0,0 +1,236 @@ +package codex + +import ( + "context" + "embed" + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/aoagents/agent-orchestrator/backend/internal/adapters/agent" +) + +const ( + codexHooksDirName = ".codex" + codexHooksFileName = "hooks.json" + codexHooksTemplate = ".codex/hooks.json" + + codexConfigFileName = "config.toml" 
+ codexHooksFeatureLine = "hooks = true" + codexLegacyHookFeatureLine = "codex_hooks = true" +) + +//go:embed .codex/hooks.json +var codexHookTemplateFS embed.FS + +type codexHookFile struct { + Hooks map[string][]codexMatcherGroup `json:"hooks"` +} + +type codexMatcherGroup struct { + Matcher *string `json:"matcher"` + Hooks []codexHookEntry `json:"hooks"` +} + +type codexHookEntry struct { + Type string `json:"type"` + Command string `json:"command"` + Timeout int `json:"timeout,omitempty"` +} + +// GetAgentHooks installs Better-AO's Codex hooks into the worktree-local +// .codex/hooks.json file. Existing hook entries are preserved and duplicate +// Better-AO commands are not appended. +func (p *Plugin) GetAgentHooks(ctx context.Context, cfg agent.WorkspaceHookConfig) error { + if err := ctx.Err(); err != nil { + return err + } + if strings.TrimSpace(cfg.WorkspacePath) == "" { + return errors.New("codex.GetAgentHooks: WorkspacePath is required") + } + + hooksPath := filepath.Join(cfg.WorkspacePath, codexHooksDirName, codexHooksFileName) + topLevel := map[string]json.RawMessage{} + rawHooks := map[string]json.RawMessage{} + + if existingData, err := os.ReadFile(hooksPath); err == nil { + if strings.TrimSpace(string(existingData)) != "" { + if err := json.Unmarshal(existingData, &topLevel); err != nil { + return fmt.Errorf("codex.GetAgentHooks: parse %s: %w", hooksPath, err) + } + if hooksRaw, ok := topLevel["hooks"]; ok { + if err := json.Unmarshal(hooksRaw, &rawHooks); err != nil { + return fmt.Errorf("codex.GetAgentHooks: parse hooks in %s: %w", hooksPath, err) + } + } + } + } else if !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("codex.GetAgentHooks: read %s: %w", hooksPath, err) + } + + templateHooks, err := codexEmbeddedHookGroups() + if err != nil { + return err + } + for event, templateGroups := range templateHooks { + var existingGroups []codexMatcherGroup + if err := parseCodexHookType(rawHooks, event, &existingGroups); err != nil { + return err + 
} + for _, group := range templateGroups { + for _, hook := range group.Hooks { + if !codexHookCommandExists(existingGroups, hook.Command) { + existingGroups = addCodexHook(existingGroups, hook) + } + } + } + if err := marshalCodexHookType(rawHooks, event, existingGroups); err != nil { + return err + } + } + + hooksJSON, err := json.Marshal(rawHooks) + if err != nil { + return fmt.Errorf("codex.GetAgentHooks: encode hooks: %w", err) + } + topLevel["hooks"] = hooksJSON + + if err := os.MkdirAll(filepath.Dir(hooksPath), 0o750); err != nil { + return fmt.Errorf("codex.GetAgentHooks: create hook dir: %w", err) + } + data, err := json.MarshalIndent(topLevel, "", " ") + if err != nil { + return fmt.Errorf("codex.GetAgentHooks: encode %s: %w", hooksPath, err) + } + data = append(data, '\n') + if err := os.WriteFile(hooksPath, data, 0o600); err != nil { + return fmt.Errorf("codex.GetAgentHooks: write %s: %w", hooksPath, err) + } + + if err := ensureCodexHooksFeatureEnabled(cfg.WorkspacePath); err != nil { + return fmt.Errorf("codex.GetAgentHooks: enable hooks feature: %w", err) + } + return nil +} + +func codexEmbeddedHookGroups() (map[string][]codexMatcherGroup, error) { + data, err := codexHookTemplateFS.ReadFile(codexHooksTemplate) + if err != nil { + return nil, fmt.Errorf("codex.GetAgentHooks: read embedded %s: %w", codexHooksTemplate, err) + } + var file codexHookFile + if err := json.Unmarshal(data, &file); err != nil { + return nil, fmt.Errorf("codex.GetAgentHooks: parse embedded %s: %w", codexHooksTemplate, err) + } + if file.Hooks == nil { + return map[string][]codexMatcherGroup{}, nil + } + return file.Hooks, nil +} + +func parseCodexHookType(rawHooks map[string]json.RawMessage, event string, target *[]codexMatcherGroup) error { + data, ok := rawHooks[event] + if !ok { + return nil + } + if err := json.Unmarshal(data, target); err != nil { + return fmt.Errorf("codex.GetAgentHooks: parse %s hooks: %w", event, err) + } + return nil +} + +func 
marshalCodexHookType(rawHooks map[string]json.RawMessage, event string, groups []codexMatcherGroup) error { + if len(groups) == 0 { + delete(rawHooks, event) + return nil + } + data, err := json.Marshal(groups) + if err != nil { + return fmt.Errorf("codex.GetAgentHooks: encode %s hooks: %w", event, err) + } + rawHooks[event] = data + return nil +} + +func codexHookCommandExists(groups []codexMatcherGroup, command string) bool { + for _, group := range groups { + for _, hook := range group.Hooks { + if hook.Command == command { + return true + } + } + } + return false +} + +func addCodexHook(groups []codexMatcherGroup, hook codexHookEntry) []codexMatcherGroup { + for i, group := range groups { + if group.Matcher == nil { + groups[i].Hooks = append(groups[i].Hooks, hook) + return groups + } + } + return append(groups, codexMatcherGroup{ + Matcher: nil, + Hooks: []codexHookEntry{hook}, + }) +} + +func ensureCodexHooksFeatureEnabled(workspacePath string) error { + configPath := filepath.Join(workspacePath, codexHooksDirName, codexConfigFileName) + data, err := os.ReadFile(configPath) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("read config.toml: %w", err) + } + + content := string(data) + hasNew := containsCodexFeatureLine(content, codexHooksFeatureLine) + hasLegacy := containsCodexFeatureLine(content, codexLegacyHookFeatureLine) + switch { + case hasNew && hasLegacy: + content = stripCodexLegacyHookFeatureLine(content) + case hasNew: + return nil + case hasLegacy: + content = strings.Replace(content, codexLegacyHookFeatureLine, codexHooksFeatureLine, 1) + case strings.Contains(content, "[features]"): + content = strings.Replace(content, "[features]", "[features]\n"+codexHooksFeatureLine, 1) + default: + if content != "" && !strings.HasSuffix(content, "\n") { + content += "\n" + } + content += "\n[features]\n" + codexHooksFeatureLine + "\n" + } + + if err := os.MkdirAll(filepath.Dir(configPath), 0o750); err != nil { + return 
fmt.Errorf("create .codex directory: %w", err) + } + if err := os.WriteFile(configPath, []byte(content), 0o600); err != nil { + return fmt.Errorf("write config.toml: %w", err) + } + return nil +} + +func containsCodexFeatureLine(content, line string) bool { + for raw := range strings.SplitSeq(content, "\n") { + if strings.TrimSpace(raw) == line { + return true + } + } + return false +} + +func stripCodexLegacyHookFeatureLine(content string) string { + idx := strings.Index(content, codexLegacyHookFeatureLine) + if idx < 0 { + return content + } + end := idx + len(codexLegacyHookFeatureLine) + if end < len(content) && content[end] == '\n' { + end++ + } + return content[:idx] + content[end:] +} diff --git a/backend/internal/adapters/agent/portshim/shim.go b/backend/internal/adapters/agent/portshim/shim.go new file mode 100644 index 00000000..541fbcc7 --- /dev/null +++ b/backend/internal/adapters/agent/portshim/shim.go @@ -0,0 +1,114 @@ +// Package portshim bridges the richer adapters/agent.Agent interface onto the +// narrower ports.Agent the Session Manager consumes. The richer interface +// returns argv slices and takes a context; ports.Agent returns a single shell +// string and is context-free. The shim joins argv with POSIX shell quoting so +// the zellij runtime, which evaluates LaunchCommand under `sh -lc`, sees the +// agent's argv intact. +package portshim + +import ( + "context" + "strings" + + "github.com/aoagents/agent-orchestrator/backend/internal/adapters/agent" + "github.com/aoagents/agent-orchestrator/backend/internal/ports" +) + +// Shim wraps an adapters/agent.Agent and satisfies ports.Agent. The shim is +// context-free at its API surface; it threads context.Background() into the +// richer interface. That matches the existing ports.Agent shape — extending it +// is a separate change. +type Shim struct { + agent agent.Agent +} + +// New constructs a Shim. agent is required; nil is not supported. 
+func New(a agent.Agent) *Shim { return &Shim{agent: a} } + +var _ ports.Agent = (*Shim)(nil) + +// GetLaunchCommand asks the wrapped agent for its launch argv and renders it as +// a single POSIX-shell-safe string. An adapter error or empty argv yields "". +func (s *Shim) GetLaunchCommand(cfg ports.AgentConfig) string { + argv, err := s.agent.GetLaunchCommand(context.Background(), launchConfigFor(cfg)) + if err != nil { + return "" + } + return joinShellArgv(argv) +} + +// GetEnvironment returns nil: the richer agent interface doesn't carry the env +// keys ports.AgentConfig exposes, and the SM layers AO_SESSION_ID, +// AO_PROJECT_ID, AO_ISSUE_ID on top of whatever the agent contributes. A nil +// map is fine here — session.spawnEnv treats nil as empty. +func (s *Shim) GetEnvironment(ports.AgentConfig) map[string]string { + return nil +} + +// GetRestoreCommand resumes a native agent session given its agentSessionID and +// returns the resume command as a POSIX-shell-safe string. An adapter error or +// ok=false yields "" — the SM falls back to a fresh Spawn. +func (s *Shim) GetRestoreCommand(agentSessionID string) string { + cfg := agent.RestoreConfig{ + Session: agent.SessionRef{ + ID: agentSessionID, + Metadata: map[string]string{ + agent.MetadataKeyAgentSessionID: agentSessionID, + }, + }, + } + argv, ok, err := s.agent.GetRestoreCommand(context.Background(), cfg) + if err != nil || !ok { + return "" + } + return joinShellArgv(argv) +} + +func launchConfigFor(cfg ports.AgentConfig) agent.LaunchConfig { + return agent.LaunchConfig{ + SessionID: string(cfg.SessionID), + WorkspacePath: cfg.WorkspacePath, + Prompt: cfg.Prompt, + } +} + +// joinShellArgv renders argv as a single string the POSIX shell will re-parse +// into the same tokens. Each arg is quoted in single quotes unless it consists +// only of characters guaranteed safe to leave bare. 
func joinShellArgv(argv []string) string {
	// Quote each argument independently and separate them with one space;
	// an empty argv renders as the empty string.
	var out strings.Builder
	for i, arg := range argv {
		if i > 0 {
			out.WriteByte(' ')
		}
		out.WriteString(shellQuote(arg))
	}
	return out.String()
}

// shellQuote returns s in a form a POSIX shell reads back as exactly s: the
// empty string becomes '', strings made only of safe characters are returned
// bare, and anything else is wrapped in single quotes with each embedded
// single quote written as '\''.
func shellQuote(s string) string {
	switch {
	case s == "":
		return "''"
	case isShellSafe(s):
		return s
	default:
		return "'" + strings.ReplaceAll(s, "'", `'\''`) + "'"
	}
}

// isShellSafe reports whether s consists solely of ASCII letters, ASCII
// digits and the punctuation - _ / . , : + @ =, so it can be left unquoted.
// The empty string is reported safe; shellQuote handles it before asking.
func isShellSafe(s string) bool {
	const barePunctuation = "-_/.,:+@="
	needsQuoting := func(r rune) bool {
		switch {
		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
			return false
		}
		return !strings.ContainsRune(barePunctuation, r)
	}
	return strings.IndexFunc(s, needsQuoting) < 0
}
f.launchErr +} +func (f *fakeAgent) GetPromptDeliveryStrategy(context.Context, agent.LaunchConfig) (agent.PromptDeliveryStrategy, error) { + return agent.PromptDeliveryInCommand, nil +} +func (f *fakeAgent) GetAgentHooks(context.Context, agent.WorkspaceHookConfig) error { return nil } +func (f *fakeAgent) GetRestoreCommand(_ context.Context, cfg agent.RestoreConfig) ([]string, bool, error) { + f.gotRestoreCfg = cfg + return f.restoreCmd, f.restoreOK, f.restoreErr +} +func (f *fakeAgent) SessionInfo(context.Context, agent.SessionRef) (agent.SessionInfo, bool, error) { + return agent.SessionInfo{}, false, nil +} + +func TestSatisfiesPortsAgent(t *testing.T) { + var _ ports.Agent = (*portshim.Shim)(nil) +} + +func TestGetLaunchCommand_JoinsArgvShellSafely(t *testing.T) { + tests := []struct { + name string + argv []string + want string + }{ + {"simple", []string{"claude"}, "claude"}, + {"flags and prompt", []string{"claude", "--", "do it"}, "claude -- 'do it'"}, + {"path with spaces", []string{"/Applications/My App/claude", "--flag"}, "'/Applications/My App/claude' --flag"}, + {"prompt with single quote", []string{"claude", "--", "it's fine"}, `claude -- 'it'\''s fine'`}, + {"empty argv", []string{}, ""}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + s := portshim.New(&fakeAgent{launchCmd: tc.argv}) + got := s.GetLaunchCommand(ports.AgentConfig{}) + if got != tc.want { + t.Fatalf("got %q want %q", got, tc.want) + } + }) + } +} + +func TestGetLaunchCommand_PropagatesAgentConfig(t *testing.T) { + fake := &fakeAgent{launchCmd: []string{"claude"}} + s := portshim.New(fake) + cfg := ports.AgentConfig{SessionID: "p-1", WorkspacePath: "/ws/p-1", Prompt: "hello"} + _ = s.GetLaunchCommand(cfg) + if fake.gotLaunchCfg.SessionID != "p-1" { + t.Errorf("SessionID not propagated: %+v", fake.gotLaunchCfg) + } + if fake.gotLaunchCfg.WorkspacePath != "/ws/p-1" { + t.Errorf("WorkspacePath not propagated: %+v", fake.gotLaunchCfg) + } + if 
fake.gotLaunchCfg.Prompt != "hello" { + t.Errorf("Prompt not propagated: %+v", fake.gotLaunchCfg) + } +} + +func TestGetLaunchCommand_AgentErrorReturnsEmpty(t *testing.T) { + fake := &fakeAgent{launchErr: errors.New("boom")} + s := portshim.New(fake) + got := s.GetLaunchCommand(ports.AgentConfig{SessionID: "p-1"}) + if got != "" { + t.Fatalf("expected empty on error, got %q", got) + } +} + +func TestGetEnvironment_ReturnsAgentEnvKeysOnly(t *testing.T) { + // The richer Agent interface doesn't carry the env keys the SM port supplies, + // so the shim has nothing agent-specific to surface. SM layers AO_* on top. + s := portshim.New(&fakeAgent{}) + got := s.GetEnvironment(ports.AgentConfig{SessionID: "p-1"}) + if len(got) != 0 { + t.Fatalf("expected empty env from shim, got %v", got) + } + for _, k := range []string{sessionmanager.EnvSessionID, sessionmanager.EnvProjectID, sessionmanager.EnvIssueID} { + if _, ok := got[k]; ok { + t.Errorf("shim must not pre-populate AO env key %s; SM owns it", k) + } + } +} + +func TestGetRestoreCommand_JoinsWhenOK(t *testing.T) { + fake := &fakeAgent{restoreCmd: []string{"claude", "--resume", "abc 123"}, restoreOK: true} + s := portshim.New(fake) + got := s.GetRestoreCommand("abc 123") + want := `claude --resume 'abc 123'` + if got != want { + t.Fatalf("got %q want %q", got, want) + } + if fake.gotRestoreCfg.Session.ID != "abc 123" { + t.Errorf("session id not propagated: %+v", fake.gotRestoreCfg) + } +} + +func TestGetRestoreCommand_NotOKReturnsEmpty(t *testing.T) { + fake := &fakeAgent{restoreOK: false} + s := portshim.New(fake) + if got := s.GetRestoreCommand("anything"); got != "" { + t.Fatalf("expected empty when not restorable, got %q", got) + } +} + +func TestGetRestoreCommand_ErrorReturnsEmpty(t *testing.T) { + fake := &fakeAgent{restoreErr: errors.New("boom")} + s := portshim.New(fake) + if got := s.GetRestoreCommand("x"); got != "" { + t.Fatalf("expected empty on restore error, got %q", got) + } +} + +func 
TestGetRestoreCommand_PassesAgentSessionIDAsMetadata(t *testing.T) { + // Claude-code (and Codex) read the native session id off cfg.Session.Metadata + // ["agentSessionId"] to rebuild the --resume command. Pass it via both Session.ID + // (the legacy fallback) and Session.Metadata so the richer adapter can find it. + fake := &fakeAgent{restoreCmd: []string{"claude", "--resume", "x"}, restoreOK: true} + s := portshim.New(fake) + _ = s.GetRestoreCommand("native-uuid") + gotID := fake.gotRestoreCfg.Session.ID + if gotID != "native-uuid" { + t.Errorf("Session.ID want native-uuid, got %q", gotID) + } + if m := fake.gotRestoreCfg.Session.Metadata[agent.MetadataKeyAgentSessionID]; m != "native-uuid" { + t.Errorf("Session.Metadata[%s] want native-uuid, got %q", agent.MetadataKeyAgentSessionID, m) + } +} + +func TestShellQuotingDoesNotDoubleQuoteSafeStrings(t *testing.T) { + // Safe identifiers (letters, digits, dash, dot, slash, underscore) should + // pass through unquoted; quoting them would inflate every command. + s := portshim.New(&fakeAgent{launchCmd: []string{"/usr/local/bin/claude", "--session-id", "abc-123_xyz.uuid"}}) + got := s.GetLaunchCommand(ports.AgentConfig{}) + if strings.Contains(got, "'") { + t.Fatalf("got unexpected quotes: %q", got) + } +} diff --git a/backend/internal/adapters/messenger/inbox/inbox.go b/backend/internal/adapters/messenger/inbox/inbox.go new file mode 100644 index 00000000..55aaf524 --- /dev/null +++ b/backend/internal/adapters/messenger/inbox/inbox.go @@ -0,0 +1,110 @@ +// Package inbox implements ports.AgentMessenger by writing each message as a +// file in /.ao/inbox/. The agent reads its inbox on demand; +// pinging the runtime pane to consume new files is a separate concern that +// lives in the runtime adapter, not here. 
+package inbox + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "os" + "path/filepath" + "strconv" + "time" + + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/ports" +) + +// SessionWorkspace resolves a session id to the absolute path of its workspace. +// The sqlite store satisfies this via GetSession; the adapter is in +// daemon/session_wiring.go. +type SessionWorkspace interface { + WorkspacePath(ctx context.Context, id domain.SessionID) (string, error) +} + +// Messenger writes inbox files into per-session workspaces. +type Messenger struct { + lookup SessionWorkspace + clock func() time.Time +} + +// New builds a Messenger over the given workspace lookup. lookup is required. +func New(lookup SessionWorkspace) *Messenger { + return &Messenger{lookup: lookup, clock: time.Now} +} + +var _ ports.AgentMessenger = (*Messenger)(nil) + +// Send writes message into /.ao/inbox/_.md. +// +// Filename collisions are practically impossible: nanosecond timestamp plus an +// 8-char hash of the body. We do not retry on EEXIST. +// +// Symlink safety: if .ao or .ao/inbox already exists as a symlink, refuse. +// Otherwise os.MkdirAll creates real directories and os.WriteFile (which uses +// O_CREATE|O_WRONLY|O_TRUNC without O_NOFOLLOW) writes the message body. The +// inbox is owned by ao; a symlink there is either user misconfig or attack. 
+func (m *Messenger) Send(ctx context.Context, id domain.SessionID, message string) error { + ws, err := m.lookup.WorkspacePath(ctx, id) + if err != nil { + return fmt.Errorf("inbox: lookup workspace for %s: %w", id, err) + } + if ws == "" { + return fmt.Errorf("inbox: empty workspace path for %s", id) + } + + aoDir := filepath.Join(ws, ".ao") + if err := ensureRealDir(aoDir); err != nil { + return fmt.Errorf("inbox: prepare .ao for %s: %w", id, err) + } + inboxDir := filepath.Join(aoDir, "inbox") + if err := ensureRealDir(inboxDir); err != nil { + return fmt.Errorf("inbox: prepare inbox for %s: %w", id, err) + } + + name := filenameFor(m.clock(), message) + if err := os.WriteFile(filepath.Join(inboxDir, name), []byte(message), 0o600); err != nil { + return fmt.Errorf("inbox: write %s for %s: %w", name, id, err) + } + return nil +} + +// ensureRealDir creates path if missing (0755), refuses if path is a symlink. +// Lstat (not Stat) is used so a symlink isn't followed into a different tree. +// +// The workspace root itself is not Lstat-checked because gitworktree.Workspace +// resolves ManagedRoot to an absolute, symlink-free path at construction +// (gitworktree.physicalAbs); per-session workspaces under it are created by ao. +// A symlinked .ao or .ao/inbox inside an ao-owned workspace would be user +// misconfig or attack, and is the only segment that can be tampered with +// between Spawn and Send. +func ensureRealDir(path string) error { + info, err := os.Lstat(path) + switch { + case err == nil: + if info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("%q is a symlink; refusing to follow", path) + } + if !info.IsDir() { + return fmt.Errorf("%q exists and is not a directory", path) + } + return nil + case errors.Is(err, os.ErrNotExist): + return os.MkdirAll(path, 0o750) + default: + return err + } +} + +// filenameFor builds a sortable, collision-resistant name from the timestamp +// and message body. 
// The name is the decimal UnixNano of t, an underscore, the first 8 hex
// characters of the SHA-256 of message, and the ".md" extension.
func filenameFor(t time.Time, message string) string {
	digest := sha256.Sum256([]byte(message))
	// Four digest bytes encode to exactly the 8-character hex prefix.
	return fmt.Sprintf("%d_%s.md", t.UnixNano(), hex.EncodeToString(digest[:4]))
}
contains no .ao yet. + m := inbox.New(fakeLookup{path: dir}) + if err := m.Send(context.Background(), "s-1", "x"); err != nil { + t.Fatal(err) + } + if _, err := os.Stat(filepath.Join(dir, ".ao", "inbox")); err != nil { + t.Fatalf("inbox dir not created: %v", err) + } +} + +func TestSend_TwoSendsProduceTwoFiles(t *testing.T) { + dir := t.TempDir() + m := inbox.New(fakeLookup{path: dir}) + ctx := context.Background() + if err := m.Send(ctx, "s-1", "first"); err != nil { + t.Fatal(err) + } + if err := m.Send(ctx, "s-1", "second"); err != nil { + t.Fatal(err) + } + entries, _ := os.ReadDir(filepath.Join(dir, ".ao", "inbox")) + if len(entries) != 2 { + t.Fatalf("want 2 files, got %d", len(entries)) + } +} + +func TestSend_UnknownSessionReturnsError(t *testing.T) { + m := inbox.New(fakeLookup{err: errors.New("not found")}) + err := m.Send(context.Background(), "s-1", "x") + if err == nil { + t.Fatal("expected error when workspace lookup fails") + } + if !strings.Contains(err.Error(), "not found") { + t.Errorf("error should wrap lookup error, got %v", err) + } +} + +func TestSend_EmptyWorkspacePathReturnsError(t *testing.T) { + // A spawned-but-not-yet-mark-spawned row has WorkspacePath == "". The + // messenger must refuse rather than write into "/.ao/inbox/...". + m := inbox.New(fakeLookup{path: ""}) + if err := m.Send(context.Background(), "s-1", "x"); err == nil { + t.Fatal("expected error for empty workspace path") + } +} + +func TestSend_SymlinkedInboxIsRefused(t *testing.T) { + dir := t.TempDir() + // Create .ao/inbox as a symlink to a sibling directory. 
+ target := t.TempDir() + if err := os.MkdirAll(filepath.Join(dir, ".ao"), 0o755); err != nil { + t.Fatal(err) + } + if err := os.Symlink(target, filepath.Join(dir, ".ao", "inbox")); err != nil { + t.Skipf("symlink not supported: %v", err) + } + m := inbox.New(fakeLookup{path: dir}) + err := m.Send(context.Background(), "s-1", "x") + if err == nil { + t.Fatal("expected refusal when inbox is a symlink") + } + if entries, _ := os.ReadDir(target); len(entries) != 0 { + t.Errorf("symlink target should not have received writes, got %d entries", len(entries)) + } +} + +func TestSend_EmptyMessageStillWritesAFile(t *testing.T) { + dir := t.TempDir() + m := inbox.New(fakeLookup{path: dir}) + if err := m.Send(context.Background(), "s-1", ""); err != nil { + t.Fatal(err) + } + entries, _ := os.ReadDir(filepath.Join(dir, ".ao", "inbox")) + if len(entries) != 1 { + t.Fatalf("want 1 file even for empty message, got %d", len(entries)) + } +} + +func TestSend_FilenameContainsTimestampAndHashPrefix(t *testing.T) { + dir := t.TempDir() + m := inbox.New(fakeLookup{path: dir}) + if err := m.Send(context.Background(), "s-1", "payload"); err != nil { + t.Fatal(err) + } + entries, _ := os.ReadDir(filepath.Join(dir, ".ao", "inbox")) + name := strings.TrimSuffix(entries[0].Name(), ".md") + // Format: _; underscore separator avoids the timestamp's own dashes. + parts := strings.SplitN(name, "_", 2) + if len(parts) != 2 { + t.Fatalf("filename should be _.md, got %q", entries[0].Name()) + } + if len(parts[1]) < 4 { + t.Errorf("hash prefix too short: %q", parts[1]) + } +} diff --git a/backend/internal/adapters/registry.go b/backend/internal/adapters/registry.go new file mode 100644 index 00000000..a384979a --- /dev/null +++ b/backend/internal/adapters/registry.go @@ -0,0 +1,83 @@ +// Package adapters defines the plugin contract every external integration +// (agent, tracker, scm, runtime) satisfies plus a registry that holds the +// concrete plugins the daemon resolves by id. 
// Capability tags a Manifest with the role(s) a plugin fills.
type Capability string

// Known capabilities. A plugin may advertise more than one.
const (
	CapabilityAgent        Capability = "agent"
	CapabilityIssueTracker Capability = "issue-tracker"
)

// Manifest is the self-describing record every Adapter returns.
type Manifest struct {
	ID           string       `json:"id"`
	Name         string       `json:"name"`
	Description  string       `json:"description"`
	Version      string       `json:"version"`
	Capabilities []Capability `json:"capabilities"`
}

// Adapter is the minimal contract every registered plugin satisfies: it can
// describe itself via Manifest. Per-capability behaviour lives on richer
// interfaces that callers obtain via type assertion.
type Adapter interface {
	Manifest() Manifest
}

// Registry holds registered adapters keyed by Manifest.ID. It performs no
// locking of its own.
type Registry struct {
	adapters map[string]Adapter
}

// NewRegistry returns an empty Registry ready to accept Register calls.
func NewRegistry() *Registry {
	return &Registry{
		adapters: make(map[string]Adapter),
	}
}

// Register adds adapter under its Manifest.ID. It returns an error when
// adapter is nil, when the id is empty, or when the id is already in use.
func (r *Registry) Register(adapter Adapter) error {
	// A nil interface would panic on the Manifest call below; report it as
	// an ordinary registration error instead.
	if adapter == nil {
		return fmt.Errorf("adapter is required")
	}
	manifest := adapter.Manifest()
	if manifest.ID == "" {
		return fmt.Errorf("adapter id is required")
	}
	if _, exists := r.adapters[manifest.ID]; exists {
		return fmt.Errorf("adapter %q is already registered", manifest.ID)
	}

	// Allocate lazily so a zero-value Registry is usable too.
	if r.adapters == nil {
		r.adapters = make(map[string]Adapter)
	}
	r.adapters[manifest.ID] = adapter
	return nil
}

// Get returns the registered adapter with the given id, or nil and false
// when no such adapter exists.
func (r *Registry) Get(id string) (Adapter, bool) {
	p, ok := r.adapters[id]
	return p, ok
}

// Manifests returns every registered adapter's Manifest, sorted by id for
// deterministic output. The result is non-nil even when nothing is
// registered.
func (r *Registry) Manifests() []Manifest {
	manifests := make([]Manifest, 0, len(r.adapters))
	for _, adapter := range r.adapters {
		manifests = append(manifests, adapter.Manifest())
	}

	sort.Slice(manifests, func(i, j int) bool {
		return manifests[i].ID < manifests[j].ID
	})

	return manifests
}
+func (p *Provider) FindOpenPRForBranch(ctx context.Context, owner, repo, branch string) (string, error) { + owner = strings.TrimSpace(owner) + repo = strings.TrimSpace(repo) + branch = strings.TrimSpace(branch) + if owner == "" || repo == "" || branch == "" { + return "", fmt.Errorf("github scm: FindOpenPRForBranch requires owner/repo/branch (got %q/%q/%q)", owner, repo, branch) + } + + q := url.Values{} + q.Set("state", "open") + q.Set("head", owner+":"+branch) + q.Set("per_page", "100") + + resp, err := p.client.doREST(ctx, http.MethodGet, repoPath(owner, repo, "pulls"), q, nil) + if err != nil { + return "", err + } + if len(resp.Body) == 0 { + return "", nil + } + var list []listedPR + if err := json.Unmarshal(resp.Body, &list); err != nil { + return "", fmt.Errorf("github scm: decode pulls list: %w", err) + } + if len(list) == 0 { + return "", nil + } + + best := -1 + var bestTime time.Time + for i, pr := range list { + if !strings.EqualFold(pr.State, "open") { + continue + } + t := parsePRTimestamp(pr.UpdatedAt) + if best < 0 || t.After(bestTime) { + best = i + bestTime = t + } + } + if best < 0 { + return "", nil + } + chosen := list[best] + if chosen.HTMLURL != "" { + return chosen.HTMLURL, nil + } + // Construct the canonical web URL from owner/repo/number when the + // API response omits html_url (some enterprise responses elide it). 
+ return "https://github.com/" + owner + "/" + repo + "/pull/" + strconv.Itoa(chosen.Number), nil +} + +type listedPR struct { + Number int `json:"number"` + State string `json:"state"` + HTMLURL string `json:"html_url"` + UpdatedAt string `json:"updated_at"` +} + +func parsePRTimestamp(s string) time.Time { + t, err := time.Parse(time.RFC3339, s) + if err != nil { + return time.Time{} + } + return t +} diff --git a/backend/internal/adapters/scm/github/find_branch_pr_test.go b/backend/internal/adapters/scm/github/find_branch_pr_test.go new file mode 100644 index 00000000..39b5be77 --- /dev/null +++ b/backend/internal/adapters/scm/github/find_branch_pr_test.go @@ -0,0 +1,131 @@ +package github + +import ( + "encoding/json" + "errors" + "net/http" + "strings" + "testing" +) + +func TestFindOpenPRForBranchSingleMatch(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + fake.on(http.MethodGet, "/repos/acme/repo/pulls", func(w http.ResponseWriter, r *http.Request) { + if got := r.URL.Query().Get("head"); got != "acme:feat/x" { + t.Errorf("head query = %q, want acme:feat/x", got) + } + if got := r.URL.Query().Get("state"); got != "open" { + t.Errorf("state query = %q, want open", got) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode([]map[string]any{ + {"number": 7, "state": "open", "html_url": "https://github.com/acme/repo/pull/7", "updated_at": "2026-05-01T10:00:00Z"}, + }) + }) + + url, err := p.FindOpenPRForBranch(ctx(), "acme", "repo", "feat/x") + if err != nil { + t.Fatalf("FindOpenPRForBranch: %v", err) + } + if url != "https://github.com/acme/repo/pull/7" { + t.Fatalf("url = %q", url) + } +} + +func TestFindOpenPRForBranchNoMatch(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + fake.on(http.MethodGet, "/repos/acme/repo/pulls", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("[]")) + }) + url, err := 
p.FindOpenPRForBranch(ctx(), "acme", "repo", "feat/x") + if err != nil { + t.Fatalf("FindOpenPRForBranch: %v", err) + } + if url != "" { + t.Fatalf("url = %q, want empty", url) + } +} + +func TestFindOpenPRForBranchMultiplePicksMostRecent(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + fake.on(http.MethodGet, "/repos/acme/repo/pulls", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode([]map[string]any{ + {"number": 1, "state": "open", "html_url": "https://github.com/acme/repo/pull/1", "updated_at": "2026-01-01T00:00:00Z"}, + {"number": 9, "state": "open", "html_url": "https://github.com/acme/repo/pull/9", "updated_at": "2026-05-01T00:00:00Z"}, + {"number": 4, "state": "open", "html_url": "https://github.com/acme/repo/pull/4", "updated_at": "2026-03-01T00:00:00Z"}, + }) + }) + url, err := p.FindOpenPRForBranch(ctx(), "acme", "repo", "feat/x") + if err != nil { + t.Fatalf("FindOpenPRForBranch: %v", err) + } + if url != "https://github.com/acme/repo/pull/9" { + t.Fatalf("url = %q, want pull/9", url) + } +} + +func TestFindOpenPRForBranchEmptyInputsError(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + for _, tc := range []struct{ owner, repo, branch string }{ + {"", "repo", "b"}, + {"o", "", "b"}, + {"o", "r", ""}, + } { + _, err := p.FindOpenPRForBranch(ctx(), tc.owner, tc.repo, tc.branch) + if err == nil { + t.Errorf("expected error for empty input %+v", tc) + } + } +} + +func TestFindOpenPRForBranchRateLimited(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + fake.on(http.MethodGet, "/repos/acme/repo/pulls", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-RateLimit-Remaining", "0") + w.Header().Set("X-RateLimit-Reset", "1700000000") + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message":"API rate limit exceeded"}`)) + }) + _, err := p.FindOpenPRForBranch(ctx(), "acme", 
"repo", "feat/x") + if !errors.Is(err, ErrRateLimited) { + t.Fatalf("err = %v, want ErrRateLimited", err) + } +} + +func TestFindOpenPRForBranchAuthFailed(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + fake.on(http.MethodGet, "/repos/acme/repo/pulls", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message":"Bad credentials"}`)) + }) + _, err := p.FindOpenPRForBranch(ctx(), "acme", "repo", "feat/x") + if !errors.Is(err, ErrAuthFailed) { + t.Fatalf("err = %v, want ErrAuthFailed", err) + } +} + +func TestFindOpenPRForBranchSynthesizesURLWhenHTMLEmpty(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + fake.on(http.MethodGet, "/repos/acme/repo/pulls", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode([]map[string]any{ + {"number": 42, "state": "open", "updated_at": "2026-05-01T10:00:00Z"}, + }) + }) + url, err := p.FindOpenPRForBranch(ctx(), "acme", "repo", "feat/x") + if err != nil { + t.Fatalf("err = %v", err) + } + if !strings.HasSuffix(url, "/acme/repo/pull/42") { + t.Fatalf("url = %q, want suffix /acme/repo/pull/42", url) + } +} diff --git a/backend/internal/adapters/workspace/gitworktree/projectresolver/resolver.go b/backend/internal/adapters/workspace/gitworktree/projectresolver/resolver.go new file mode 100644 index 00000000..496f8adc --- /dev/null +++ b/backend/internal/adapters/workspace/gitworktree/projectresolver/resolver.go @@ -0,0 +1,47 @@ +// Package projectresolver supplies gitworktree.Workspace with a RepoResolver +// backed by the project.Manager. It lives in its own subpackage so the +// gitworktree package can stay free of the project package import (and the +// import cycle that would create if project ever depended on gitworktree). 
+package projectresolver
+
+import (
+	"context"
+	"fmt"
+
+	"github.com/aoagents/agent-orchestrator/backend/internal/adapters/workspace/gitworktree"
+	"github.com/aoagents/agent-orchestrator/backend/internal/domain"
+	"github.com/aoagents/agent-orchestrator/backend/internal/project"
+)
+
+// Compile-time check that *Resolver satisfies gitworktree.RepoResolver.
+var _ gitworktree.RepoResolver = (*Resolver)(nil)
+
+// Resolver answers "where does this project's repo live on disk?" by asking
+// the project.Manager it wraps.
+type Resolver struct {
+	projects project.Manager
+}
+
+// New wraps the given Manager in a Resolver. projects is required.
+func New(projects project.Manager) *Resolver {
+	r := new(Resolver)
+	r.projects = projects
+	return r
+}
+
+// RepoPath returns the absolute repo path the project is registered against.
+// A degraded project (config failed to load) and an unknown project both yield
+// an error rather than the empty path that would silently mis-create worktrees.
+//
+// The gitworktree.RepoResolver interface is context-free, so we use
+// context.Background() to call the underlying Manager.
+func (r *Resolver) RepoPath(projectID domain.ProjectID) (string, error) { + res, err := r.projects.Get(context.Background(), projectID) + if err != nil { + return "", fmt.Errorf("projectresolver: lookup %q: %w", projectID, err) + } + if res.Project == nil { + return "", fmt.Errorf("projectresolver: project %q is %s; no repo path available", projectID, res.Status) + } + if res.Project.Path == "" { + return "", fmt.Errorf("projectresolver: project %q has no path", projectID) + } + return res.Project.Path, nil +} diff --git a/backend/internal/adapters/workspace/gitworktree/projectresolver/resolver_test.go b/backend/internal/adapters/workspace/gitworktree/projectresolver/resolver_test.go new file mode 100644 index 00000000..4721a529 --- /dev/null +++ b/backend/internal/adapters/workspace/gitworktree/projectresolver/resolver_test.go @@ -0,0 +1,69 @@ +package projectresolver_test + +import ( + "context" + "os/exec" + "testing" + + "github.com/aoagents/agent-orchestrator/backend/internal/adapters/workspace/gitworktree" + "github.com/aoagents/agent-orchestrator/backend/internal/adapters/workspace/gitworktree/projectresolver" + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/project" +) + +func TestSatisfiesRepoResolver(t *testing.T) { + var _ gitworktree.RepoResolver = (*projectresolver.Resolver)(nil) +} + +func TestRepoPath_ReturnsProjectPath(t *testing.T) { + mgr := project.NewMemoryManager() + repo := mkGitRepo(t) + added, err := mgr.Add(context.Background(), project.AddInput{Path: repo}) + if err != nil { + t.Fatal(err) + } + r := projectresolver.New(mgr) + got, err := r.RepoPath(added.ID) + if err != nil { + t.Fatal(err) + } + if got != added.Path { + t.Fatalf("got %q want %q", got, added.Path) + } +} + +func TestRepoPath_UnknownProjectReturnsError(t *testing.T) { + mgr := project.NewMemoryManager() + r := projectresolver.New(mgr) + if _, err := r.RepoPath("nope"); err == nil { + 
t.Fatal("expected error for unknown project") + } +} + +func TestRepoPath_DegradedProjectReturnsError(t *testing.T) { + // Degraded resolves a status, not a Project — the resolver must surface an + // error rather than the empty path that would silently mis-create worktrees. + r := projectresolver.New(stubManagerDegraded{}) + _, err := r.RepoPath("p1") + if err == nil { + t.Fatal("expected error for degraded project") + } +} + +// stubManagerDegraded only overrides Get; other Manager methods would panic if +// reached, which they should not in this test. +type stubManagerDegraded struct{ project.Manager } + +func (stubManagerDegraded) Get(context.Context, domain.ProjectID) (project.GetResult, error) { + return project.GetResult{Status: "degraded"}, nil +} + +func mkGitRepo(t *testing.T) string { + t.Helper() + dir := t.TempDir() + cmd := exec.Command("git", "init", "-q", dir) + if err := cmd.Run(); err != nil { + t.Skipf("git not available: %v", err) + } + return dir +} diff --git a/backend/internal/cli/root.go b/backend/internal/cli/root.go index ce015738..9dfd49f4 100644 --- a/backend/internal/cli/root.go +++ b/backend/internal/cli/root.go @@ -146,6 +146,7 @@ func NewRootCommand(deps Deps) *cobra.Command { root.AddCommand(newStartCommand(ctx)) root.AddCommand(newStopCommand(ctx)) root.AddCommand(newStatusCommand(ctx)) + root.AddCommand(newSpawnCommand(ctx)) root.AddCommand(newDoctorCommand(ctx)) root.AddCommand(newCompletionCommand()) root.AddCommand(newVersionCommand()) diff --git a/backend/internal/cli/spawn.go b/backend/internal/cli/spawn.go new file mode 100644 index 00000000..84e52720 --- /dev/null +++ b/backend/internal/cli/spawn.go @@ -0,0 +1,145 @@ +package cli + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/spf13/cobra" + + "github.com/aoagents/agent-orchestrator/backend/internal/config" + "github.com/aoagents/agent-orchestrator/backend/internal/runfile" +) + +// 
spawnRequestTimeout bounds a single POST /api/v1/sessions call. It is +// deliberately longer than DefaultDeps.HTTPClient.Timeout (which is sized for +// fast probes like /healthz and /shutdown) because spawn synchronously creates +// a worktree, launches a zellij pane, and starts the agent — that can comfortably +// exceed 2 s on a cold cache. 90 s buys headroom over the server's +// config.DefaultRequestTimeout (60 s) without hanging the CLI forever on a +// truly stuck daemon. +const spawnRequestTimeout = 90 * time.Second + +type spawnOptions struct { + project string + prompt string + agent string +} + +func newSpawnCommand(ctx *commandContext) *cobra.Command { + var opts spawnOptions + cmd := &cobra.Command{ + Use: "spawn", + Short: "Spawn a new agent session", + Args: noArgs, + RunE: func(cmd *cobra.Command, _ []string) error { + return ctx.spawnSession(cmd.Context(), cmd.OutOrStdout(), opts) + }, + } + cmd.Flags().StringVar(&opts.prompt, "prompt", "", "Initial prompt for the agent") + cmd.Flags().StringVar(&opts.project, "project", "", "Project id") + cmd.Flags().StringVar(&opts.agent, "agent", "claude-code", "Agent plugin") + return cmd +} + +type spawnAPIRequest struct { + ProjectID string `json:"projectId"` + Prompt string `json:"prompt"` + Agent string `json:"agent,omitempty"` +} + +type spawnAPIResponse struct { + SessionID string `json:"sessionId"` + WorkspacePath string `json:"workspacePath"` + RuntimeHandle string `json:"runtimeHandle"` +} + +type apiError struct { + Kind string `json:"error"` + Code string `json:"code"` + Message string `json:"message"` +} + +func (c *commandContext) spawnSession(ctx context.Context, out io.Writer, opts spawnOptions) error { + prompt := strings.TrimSpace(opts.prompt) + if prompt == "" { + return usageError{errors.New("usage: --prompt is required")} + } + project := strings.TrimSpace(opts.project) + if project == "" { + return usageError{errors.New("usage: --project is required")} + } + + cfg, err := config.Load() + 
if err != nil { + return err + } + + info, err := runfile.Read(cfg.RunFilePath) + if err != nil { + return fmt.Errorf("read run-file: %w", err) + } + if info == nil { + return errors.New("AO daemon is not running; start it with `ao start`") + } + + payload := spawnAPIRequest{ + ProjectID: project, + Prompt: prompt, + Agent: opts.agent, + } + body, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("encode request: %w", err) + } + + url := fmt.Sprintf("http://%s:%d/api/v1/sessions", config.LoopbackHost, info.Port) + + reqCtx, cancel := context.WithTimeout(ctx, spawnRequestTimeout) + defer cancel() + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + // Use a dedicated client (no client-level timeout) so the deadline is + // driven solely by reqCtx. The shared deps.HTTPClient is sized for + // short-lived probes; reusing it here would preempt spawn long before + // the daemon could finish provisioning. + resp, err := (&http.Client{}).Do(req) + if err != nil { + return fmt.Errorf("daemon request: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + var ok spawnAPIResponse + if err := json.Unmarshal(respBody, &ok); err != nil { + return fmt.Errorf("decode response: %w", err) + } + _, err := fmt.Fprintf(out, "Spawned session %s in %s\nAttach: zellij attach %s\n", + ok.SessionID, ok.WorkspacePath, ok.RuntimeHandle) + return err + } + + // Non-2xx: surface the server's error envelope when present, otherwise the + // raw status. Both 4xx and 5xx exit 1; usage errors (which exit 2) come from + // flag validation above. 
+ var apiErr apiError + if jerr := json.Unmarshal(respBody, &apiErr); jerr == nil && apiErr.Kind != "" { + return fmt.Errorf("%s: %s", apiErr.Kind, apiErr.Message) + } + return fmt.Errorf("daemon returned HTTP %d", resp.StatusCode) +} diff --git a/backend/internal/cli/spawn_test.go b/backend/internal/cli/spawn_test.go new file mode 100644 index 00000000..2638a4d4 --- /dev/null +++ b/backend/internal/cli/spawn_test.go @@ -0,0 +1,230 @@ +package cli + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "testing" + "time" + + "github.com/aoagents/agent-orchestrator/backend/internal/runfile" +) + +// spawnServer wires up an httptest server, writes a runfile pointing at it, and +// returns the captured request body slot the caller assertions can read. +func spawnServer(t *testing.T, status int, respBody string) (*httptest.Server, *string) { + t.Helper() + var captured string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/api/v1/sessions" && r.Method == http.MethodPost { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read req body: %v", err) + } + captured = string(body) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _, _ = io.WriteString(w, respBody) + return + } + http.NotFound(w, r) + })) + t.Cleanup(srv.Close) + return srv, &captured +} + +func writeRunFileFor(t *testing.T, cfg testConfig, srv *httptest.Server) { + t.Helper() + port := serverPort(t, srv.URL) + if err := runfile.Write(cfg.runFile, runfile.Info{ + PID: os.Getpid(), + Port: port, + StartedAt: time.Unix(100, 0).UTC(), + }); err != nil { + t.Fatal(err) + } +} + +func TestSpawn_Success(t *testing.T) { + cfg := setConfigEnv(t) + resp := `{"sessionId":"demo-1","workspacePath":"/tmp/demo-1","runtimeHandle":"zellij-demo-1"}` + srv, captured := spawnServer(t, http.StatusCreated, resp) + writeRunFileFor(t, cfg, srv) + + out, errOut, err := executeCLI(t, 
Deps{ + ProcessAlive: func(int) bool { return true }, + }, "spawn", "--project", "demo", "--prompt", "do the thing", "--agent", "claude-code") + if err != nil { + t.Fatalf("unexpected error: %v\nstderr=%s", err, errOut) + } + if !strings.Contains(out, "Spawned session demo-1 in /tmp/demo-1") { + t.Fatalf("stdout missing spawn line:\n%s", out) + } + if !strings.Contains(out, "Attach: zellij attach zellij-demo-1") { + t.Fatalf("stdout missing attach line:\n%s", out) + } + + var req struct { + ProjectID string `json:"projectId"` + Prompt string `json:"prompt"` + Agent string `json:"agent"` + } + if err := json.Unmarshal([]byte(*captured), &req); err != nil { + t.Fatalf("decode captured req: %v\nbody=%s", err, *captured) + } + if req.ProjectID != "demo" || req.Prompt != "do the thing" || req.Agent != "claude-code" { + t.Fatalf("captured payload = %#v", req) + } +} + +func TestSpawn_DefaultsAgent(t *testing.T) { + cfg := setConfigEnv(t) + srv, captured := spawnServer(t, http.StatusCreated, + `{"sessionId":"demo-1","workspacePath":"/tmp/demo-1","runtimeHandle":"zellij-demo-1"}`) + writeRunFileFor(t, cfg, srv) + + _, errOut, err := executeCLI(t, Deps{ + ProcessAlive: func(int) bool { return true }, + }, "spawn", "--project", "demo", "--prompt", "x") + if err != nil { + t.Fatalf("unexpected error: %v\nstderr=%s", err, errOut) + } + if !strings.Contains(*captured, `"agent":"claude-code"`) { + t.Fatalf("agent default not sent: %s", *captured) + } +} + +func TestSpawn_EmptyPromptIsUsageError(t *testing.T) { + setConfigEnv(t) + _, _, err := executeCLI(t, Deps{}, "spawn", "--project", "demo", "--prompt", " ") + if err == nil { + t.Fatal("expected usage error for empty prompt") + } + if got := ExitCode(err); got != 2 { + t.Fatalf("exit code = %d, want 2", got) + } + if !strings.Contains(err.Error(), "--prompt is required") { + t.Fatalf("error missing usage message: %v", err) + } +} + +func TestSpawn_MissingProjectIsUsageError(t *testing.T) { + setConfigEnv(t) + _, _, err := 
executeCLI(t, Deps{}, "spawn", "--prompt", "x") + if err == nil { + t.Fatal("expected usage error for missing project") + } + if got := ExitCode(err); got != 2 { + t.Fatalf("exit code = %d, want 2", got) + } +} + +func TestSpawn_ServerBadRequestExits1(t *testing.T) { + cfg := setConfigEnv(t) + srv, _ := spawnServer(t, http.StatusBadRequest, + `{"error":"bad_request","code":"PROMPT_REQUIRED","message":"prompt is required"}`) + writeRunFileFor(t, cfg, srv) + + _, errOut, err := executeCLI(t, Deps{ + ProcessAlive: func(int) bool { return true }, + }, "spawn", "--project", "demo", "--prompt", "x") + if err == nil { + t.Fatal("expected runtime error from 400") + } + if got := ExitCode(err); got != 1 { + t.Fatalf("exit code = %d, want 1", got) + } + if !strings.Contains(err.Error(), "bad_request") && !strings.Contains(errOut, "bad_request") { + t.Fatalf("error did not include server kind: %v\nstderr=%s", err, errOut) + } +} + +func TestSpawn_ServerNotFoundExits1(t *testing.T) { + cfg := setConfigEnv(t) + srv, _ := spawnServer(t, http.StatusNotFound, + `{"error":"not_found","code":"PROJECT_NOT_FOUND","message":"Unknown project"}`) + writeRunFileFor(t, cfg, srv) + + _, _, err := executeCLI(t, Deps{ + ProcessAlive: func(int) bool { return true }, + }, "spawn", "--project", "missing", "--prompt", "x") + if err == nil { + t.Fatal("expected runtime error from 404") + } + if got := ExitCode(err); got != 1 { + t.Fatalf("exit code = %d, want 1", got) + } +} + +func TestSpawn_ServerInternalErrorExits1(t *testing.T) { + cfg := setConfigEnv(t) + srv, _ := spawnServer(t, http.StatusInternalServerError, + `{"error":"internal","code":"SPAWN_FAILED","message":"Failed to spawn session"}`) + writeRunFileFor(t, cfg, srv) + + _, _, err := executeCLI(t, Deps{ + ProcessAlive: func(int) bool { return true }, + }, "spawn", "--project", "demo", "--prompt", "x") + if err == nil { + t.Fatal("expected runtime error from 500") + } + if got := ExitCode(err); got != 1 { + t.Fatalf("exit code = %d, 
want 1", got) + } +} + +func TestSpawn_DaemonNotRunningExits1(t *testing.T) { + setConfigEnv(t) + // No runfile: daemon is stopped. + _, _, err := executeCLI(t, Deps{}, "spawn", "--project", "demo", "--prompt", "x") + if err == nil { + t.Fatal("expected error when daemon is not running") + } + if got := ExitCode(err); got != 1 { + t.Fatalf("exit code = %d, want 1", got) + } +} + +func TestSpawn_SessionsDisabledExits1(t *testing.T) { + cfg := setConfigEnv(t) + srv, _ := spawnServer(t, http.StatusServiceUnavailable, + `{"error":"sessions_disabled","code":"SESSIONS_DISABLED","message":"Session Manager is not wired in this daemon"}`) + writeRunFileFor(t, cfg, srv) + + _, errOut, err := executeCLI(t, Deps{ + ProcessAlive: func(int) bool { return true }, + }, "spawn", "--project", "demo", "--prompt", "x") + if err == nil { + t.Fatal("expected error from 503") + } + if got := ExitCode(err); got != 1 { + t.Fatalf("exit code = %d, want 1", got) + } + if !strings.Contains(err.Error(), "sessions_disabled") && !strings.Contains(errOut, "sessions_disabled") { + t.Fatalf("error did not include sessions_disabled: %v\nstderr=%s", err, errOut) + } +} + +// Sanity helper: ensure the formatted spawn message is stable. 
+func TestSpawn_StdoutShape(t *testing.T) { + cfg := setConfigEnv(t) + srv, _ := spawnServer(t, http.StatusCreated, fmt.Sprintf( + `{"sessionId":%q,"workspacePath":%q,"runtimeHandle":%q}`, + "proj-7", "/tmp/proj-7", "zellij-proj-7")) + writeRunFileFor(t, cfg, srv) + + out, _, err := executeCLI(t, Deps{ + ProcessAlive: func(int) bool { return true }, + }, "spawn", "--project", "proj", "--prompt", "go") + if err != nil { + t.Fatal(err) + } + want := "Spawned session proj-7 in /tmp/proj-7\nAttach: zellij attach zellij-proj-7\n" + if out != want { + t.Fatalf("stdout mismatch:\n got %q\n want %q", out, want) + } +} diff --git a/backend/internal/daemon/daemon.go b/backend/internal/daemon/daemon.go index b8d89053..228b2ddd 100644 --- a/backend/internal/daemon/daemon.go +++ b/backend/internal/daemon/daemon.go @@ -11,6 +11,7 @@ import ( "os/signal" "syscall" + "github.com/aoagents/agent-orchestrator/backend/internal/adapters/messenger/inbox" "github.com/aoagents/agent-orchestrator/backend/internal/adapters/runtime/zellij" "github.com/aoagents/agent-orchestrator/backend/internal/config" "github.com/aoagents/agent-orchestrator/backend/internal/httpd" @@ -58,27 +59,50 @@ func Run() error { return err } + // Singletons shared across the daemon. Constructing each exactly once and + // passing the same instance everywhere prevents the multi-zellij-socket / + // dual-LCM / dual-project-store hazards that fragmented adapters create. + runtimeAdapter := zellij.New(zellij.Options{}) + projects := project.NewManager(store) + messenger := inbox.New(newStoreWorkspaceLookup(store)) + // Terminal streaming: the Zellij runtime supplies the PTY-attach command and // liveness; the CDC broadcaster feeds the session-state channel. The manager // is handed to httpd, which mounts it at /mux. Raw PTY bytes never flow // through the CDC change_log — only session-state events do. 
- runtimeAdapter := zellij.New(zellij.Options{}) termMgr := terminal.NewManager(runtimeAdapter, cdcPipe.Broadcaster, log) defer termMgr.Close() - srv, err := httpd.NewWithDeps(cfg, log, termMgr, httpd.APIDeps{Projects: project.NewManager(store)}) + // Bring up the Lifecycle Manager + reaper, then the Session Manager stack + // over the same lcm/runtime/projects/messenger singletons. SM is constructed + // before the HTTP server so the service.Session wrapper can be plumbed into + // APIDeps and the /api/v1/sessions controller can drive it. + lcStack := startLifecycle(ctx, store, runtimeAdapter, messenger, log) + ss, err := buildSessionStack(cfg, store, runtimeAdapter, projects, lcStack.lcm, messenger) + if err != nil { + stop() + lcStack.Stop() + if cdcErr := cdcPipe.Stop(); cdcErr != nil { + log.Error("cdc pipeline shutdown", "err", cdcErr) + } + return err + } + + srv, err := httpd.NewWithDeps(cfg, log, termMgr, httpd.APIDeps{Projects: projects, Sessions: ss.svc}) if err != nil { stop() + lcStack.Stop() if cdcErr := cdcPipe.Stop(); cdcErr != nil { log.Error("cdc pipeline shutdown", "err", cdcErr) } return err } - // Bring up the Lifecycle Manager and the reaper. This makes the session - // lifecycle write path live end-to-end: reducer write -> store -> DB trigger - // -> change_log -> poller -> broadcaster. - lcStack := startLifecycle(ctx, store, runtimeAdapter, log) + // SCM observation: polling Provider -> pr.Manager -> lifecycle nudges. + // Constructed after lifecycle so the PR Manager can forward observations + // to ApplyPRObservation; runs alongside the reaper as a sibling background + // loop. Missing GITHUB_TOKEN degrades gracefully (loop is not started). + scmStk := startSCM(ctx, store, projects, lcStack.lcm, log) runErr := srv.Run(ctx) @@ -87,6 +111,7 @@ func Run() error { // via defer) avoids the LIFO trap where a Stop() that blocks on ctx-cancel // runs before the cancel — which would hang any non-signal exit path. 
stop() + scmStk.Stop() lcStack.Stop() if err := cdcPipe.Stop(); err != nil { log.Error("cdc pipeline shutdown", "err", err) diff --git a/backend/internal/daemon/lifecycle_wiring.go b/backend/internal/daemon/lifecycle_wiring.go index 5c04002d..23071cb9 100644 --- a/backend/internal/daemon/lifecycle_wiring.go +++ b/backend/internal/daemon/lifecycle_wiring.go @@ -10,18 +10,21 @@ import ( "github.com/aoagents/agent-orchestrator/backend/internal/storage/sqlite" ) -// lifecycleStack owns the runtime reaper goroutine started with the lifecycle -// reducer. The reducer itself is only used for wiring observations into storage. +// lifecycleStack owns the Lifecycle Manager (which the Session Manager and the +// reaper both depend on) and the reaper goroutine. type lifecycleStack struct { + lcm *lifecycle.Manager reaperDone <-chan struct{} } // startLifecycle constructs the Lifecycle Manager over the store and starts the -// reaper. The goroutine stops when ctx is cancelled; Stop waits for it to drain. -func startLifecycle(ctx context.Context, store *sqlite.Store, runtime ports.Runtime, logger *slog.Logger) *lifecycleStack { - lcm := lifecycle.New(store, nil) +// reaper. The messenger is passed into the LCM so PR-driven reactions (CI fail, +// review feedback, merge conflict) can nudge the agent. The goroutine stops +// when ctx is cancelled; Stop waits for it to drain. +func startLifecycle(ctx context.Context, store *sqlite.Store, runtime ports.Runtime, messenger ports.AgentMessenger, logger *slog.Logger) *lifecycleStack { + lcm := lifecycle.New(store, messenger) rp := reaper.New(lcm, store, runtime, reaper.Config{Logger: logger}) - return &lifecycleStack{reaperDone: rp.Start(ctx)} + return &lifecycleStack{lcm: lcm, reaperDone: rp.Start(ctx)} } // Stop waits for the reaper goroutine to exit. 
The caller must cancel the ctx diff --git a/backend/internal/daemon/scm_wiring.go b/backend/internal/daemon/scm_wiring.go new file mode 100644 index 00000000..a0390cac --- /dev/null +++ b/backend/internal/daemon/scm_wiring.go @@ -0,0 +1,61 @@ +package daemon + +import ( + "context" + "errors" + "log/slog" + + scmgithub "github.com/aoagents/agent-orchestrator/backend/internal/adapters/scm/github" + "github.com/aoagents/agent-orchestrator/backend/internal/lifecycle" + "github.com/aoagents/agent-orchestrator/backend/internal/observe/scm" + "github.com/aoagents/agent-orchestrator/backend/internal/pr" + "github.com/aoagents/agent-orchestrator/backend/internal/project" + "github.com/aoagents/agent-orchestrator/backend/internal/storage/sqlite" +) + +// scmStack owns the SCM observation loop: a GitHub Provider, a pr.Manager +// that writes PR rows and forwards observations to lifecycle for nudges, +// and the polling goroutine that drives both. A nil-token environment +// degrades gracefully — the daemon still runs locally without SCM +// observation; PR-driven nudges (CI-failure log tail, review feedback, +// merge-conflict rebase) will not fire until a token is supplied. +type scmStack struct { + pollerDone <-chan struct{} +} + +// startSCM constructs and starts the SCM observation stack. The Provider +// reads its token from AO_GITHUB_TOKEN (preferred) or GITHUB_TOKEN, both +// via os.Getenv. Without a token, the poller is not started and a no-op +// done channel is returned — Stop is a free call in that case. 
+func startSCM(ctx context.Context, store *sqlite.Store, projects project.Manager, lcm *lifecycle.Manager, log *slog.Logger) *scmStack {
+	// Single source of truth for the variables the token source consults, so
+	// the "no token" log line below can never drift from what is actually read.
+	tokenEnvVars := []string{"AO_GITHUB_TOKEN", "GITHUB_TOKEN"}
+	provider, err := scmgithub.NewProvider(scmgithub.ProviderOptions{
+		Token: scmgithub.EnvTokenSource{EnvVars: tokenEnvVars},
+	})
+	if err != nil {
+		// Both branches leave the daemon running without SCM observation; only
+		// the severity differs between "not configured" and "construction failed".
+		if errors.Is(err, scmgithub.ErrNoToken) {
+			log.Info("scm poller: no GitHub token configured, SCM observation disabled", "checkedEnv", tokenEnvVars)
+		} else {
+			log.Warn("scm poller: provider construction failed, SCM observation disabled", "err", err)
+		}
+		// An already-closed channel makes Stop return immediately.
+		return &scmStack{pollerDone: closedDone()}
+	}
+	// pr.Manager persists PR rows through the store and forwards observations
+	// to the shared Lifecycle Manager.
+	prMgr := pr.New(pr.Deps{Writer: store, Lifecycle: lcm})
+	poller := scm.New(scm.Deps{
+		Provider: provider,
+		Branches: provider,
+		Sessions: store,
+		Projects: projects,
+		PR:       prMgr,
+		Logger:   log,
+	})
+	return &scmStack{pollerDone: poller.Start(ctx)}
+}
+
+// Stop waits for the poller goroutine to exit. The caller must cancel the
+// ctx passed to startSCM before calling Stop.
+func (s *scmStack) Stop() {
+	<-s.pollerDone
+}
+
+// closedDone returns a channel that is already closed, so a receive on it
+// completes immediately.
+func closedDone() <-chan struct{} {
+	done := make(chan struct{})
+	close(done)
+	return done
+}
diff --git a/backend/internal/daemon/session_wiring.go b/backend/internal/daemon/session_wiring.go
new file mode 100644
index 00000000..f9ada6ef
--- /dev/null
+++ b/backend/internal/daemon/session_wiring.go
@@ -0,0 +1,77 @@
+package daemon
+
+import (
+	"context"
+	"fmt"
+	"path/filepath"
+
+	"github.com/aoagents/agent-orchestrator/backend/internal/adapters/agent/claudecode"
+	"github.com/aoagents/agent-orchestrator/backend/internal/adapters/agent/portshim"
+	"github.com/aoagents/agent-orchestrator/backend/internal/adapters/messenger/inbox"
+	"github.com/aoagents/agent-orchestrator/backend/internal/adapters/workspace/gitworktree"
+	"github.com/aoagents/agent-orchestrator/backend/internal/adapters/workspace/gitworktree/projectresolver"
+	"github.com/aoagents/agent-orchestrator/backend/internal/config"
+	"github.com/aoagents/agent-orchestrator/backend/internal/domain"
+	"github.com/aoagents/agent-orchestrator/backend/internal/lifecycle"
+	"github.com/aoagents/agent-orchestrator/backend/internal/ports"
+	"github.com/aoagents/agent-orchestrator/backend/internal/project"
+	"github.com/aoagents/agent-orchestrator/backend/internal/service"
+	sessionmanager "github.com/aoagents/agent-orchestrator/backend/internal/session_manager"
+	"github.com/aoagents/agent-orchestrator/backend/internal/storage/sqlite"
+)
+
+// sessionStack groups the per-session collaborators the daemon assembles around
+// the Session Manager. The controller-facing surface is *service.Session, which
+// wraps the internal session_manager.Manager with read-model assembly.
+type sessionStack struct {
+	svc       *service.Session     // controller-facing wrapper around the Session Manager
+	workspace ports.Workspace      // gitworktree-backed workspace
+	messenger ports.AgentMessenger // the messenger handed to buildSessionStack, kept as-is
+}
+
+// buildSessionStack wires the session-runtime stack together. It creates a
+// gitworktree workspace rooted at <DataDir>/worktrees whose RepoResolver is
+// backed by projects, a claudecode agent wrapped in portshim, the internal
+// session_manager.Manager, and the service.Session wrapper that the HTTP
+// controller binds to. The runtime, store, lcm, projects and messenger are
+// injected rather than built here, so the stack shares the caller's instances.
+// The only failure surfaced is a gitworktree construction error.
+func buildSessionStack(cfg config.Config, store *sqlite.Store, runtime ports.Runtime, projects project.Manager, lcm *lifecycle.Manager, messenger ports.AgentMessenger) (*sessionStack, error) {
+	worktreeOpts := gitworktree.Options{
+		ManagedRoot:  filepath.Join(cfg.DataDir, "worktrees"),
+		RepoResolver: projectresolver.New(projects),
+	}
+	worktrees, err := gitworktree.New(worktreeOpts)
+	if err != nil {
+		return nil, fmt.Errorf("gitworktree: %w", err)
+	}
+
+	// The agent adapter is built only after the workspace succeeded, matching
+	// the order in which the collaborators are needed.
+	managerDeps := sessionmanager.Deps{
+		Runtime:   runtime,
+		Agent:     portshim.New(claudecode.New()),
+		Workspace: worktrees,
+		Store:     store,
+		Messenger: messenger,
+		Lifecycle: lcm,
+	}
+	manager := sessionmanager.New(managerDeps)
+
+	stack := &sessionStack{
+		svc:       service.NewSession(manager, store),
+		workspace: worktrees,
+		messenger: messenger,
+	}
+	return stack, nil
+}
+
+// storeWorkspaceLookup adapts the sqlite store to the SessionWorkspace lookup
+// the inbox messenger needs. WorkspacePath becomes meaningful only after the
+// LCM records spawn metadata, so a session that exists but has no path is an
+// error — Send must not invent a destination.
+type storeWorkspaceLookup struct{ store *sqlite.Store }
+
+// newStoreWorkspaceLookup exposes store through the inbox.SessionWorkspace
+// interface.
+func newStoreWorkspaceLookup(store *sqlite.Store) inbox.SessionWorkspace {
+	return storeWorkspaceLookup{store: store}
+}
+
+// WorkspacePath returns the workspace path recorded in the session's metadata.
+// It returns an error when the store lookup fails, when the session does not
+// exist, or when the session exists but has no workspace path recorded yet;
+// it never returns an empty path with a nil error.
+func (s storeWorkspaceLookup) WorkspacePath(ctx context.Context, id domain.SessionID) (string, error) {
+	rec, ok, err := s.store.GetSession(ctx, id)
+	if err != nil {
+		return "", fmt.Errorf("lookup session %s: %w", id, err)
+	}
+	if !ok {
+		return "", fmt.Errorf("session %s not found", id)
+	}
+	// An empty path must not reach the messenger: joining it with a relative
+	// inbox directory would silently target the daemon's working directory.
+	if rec.Metadata.WorkspacePath == "" {
+		return "", fmt.Errorf("session %s has no workspace path", id)
+	}
+	return rec.Metadata.WorkspacePath, nil
+}
diff --git a/backend/internal/daemon/wiring_test.go b/backend/internal/daemon/wiring_test.go
index 6d6dae04..78cca537 100644
--- a/backend/internal/daemon/wiring_test.go
+++ b/backend/internal/daemon/wiring_test.go
@@ -2,11 +2,16 @@ package daemon
 
 import (
 	"context"
+	"os"
+	"path/filepath"
 	"sync"
 	"testing"
 	"time"
 
+	"github.com/aoagents/agent-orchestrator/backend/internal/adapters/messenger/inbox"
+	"github.com/aoagents/agent-orchestrator/backend/internal/adapters/runtime/zellij"
 	"github.com/aoagents/agent-orchestrator/backend/internal/cdc"
+	"github.com/aoagents/agent-orchestrator/backend/internal/config"
 	"github.com/aoagents/agent-orchestrator/backend/internal/domain"
 	"github.com/aoagents/agent-orchestrator/backend/internal/lifecycle"
 	"github.com/aoagents/agent-orchestrator/backend/internal/ports"
@@ -69,3 +74,74 @@ func TestWiring_WriteFlowsToBroadcaster(t *testing.T) {
 		t.Fatalf("expected a change_log event for %s to reach the broadcaster, got %d events", rec.ID, len(got))
 	}
 }
+
+// TestWiring_SessionStackSharesSingletons asserts the daemon's wiring shape:
+// startLifecycle and buildSessionStack share the same messenger and LCM, and
+// the messenger reaches the same store the SM reads. Two LCMs would split
+// agent-nudge state; two messengers would route inbox writes inconsistently.
+// +// The pointer-identity check on ss.messenger proves buildSessionStack does not +// silently construct a second messenger; the end-to-end Send through a row the +// store owns proves the storeWorkspaceLookup is the same store SM uses. +func TestWiring_SessionStackSharesSingletons(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + store, err := sqlite.Open(t.TempDir()) + if err != nil { + t.Fatal(err) + } + defer store.Close() + cfg := config.Config{DataDir: t.TempDir()} + + projects := project.NewManager(store) + runtime := zellij.New(zellij.Options{}) + messenger := inbox.New(newStoreWorkspaceLookup(store)) + lcStack := startLifecycle(ctx, store, runtime, messenger, nil) + // Cancel-then-Stop in order: Stop drains the reaper goroutine, which only + // exits when ctx is cancelled. A naive `defer cancel(); defer lcStack.Stop()` + // reverses this (defer is LIFO) and deadlocks. + t.Cleanup(func() { + cancel() + lcStack.Stop() + }) + + if lcStack.lcm == nil { + t.Fatal("lifecycleStack must expose its LCM so the SM can share it") + } + ss, err := buildSessionStack(cfg, store, runtime, projects, lcStack.lcm, messenger) + if err != nil { + t.Fatalf("buildSessionStack: %v", err) + } + if ss.svc == nil || ss.workspace == nil || ss.messenger == nil { + t.Fatal("session stack must be fully populated") + } + if ss.messenger != messenger { + t.Error("buildSessionStack must reuse the messenger it is given, not construct a second one") + } + + // End-to-end: a session row in the shared store should be reachable through + // the messenger that buildSessionStack wired up. A second store would + // surface as "session not found" here. 
+ if err := store.Upsert(ctx, project.Row{ID: "p", Path: "/repo/p", RegisteredAt: time.Now()}); err != nil { + t.Fatal(err) + } + workspaceDir := t.TempDir() + rec, err := store.CreateSession(ctx, domain.SessionRecord{ + ProjectID: "p", Kind: domain.KindWorker, + Activity: domain.Activity{State: domain.ActivityIdle, LastActivityAt: time.Now()}, + Metadata: domain.SessionMetadata{WorkspacePath: workspaceDir}, + }) + if err != nil { + t.Fatal(err) + } + if err := ss.messenger.Send(ctx, rec.ID, "hello"); err != nil { + t.Fatalf("messenger.Send through shared store lookup: %v", err) + } + entries, err := os.ReadDir(filepath.Join(workspaceDir, ".ao", "inbox")) + if err != nil { + t.Fatalf("inbox dir: %v", err) + } + if len(entries) != 1 { + t.Fatalf("want 1 inbox file, got %d", len(entries)) + } +} diff --git a/backend/internal/integration/scm_poller_test.go b/backend/internal/integration/scm_poller_test.go new file mode 100644 index 00000000..021fed82 --- /dev/null +++ b/backend/internal/integration/scm_poller_test.go @@ -0,0 +1,185 @@ +package integration + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + scmgithub "github.com/aoagents/agent-orchestrator/backend/internal/adapters/scm/github" + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/observe/scm" + "github.com/aoagents/agent-orchestrator/backend/internal/ports" + "github.com/aoagents/agent-orchestrator/backend/internal/project" +) + +// TestSCMPollerEndToEnd boots store + LCM + pr.Manager + the scm.Poller +// against an httptest GitHub stub, ticks once, and asserts: +// - the poller resolved the PR URL via branch discovery +// - pr.Manager persisted the PR row (PRWriter side of the bus) +// - lifecycle.ApplyPRObservation fired the CI-failure nudge to the messenger +// +// This is the seam-by-seam validation that aa-37's spec describes: from +// SCM observation to PR 
row to agent nudge, with every dependency the +// daemon wires in production. +func TestSCMPollerEndToEnd(t *testing.T) { + ctx := context.Background() + st := newStack(t) + + if err := st.store.Upsert(ctx, project.Row{ID: "acme", Path: "/repo/acme", RepoOriginURL: "https://github.com/acme/repo.git", RegisteredAt: time.Now()}); err != nil { + t.Fatal(err) + } + sess, err := st.sm.Spawn(ctx, ports.SpawnConfig{ProjectID: "acme", Kind: domain.KindWorker, Branch: "feat/x", Prompt: "fix CI"}) + if err != nil { + t.Fatal(err) + } + + // The PR URL the GitHub stub will report for branch acme:feat/x. + prURL := "https://github.com/acme/repo/pull/77" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == http.MethodGet && r.URL.Path == "/repos/acme/repo/pulls": + if got := r.URL.Query().Get("head"); got != "acme:feat/x" { + t.Errorf("pulls list head = %q, want acme:feat/x", got) + } + _ = json.NewEncoder(w).Encode([]map[string]any{ + {"number": 77, "state": "open", "html_url": prURL, "updated_at": "2026-05-15T10:00:00Z"}, + }) + case r.Method == http.MethodGet && r.URL.Path == "/repos/acme/repo/pulls/77": + w.Header().Set("ETag", `W/"v1"`) + _ = json.NewEncoder(w).Encode(map[string]any{ + "number": 77, + "state": "open", + "draft": false, + "merged": false, + "merged_at": nil, + "html_url": prURL, + "head": map[string]any{"ref": "feat/x", "sha": "deadbeef"}, + "base": map[string]any{"ref": "main"}, + "mergeable": false, + "rebaseable": true, + "mergeable_state": "blocked", + "merge_state_status": "BLOCKED", + }) + case r.Method == http.MethodPost && r.URL.Path == "/graphql": + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "number": 77, + "url": prURL, + "state": "OPEN", + "isDraft": false, + "merged": false, + "closed": false, + 
"mergeable": "MERGEABLE", + "mergeStateStatus": "BLOCKED", + "reviewDecision": "REVIEW_REQUIRED", + "headRefOid": "deadbeef", + "commits": map[string]any{"nodes": []any{ + map[string]any{"commit": map[string]any{ + "oid": "deadbeef", + "statusCheckRollup": map[string]any{ + "state": "FAILURE", + "contexts": map[string]any{ + "nodes": []any{ + map[string]any{ + "__typename": "CheckRun", + "name": "build", + "status": "COMPLETED", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/acme/repo/runs/9001", + "databaseId": float64(9001), + }, + }, + "pageInfo": map[string]any{"hasNextPage": false}, + }, + }, + }}, + }}, + "reviewThreads": map[string]any{"nodes": []any{}}, + }, + }, + }, + }) + case r.Method == http.MethodGet && r.URL.Path == "/repos/acme/repo/actions/jobs/9001/logs": + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte("FAIL TestX\nFAIL TestY\n")) + default: + t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path) + http.Error(w, "no handler", http.StatusNotImplemented) + } + })) + t.Cleanup(server.Close) + + provider, err := scmgithub.NewProvider(scmgithub.ProviderOptions{ + Token: scmgithub.StaticTokenSource("tkn"), + HTTPClient: server.Client(), + RESTBase: server.URL, + GraphQLURL: server.URL + "/graphql", + }) + if err != nil { + t.Fatal(err) + } + + projects := project.NewManager(st.store) + poller := scm.New(scm.Deps{ + Provider: provider, + Branches: provider, + Sessions: st.store, + Projects: projects, + PR: st.prm, + Interval: time.Hour, // ticker won't fire — we call Tick directly + ObserveTimeout: 5 * time.Second, + RemoteResolver: func(context.Context, string) (string, error) { + // The project Row.RepoOriginURL is set above, so this fallback + // should never be called; failing loudly catches a regression + // where the poller silently shells out instead of using + // project.Repo. 
+ t.Fatalf("remote resolver should not be invoked when project.Repo is set") + return "", nil + }, + }) + + if err := poller.Tick(ctx); err != nil { + t.Fatalf("poller.Tick: %v", err) + } + + got, ok, err := st.store.GetPR(ctx, prURL) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatalf("pr row not written for %s", prURL) + } + if got.SessionID != sess.ID { + t.Errorf("pr.SessionID = %q, want %q", got.SessionID, sess.ID) + } + if got.CI != domain.CIFailing { + t.Errorf("pr.CI = %q, want %q", got.CI, domain.CIFailing) + } + checks, err := st.store.ListChecks(ctx, prURL) + if err != nil { + t.Fatal(err) + } + if len(checks) != 1 || checks[0].Status != domain.PRCheckFailed { + t.Fatalf("checks = %+v", checks) + } + + if len(st.msg.msgs) != 1 { + t.Fatalf("expected exactly 1 lifecycle nudge, got %d (a double-nudge would regress sendOnce)", len(st.msg.msgs)) + } + if !strings.Contains(st.msg.msgs[0], "CI is failing") { + t.Errorf("messenger did not receive CI-failure body; got %q", st.msg.msgs[0]) + } + if !strings.Contains(st.msg.msgs[0], "FAIL TestX") { + t.Errorf("messenger did not receive log-tail body; got %q", st.msg.msgs[0]) + } +} diff --git a/backend/internal/observe/scm/poller.go b/backend/internal/observe/scm/poller.go new file mode 100644 index 00000000..e907d954 --- /dev/null +++ b/backend/internal/observe/scm/poller.go @@ -0,0 +1,364 @@ +// Package scm implements the OBSERVE-layer polling loop that drives +// SCM (pull-request) observations into the PR Manager and Lifecycle +// Manager. The loop is intentionally dumb: every tick it lists alive +// sessions, finds the open PR for each session's branch, asks the +// Provider for an observation, and hands the result to the PR +// Manager (which transactionally writes the row and forwards to +// lifecycle for nudges). +// +// The poller does not own any reaction logic. 
CI-failure log-tail
+// nudges, review-feedback nudges (capped at reviewMaxNudge), and
+// merge-conflict rebase nudges all live in lifecycle.ApplyPRObservation.
+// Polling is uniform 30s for v1; per-PR adaptive cadence is a follow-up.
+package scm
+
+import (
+	"context"
+	"errors"
+	"log/slog"
+	"net/url"
+	"os/exec"
+	"strings"
+	"sync/atomic"
+	"time"
+
+	scmgithub "github.com/aoagents/agent-orchestrator/backend/internal/adapters/scm/github"
+	"github.com/aoagents/agent-orchestrator/backend/internal/domain"
+	"github.com/aoagents/agent-orchestrator/backend/internal/ports"
+	"github.com/aoagents/agent-orchestrator/backend/internal/project"
+)
+
+// DefaultInterval is the cadence used when Deps.Interval is zero or negative.
+const DefaultInterval = 30 * time.Second
+
+// DefaultObserveTimeout caps one Provider.Observe call so a single hung
+// request can't stall the whole tick. It is used when Deps.ObserveTimeout
+// is zero or negative.
+const DefaultObserveTimeout = 15 * time.Second
+
+// Provider observes one PR by its canonical URL. The github adapter
+// satisfies this; other SCM adapters (gitlab, etc.) can implement the
+// same surface without touching the poller.
+type Provider interface {
+	Observe(ctx context.Context, prURL string) (ports.PRObservation, error)
+}
+
+// BranchPRFinder resolves a session's branch to its open PR URL. v1
+// uses this because sessions do not (yet) carry a PR URL field; when
+// they do, the poller will prefer the stored URL and only fall back
+// here. An empty return with nil error means "no matching open PR".
+type BranchPRFinder interface {
+	FindOpenPRForBranch(ctx context.Context, owner, repo, branch string) (string, error)
+}
+
+// sessionLister narrows the sqlite store to what the poller needs.
+// The poller filters out terminated and branch-less rows itself, so
+// implementations may return every session.
+type sessionLister interface {
+	ListAllSessions(ctx context.Context) ([]domain.SessionRecord, error)
+}
+
+// projectGetter narrows project.Manager to its read path.
+type projectGetter interface {
+	Get(ctx context.Context, id domain.ProjectID) (project.GetResult, error)
+}
+
+// prApplier is the seam over pr.Manager.ApplyObservation — which itself
+// persists the PR row and forwards to lifecycle for nudges. Keeping
+// this one method on the seam means the poller never needs to know
+// about lifecycle directly.
+type prApplier interface {
+	ApplyObservation(ctx context.Context, id domain.SessionID, o ports.PRObservation) error
+}
+
+// remoteResolver returns the origin URL of the repository checked out at
+// projectPath. The default implementation shells out to git; it is
+// injected so tests don't require a real git checkout.
+type remoteResolver func(ctx context.Context, projectPath string) (string, error)
+
+// Deps groups every collaborator the Poller needs. Zero-valued
+// optional fields fall back to safe defaults (slog.Default, 30s tick,
+// 15s observe deadline, real `git` for origin lookup).
+type Deps struct {
+	// Provider fetches the observation for a resolved PR URL. Required:
+	// it is called without a nil check.
+	Provider Provider
+	// Branches maps (owner, repo, branch) to an open PR URL. When nil,
+	// URL resolution yields nothing and every session is skipped.
+	Branches BranchPRFinder
+	// Sessions lists the sessions to poll. Required.
+	Sessions sessionLister
+	// Projects looks up the project a session belongs to. Required
+	// whenever Branches is set.
+	Projects projectGetter
+	// PR receives every successfully fetched observation. Required.
+	PR prApplier
+	// Logger defaults to slog.Default() when nil.
+	Logger *slog.Logger
+	// Interval is the background tick cadence; <= 0 selects DefaultInterval.
+	Interval time.Duration
+	// ObserveTimeout is the per-call deadline; <= 0 selects DefaultObserveTimeout.
+	ObserveTimeout time.Duration
+	// RemoteResolver is the fallback used when a project's Repo field does
+	// not yield an owner/repo pair; nil selects the git-based default.
+	RemoteResolver func(ctx context.Context, projectPath string) (string, error)
+}
+
+// Poller is the SCM observation loop. Construct it with New, start the
+// background goroutine with Start. Tick is exported so daemon and tests
+// can drive a single cycle synchronously.
+type Poller struct {
+	provider       Provider
+	branches       BranchPRFinder
+	sessions       sessionLister
+	projects       projectGetter
+	pr             prApplier
+	logger         *slog.Logger
+	interval       time.Duration
+	observeTimeout time.Duration
+	remoteResolver remoteResolver
+
+	// healthy backs Healthy(): set to true by New and cleared when the
+	// provider reports an authentication failure.
+	healthy atomic.Bool
+}
+
+// New constructs a Poller from its dependencies, substituting defaults
+// for zero-valued optional fields.
+func New(d Deps) *Poller {
+	p := &Poller{
+		provider:       d.Provider,
+		branches:       d.Branches,
+		sessions:       d.Sessions,
+		projects:       d.Projects,
+		pr:             d.PR,
+		logger:         d.Logger,
+		interval:       d.Interval,
+		observeTimeout: d.ObserveTimeout,
+		remoteResolver: d.RemoteResolver,
+	}
+	// Fill in defaults for every optional dependency left at its zero value.
+	if p.interval <= 0 {
+		p.interval = DefaultInterval
+	}
+	if p.observeTimeout <= 0 {
+		p.observeTimeout = DefaultObserveTimeout
+	}
+	if p.logger == nil {
+		p.logger = slog.Default()
+	}
+	if p.remoteResolver == nil {
+		p.remoteResolver = defaultRemoteResolver
+	}
+	// Optimistic start: the poller is healthy until an auth failure is seen.
+	p.healthy.Store(true)
+	return p
+}
+
+// Healthy reports whether the SCM provider's authentication has been
+// observed working since the poller started. It starts true and flips
+// to false the first time the provider returns ErrAuthFailed; it does
+// NOT auto-recover, because a single subsequent success could be an
+// ETag-cached 304 that didn't actually exercise the token. A future
+// health route consumes this bit; clearing it after token rotation is
+// a daemon-restart concern.
+func (p *Poller) Healthy() bool { return p.healthy.Load() }
+
+// Start launches the background goroutine and returns a channel that
+// closes once the loop has exited. The loop exits when ctx is cancelled;
+// callers should wait on the returned channel before tearing down the
+// PR Manager / lifecycle / store dependencies.
+func (p *Poller) Start(ctx context.Context) <-chan struct{} {
+	done := make(chan struct{})
+	go p.loop(ctx, done)
+	return done
+}
+
+// loop runs Tick once per interval until ctx is cancelled, then closes
+// done. The first Tick happens one full interval after Start. A Tick
+// failure is logged and the loop keeps going, unless the failure is just
+// the context being cancelled mid-tick: that is an orderly shutdown, so
+// the loop exits without logging an error.
+func (p *Poller) loop(ctx context.Context, done chan<- struct{}) {
+	defer close(done)
+	t := time.NewTicker(p.interval)
+	defer t.Stop()
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case <-t.C:
+			err := p.Tick(ctx)
+			if err == nil {
+				continue
+			}
+			if ctx.Err() != nil {
+				// Tick was interrupted by shutdown; not a poller failure.
+				return
+			}
+			p.logger.Error("scm poller: tick failed", "err", err)
+		}
+	}
+}
+
+// Tick runs one observation cycle.
+//
+// Every session is listed up front. Terminated rows and rows without a
+// branch are ignored; each remaining session is handed to pollOne,
+// which resolves the open PR URL through the BranchPRFinder, fetches an
+// observation from the Provider under a per-call deadline, and passes a
+// fetched observation on to the PR Manager. Failures inside pollOne are
+// classified by sentinel:
+//   - ErrRateLimited: the rest of the tick is abandoned, so the
+//     remaining sessions are not polled while GitHub asks us to back off.
+//   - ErrAuthFailed: Healthy() flips to false and the tick moves on to
+//     the next session, so one misconfigured token does not stall the loop.
+//   - anything else: logged at warn level, then the tick moves on.
+//
+// Tick returns a non-nil error in two cases only: the session listing
+// failed, or ctx was already done when the next session was about to be
+// polled. Per-session failures are never returned.
+func (p *Poller) Tick(ctx context.Context) error {
+	all, err := p.sessions.ListAllSessions(ctx)
+	if err != nil {
+		return err
+	}
+	for i := range all {
+		rec := all[i]
+		if rec.IsTerminated || rec.Metadata.Branch == "" {
+			continue
+		}
+		if cancelErr := ctx.Err(); cancelErr != nil {
+			return cancelErr
+		}
+		if p.pollOne(ctx, rec) {
+			// Rate-limit signal: leave the remaining sessions for the next tick.
+			break
+		}
+	}
+	return nil
+}
+
+// pollOne handles one session. Returns stop=true when the caller
+// should short-circuit the remaining sessions (rate-limit signal).
+func (p *Poller) pollOne(ctx context.Context, sess domain.SessionRecord) bool {
+	// Bound the URL resolution as well as the observation. Resolution can
+	// involve a project lookup, a `git` subprocess and a network request to
+	// the SCM host; on the bare tick context a single hung lookup would
+	// stall every remaining session, which is exactly what the per-call
+	// deadline exists to prevent.
+	resolveCtx, cancelResolve := context.WithTimeout(ctx, p.observeTimeout)
+	prURL, err := p.resolvePRURL(resolveCtx, sess)
+	cancelResolve()
+	if err != nil {
+		return p.classify(sess.ID, "resolve-pr-url", err)
+	}
+	if prURL == "" {
+		p.logger.Debug("scm poller: no open PR for branch, skipping",
+			"session", sess.ID, "branch", sess.Metadata.Branch)
+		return false
+	}
+
+	// The observation gets its own, fresh deadline.
+	pollCtx, cancel := context.WithTimeout(ctx, p.observeTimeout)
+	defer cancel()
+	obs, err := p.provider.Observe(pollCtx, prURL)
+	if err != nil {
+		return p.classify(sess.ID, "observe", err)
+	}
+	if !obs.Fetched {
+		p.logger.Debug("scm poller: observation not fetched, skipping",
+			"session", sess.ID, "url", prURL)
+		return false
+	}
+	// Apply on the tick context, not pollCtx: persisting the row should not
+	// be cut short by whatever time the fetch already consumed. A failure
+	// here is logged only; it never stops the tick.
+	if err := p.pr.ApplyObservation(ctx, sess.ID, obs); err != nil {
+		p.logger.Warn("scm poller: apply observation failed",
+			"session", sess.ID, "err", err)
+	}
+	return false
+}
+
+// classify maps a Provider/lookup error to the loop's continue/stop
+// decision and surfaces it in the logs. Auth-class failures flip the
+// Healthy() bool; rate-limit signals stop the tick.
+//
+// sid and stage are used for log context only. The return value is
+// true when the caller should abandon the rest of the tick.
+func (p *Poller) classify(sid domain.SessionID, stage string, err error) bool {
+	switch {
+	case errors.Is(err, scmgithub.ErrRateLimited):
+		p.logger.Warn("scm poller: rate limited, skipping rest of tick",
+			"session", sid, "stage", stage, "err", err)
+		return true
+	case errors.Is(err, scmgithub.ErrAuthFailed):
+		p.healthy.Store(false)
+		p.logger.Error("scm poller: auth failed, provider marked unhealthy",
+			"session", sid, "stage", stage, "err", err)
+		return false
+	default:
+		// Includes deadline expiry from the per-call timeouts.
+		p.logger.Warn("scm poller: error",
+			"session", sid, "stage", stage, "err", err)
+		return false
+	}
+}
+
+// resolvePRURL finds the open PR URL for a session's branch.
+//
+// v1 strategy: branch-based discovery. Look up the session's project,
+// derive owner/repo from project.Repo (which today holds the origin URL),
+// falling back to `git remote get-url origin` against the project's
+// on-disk path, then ask BranchPRFinder.
When neither yields an
+// owner/repo, the session is silently skipped — that is not a poller bug,
+// it's a project that hasn't been configured for SCM observation.
+//
+// When the session record grows a stored PR URL field (separate PR),
+// this function should prefer it over branch discovery.
+func (p *Poller) resolvePRURL(ctx context.Context, sess domain.SessionRecord) (string, error) {
+	// Without a branch finder there is nothing to resolve against.
+	if p.branches == nil {
+		return "", nil
+	}
+	lookup, err := p.projects.Get(ctx, sess.ProjectID)
+	switch {
+	case err != nil:
+		return "", err
+	case lookup.Project == nil:
+		return "", nil
+	}
+
+	owner, repo, found := ownerRepoFromProject(*lookup.Project)
+	if !found {
+		// project.Repo was empty or unparseable: fall back to the checkout's
+		// origin remote. A failure here is a skip, not an error.
+		origin, resolveErr := p.remoteResolver(ctx, lookup.Project.Path)
+		if resolveErr != nil {
+			p.logger.Debug("scm poller: git remote lookup failed, skipping session",
+				"session", sess.ID, "project", sess.ProjectID, "err", resolveErr)
+			return "", nil
+		}
+		if owner, repo, found = parseGitHubRemote(origin); !found {
+			return "", nil
+		}
+	}
+	return p.branches.FindOpenPRForBranch(ctx, owner, repo, sess.Metadata.Branch)
+}
+
+// ownerRepoFromProject extracts the (owner, repo) pair from a Project's
+// Repo field. The field currently carries the origin URL even though its
+// type comment advertises "owner/name", so both shapes are accepted by
+// delegating to parseGitHubRemote; the project package itself is left
+// alone. ok is false when the field is blank or cannot be parsed.
+func ownerRepoFromProject(p project.Project) (owner, repo string, ok bool) {
+	trimmed := strings.TrimSpace(p.Repo)
+	if trimmed == "" {
+		return "", "", false
+	}
+	// parseGitHubRemote already yields ("", "", false) on every failure path.
+	return parseGitHubRemote(trimmed)
+}
+
+// parseGitHubRemote accepts both URL- and SSH-style remote strings and
+// the bare "owner/repo" shorthand. It is intentionally host-agnostic —
+// the github.Provider will reject non-github hosts at Observe time, so
+// rejecting them here would just duplicate that check and silently drop
+// legitimately-configured projects on enterprise hosts.
+//
+// Recognised forms:
+//   - https://github.com/owner/repo[.git]
+//   - http(s)://host/owner/repo[.git]
+//   - git@host:owner/repo[.git]
+//   - ssh://git@host/owner/repo[.git]
+//   - owner/repo
+//
+// Only the first two path segments are used, so a trailing slash or
+// extra segments after the repository (for example a pasted
+// ".../owner/repo/pull/7" URL) are tolerated. The ".git" suffix is
+// stripped from the repository segment itself, which keeps
+// "owner/repo.git/" from producing the repo name "repo.git".
+//
+// ok is false, with empty owner and repo, when no owner/repo pair can
+// be extracted.
+func parseGitHubRemote(s string) (owner, repo string, ok bool) {
+	s = strings.TrimSpace(s)
+	if s == "" {
+		return "", "", false
+	}
+	switch {
+	case strings.HasPrefix(s, "git@"):
+		// scp-like syntax: everything after the first ':' is the path.
+		_, path, found := strings.Cut(s, ":")
+		if !found {
+			return "", "", false
+		}
+		// Accept the absolute variant "git@host:/owner/repo" as well.
+		s = strings.TrimPrefix(path, "/")
+	case strings.Contains(s, "://"):
+		u, err := url.Parse(s)
+		if err != nil || u.Host == "" {
+			return "", "", false
+		}
+		s = strings.TrimPrefix(u.Path, "/")
+	}
+	parts := strings.SplitN(s, "/", 3)
+	if len(parts) < 2 {
+		return "", "", false
+	}
+	owner = strings.TrimSpace(parts[0])
+	// Strip ".git" from the repo segment only; trimming the whole string
+	// misses it whenever anything (even a lone '/') follows the repo.
+	repo = strings.TrimSpace(strings.TrimSuffix(strings.TrimSpace(parts[1]), ".git"))
+	if owner == "" || repo == "" {
+		return "", "", false
+	}
+	return owner, repo, true
+}
+
+// defaultRemoteResolver reads the origin URL of the checkout at
+// projectPath by running `git -C <path> remote get-url origin`. It
+// returns an error for a blank path or when the git command fails;
+// cancelling ctx stops the subprocess.
+func defaultRemoteResolver(ctx context.Context, projectPath string) (string, error) {
+	if strings.TrimSpace(projectPath) == "" {
+		return "", errors.New("scm poller: project has no path")
+	}
+	out, err := exec.CommandContext(ctx, "git", "-C", projectPath, "remote", "get-url", "origin").Output()
+	if err != nil {
+		return "", err
+	}
+	return strings.TrimSpace(string(out)), nil
+}
diff --git a/backend/internal/observe/scm/poller_test.go b/backend/internal/observe/scm/poller_test.go
new file mode 100644
index 00000000..64e09162
--- /dev/null
+++ b/backend/internal/observe/scm/poller_test.go
@@ -0,0 +1,516 @@
+package scm
+
+import (
+	"context"
+	"errors"
+	"log/slog"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	scmgithub "github.com/aoagents/agent-orchestrator/backend/internal/adapters/scm/github"
+	"github.com/aoagents/agent-orchestrator/backend/internal/domain"
+	"github.com/aoagents/agent-orchestrator/backend/internal/ports"
+	"github.com/aoagents/agent-orchestrator/backend/internal/project"
+)
+
+// 
--------------------------------------------------------------------------- +// Fakes +// --------------------------------------------------------------------------- + +type fakeProvider struct { + mu sync.Mutex + calls []string + obs map[string]ports.PRObservation + errs map[string]error + hangFor time.Duration +} + +func (f *fakeProvider) Observe(ctx context.Context, prURL string) (ports.PRObservation, error) { + f.mu.Lock() + f.calls = append(f.calls, prURL) + hang := f.hangFor + f.mu.Unlock() + if hang > 0 { + select { + case <-time.After(hang): + case <-ctx.Done(): + return ports.PRObservation{URL: prURL}, ctx.Err() + } + } + f.mu.Lock() + defer f.mu.Unlock() + if err, ok := f.errs[prURL]; ok { + return ports.PRObservation{URL: prURL}, err + } + if o, ok := f.obs[prURL]; ok { + return o, nil + } + return ports.PRObservation{URL: prURL}, nil +} + +func (f *fakeProvider) seenURLs() []string { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]string, len(f.calls)) + copy(out, f.calls) + return out +} + +type fakeBranches struct { + mu sync.Mutex + urls map[string]string // owner/repo/branch -> prURL + err error + callCount int +} + +func (f *fakeBranches) FindOpenPRForBranch(_ context.Context, owner, repo, branch string) (string, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.callCount++ + if f.err != nil { + return "", f.err + } + return f.urls[owner+"/"+repo+"/"+branch], nil +} + +type fakeSessions struct { + sessions []domain.SessionRecord + err error +} + +func (f *fakeSessions) ListAllSessions(context.Context) ([]domain.SessionRecord, error) { + if f.err != nil { + return nil, f.err + } + out := make([]domain.SessionRecord, len(f.sessions)) + copy(out, f.sessions) + return out, nil +} + +type fakeProjects struct { + projects map[domain.ProjectID]project.Project +} + +func (f *fakeProjects) Get(_ context.Context, id domain.ProjectID) (project.GetResult, error) { + p, ok := f.projects[id] + if !ok { + return project.GetResult{}, errors.New("project not 
found") + } + pp := p + return project.GetResult{Status: "ok", Project: &pp}, nil +} + +type fakePR struct { + mu sync.Mutex + applied []appliedObs + applyErr error +} + +type appliedObs struct { + id domain.SessionID + obs ports.PRObservation +} + +func (f *fakePR) ApplyObservation(_ context.Context, id domain.SessionID, o ports.PRObservation) error { + f.mu.Lock() + defer f.mu.Unlock() + f.applied = append(f.applied, appliedObs{id: id, obs: o}) + return f.applyErr +} + +func (f *fakePR) records() []appliedObs { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]appliedObs, len(f.applied)) + copy(out, f.applied) + return out +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func newTestPoller(t *testing.T, d Deps) *Poller { + t.Helper() + if d.Logger == nil { + d.Logger = slog.New(slog.NewTextHandler(testWriter{t}, &slog.HandlerOptions{Level: slog.LevelDebug})) + } + return New(d) +} + +type testWriter struct{ t *testing.T } + +func (w testWriter) Write(p []byte) (int, error) { + w.t.Log(string(p)) + return len(p), nil +} + +func aliveSession(id domain.SessionID, project domain.ProjectID, branch string) domain.SessionRecord { + return domain.SessionRecord{ + ID: id, + ProjectID: project, + Kind: domain.KindWorker, + Metadata: domain.SessionMetadata{Branch: branch, RuntimeHandleID: "h"}, + } +} + +func terminatedSession(id domain.SessionID, project domain.ProjectID, branch string) domain.SessionRecord { + s := aliveSession(id, project, branch) + s.IsTerminated = true + return s +} + +func githubProject(id domain.ProjectID) project.Project { + return project.Project{ID: id, Path: "/repo/" + string(id), Repo: "https://github.com/acme/repo.git"} +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +func 
TestTickObservesAliveSessionAndAppliesObservation(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{ + aliveSession("s-1", "acme", "feat/x"), + terminatedSession("s-2", "acme", "feat/y"), + }} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{ + "acme/repo/feat/x": "https://github.com/acme/repo/pull/11", + "acme/repo/feat/y": "https://github.com/acme/repo/pull/12", + }} + provider := &fakeProvider{obs: map[string]ports.PRObservation{ + "https://github.com/acme/repo/pull/11": {Fetched: true, URL: "https://github.com/acme/repo/pull/11", Number: 11, CI: domain.CIPassing}, + }} + prm := &fakePR{} + + p := newTestPoller(t, Deps{ + Provider: provider, + Branches: branches, + Sessions: sessions, + Projects: projects, + PR: prm, + }) + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick error: %v", err) + } + + if got := provider.seenURLs(); len(got) != 1 || got[0] != "https://github.com/acme/repo/pull/11" { + t.Fatalf("provider.Observe calls = %v, want [pull/11] (terminated session skipped)", got) + } + rec := prm.records() + if len(rec) != 1 || rec[0].id != "s-1" || rec[0].obs.Number != 11 { + t.Fatalf("pr.ApplyObservation = %+v, want one call for s-1/pull-11", rec) + } +} + +func TestTickSkipsApplyWhenNotFetched(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{aliveSession("s-1", "acme", "feat/x")}} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{"acme/repo/feat/x": "https://github.com/acme/repo/pull/11"}} + provider := &fakeProvider{obs: map[string]ports.PRObservation{ + "https://github.com/acme/repo/pull/11": {Fetched: false, URL: "https://github.com/acme/repo/pull/11"}, + }} + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: 
provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if got := prm.records(); len(got) != 0 { + t.Fatalf("ApplyObservation called %d times on !Fetched obs", len(got)) + } +} + +func TestTickSkipsSessionsWithoutBranch(t *testing.T) { + ctx := context.Background() + noBranch := aliveSession("s-1", "acme", "") + sessions := &fakeSessions{sessions: []domain.SessionRecord{noBranch}} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{} + provider := &fakeProvider{} + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if got := provider.seenURLs(); len(got) != 0 { + t.Fatalf("provider should not be called for session without branch, got %v", got) + } + if got := branches.callCount; got != 0 { + t.Fatalf("branches lookup should not be called for session without branch, got %d", got) + } +} + +func TestTickSkipsSessionsWithNoOpenPR(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{aliveSession("s-1", "acme", "feat/x")}} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{}} // empty: no PR exists + provider := &fakeProvider{} + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if got := provider.seenURLs(); len(got) != 0 { + t.Fatalf("provider should not be called when no PR found, got %v", got) + } +} + +func TestTickRateLimitShortCircuits(t *testing.T) { + ctx := context.Background() + sessions := 
&fakeSessions{sessions: []domain.SessionRecord{ + aliveSession("s-1", "acme", "feat/x"), + aliveSession("s-2", "acme", "feat/y"), + }} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{ + "acme/repo/feat/x": "https://github.com/acme/repo/pull/11", + "acme/repo/feat/y": "https://github.com/acme/repo/pull/12", + }} + provider := &fakeProvider{ + errs: map[string]error{ + "https://github.com/acme/repo/pull/11": scmgithub.ErrRateLimited, + }, + obs: map[string]ports.PRObservation{ + "https://github.com/acme/repo/pull/12": {Fetched: true, URL: "https://github.com/acme/repo/pull/12", Number: 12}, + }, + } + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if got := provider.seenURLs(); len(got) != 1 { + t.Fatalf("expected exactly one Observe call (rate-limit short-circuits), got %v", got) + } + if got := prm.records(); len(got) != 0 { + t.Fatalf("no observations should be applied after rate-limit, got %d", len(got)) + } +} + +func TestTickAuthFailureMarksUnhealthyAndContinues(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{ + aliveSession("s-1", "acme", "feat/x"), + aliveSession("s-2", "acme", "feat/y"), + }} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{ + "acme/repo/feat/x": "https://github.com/acme/repo/pull/11", + "acme/repo/feat/y": "https://github.com/acme/repo/pull/12", + }} + provider := &fakeProvider{ + errs: map[string]error{ + "https://github.com/acme/repo/pull/11": scmgithub.ErrAuthFailed, + }, + obs: map[string]ports.PRObservation{ + "https://github.com/acme/repo/pull/12": {Fetched: true, URL: 
"https://github.com/acme/repo/pull/12", Number: 12, CI: domain.CIPassing}, + }, + } + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + if !p.Healthy() { + t.Fatalf("poller should start healthy") + } + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if p.Healthy() { + t.Fatalf("poller should be unhealthy after ErrAuthFailed") + } + if got := provider.seenURLs(); len(got) != 2 { + t.Fatalf("expected provider to be called for both sessions, got %v", got) + } + rec := prm.records() + if len(rec) != 1 || rec[0].id != "s-2" { + t.Fatalf("expected one apply for s-2 after auth failure on s-1, got %+v", rec) + } +} + +func TestTickProjectLookupErrorContinues(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{ + aliveSession("s-1", "missing", "feat/x"), + aliveSession("s-2", "acme", "feat/y"), + }} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{ + "acme/repo/feat/y": "https://github.com/acme/repo/pull/12", + }} + provider := &fakeProvider{obs: map[string]ports.PRObservation{ + "https://github.com/acme/repo/pull/12": {Fetched: true, URL: "https://github.com/acme/repo/pull/12", Number: 12}, + }} + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if got := prm.records(); len(got) != 1 || got[0].id != "s-2" { + t.Fatalf("expected s-2 applied after project-lookup err on s-1, got %+v", got) + } + if !p.Healthy() { + t.Fatalf("project lookup error should not mark unhealthy") + } +} + +func TestTickGenericErrorContinues(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{ + aliveSession("s-1", 
"acme", "feat/x"), + aliveSession("s-2", "acme", "feat/y"), + }} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{ + "acme/repo/feat/x": "https://github.com/acme/repo/pull/11", + "acme/repo/feat/y": "https://github.com/acme/repo/pull/12", + }} + provider := &fakeProvider{ + errs: map[string]error{ + "https://github.com/acme/repo/pull/11": errors.New("transient network blip"), + }, + obs: map[string]ports.PRObservation{ + "https://github.com/acme/repo/pull/12": {Fetched: true, URL: "https://github.com/acme/repo/pull/12", Number: 12}, + }, + } + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if got := prm.records(); len(got) != 1 || got[0].id != "s-2" { + t.Fatalf("expected s-2 applied after generic err on s-1, got %+v", got) + } + if !p.Healthy() { + t.Fatalf("generic errors should not mark unhealthy") + } +} + +func TestPerCallDeadline(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{aliveSession("s-1", "acme", "feat/x")}} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{"acme/repo/feat/x": "https://github.com/acme/repo/pull/11"}} + provider := &fakeProvider{hangFor: 200 * time.Millisecond} + prm := &fakePR{} + p := newTestPoller(t, Deps{ + Provider: provider, + Branches: branches, + Sessions: sessions, + Projects: projects, + PR: prm, + ObserveTimeout: 10 * time.Millisecond, + }) + start := time.Now() + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if elapsed := time.Since(start); elapsed > 150*time.Millisecond { + t.Fatalf("Tick took %v — per-call deadline did not fire", elapsed) + } + if got := 
prm.records(); len(got) != 0 { + t.Fatalf("no apply on deadline timeout, got %d", len(got)) + } +} + +func TestStartDrainsOnContextCancel(t *testing.T) { + sessions := &fakeSessions{} + projects := &fakeProjects{} + branches := &fakeBranches{} + provider := &fakeProvider{} + prm := &fakePR{} + p := newTestPoller(t, Deps{ + Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm, + Interval: 5 * time.Millisecond, + }) + ctx, cancel := context.WithCancel(context.Background()) + done := p.Start(ctx) + cancel() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("poller did not exit within 1s of ctx cancel") + } +} + +func TestStartTicksRepeatedly(t *testing.T) { + var ticks atomic.Int32 + sessions := &fakeSessions{} + projects := &fakeProjects{} + branches := &fakeBranches{} + provider := &fakeProvider{} + prm := &fakePR{} + p := newTestPoller(t, Deps{ + Provider: provider, + Branches: branches, + Sessions: &countingSessions{wrap: sessions, ticks: &ticks}, + Projects: projects, + PR: prm, + Interval: 5 * time.Millisecond, + }) + ctx, cancel := context.WithCancel(context.Background()) + done := p.Start(ctx) + deadline := time.After(500 * time.Millisecond) +loop: + for ticks.Load() < 3 { + select { + case <-deadline: + break loop + case <-time.After(2 * time.Millisecond): + } + } + cancel() + <-done + if ticks.Load() < 2 { + t.Fatalf("expected at least 2 ticks, got %d", ticks.Load()) + } +} + +// countingSessions ticks the counter each time ListAllSessions is called. 
+type countingSessions struct { + wrap *fakeSessions + ticks *atomic.Int32 +} + +func (c *countingSessions) ListAllSessions(ctx context.Context) ([]domain.SessionRecord, error) { + c.ticks.Add(1) + return c.wrap.ListAllSessions(ctx) +} + +// --------------------------------------------------------------------------- +// owner/repo derivation +// --------------------------------------------------------------------------- + +func TestParseGitHubRemote(t *testing.T) { + tests := []struct{ in, owner, repo string }{ + {"https://github.com/acme/repo.git", "acme", "repo"}, + {"https://github.com/acme/repo", "acme", "repo"}, + {"git@github.com:acme/repo.git", "acme", "repo"}, + {"ssh://git@github.com/acme/repo.git", "acme", "repo"}, + {"acme/repo", "acme", "repo"}, + {"", "", ""}, + {"https://gitlab.com/x/y", "x", "y"}, // host-agnostic parser; provider rejects non-GitHub at Observe time + } + for _, tc := range tests { + owner, repo, ok := parseGitHubRemote(tc.in) + if tc.owner == "" { + if ok { + t.Errorf("parseGitHubRemote(%q): expected !ok, got %q/%q", tc.in, owner, repo) + } + continue + } + if !ok || owner != tc.owner || repo != tc.repo { + t.Errorf("parseGitHubRemote(%q) = %q/%q ok=%v; want %q/%q true", tc.in, owner, repo, ok, tc.owner, tc.repo) + } + } +} diff --git a/backend/internal/session_manager/manager.go b/backend/internal/session_manager/manager.go index 18c21c5c..1307ce1f 100644 --- a/backend/internal/session_manager/manager.go +++ b/backend/internal/session_manager/manager.go @@ -94,7 +94,16 @@ func (m *Manager) Spawn(ctx context.Context, cfg ports.SpawnConfig) (domain.Sess } id := rec.ID - ws, err := m.workspace.Create(ctx, ports.WorkspaceConfig{ProjectID: cfg.ProjectID, SessionID: id, Branch: cfg.Branch}) + // The CLI/API does not expose a branch flag, but the gitworktree adapter + // requires a non-empty branch (and cannot have two worktrees on the same + // branch). 
The session id is assigned by the store above, so this is the + // only layer where a per-session default ref can be computed. + branch := cfg.Branch + if branch == "" { + branch = "ao/" + string(id) + } + + ws, err := m.workspace.Create(ctx, ports.WorkspaceConfig{ProjectID: cfg.ProjectID, SessionID: id, Branch: branch}) if err != nil { m.markSpawnFailedTerminated(ctx, id) return domain.SessionRecord{}, fmt.Errorf("spawn %s: workspace: %w", id, err) diff --git a/backend/internal/session_manager/manager_test.go b/backend/internal/session_manager/manager_test.go index 22ef7875..0e8c9dd7 100644 --- a/backend/internal/session_manager/manager_test.go +++ b/backend/internal/session_manager/manager_test.go @@ -159,6 +159,33 @@ func TestSpawn_AssignsIDAndGoesIdle(t *testing.T) { t.Fatal("handle not folded") } } + +// SpawnConfig.Branch is optional from the API surface (the CLI does not expose +// it). The SM is the only layer with the session id (assigned by the store +// inside Spawn), so it defaults the branch to a per-session ref. The +// gitworktree workspace requires a non-empty branch and cannot have two +// worktrees on the same branch — so the default must be unique per session. +func TestSpawn_DefaultsBranchPerSession_WhenUnset(t *testing.T) { + m, st, _, _ := newManager() + if _, err := m.Spawn(ctx, ports.SpawnConfig{ProjectID: "mer", Kind: domain.KindWorker, Prompt: "do it"}); err != nil { + t.Fatal(err) + } + if got := st.sessions["mer-1"].Metadata.Branch; got != "ao/mer-1" { + t.Fatalf("default branch: got %q, want %q", got, "ao/mer-1") + } +} + +// An explicit branch in SpawnConfig must win over the default — the API/CLI +// layer can still pin a branch when it wants to. 
+func TestSpawn_HonorsExplicitBranch(t *testing.T) { + m, st, _, _ := newManager() + if _, err := m.Spawn(ctx, ports.SpawnConfig{ProjectID: "mer", Kind: domain.KindWorker, Prompt: "do it", Branch: "feature/x"}); err != nil { + t.Fatal(err) + } + if got := st.sessions["mer-1"].Metadata.Branch; got != "feature/x" { + t.Fatalf("explicit branch: got %q, want %q", got, "feature/x") + } +} func TestSpawn_RollsBackOnRuntimeFailure(t *testing.T) { m, st, _, ws := newManager() m.runtime = &fakeRuntime{createErr: errors.New("boom")} diff --git a/docs/agent/README.md b/docs/agent/README.md new file mode 100644 index 00000000..9c079926 --- /dev/null +++ b/docs/agent/README.md @@ -0,0 +1,118 @@ +# Agent Adapter PRD + +## Goal + +Agent adapters let Better-AO run and observe different CLI coding agents without hardcoding agent-specific behavior into the spawn engine. Every CLI coding agent must implement the contract in `backend/internal/adapters/agent/agent.go`. + +The important current slice is hook-derived session info. Better-AO should know a running worker's native agent session id, title, and summary from agent hooks installed in the per-session worktree, not from scanning agent transcript/cache files. + +## Current Decisions + +- Better-AO only needs to derive session info for Better-AO-managed sessions. +- Hook installation happens at worktree/session creation time. +- `SessionInfo` reads normalized metadata persisted in Better-AO's session store. +- `SessionInfo` must not infer display info by reading agent transcript/cache files. +- `SummaryIsFallback` is removed from `agent.SessionInfo`. +- `TranscriptPath` is removed from `agent.SessionInfo`. +- `Title` and `Summary` are both first-class fields. +- `Title` is derived from the user prompt hook. +- `Summary` is derived from the stop/final assistant hook. +- Agent adapter `Metadata` should stay nil/empty unless an adapter has a real extra field that does not belong in the normalized contract. 
+ +## Agent Contract + +The shared contract lives in `backend/internal/adapters/agent/agent.go`. + +Required adapter behavior: + +- `GetConfigSpec` describes user-facing agent config. +- `GetLaunchCommand` builds the native agent command. +- `GetPromptDeliveryStrategy` says whether the prompt is passed in argv or sent after launch. +- `GetAgentHooks` installs or merges Better-AO hooks into the agent's workspace-local hook config. +- `GetRestoreCommand` builds a native resume command when restore is supported. +- `SessionInfo` returns normalized metadata: + - `AgentSessionID` + - `Title` + - `Summary` + - optional adapter-specific `Metadata` + +Implementation layout: + +- Agent-specific hook installation and embedded hook templates should live beside the agent adapter in `backend/internal/adapters/agent/<agent>/hooks.go`. +- Launch, restore, and session-info behavior can stay in the main agent implementation unless the file grows enough to justify another split. + +## Metadata Keys + +Hook callbacks persist these normalized keys in the session metadata JSON blob: + +- `agentSessionId`: native agent session id. +- `title`: display title, derived from the first user prompt hook for the session. +- `summary`: display summary, derived from the final assistant message exposed to the stop hook. + +The original spawn prompt may remain in metadata as `prompt` for launch/debug fallback, but `title` is the preferred display title once hook metadata lands. + +## Hook Methodology + +Agent adapters install hooks into the worktree-local config owned by the native agent. + +Hook callbacks run through hidden Better-AO CLI commands: + +```text +better-ao hooks +``` + +The callback: + +1. Reads the native hook JSON payload from stdin. +2. Reads the Better-AO session id from `BETTER_AO_SESSION_ID`. +3. Opens `~/.better-ao/state.db`. +4. Merges normalized metadata into the matching session row. +5. Publishes `session.updated` when metadata changed. +6. 
Prints `{}` and exits 0 for successful no-op cases, including non-AO sessions or missing rows. + +The spawn engine inserts the Better-AO session row before launching the durability provider so early startup hooks can update an existing row. If launch fails after insertion, spawn deletes the row during rollback. + +## Restore Boundary + +Session display info and native restore are separate concerns. + +Some agents may still need transcript-derived or deterministic native ids for `GetRestoreCommand` until restore is redesigned for that agent. Do not remove restore support just because `SessionInfo` stops reading transcripts. + +For `SessionInfo`, transcript/cache files are not an acceptable source of title or summary. + +## UI And Events + +The workspace adapter prefers: + +- `metadata.title` as session title. +- `metadata.summary` as session description. +- `metadata.prompt` only as fallback. + +Hook metadata changes publish `session.updated`. The frontend listens to `session.created`, `session.terminated`, and `session.updated` and invalidates the workspace query. + + +## Acceptance Criteria + +Agent adapter behavior: + +- Agent hook installation preserves user hooks and deduplicates Better-AO hooks. +- Hook callbacks persist native session id, title, and summary. +- `SessionInfo` returns normalized fields from persisted metadata. +- `SessionInfo` does not read transcripts or caches for title/summary. +- Adapter-specific metadata stays nil/empty unless a concrete feature requires it. + +Engine and UI: + +- Spawn installs hooks before launching the native agent. +- The session row exists before launch so hooks can merge metadata. +- Launch failure after row insertion deletes the row. +- Metadata updates publish `session.updated`. +- The dashboard refreshes title/summary without a manual reload. + +Verification: + +```sh +go test ./... 
+node --test scripts/*.test.mjs +pnpm --filter @better-ao/web lint:ts +``` diff --git a/flake.lock b/flake.lock new file mode 100644 index 00000000..0cf6e3d9 --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1780030872, + "narHash": "sha256-u6WU/yd/o8iYQrHX3RAwO1hYa3LkoSL+WNQD0rJfJZQ=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "e9a7635a57597d9754eccebdfc7045e6c8600e6b", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 00000000..d99c9381 --- /dev/null +++ b/flake.nix @@ -0,0 +1,41 @@ +{ + description = "agent-orchestrator development shell"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = + { + nixpkgs, + flake-utils, + ... 
+ }: + flake-utils.lib.eachDefaultSystem ( + system: + let + pkgs = import nixpkgs { inherit system; }; + go = pkgs.go_1_25; + in + { + devShells.default = pkgs.mkShell { + buildInputs = [ + go + pkgs.gotools + pkgs.nodejs_22 + pkgs.pnpm_10 + pkgs.just + ]; + + shellHook = '' + export GOROOT="${go}/share/go" + export GOPATH="$PWD/.go" + export GOBIN="$GOPATH/bin" + export PNPM_HOME="$PWD/.pnpm" + export PATH="$GOBIN:$PNPM_HOME:$PATH" + ''; + }; + } + ); +} diff --git a/scripts/ao-here.sh b/scripts/ao-here.sh new file mode 100755 index 00000000..df980e0e --- /dev/null +++ b/scripts/ao-here.sh @@ -0,0 +1,104 @@ +#!/usr/bin/env bash +# +# ao-here.sh — register the current (or given) directory as an AO project and +# start the daemon. Uses OUR Go binary (built from this repo's +# backend/cmd/ao) explicitly — does NOT rely on whatever `ao` is on PATH +# (which on dev machines is usually the TypeScript orchestrator CLI). +# +# Usage: +# ./scripts/ao-here.sh # registers $PWD +# ./scripts/ao-here.sh /path/to/repo # registers given path +# +# Env overrides: +# AO_HOST (default 127.0.0.1) +# AO_PORT (default 3001) + +set -euo pipefail + +# Find the repo root: this script lives at <repo>/scripts/ao-here.sh +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" +BACKEND_DIR="${REPO_ROOT}/backend" + +if [[ ! -d "${BACKEND_DIR}/cmd/ao" ]]; then + echo "error: can't find backend/cmd/ao under ${REPO_ROOT}" >&2 + echo " (this script must live inside the agent-orchestrator repo)" >&2 + exit 1 +fi + +if ! command -v go >/dev/null 2>&1; then + echo "error: 'go' is required to build the daemon" >&2 + exit 1 +fi + +if ! command -v jq >/dev/null 2>&1; then + echo "error: 'jq' is required (brew install jq)" >&2 + exit 1 +fi + +# Build the daemon binary to a local path inside the repo (gitignored). +# Rebuild if any source file is newer than the existing binary. +AO_BIN="${BACKEND_DIR}/bin/ao" +if [[ ! 
-x "$AO_BIN" ]] || [[ -n "$(find "${BACKEND_DIR}" -newer "$AO_BIN" -type f -name '*.go' -print -quit 2>/dev/null || true)" ]]; then + echo "[ao] building daemon -> ${AO_BIN}" + (cd "$BACKEND_DIR" && go build -o "$AO_BIN" ./cmd/ao) +fi + +PROJECT_PATH="$(cd "${1:-$PWD}" && pwd)" + +if [[ ! -d "$PROJECT_PATH/.git" ]]; then + echo "error: $PROJECT_PATH is not a git repository (no .git dir)" >&2 + exit 1 +fi + +AO_HOST="${AO_HOST:-127.0.0.1}" +AO_PORT="${AO_PORT:-3001}" +BASE="http://${AO_HOST}:${AO_PORT}" + +is_ready() { curl -fsS --max-time 1 "${BASE}/readyz" >/dev/null 2>&1; } + +if is_ready; then + echo "[ao] daemon already running at ${BASE}" +else + echo "[ao] starting daemon..." + "$AO_BIN" start + for _ in {1..30}; do + if is_ready; then break; fi + sleep 1 + done + if ! is_ready; then + echo "error: daemon did not become ready in 30s at ${BASE}" >&2 + exit 1 + fi + echo "[ao] daemon ready at ${BASE}" +fi + +BODY="$(jq -nc --arg path "$PROJECT_PATH" '{path: $path}')" +RESPONSE="$(curl -sS -w '\n%{http_code}' -X POST -H 'Content-Type: application/json' -d "$BODY" "${BASE}/api/v1/projects")" +HTTP_CODE="$(echo "$RESPONSE" | tail -1)" +BODY_OUT="$(echo "$RESPONSE" | sed '$d')" + +case "$HTTP_CODE" in + 201) + PROJECT_ID="$(echo "$BODY_OUT" | jq -r '.project.id')" + echo "[ao] registered project: $PROJECT_ID -> $PROJECT_PATH" + ;; + 409) + PROJECT_ID="$(echo "$BODY_OUT" | jq -r '.details.existingProjectId // empty')" + if [[ -z "$PROJECT_ID" ]]; then + echo "error: conflict response missing existingProjectId; raw:" >&2 + echo "$BODY_OUT" | jq . >&2 2>/dev/null || echo "$BODY_OUT" >&2 + exit 1 + fi + echo "[ao] project already registered: $PROJECT_ID -> $PROJECT_PATH" + ;; + *) + echo "error: unexpected HTTP $HTTP_CODE from POST /api/v1/projects:" >&2 + echo "$BODY_OUT" | jq . >&2 2>/dev/null || echo "$BODY_OUT" >&2 + exit 1 + ;; +esac + +echo "" +echo " next:" +echo " ${AO_BIN} spawn --project $PROJECT_ID --prompt \"\""