diff --git a/backend/internal/adapters/scm/github/find_branch_pr.go b/backend/internal/adapters/scm/github/find_branch_pr.go new file mode 100644 index 00000000..7ee79981 --- /dev/null +++ b/backend/internal/adapters/scm/github/find_branch_pr.go @@ -0,0 +1,94 @@ +package github + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +// FindOpenPRForBranch returns the canonical github.com URL of the most +// recently updated open PR whose head ref is "{owner}:{branch}", or "" +// with a nil error when no open PR matches. +// +// The poller uses this for branch-based discovery: since the session +// record does not (yet) carry a stored PR URL, the only way to find +// "the PR for this session" is by the workspace branch. The endpoint +// hit is GET /repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open +// per the GitHub REST API. +// +// When multiple open PRs share the same head ref (rare but legal — +// e.g. forks that pushed to the same branch name), we pick the most +// recently updated one rather than failing closed. Failing closed +// would silently stop observing the PR every time a stale duplicate +// shows up. +func (p *Provider) FindOpenPRForBranch(ctx context.Context, owner, repo, branch string) (string, error) { + owner = strings.TrimSpace(owner) + repo = strings.TrimSpace(repo) + branch = strings.TrimSpace(branch) + if owner == "" || repo == "" || branch == "" { + return "", fmt.Errorf("github scm: FindOpenPRForBranch requires owner/repo/branch (got %q/%q/%q)", owner, repo, branch) + } + + q := url.Values{} + q.Set("state", "open") + q.Set("head", owner+":"+branch) + q.Set("per_page", "100") + + resp, err := p.client.doREST(ctx, http.MethodGet, repoPath(owner, repo, "pulls"), q, nil) + if err != nil { + return "", err + } + if len(resp.Body) == 0 { + return "", nil + } + var list []listedPR + if err := json.Unmarshal(resp.Body, &list); err != nil { + return "", fmt.Errorf("github scm: decode pulls list: %w", err) + } + if len(list) == 0 { + return "", nil + } + + best := -1 + var bestTime time.Time + for i, pr := range list { + if !strings.EqualFold(pr.State, "open") { + continue + } + t := parsePRTimestamp(pr.UpdatedAt) + if best < 0 || t.After(bestTime) { + best = i + bestTime = t + } + } + if best < 0 { + return "", nil + } + chosen := list[best] + if chosen.HTMLURL != "" { + return chosen.HTMLURL, nil + } + // Construct the canonical web URL from owner/repo/number when the + // API response omits html_url (some enterprise responses elide it). + return "https://github.com/" + owner + "/" + repo + "/pull/" + strconv.Itoa(chosen.Number), nil +} + +type listedPR struct { + Number int `json:"number"` + State string `json:"state"` + HTMLURL string `json:"html_url"` + UpdatedAt string `json:"updated_at"` +} + +func parsePRTimestamp(s string) time.Time { + t, err := time.Parse(time.RFC3339, s) + if err != nil { + return time.Time{} + } + return t +} diff --git a/backend/internal/adapters/scm/github/find_branch_pr_test.go b/backend/internal/adapters/scm/github/find_branch_pr_test.go new file mode 100644 index 00000000..39b5be77 --- /dev/null +++ b/backend/internal/adapters/scm/github/find_branch_pr_test.go @@ -0,0 +1,131 @@ +package github + +import ( + "encoding/json" + "errors" + "net/http" + "strings" + "testing" +) + +func TestFindOpenPRForBranchSingleMatch(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + fake.on(http.MethodGet, "/repos/acme/repo/pulls", func(w http.ResponseWriter, r *http.Request) { + if got := r.URL.Query().Get("head"); got != "acme:feat/x" { + t.Errorf("head query = %q, want acme:feat/x", got) + } + if got := r.URL.Query().Get("state"); got != "open" { + t.Errorf("state query = %q, want open", got) + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode([]map[string]any{ + {"number": 7, "state": "open", "html_url": "https://github.com/acme/repo/pull/7", "updated_at": "2026-05-01T10:00:00Z"}, + }) + }) + + url, err := p.FindOpenPRForBranch(ctx(), "acme", "repo", "feat/x") + if err != nil { + t.Fatalf("FindOpenPRForBranch: %v", err) + } + if url != "https://github.com/acme/repo/pull/7" { + t.Fatalf("url = %q", url) + } +} + +func TestFindOpenPRForBranchNoMatch(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + fake.on(http.MethodGet, "/repos/acme/repo/pulls", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("[]")) + }) + url, err := p.FindOpenPRForBranch(ctx(), "acme", "repo", "feat/x") + if err != nil { + t.Fatalf("FindOpenPRForBranch: %v", err) + } + if url != "" { + t.Fatalf("url = %q, want empty", url) + } +} + +func TestFindOpenPRForBranchMultiplePicksMostRecent(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + fake.on(http.MethodGet, "/repos/acme/repo/pulls", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode([]map[string]any{ + {"number": 1, "state": "open", "html_url": "https://github.com/acme/repo/pull/1", "updated_at": "2026-01-01T00:00:00Z"}, + {"number": 9, "state": "open", "html_url": "https://github.com/acme/repo/pull/9", "updated_at": "2026-05-01T00:00:00Z"}, + {"number": 4, "state": "open", "html_url": "https://github.com/acme/repo/pull/4", "updated_at": "2026-03-01T00:00:00Z"}, + }) + }) + url, err := p.FindOpenPRForBranch(ctx(), "acme", "repo", "feat/x") + if err != nil { + t.Fatalf("FindOpenPRForBranch: %v", err) + } + if url != "https://github.com/acme/repo/pull/9" { + t.Fatalf("url = %q, want pull/9", url) + } +} + +func TestFindOpenPRForBranchEmptyInputsError(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + for _, tc := range []struct{ owner, repo, branch string }{ + {"", "repo", "b"}, + {"o", "", "b"}, + {"o", "r", ""}, + } { + _, err := p.FindOpenPRForBranch(ctx(), tc.owner, tc.repo, tc.branch) + if err == nil { + t.Errorf("expected error for empty input %+v", tc) + } + } +} + +func TestFindOpenPRForBranchRateLimited(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + fake.on(http.MethodGet, "/repos/acme/repo/pulls", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-RateLimit-Remaining", "0") + w.Header().Set("X-RateLimit-Reset", "1700000000") + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"message":"API rate limit exceeded"}`)) + }) + _, err := p.FindOpenPRForBranch(ctx(), "acme", "repo", "feat/x") + if !errors.Is(err, ErrRateLimited) { + t.Fatalf("err = %v, want ErrRateLimited", err) + } +} + +func TestFindOpenPRForBranchAuthFailed(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + fake.on(http.MethodGet, "/repos/acme/repo/pulls", func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message":"Bad credentials"}`)) + }) + _, err := p.FindOpenPRForBranch(ctx(), "acme", "repo", "feat/x") + if !errors.Is(err, ErrAuthFailed) { + t.Fatalf("err = %v, want ErrAuthFailed", err) + } +} + +func TestFindOpenPRForBranchSynthesizesURLWhenHTMLEmpty(t *testing.T) { + fake := newFakeGH(t) + p := newProviderForTest(t, fake) + fake.on(http.MethodGet, "/repos/acme/repo/pulls", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode([]map[string]any{ + {"number": 42, "state": "open", "updated_at": "2026-05-01T10:00:00Z"}, + }) + }) + url, err := p.FindOpenPRForBranch(ctx(), "acme", "repo", "feat/x") + if err != nil { + t.Fatalf("err = %v", err) + } + if !strings.HasSuffix(url, "/acme/repo/pull/42") { + t.Fatalf("url = %q, want suffix /acme/repo/pull/42", url) + } +} diff --git a/backend/internal/daemon/daemon.go b/backend/internal/daemon/daemon.go index 626656f5..c897a027 100644 --- a/backend/internal/daemon/daemon.go +++ b/backend/internal/daemon/daemon.go @@ -98,6 +98,12 @@ func Run() error { } _ = ss // sm: HTTP routes land in a follow-up PR (γ) + // SCM observation: polling Provider -> pr.Manager -> lifecycle nudges. + // Constructed after lifecycle so the PR Manager can forward observations + // to ApplyPRObservation; runs alongside the reaper as a sibling background + // loop. Missing GITHUB_TOKEN degrades gracefully (loop is not started). + scmStk := startSCM(ctx, store, projects, lcStack.lcm, log) + runErr := srv.Run(ctx) // Shut the background goroutines down in order: cancel the context FIRST so @@ -105,6 +111,7 @@ func Run() error { // via defer) avoids the LIFO trap where a Stop() that blocks on ctx-cancel // runs before the cancel — which would hang any non-signal exit path. stop() + scmStk.Stop() lcStack.Stop() if err := cdcPipe.Stop(); err != nil { log.Error("cdc pipeline shutdown", "err", err) diff --git a/backend/internal/daemon/scm_wiring.go b/backend/internal/daemon/scm_wiring.go new file mode 100644 index 00000000..a0390cac --- /dev/null +++ b/backend/internal/daemon/scm_wiring.go @@ -0,0 +1,61 @@ +package daemon + +import ( + "context" + "errors" + "log/slog" + + scmgithub "github.com/aoagents/agent-orchestrator/backend/internal/adapters/scm/github" + "github.com/aoagents/agent-orchestrator/backend/internal/lifecycle" + "github.com/aoagents/agent-orchestrator/backend/internal/observe/scm" + "github.com/aoagents/agent-orchestrator/backend/internal/pr" + "github.com/aoagents/agent-orchestrator/backend/internal/project" + "github.com/aoagents/agent-orchestrator/backend/internal/storage/sqlite" +) + +// scmStack owns the SCM observation loop: a GitHub Provider, a pr.Manager +// that writes PR rows and forwards observations to lifecycle for nudges, +// and the polling goroutine that drives both. A nil-token environment +// degrades gracefully — the daemon still runs locally without SCM +// observation; PR-driven nudges (CI-failure log tail, review feedback, +// merge-conflict rebase) will not fire until a token is supplied. +type scmStack struct { + pollerDone <-chan struct{} +} + +// startSCM constructs and starts the SCM observation stack. The Provider +// reads its token from AO_GITHUB_TOKEN (preferred) or GITHUB_TOKEN, both +// via os.Getenv. Without a token, the poller is not started and a no-op +// done channel is returned — Stop is a free call in that case. +func startSCM(ctx context.Context, store *sqlite.Store, projects project.Manager, lcm *lifecycle.Manager, log *slog.Logger) *scmStack { + tokenSource := scmgithub.EnvTokenSource{EnvVars: []string{"AO_GITHUB_TOKEN", "GITHUB_TOKEN"}} + provider, err := scmgithub.NewProvider(scmgithub.ProviderOptions{Token: tokenSource}) + if err != nil { + if errors.Is(err, scmgithub.ErrNoToken) { + log.Info("scm poller: no GITHUB_TOKEN configured, SCM observation disabled") + } else { + log.Warn("scm poller: provider construction failed, SCM observation disabled", "err", err) + } + return &scmStack{pollerDone: closedDone()} + } + prMgr := pr.New(pr.Deps{Writer: store, Lifecycle: lcm}) + poller := scm.New(scm.Deps{ + Provider: provider, + Branches: provider, + Sessions: store, + Projects: projects, + PR: prMgr, + Logger: log, + }) + return &scmStack{pollerDone: poller.Start(ctx)} +} + +// Stop waits for the poller goroutine to exit. The caller must cancel the +// ctx passed to startSCM before calling Stop. +func (s *scmStack) Stop() { <-s.pollerDone } + +func closedDone() <-chan struct{} { + ch := make(chan struct{}) + close(ch) + return ch +} diff --git a/backend/internal/integration/scm_poller_test.go b/backend/internal/integration/scm_poller_test.go new file mode 100644 index 00000000..021fed82 --- /dev/null +++ b/backend/internal/integration/scm_poller_test.go @@ -0,0 +1,185 @@ +package integration + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + scmgithub "github.com/aoagents/agent-orchestrator/backend/internal/adapters/scm/github" + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/observe/scm" + "github.com/aoagents/agent-orchestrator/backend/internal/ports" + "github.com/aoagents/agent-orchestrator/backend/internal/project" +) + +// TestSCMPollerEndToEnd boots store + LCM + pr.Manager + the scm.Poller +// against an httptest GitHub stub, ticks once, and asserts: +// - the poller resolved the PR URL via branch discovery +// - pr.Manager persisted the PR row (PRWriter side of the bus) +// - lifecycle.ApplyPRObservation fired the CI-failure nudge to the messenger +// +// This is the seam-by-seam validation that aa-37's spec describes: from +// SCM observation to PR row to agent nudge, with every dependency the +// daemon wires in production. +func TestSCMPollerEndToEnd(t *testing.T) { + ctx := context.Background() + st := newStack(t) + + if err := st.store.Upsert(ctx, project.Row{ID: "acme", Path: "/repo/acme", RepoOriginURL: "https://github.com/acme/repo.git", RegisteredAt: time.Now()}); err != nil { + t.Fatal(err) + } + sess, err := st.sm.Spawn(ctx, ports.SpawnConfig{ProjectID: "acme", Kind: domain.KindWorker, Branch: "feat/x", Prompt: "fix CI"}) + if err != nil { + t.Fatal(err) + } + + // The PR URL the GitHub stub will report for branch acme:feat/x. + prURL := "https://github.com/acme/repo/pull/77" + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/json") + switch { + case r.Method == http.MethodGet && r.URL.Path == "/repos/acme/repo/pulls": + if got := r.URL.Query().Get("head"); got != "acme:feat/x" { + t.Errorf("pulls list head = %q, want acme:feat/x", got) + } + _ = json.NewEncoder(w).Encode([]map[string]any{ + {"number": 77, "state": "open", "html_url": prURL, "updated_at": "2026-05-15T10:00:00Z"}, + }) + case r.Method == http.MethodGet && r.URL.Path == "/repos/acme/repo/pulls/77": + w.Header().Set("ETag", `W/"v1"`) + _ = json.NewEncoder(w).Encode(map[string]any{ + "number": 77, + "state": "open", + "draft": false, + "merged": false, + "merged_at": nil, + "html_url": prURL, + "head": map[string]any{"ref": "feat/x", "sha": "deadbeef"}, + "base": map[string]any{"ref": "main"}, + "mergeable": false, + "rebaseable": true, + "mergeable_state": "blocked", + "merge_state_status": "BLOCKED", + }) + case r.Method == http.MethodPost && r.URL.Path == "/graphql": + _ = json.NewEncoder(w).Encode(map[string]any{ + "data": map[string]any{ + "repository": map[string]any{ + "pullRequest": map[string]any{ + "number": 77, + "url": prURL, + "state": "OPEN", + "isDraft": false, + "merged": false, + "closed": false, + "mergeable": "MERGEABLE", + "mergeStateStatus": "BLOCKED", + "reviewDecision": "REVIEW_REQUIRED", + "headRefOid": "deadbeef", + "commits": map[string]any{"nodes": []any{ + map[string]any{"commit": map[string]any{ + "oid": "deadbeef", + "statusCheckRollup": map[string]any{ + "state": "FAILURE", + "contexts": map[string]any{ + "nodes": []any{ + map[string]any{ + "__typename": "CheckRun", + "name": "build", + "status": "COMPLETED", + "conclusion": "FAILURE", + "detailsUrl": "https://github.com/acme/repo/runs/9001", + "databaseId": float64(9001), + }, + }, + "pageInfo": map[string]any{"hasNextPage": false}, + }, + }, + }}, + }}, + "reviewThreads": map[string]any{"nodes": []any{}}, + }, + }, + }, + }) + case r.Method == http.MethodGet && r.URL.Path == "/repos/acme/repo/actions/jobs/9001/logs": + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte("FAIL TestX\nFAIL TestY\n")) + default: + t.Errorf("unexpected request: %s %s", r.Method, r.URL.Path) + http.Error(w, "no handler", http.StatusNotImplemented) + } + })) + t.Cleanup(server.Close) + + provider, err := scmgithub.NewProvider(scmgithub.ProviderOptions{ + Token: scmgithub.StaticTokenSource("tkn"), + HTTPClient: server.Client(), + RESTBase: server.URL, + GraphQLURL: server.URL + "/graphql", + }) + if err != nil { + t.Fatal(err) + } + + projects := project.NewManager(st.store) + poller := scm.New(scm.Deps{ + Provider: provider, + Branches: provider, + Sessions: st.store, + Projects: projects, + PR: st.prm, + Interval: time.Hour, // ticker won't fire — we call Tick directly + ObserveTimeout: 5 * time.Second, + RemoteResolver: func(context.Context, string) (string, error) { + // The project Row.RepoOriginURL is set above, so this fallback + // should never be called; failing loudly catches a regression + // where the poller silently shells out instead of using + // project.Repo. + t.Fatalf("remote resolver should not be invoked when project.Repo is set") + return "", nil + }, + }) + + if err := poller.Tick(ctx); err != nil { + t.Fatalf("poller.Tick: %v", err) + } + + got, ok, err := st.store.GetPR(ctx, prURL) + if err != nil { + t.Fatal(err) + } + if !ok { + t.Fatalf("pr row not written for %s", prURL) + } + if got.SessionID != sess.ID { + t.Errorf("pr.SessionID = %q, want %q", got.SessionID, sess.ID) + } + if got.CI != domain.CIFailing { + t.Errorf("pr.CI = %q, want %q", got.CI, domain.CIFailing) + } + checks, err := st.store.ListChecks(ctx, prURL) + if err != nil { + t.Fatal(err) + } + if len(checks) != 1 || checks[0].Status != domain.PRCheckFailed { + t.Fatalf("checks = %+v", checks) + } + + if len(st.msg.msgs) != 1 { + t.Fatalf("expected exactly 1 lifecycle nudge, got %d (a double-nudge would regress sendOnce)", len(st.msg.msgs)) + } + if !strings.Contains(st.msg.msgs[0], "CI is failing") { + t.Errorf("messenger did not receive CI-failure body; got %q", st.msg.msgs[0]) + } + if !strings.Contains(st.msg.msgs[0], "FAIL TestX") { + t.Errorf("messenger did not receive log-tail body; got %q", st.msg.msgs[0]) + } +} diff --git a/backend/internal/observe/scm/poller.go b/backend/internal/observe/scm/poller.go new file mode 100644 index 00000000..e907d954 --- /dev/null +++ b/backend/internal/observe/scm/poller.go @@ -0,0 +1,364 @@ +// Package scm implements the OBSERVE-layer polling loop that drives +// SCM (pull-request) observations into the PR Manager and Lifecycle +// Manager. The loop is intentionally dumb: every tick it lists alive +// sessions, finds the open PR for each session's branch, asks the +// Provider for an observation, and hands the result to the PR +// Manager (which transactionally writes the row and forwards to +// lifecycle for nudges). +// +// The poller does not own any reaction logic. CI-failure log-tail +// nudges, review-feedback nudges (capped at reviewMaxNudge), and +// merge-conflict rebase nudges all live in lifecycle.ApplyPRObservation. +// Polling is uniform 30s for v1; per-PR adaptive cadence is a follow-up. +package scm + +import ( + "context" + "errors" + "log/slog" + "net/url" + "os/exec" + "strings" + "sync/atomic" + "time" + + scmgithub "github.com/aoagents/agent-orchestrator/backend/internal/adapters/scm/github" + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/ports" + "github.com/aoagents/agent-orchestrator/backend/internal/project" +) + +// DefaultInterval is the cadence used when Deps.Interval is zero. +const DefaultInterval = 30 * time.Second + +// DefaultObserveTimeout caps one Provider.Observe call so a single hung +// request can't stall the whole tick. +const DefaultObserveTimeout = 15 * time.Second + +// Provider observes one PR by its canonical URL. The github adapter +// satisfies this; other SCM adapters (gitlab, etc.) can implement the +// same surface without touching the poller. +type Provider interface { + Observe(ctx context.Context, prURL string) (ports.PRObservation, error) +} + +// BranchPRFinder resolves a session's branch to its open PR URL. v1 +// uses this because sessions do not (yet) carry a PR URL field; when +// they do, the poller will prefer the stored URL and only fall back +// here. An empty return with nil error means "no matching open PR". +type BranchPRFinder interface { + FindOpenPRForBranch(ctx context.Context, owner, repo, branch string) (string, error) +} + +// sessionLister narrows the sqlite store to what the poller needs. +type sessionLister interface { + ListAllSessions(ctx context.Context) ([]domain.SessionRecord, error) +} + +// projectGetter narrows project.Manager to its read path. +type projectGetter interface { + Get(ctx context.Context, id domain.ProjectID) (project.GetResult, error) +} + +// prApplier is the seam over pr.Manager.ApplyObservation — which itself +// persists the PR row and forwards to lifecycle for nudges. Keeping +// this one method on the seam means the poller never needs to know +// about lifecycle directly. +type prApplier interface { + ApplyObservation(ctx context.Context, id domain.SessionID, o ports.PRObservation) error +} + +// remoteResolver shells out to git to read a repo's origin URL. +// Injected so tests don't require a real git checkout. +type remoteResolver func(ctx context.Context, projectPath string) (string, error) + +// Deps groups every collaborator the Poller needs. Zero-valued +// optional fields fall back to safe defaults (slog.Default, 30s tick, +// 15s observe deadline, real `git` for origin lookup). +type Deps struct { + Provider Provider + Branches BranchPRFinder + Sessions sessionLister + Projects projectGetter + PR prApplier + Logger *slog.Logger + Interval time.Duration + ObserveTimeout time.Duration + RemoteResolver func(ctx context.Context, projectPath string) (string, error) +} + +// Poller is the SCM observation loop. Construct it with New, start the +// background goroutine with Start. Tick is exported so daemon and tests +// can drive a single cycle synchronously. +type Poller struct { + provider Provider + branches BranchPRFinder + sessions sessionLister + projects projectGetter + pr prApplier + logger *slog.Logger + interval time.Duration + observeTimeout time.Duration + remoteResolver remoteResolver + + healthy atomic.Bool +} + +// New constructs a Poller from its dependencies. +func New(d Deps) *Poller { + p := &Poller{ + provider: d.Provider, + branches: d.Branches, + sessions: d.Sessions, + projects: d.Projects, + pr: d.PR, + logger: d.Logger, + interval: d.Interval, + observeTimeout: d.ObserveTimeout, + remoteResolver: d.RemoteResolver, + } + if p.interval <= 0 { + p.interval = DefaultInterval + } + if p.observeTimeout <= 0 { + p.observeTimeout = DefaultObserveTimeout + } + if p.logger == nil { + p.logger = slog.Default() + } + if p.remoteResolver == nil { + p.remoteResolver = defaultRemoteResolver + } + p.healthy.Store(true) + return p +} + +// Healthy reports whether the SCM provider's authentication has been +// observed working since the poller started. It starts true and flips +// to false the first time the provider returns ErrAuthFailed; it does +// NOT auto-recover, because a single subsequent success could be an +// ETag-cached 304 that didn't actually exercise the token. A future +// health route consumes this bit; clearing it after token rotation is +// a daemon-restart concern. +func (p *Poller) Healthy() bool { return p.healthy.Load() } + +// Start launches the background goroutine and returns a channel that +// closes once the loop has exited. The loop exits when ctx is cancelled; +// callers should wait on the returned channel before tearing down the +// PR Manager / lifecycle / store dependencies. +func (p *Poller) Start(ctx context.Context) <-chan struct{} { + done := make(chan struct{}) + go p.loop(ctx, done) + return done +} + +func (p *Poller) loop(ctx context.Context, done chan<- struct{}) { + defer close(done) + t := time.NewTicker(p.interval) + defer t.Stop() + for { + select { + case <-ctx.Done(): + return + case <-t.C: + if err := p.Tick(ctx); err != nil { + p.logger.Error("scm poller: tick failed", "err", err) + } + } + } +} + +// Tick runs one observation cycle. +// +// It lists every session, skips terminated rows and rows without a +// branch, resolves each remaining session's open PR URL via the +// BranchPRFinder, asks the Provider for an observation under a +// per-call deadline, and hands a successful observation to the PR +// Manager. Errors are classified by sentinel: +// - ErrRateLimited: short-circuit the rest of the tick (don't burn +// through remaining sessions while GitHub is asking us to back off). +// - ErrAuthFailed: flip Healthy() to false; continue with the next +// session so a single misconfigured token does not stall the loop. +// - other: log warn, continue. +// +// A session-listing failure is the only error Tick propagates; it +// short-circuits the cycle just like the reaper. +func (p *Poller) Tick(ctx context.Context) error { + sessions, err := p.sessions.ListAllSessions(ctx) + if err != nil { + return err + } + for _, sess := range sessions { + if sess.IsTerminated || sess.Metadata.Branch == "" { + continue + } + if err := ctx.Err(); err != nil { + return err + } + stop := p.pollOne(ctx, sess) + if stop { + return nil + } + } + return nil +} + +// pollOne handles one session. Returns stop=true when the caller +// should short-circuit the remaining sessions (rate-limit signal). +func (p *Poller) pollOne(ctx context.Context, sess domain.SessionRecord) bool { + prURL, err := p.resolvePRURL(ctx, sess) + if err != nil { + return p.classify(sess.ID, "resolve-pr-url", err) + } + if prURL == "" { + p.logger.Debug("scm poller: no open PR for branch, skipping", + "session", sess.ID, "branch", sess.Metadata.Branch) + return false + } + + pollCtx, cancel := context.WithTimeout(ctx, p.observeTimeout) + defer cancel() + obs, err := p.provider.Observe(pollCtx, prURL) + if err != nil { + return p.classify(sess.ID, "observe", err) + } + if !obs.Fetched { + p.logger.Debug("scm poller: observation not fetched, skipping", + "session", sess.ID, "url", prURL) + return false + } + if err := p.pr.ApplyObservation(ctx, sess.ID, obs); err != nil { + p.logger.Warn("scm poller: apply observation failed", + "session", sess.ID, "err", err) + } + return false +} + +// classify maps a Provider/lookup error to the loop's continue/stop +// decision and surfaces it in the logs. Auth-class failures flip the +// Healthy() bool; rate-limit signals stop the tick. +func (p *Poller) classify(sid domain.SessionID, stage string, err error) bool { + switch { + case errors.Is(err, scmgithub.ErrRateLimited): + p.logger.Warn("scm poller: rate limited, skipping rest of tick", + "session", sid, "stage", stage, "err", err) + return true + case errors.Is(err, scmgithub.ErrAuthFailed): + p.healthy.Store(false) + p.logger.Error("scm poller: auth failed, provider marked unhealthy", + "session", sid, "stage", stage, "err", err) + return false + default: + p.logger.Warn("scm poller: error", + "session", sid, "stage", stage, "err", err) + return false + } +} + +// resolvePRURL finds the open PR URL for a session's branch. +// +// v1 strategy: branch-based discovery. Look up the session's project, +// derive owner/repo from project.Repo (which today holds the origin URL), +// falling back to `git remote get-url origin` against the project's +// on-disk path, then ask BranchPRFinder. When neither yields an +// owner/repo, the session is silently skipped — that is not a poller bug, +// it's a project that hasn't been configured for SCM observation. +// +// When the session record grows a stored PR URL field (separate PR), +// this function should prefer it over branch discovery. +func (p *Poller) resolvePRURL(ctx context.Context, sess domain.SessionRecord) (string, error) { + if p.branches == nil { + return "", nil + } + res, err := p.projects.Get(ctx, sess.ProjectID) + if err != nil { + return "", err + } + if res.Project == nil { + return "", nil + } + owner, repo, ok := ownerRepoFromProject(*res.Project) + if !ok { + remoteURL, err := p.remoteResolver(ctx, res.Project.Path) + if err != nil { + p.logger.Debug("scm poller: git remote lookup failed, skipping session", + "session", sess.ID, "project", sess.ProjectID, "err", err) + return "", nil + } + owner, repo, ok = parseGitHubRemote(remoteURL) + if !ok { + return "", nil + } + } + return p.branches.FindOpenPRForBranch(ctx, owner, repo, sess.Metadata.Branch) +} + +// ownerRepoFromProject derives (owner, repo) from a Project. Today +// project.Repo holds the origin URL (despite the type comment claiming +// "owner/name") — so we try both shapes here without touching the +// project package. +func ownerRepoFromProject(p project.Project) (owner, repo string, ok bool) { + repoField := strings.TrimSpace(p.Repo) + if repoField == "" { + return "", "", false + } + if o, r, ok := parseGitHubRemote(repoField); ok { + return o, r, true + } + return "", "", false +} + +// parseGitHubRemote accepts both URL- and SSH-style remote strings and +// the bare "owner/repo" shorthand. It is intentionally host-agnostic — +// the github.Provider will reject non-github hosts at Observe time, so +// rejecting them here would just duplicate that check and silently drop +// legitimately-configured projects on enterprise hosts. +// +// Recognised forms: +// - https://github.com/owner/repo[.git] +// - http(s)://host/owner/repo[.git] +// - git@host:owner/repo[.git] +// - ssh://git@host/owner/repo[.git] +// - owner/repo +func parseGitHubRemote(s string) (owner, repo string, ok bool) { + s = strings.TrimSpace(s) + if s == "" { + return "", "", false + } + switch { + case strings.HasPrefix(s, "git@"): + idx := strings.Index(s, ":") + if idx < 0 { + return "", "", false + } + s = s[idx+1:] + case strings.Contains(s, "://"): + u, err := url.Parse(s) + if err != nil || u.Host == "" { + return "", "", false + } + s = strings.TrimPrefix(u.Path, "/") + } + s = strings.TrimSuffix(s, ".git") + parts := strings.SplitN(s, "/", 3) + if len(parts) < 2 { + return "", "", false + } + owner = strings.TrimSpace(parts[0]) + repo = strings.TrimSpace(parts[1]) + if owner == "" || repo == "" { + return "", "", false + } + return owner, repo, true +} + +func defaultRemoteResolver(ctx context.Context, projectPath string) (string, error) { + if strings.TrimSpace(projectPath) == "" { + return "", errors.New("scm poller: project has no path") + } + out, err := exec.CommandContext(ctx, "git", "-C", projectPath, "remote", "get-url", "origin").Output() + if err != nil { + return "", err + } + return strings.TrimSpace(string(out)), nil +} diff --git a/backend/internal/observe/scm/poller_test.go b/backend/internal/observe/scm/poller_test.go new file mode 100644 index 00000000..d4350594 --- /dev/null +++ b/backend/internal/observe/scm/poller_test.go @@ -0,0 +1,519 @@ +package scm + +import ( + "context" + "errors" + "log/slog" + "sync" + "sync/atomic" + "testing" + "time" + + scmgithub "github.com/aoagents/agent-orchestrator/backend/internal/adapters/scm/github" + "github.com/aoagents/agent-orchestrator/backend/internal/domain" + "github.com/aoagents/agent-orchestrator/backend/internal/ports" + "github.com/aoagents/agent-orchestrator/backend/internal/project" +) + +// --------------------------------------------------------------------------- +// Fakes +// --------------------------------------------------------------------------- + +type fakeProvider struct { + mu sync.Mutex + calls []string + obs map[string]ports.PRObservation + errs map[string]error + hangFor time.Duration +} + +func (f *fakeProvider) Observe(ctx context.Context, prURL string) (ports.PRObservation, error) { + f.mu.Lock() + f.calls = append(f.calls, prURL) + hang := f.hangFor + f.mu.Unlock() + if hang > 0 { + select { + case <-time.After(hang): + case <-ctx.Done(): + return ports.PRObservation{URL: prURL}, ctx.Err() + } + } + f.mu.Lock() + defer f.mu.Unlock() + if err, ok := f.errs[prURL]; ok { + return ports.PRObservation{URL: prURL}, err + } + if o, ok := f.obs[prURL]; ok { + return o, nil + } + return ports.PRObservation{URL: prURL}, nil +} + +func (f *fakeProvider) seenURLs() []string { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]string, len(f.calls)) + copy(out, f.calls) + return out +} + +type fakeBranches struct { + mu sync.Mutex + urls map[string]string // owner/repo/branch -> prURL + err error + callCount int +} + +func (f *fakeBranches) FindOpenPRForBranch(_ context.Context, owner, repo, branch string) (string, error) { + f.mu.Lock() + defer f.mu.Unlock() + f.callCount++ + if f.err != nil { + return "", f.err + } + return f.urls[owner+"/"+repo+"/"+branch], nil +} + +type fakeSessions struct { + sessions []domain.SessionRecord + err error +} + +func (f *fakeSessions) ListAllSessions(context.Context) ([]domain.SessionRecord, error) { + if f.err != nil { + return nil, f.err + } + out := make([]domain.SessionRecord, len(f.sessions)) + copy(out, f.sessions) + return out, nil +} + +type fakeProjects struct { + projects map[domain.ProjectID]project.Project +} + +func (f *fakeProjects) Get(_ context.Context, id domain.ProjectID) (project.GetResult, error) { + p, ok := f.projects[id] + if !ok { + return project.GetResult{}, errors.New("project not found") + } + pp := p + return project.GetResult{Status: "ok", Project: &pp}, nil +} + +type fakePR struct { + mu sync.Mutex + applied []appliedObs + applyErr error +} + +type appliedObs struct { + id domain.SessionID + obs ports.PRObservation +} + +func (f *fakePR) ApplyObservation(_ context.Context, id domain.SessionID, o ports.PRObservation) error { + f.mu.Lock() + defer f.mu.Unlock() + f.applied = append(f.applied, appliedObs{id: id, obs: o}) + return f.applyErr +} + +func (f *fakePR) records() []appliedObs { + f.mu.Lock() + defer f.mu.Unlock() + out := make([]appliedObs, len(f.applied)) + copy(out, f.applied) + return out +} + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +func newTestPoller(t *testing.T, d Deps) *Poller { + t.Helper() + if d.Logger == nil { + d.Logger = slog.New(slog.NewTextHandler(testWriter{t}, &slog.HandlerOptions{Level: slog.LevelDebug})) + } + return New(d) +} + +type testWriter struct{ t *testing.T } + +func (w testWriter) Write(p []byte) (int, error) { + w.t.Log(string(p)) + return len(p), nil +} + +func aliveSession(id domain.SessionID, project domain.ProjectID, branch string) domain.SessionRecord { + return domain.SessionRecord{ + ID: id, + ProjectID: project, + Kind: domain.KindWorker, + Metadata: domain.SessionMetadata{Branch: branch, RuntimeHandleID: "h"}, + } +} + +func terminatedSession(id domain.SessionID, project domain.ProjectID, branch string) domain.SessionRecord { + s := aliveSession(id, project, branch) + s.IsTerminated = true + return s +} + +func githubProject(id domain.ProjectID) project.Project { + return project.Project{ID: id, Path: "/repo/" + string(id), Repo: "https://github.com/acme/repo.git"} +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +func TestTickObservesAliveSessionAndAppliesObservation(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{ + aliveSession("s-1", "acme", "feat/x"), + terminatedSession("s-2", "acme", "feat/y"), + }} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{ + "acme/repo/feat/x": "https://github.com/acme/repo/pull/11", + "acme/repo/feat/y": "https://github.com/acme/repo/pull/12", + }} + provider := &fakeProvider{obs: map[string]ports.PRObservation{ + "https://github.com/acme/repo/pull/11": {Fetched: true, URL: "https://github.com/acme/repo/pull/11", Number: 11, CI: domain.CIPassing}, + }} + prm := &fakePR{} + + p := newTestPoller(t, Deps{ + Provider: provider, + Branches: branches, + Sessions: sessions, + Projects: projects, + PR: prm, + }) + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick error: %v", err) + } + + if got := provider.seenURLs(); len(got) != 1 || got[0] != "https://github.com/acme/repo/pull/11" { + t.Fatalf("provider.Observe calls = %v, want [pull/11] (terminated session skipped)", got) + } + rec := prm.records() + if len(rec) != 1 || rec[0].id != "s-1" || rec[0].obs.Number != 11 { + t.Fatalf("pr.ApplyObservation = %+v, want one call for s-1/pull-11", rec) + } +} + +func TestTickSkipsApplyWhenNotFetched(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{aliveSession("s-1", "acme", "feat/x")}} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{"acme/repo/feat/x": "https://github.com/acme/repo/pull/11"}} + provider := &fakeProvider{obs: map[string]ports.PRObservation{ + "https://github.com/acme/repo/pull/11": {Fetched: false, URL: "https://github.com/acme/repo/pull/11"}, + }} + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if got := prm.records(); len(got) != 0 { + t.Fatalf("ApplyObservation called %d times on !Fetched obs", len(got)) + } +} + +func TestTickSkipsSessionsWithoutBranch(t *testing.T) { + ctx := context.Background() + noBranch := aliveSession("s-1", "acme", "") + sessions := &fakeSessions{sessions: []domain.SessionRecord{noBranch}} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{} + provider := &fakeProvider{} + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if got := provider.seenURLs(); len(got) != 0 { + t.Fatalf("provider should not be called for session without branch, got %v", got) + } + if got := branches.callCount; got != 0 { + t.Fatalf("branches lookup should not be called for session without branch, got %d", got) + } +} + +func TestTickSkipsSessionsWithNoOpenPR(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{aliveSession("s-1", "acme", "feat/x")}} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{}} // empty: no PR exists + provider := &fakeProvider{} + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if got := provider.seenURLs(); len(got) != 0 { + t.Fatalf("provider should not be called when no PR found, got %v", got) + } +} + +func TestTickRateLimitShortCircuits(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{ + aliveSession("s-1", "acme", "feat/x"), + aliveSession("s-2", "acme", "feat/y"), + }} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{ + "acme/repo/feat/x": "https://github.com/acme/repo/pull/11", + "acme/repo/feat/y": "https://github.com/acme/repo/pull/12", + }} + provider := &fakeProvider{ + errs: map[string]error{ + "https://github.com/acme/repo/pull/11": scmgithub.ErrRateLimited, + }, + obs: map[string]ports.PRObservation{ + "https://github.com/acme/repo/pull/12": {Fetched: true, URL: "https://github.com/acme/repo/pull/12", Number: 12}, + }, + } + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if got := provider.seenURLs(); len(got) != 1 { + t.Fatalf("expected exactly one Observe call (rate-limit short-circuits), got %v", got) + } + if got := prm.records(); len(got) != 0 { + t.Fatalf("no observations should be applied after rate-limit, got %d", len(got)) + } +} + +func TestTickAuthFailureMarksUnhealthyAndContinues(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{ + aliveSession("s-1", "acme", "feat/x"), + aliveSession("s-2", "acme", "feat/y"), + }} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{ + "acme/repo/feat/x": "https://github.com/acme/repo/pull/11", + "acme/repo/feat/y": "https://github.com/acme/repo/pull/12", + }} + provider := &fakeProvider{ + errs: map[string]error{ + "https://github.com/acme/repo/pull/11": scmgithub.ErrAuthFailed, + }, + obs: map[string]ports.PRObservation{ + "https://github.com/acme/repo/pull/12": {Fetched: true, URL: "https://github.com/acme/repo/pull/12", Number: 12, CI: domain.CIPassing}, + }, + } + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + if !p.Healthy() { + t.Fatalf("poller should start healthy") + } + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if p.Healthy() { + t.Fatalf("poller should be unhealthy after ErrAuthFailed") + } + if got := provider.seenURLs(); len(got) != 2 { + t.Fatalf("expected provider to be called for both sessions, got %v", got) + } + rec := prm.records() + if len(rec) != 1 || rec[0].id != "s-2" { + t.Fatalf("expected one apply for s-2 after auth failure on s-1, got %+v", rec) + } +} + +func TestTickProjectLookupErrorContinues(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{ + aliveSession("s-1", "missing", "feat/x"), + aliveSession("s-2", "acme", "feat/y"), + }} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{ + "acme/repo/feat/y": "https://github.com/acme/repo/pull/12", + }} + provider := &fakeProvider{obs: map[string]ports.PRObservation{ + "https://github.com/acme/repo/pull/12": {Fetched: true, URL: "https://github.com/acme/repo/pull/12", Number: 12}, + }} + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if got := prm.records(); len(got) != 1 || got[0].id != "s-2" { + t.Fatalf("expected s-2 applied after project-lookup err on s-1, got %+v", got) + } + if !p.Healthy() { + t.Fatalf("project lookup error should not mark unhealthy") + } +} + +func TestTickGenericErrorContinues(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{ + aliveSession("s-1", "acme", "feat/x"), + aliveSession("s-2", "acme", "feat/y"), + }} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{ + "acme/repo/feat/x": "https://github.com/acme/repo/pull/11", + "acme/repo/feat/y": "https://github.com/acme/repo/pull/12", + }} + provider := &fakeProvider{ + errs: map[string]error{ + "https://github.com/acme/repo/pull/11": errors.New("transient network blip"), + }, + obs: map[string]ports.PRObservation{ + "https://github.com/acme/repo/pull/12": {Fetched: true, URL: "https://github.com/acme/repo/pull/12", Number: 12}, + }, + } + prm := &fakePR{} + p := newTestPoller(t, Deps{Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm}) + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if got := prm.records(); len(got) != 1 || got[0].id != "s-2" { + t.Fatalf("expected s-2 applied after generic err on s-1, got %+v", got) + } + if !p.Healthy() { + t.Fatalf("generic errors should not mark unhealthy") + } +} + +func TestPerCallDeadline(t *testing.T) { + ctx := context.Background() + sessions := &fakeSessions{sessions: []domain.SessionRecord{aliveSession("s-1", "acme", "feat/x")}} + projects := &fakeProjects{projects: map[domain.ProjectID]project.Project{"acme": githubProject("acme")}} + branches := &fakeBranches{urls: map[string]string{"acme/repo/feat/x": "https://github.com/acme/repo/pull/11"}} + provider := &fakeProvider{hangFor: 200 * time.Millisecond} + prm := &fakePR{} + p := newTestPoller(t, Deps{ + Provider: provider, + Branches: branches, + Sessions: sessions, + Projects: projects, + PR: prm, + ObserveTimeout: 10 * time.Millisecond, + }) + start := time.Now() + if err := p.Tick(ctx); err != nil { + t.Fatalf("Tick: %v", err) + } + if elapsed := time.Since(start); elapsed > 150*time.Millisecond { + t.Fatalf("Tick took %v — per-call deadline did not fire", elapsed) + } + if got := prm.records(); len(got) != 0 { + t.Fatalf("no apply on deadline timeout, got %d", len(got)) + } +} + +func TestStartDrainsOnContextCancel(t *testing.T) { + sessions := &fakeSessions{} + projects := &fakeProjects{} + branches := &fakeBranches{} + provider := &fakeProvider{} + prm := &fakePR{} + p := newTestPoller(t, Deps{ + Provider: provider, Branches: branches, Sessions: sessions, Projects: projects, PR: prm, + Interval: 5 * time.Millisecond, + }) + ctx, cancel := context.WithCancel(context.Background()) + done := p.Start(ctx) + cancel() + select { + case <-done: + case <-time.After(time.Second): + t.Fatal("poller did not exit within 1s of ctx cancel") + } +} + +func TestStartTicksRepeatedly(t *testing.T) { + var ticks atomic.Int32 + sessions := &fakeSessions{} + projects := &fakeProjects{} + branches := &fakeBranches{} + provider := &fakeProvider{} + prm := &fakePR{} + p := newTestPoller(t, Deps{ + Provider: provider, + Branches: branches, + Sessions: &countingSessions{wrap: sessions, ticks: &ticks}, + Projects: projects, + PR: prm, + Interval: 5 * time.Millisecond, + }) + ctx, cancel := context.WithCancel(context.Background()) + done := p.Start(ctx) + deadline := time.After(500 * time.Millisecond) +loop: + for { + if ticks.Load() >= 3 { + break + } + select { + case <-deadline: + break loop + case <-time.After(2 * time.Millisecond): + } + } + cancel() + <-done + if ticks.Load() < 2 { + t.Fatalf("expected at least 2 ticks, got %d", ticks.Load()) + } +} + +// countingSessions ticks the counter each time ListAllSessions is called. +type countingSessions struct { + wrap *fakeSessions + ticks *atomic.Int32 +} + +func (c *countingSessions) ListAllSessions(ctx context.Context) ([]domain.SessionRecord, error) { + c.ticks.Add(1) + return c.wrap.ListAllSessions(ctx) +} + +// --------------------------------------------------------------------------- +// owner/repo derivation +// --------------------------------------------------------------------------- + +func TestParseGitHubRemote(t *testing.T) { + tests := []struct{ in, owner, repo string }{ + {"https://github.com/acme/repo.git", "acme", "repo"}, + {"https://github.com/acme/repo", "acme", "repo"}, + {"git@github.com:acme/repo.git", "acme", "repo"}, + {"ssh://git@github.com/acme/repo.git", "acme", "repo"}, + {"acme/repo", "acme", "repo"}, + {"", "", ""}, + {"https://gitlab.com/x/y", "x", "y"}, // host-agnostic parser; provider rejects non-GitHub at Observe time + } + for _, tc := range tests { + owner, repo, ok := parseGitHubRemote(tc.in) + if tc.owner == "" { + if ok { + t.Errorf("parseGitHubRemote(%q): expected !ok, got %q/%q", tc.in, owner, repo) + } + continue + } + if !ok || owner != tc.owner || repo != tc.repo { + t.Errorf("parseGitHubRemote(%q) = %q/%q ok=%v; want %q/%q true", tc.in, owner, repo, ok, tc.owner, tc.repo) + } + } +}