From b9bdf39bf61ebd3d3fb244cbd9477867e74acbca Mon Sep 17 00:00:00 2001 From: bwarminski Date: Fri, 12 Jun 2026 16:24:29 -0400 Subject: [PATCH] Implement Branch Connection View for Postgres and Vitess Add `pscale branch connections {show, top, kill, kill-transaction}`: a one-shot list and an interactive TUI that polls, captures, and replays live connections. The engine is detected per database; `branch processlist` remains as a hidden compatibility alias. --- .gitignore | 1 + go.mod | 10 +- internal/cmd/branch/branch.go | 1 + internal/cmd/branch/connections.go | 40 + internal/cmd/branch/connections/actions.go | 162 ++ .../cmd/branch/connections/actions_test.go | 249 +++ internal/cmd/branch/connections/client.go | 20 + .../cmd/branch/connections/engine_flags.go | 26 + internal/cmd/branch/connections/headless.go | 78 + .../cmd/branch/connections/headless_test.go | 122 ++ internal/cmd/branch/connections/replay.go | 39 + .../cmd/branch/connections/replay_test.go | 392 ++++ internal/cmd/branch/connections/show.go | 280 +++ .../cmd/branch/connections/show_metadata.go | 114 + .../cmd/branch/connections/show_mysql_test.go | 262 +++ internal/cmd/branch/connections/show_test.go | 492 +++++ internal/cmd/branch/connections/top.go | 654 ++++++ internal/cmd/branch/connections/top_test.go | 821 +++++++ internal/cmd/branch/connections_commands.go | 176 ++ internal/cmd/branch/connections_test.go | 763 +++++++ internal/cmd/branch/kill.go | 88 +- internal/cmd/branch/kill_test.go | 119 +- internal/cmd/branch/processlist.go | 116 +- internal/cmd/branch/processlist_test.go | 122 +- internal/cmd/root.go | 3 +- internal/connections/actions.go | 12 + internal/connections/client.go | 580 +++++ internal/connections/client_test.go | 992 +++++++++ internal/connections/connection_list.go | 260 +++ internal/connections/connection_list_test.go | 267 +++ internal/connections/history/capture.go | 19 + .../connections/history/capture_history.go | 127 ++ .../history/capture_history_test.go | 150 ++ .../connections/history/capture_reader.go | 113 + .../history/capture_reader_test.go | 166 ++ .../connections/history/capture_writer.go | 92 + .../history/capture_writer_test.go | 136 ++ internal/connections/history/replay_source.go | 52 + .../connections/history/replay_source_test.go | 102 + internal/connections/tui/blocking_graph.go | 231 ++ .../connections/tui/blocking_graph_test.go | 138 ++ internal/connections/tui/capture.go | 89 + .../tui/connection_capabilities.go | 113 + internal/connections/tui/detail.go | 454 ++++ internal/connections/tui/detail_test.go | 690 ++++++ internal/connections/tui/help.go | 143 ++ internal/connections/tui/model.go | 1193 +++++++++++ internal/connections/tui/model_test.go | 1895 +++++++++++++++++ internal/connections/tui/query_format.go | 134 ++ internal/connections/tui/query_format_test.go | 234 ++ internal/connections/tui/styles.go | 103 + internal/connections/tui/styles_test.go | 188 ++ internal/connections/tui/table.go | 1355 ++++++++++++ internal/connections/tui/table_test.go | 1118 ++++++++++ 54 files changed, 15999 insertions(+), 297 deletions(-) create mode 100644 internal/cmd/branch/connections.go create mode 100644 internal/cmd/branch/connections/actions.go create mode 100644 internal/cmd/branch/connections/actions_test.go create mode 100644 internal/cmd/branch/connections/client.go create mode 100644 internal/cmd/branch/connections/engine_flags.go create mode 100644 internal/cmd/branch/connections/headless.go create mode 100644 internal/cmd/branch/connections/headless_test.go create mode 100644 internal/cmd/branch/connections/replay.go create mode 100644 internal/cmd/branch/connections/replay_test.go create mode 100644 internal/cmd/branch/connections/show.go create mode 100644 internal/cmd/branch/connections/show_metadata.go create mode 100644 internal/cmd/branch/connections/show_mysql_test.go create mode 100644 internal/cmd/branch/connections/show_test.go create mode 100644 internal/cmd/branch/connections/top.go create mode 100644 internal/cmd/branch/connections/top_test.go create mode 100644 internal/cmd/branch/connections_commands.go create mode 100644 internal/cmd/branch/connections_test.go create mode 100644 internal/connections/actions.go create mode 100644 internal/connections/client.go create mode 100644 internal/connections/client_test.go create mode 100644 internal/connections/connection_list.go create mode 100644 internal/connections/connection_list_test.go create mode 100644 internal/connections/history/capture.go create mode 100644 internal/connections/history/capture_history.go create mode 100644 internal/connections/history/capture_history_test.go create mode 100644 internal/connections/history/capture_reader.go create mode 100644 internal/connections/history/capture_reader_test.go create mode 100644 internal/connections/history/capture_writer.go create mode 100644 internal/connections/history/capture_writer_test.go create mode 100644 internal/connections/history/replay_source.go create mode 100644 internal/connections/history/replay_source_test.go create mode 100644 internal/connections/tui/blocking_graph.go create mode 100644 internal/connections/tui/blocking_graph_test.go create mode 100644 internal/connections/tui/capture.go create mode 100644 internal/connections/tui/connection_capabilities.go create mode 100644 internal/connections/tui/detail.go create mode 100644 internal/connections/tui/detail_test.go create mode 100644 internal/connections/tui/help.go create mode 100644 internal/connections/tui/model.go create mode 100644 internal/connections/tui/model_test.go create mode 100644 internal/connections/tui/query_format.go create mode 100644 internal/connections/tui/query_format_test.go create mode 100644 internal/connections/tui/styles.go create mode 100644 internal/connections/tui/styles_test.go create mode 100644 internal/connections/tui/table.go create mode 100644 internal/connections/tui/table_test.go diff --git a/.gitignore b/.gitignore index 30eccc48..5d82f069 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ dist/ completions/ vendor/ .idea +.worktrees/ # Output of running go build cmd/pscale/main.go main diff --git a/go.mod b/go.mod index 37d54b2b..c2f2414a 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,11 @@ require ( github.com/adrg/xdg v0.5.3 github.com/benbjohnson/clock v1.3.5 github.com/briandowns/spinner v1.23.2 + github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 + github.com/charmbracelet/bubbletea v1.3.10 github.com/charmbracelet/huh v1.0.0 + github.com/charmbracelet/lipgloss v1.1.0 + github.com/charmbracelet/x/ansi v0.11.2 github.com/fatih/color v1.19.0 github.com/frankban/quicktest v1.14.6 github.com/go-sql-driver/mysql v1.9.3 @@ -24,6 +28,7 @@ require ( github.com/mattn/go-isatty v0.0.20 github.com/mattn/go-shellwords v1.0.12 github.com/mitchellh/go-homedir v1.1.0 + github.com/muesli/termenv v0.16.0 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/planetscale/planetscale-go v0.168.1 github.com/planetscale/psdb v0.0.0-20250717190954-65c6661ab6e4 @@ -53,11 +58,7 @@ require ( github.com/atotto/clipboard v0.1.4 // indirect github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect github.com/catppuccin/go v0.3.0 // indirect - github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 // indirect - github.com/charmbracelet/bubbletea v1.3.10 // indirect github.com/charmbracelet/colorprofile v0.3.3 // indirect - github.com/charmbracelet/lipgloss v1.1.0 // indirect - github.com/charmbracelet/x/ansi v0.11.2 // indirect github.com/charmbracelet/x/cellbuf v0.0.14 // indirect github.com/charmbracelet/x/exp/strings v0.0.0-20251201173703-9f73bfd934ff // indirect github.com/charmbracelet/x/term v0.2.2 // indirect @@ -93,7 +94,6 @@ require ( github.com/mtibben/percent v0.2.1 // indirect github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect github.com/muesli/cancelreader v0.2.2 // indirect - github.com/muesli/termenv v0.16.0 // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect github.com/pires/go-proxyproto v0.8.1 // indirect github.com/planetscale/vitess-types v0.0.0-20250728133330-81b28fd54ee5 // indirect diff --git a/internal/cmd/branch/branch.go b/internal/cmd/branch/branch.go index ba767c09..9eb81758 100644 --- a/internal/cmd/branch/branch.go +++ b/internal/cmd/branch/branch.go @@ -35,6 +35,7 @@ func BranchCmd(ch *cmdutil.Helper) *cobra.Command { cmd.AddCommand(RoutingRulesCmd(ch)) cmd.AddCommand(SafeMigrationsCmd(ch)) cmd.AddCommand(LintCmd(ch)) + cmd.AddCommand(ConnectionsCmd(ch)) cmd.AddCommand(ProcesslistCmd(ch)) cmd.AddCommand(vtctld.VtctldCmd(ch)) cmd.AddCommand(InfraCmd(ch)) diff --git a/internal/cmd/branch/connections.go b/internal/cmd/branch/connections.go new file mode 100644 index 00000000..70bef7e1 --- /dev/null +++ b/internal/cmd/branch/connections.go @@ -0,0 +1,40 @@ +package branch + +import ( + "github.com/planetscale/cli/internal/cmd/branch/connections" + "github.com/planetscale/cli/internal/cmdutil" + "github.com/spf13/cobra" +) + +// ConnectionsCmd manages branch connections across supported database engines. +func ConnectionsCmd(ch *cmdutil.Helper) *cobra.Command { + cmd := &cobra.Command{ + Use: "connections ", + Short: "Show and kill branch connections", + Long: `Show and kill branch connections. + +Agent workflow: + 1. Run: pscale branch connections show --format json + 2. Inspect query_id, transaction_id, and connection_id from the selected row. + 3. Explain the proposed action and wait for user approval before running it. + 4. Run exactly one action command with the matching ID. + 5. Run show again to verify the result. + +Action semantics: + kill --query Cancels the listed query_id. + kill-transaction + Postgres only. destructive. Terminates the listed transaction_id if it still matches server state. + kill destructive. Terminates the listed connection_id. + +Use --format json when an agent or script needs to inspect query_id, +transaction_id, and connection_id fields. Human output uses vertical records so +query text and action IDs are not truncated.`, + } + + cmd.AddCommand(ConnectionsShowCmd(ch)) + cmd.AddCommand(ConnectionsKillCmd(ch)) + cmd.AddCommand(ConnectionsKillTransactionCmd(ch)) + cmd.AddCommand(connections.TopCmd(ch)) + + return cmd +} diff --git a/internal/cmd/branch/connections/actions.go b/internal/cmd/branch/connections/actions.go new file mode 100644 index 00000000..4e457c08 --- /dev/null +++ b/internal/cmd/branch/connections/actions.go @@ -0,0 +1,162 @@ +package connections + +import ( + "context" + "errors" + "strings" + + "github.com/planetscale/cli/internal/cmdutil" + live "github.com/planetscale/cli/internal/connections" + "github.com/planetscale/cli/internal/printer" + ps "github.com/planetscale/planetscale-go/planetscale" +) + +type actionResult struct { + Success bool `csv:"success" header:"success" json:"success"` + Keyspace string `csv:"keyspace" header:"keyspace" json:"keyspace,omitempty"` + Shard string `csv:"shard" header:"shard" json:"shard,omitempty"` + Tablet string `csv:"tablet" header:"tablet" json:"tablet,omitempty"` + ID int64 `csv:"id" header:"id,text" json:"id,omitempty"` + Kind string `csv:"kind" header:"kind" json:"kind,omitempty"` +} + +func (a *actionResult) MarshalCSVValue() interface{} { + return []*actionResult{a} +} + +type compactActionResult struct { + Success bool `csv:"success" header:"success" json:"success"` + ID int64 `csv:"id" header:"id,text" json:"id,omitempty"` + Kind string `csv:"kind" header:"kind" json:"kind,omitempty"` +} + +func (a *compactActionResult) MarshalCSVValue() interface{} { + return []*compactActionResult{a} +} + +func toActionResult(result live.ActionResult) *actionResult { + return &actionResult{ + Success: result.Success, + Keyspace: result.Keyspace, + Shard: result.Shard, + Tablet: result.Tablet, + ID: result.ID, + Kind: result.Kind, + } +} + +func toCompactActionResult(result live.ActionResult) *compactActionResult { + return &compactActionResult{ + Success: result.Success, + ID: result.ID, + Kind: result.Kind, + } +} + +// RunCancelQuery cancels the active query identified by a live connection query ID. +func RunCancelQuery(ctx context.Context, ch *cmdutil.Helper, database, branch, queryID string, target ConnectionTarget) error { + return RunCancelQueryForEngine(ctx, ch, database, branch, queryID, ps.DatabaseEngineMySQL, target) +} + +// RunCancelQueryForEngine cancels the active query and prints output for the resolved database engine. +func RunCancelQueryForEngine(ctx context.Context, ch *cmdutil.Helper, database, branch, queryID string, engine ps.DatabaseEngine, target ConnectionTarget) error { + return runAction(ctx, ch, database, branch, "query-id", queryID, target, func(ctx context.Context, client *live.Client, id string) (live.ActionResult, error) { + return client.CancelQueryResult(ctx, live.ActionTarget{QueryID: &id}) + }, engine) +} + +// RunKillTransaction terminates the connection identified by a live connection transaction ID. +func RunKillTransaction(ctx context.Context, ch *cmdutil.Helper, database, branch, transactionID string, target ConnectionTarget) error { + return RunKillTransactionForEngine(ctx, ch, database, branch, transactionID, ps.DatabaseEnginePostgres, target) +} + +// RunKillTransactionForEngine terminates a transaction and prints output for the resolved database engine. +func RunKillTransactionForEngine(ctx context.Context, ch *cmdutil.Helper, database, branch, transactionID string, engine ps.DatabaseEngine, target ConnectionTarget) error { + return runAction(ctx, ch, database, branch, "transaction-id", transactionID, target, func(ctx context.Context, client *live.Client, id string) (live.ActionResult, error) { + return client.TerminateTransactionResult(ctx, live.ActionTarget{TransactionID: &id}) + }, engine) +} + +// RunKillConnection terminates the connection identified by a live connection_id. +func RunKillConnection(ctx context.Context, ch *cmdutil.Helper, database, branch, connectionID string, target ConnectionTarget) error { + return RunKillConnectionForEngine(ctx, ch, database, branch, connectionID, ps.DatabaseEngineMySQL, target) +} + +// RunKillConnectionForEngine terminates a connection and prints output for the resolved database engine. +func RunKillConnectionForEngine(ctx context.Context, ch *cmdutil.Helper, database, branch, connectionID string, engine ps.DatabaseEngine, target ConnectionTarget) error { + return runAction(ctx, ch, database, branch, "connection-id", connectionID, target, func(ctx context.Context, client *live.Client, id string) (live.ActionResult, error) { + return client.TerminateConnectionResult(ctx, live.ActionTarget{ConnectionID: &id}) + }, engine) +} + +func runAction(ctx context.Context, ch *cmdutil.Helper, database, branch, idName, id string, target ConnectionTarget, runAction func(context.Context, *live.Client, string) (live.ActionResult, error), engine ps.DatabaseEngine) error { + if err := validateActionID(idName, id); err != nil { + return err + } + id = strings.TrimSpace(id) + + client, err := newConnectionsClient(ch, database, branch, target) + if err != nil { + return err + } + + result, err := runAction(ctx, client, id) + if err != nil { + return err + } + return printActionResult(ch, result, engine, idName) +} + +// ValidateConnectionID checks the connection action identifier without making network calls. +func ValidateConnectionID(id string) error { + return validateActionID("connection-id", id) +} + +// ValidateQueryID checks the query action identifier without making network calls. +func ValidateQueryID(id string) error { + return validateActionID("query-id", id) +} + +// ValidateTransactionID checks the transaction action identifier without making network calls. +func ValidateTransactionID(id string) error { + return validateActionID("transaction-id", id) +} + +func validateActionID(idName, id string) error { + if strings.TrimSpace(id) == "" { + return errors.New(idName + " is required") + } + return nil +} + +func printActionResult(ch *cmdutil.Helper, result live.ActionResult, engine ps.DatabaseEngine, idName string) error { + if ch.Printer.Format() == printer.Human { + ch.Printer.Printf("%s.\n", actionResultMessage(result, idName)) + return nil + } + if ch.Printer.Format() == printer.JSON { + return ch.Printer.PrintResource(toActionResult(result)) + } + if engine == ps.DatabaseEnginePostgres { + return ch.Printer.PrintResource(toCompactActionResult(result)) + } + return ch.Printer.PrintResource(toActionResult(result)) +} + +func actionResultMessage(result live.ActionResult, idName string) string { + var message string + switch idName { + case "query-id": + message = "Cancelled query" + case "transaction-id": + message = "Killed transaction" + case "connection-id": + message = "Killed connection" + default: + message = "Action sent" + } + if result.Tablet != "" { + message += " on " + result.Tablet + } + return message +} diff --git a/internal/cmd/branch/connections/actions_test.go b/internal/cmd/branch/connections/actions_test.go new file mode 100644 index 00000000..7a4ce6bf --- /dev/null +++ b/internal/cmd/branch/connections/actions_test.go @@ -0,0 +1,249 @@ +package connections + +import ( + "bytes" + "context" + "io" + "net/http" + "net/http/httptest" + "testing" + + qt "github.com/frankban/quicktest" + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/config" + "github.com/planetscale/cli/internal/printer" +) + +func TestActionRunnersIssueDeletesWithExplicitIDs(t *testing.T) { + tests := []struct { + name string + run func(context.Context, *cmdutil.Helper, string, string, string, ConnectionTarget) error + id string + wantPath string + wantQuery string + target ConnectionTarget + }{ + { + name: "cancel query", + run: RunCancelQuery, + id: "primary-123-q", + wantPath: "/v1/organizations/acme/databases/pgload/branches/main/connections/query/primary-123-q", + }, + { + name: "kill transaction", + run: RunKillTransaction, + id: "primary-123-t", + wantPath: "/v1/organizations/acme/databases/pgload/branches/main/connections/transaction/primary-123-t", + }, + { + name: "kill connection with target", + run: RunKillConnection, + id: "zone1-1001-101", + wantPath: "/v1/organizations/acme/databases/pgload/branches/main/connections/connection/zone1-1001-101", + wantQuery: "keyspace=commerce&shard=-80", + target: ConnectionTarget{Keyspace: "commerce", Shard: "-80"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.Method, qt.Equals, http.MethodDelete) + c.Assert(r.URL.Path, qt.Equals, tt.wantPath) + c.Assert(r.URL.RawQuery, qt.Equals, tt.wantQuery) + w.WriteHeader(http.StatusNoContent) + })) + t.Cleanup(server.Close) + + var out bytes.Buffer + ch := connectionsTestHelper(server.URL, printer.Human, &out) + err := tt.run(context.Background(), ch, "pgload", "main", tt.id, tt.target) + + c.Assert(err, qt.IsNil) + }) + } +} + +func TestActionRunnersPrintResults(t *testing.T) { + tests := []struct { + name string + format printer.Format + run func(context.Context, *cmdutil.Helper, string, string, string, ConnectionTarget) error + response string + want []string + }{ + { + name: "json", + format: printer.JSON, + run: RunCancelQuery, + response: `{"success":true,"keyspace":"commerce","shard":"-80","tablet":"zone1-1001","id":101,"kind":"query"}`, + want: []string{`"success": true`, `"keyspace": "commerce"`, `"kind": "query"`}, + }, + { + name: "transaction", + format: printer.JSON, + run: RunKillTransaction, + response: `{"success":true,"id":101,"kind":"transaction"}`, + want: []string{`"success": true`, `"kind": "transaction"`}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, tt.response) + })) + t.Cleanup(server.Close) + + var out bytes.Buffer + ch := connectionsTestHelper(server.URL, tt.format, &out) + err := tt.run(context.Background(), ch, "pgload", "main", "zone1-1001-101", ConnectionTarget{Keyspace: "commerce", Shard: "-80"}) + + c.Assert(err, qt.IsNil) + for _, want := range tt.want { + c.Assert(out.String(), qt.Contains, want) + } + }) + } +} + +func TestActionRunnersPrintSentenceInHumanFormat(t *testing.T) { + tests := []struct { + name string + run func(context.Context, *cmdutil.Helper, string, string, string, ConnectionTarget) error + response string + want string + }{ + { + name: "cancel query", + run: RunCancelQuery, + response: `{"success":true}`, + want: "Cancelled query.\n", + }, + { + name: "kill transaction", + run: RunKillTransaction, + response: `{"success":true}`, + want: "Killed transaction.\n", + }, + { + name: "kill connection", + run: RunKillConnection, + response: `{"success":true,"tablet":"zone1-1001"}`, + want: "Killed connection on zone1-1001.\n", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, tt.response) + })) + t.Cleanup(server.Close) + + var out bytes.Buffer + ch := connectionsTestHelper(server.URL, printer.Human, &out) + err := tt.run(context.Background(), ch, "pgload", "main", "id-1", ConnectionTarget{}) + + c.Assert(err, qt.IsNil) + c.Assert(out.String(), qt.Equals, tt.want) + }) + } +} + +func TestActionRunnersRejectMissingID(t *testing.T) { + tests := []struct { + name string + run func(context.Context, *cmdutil.Helper, string, string, string, ConnectionTarget) error + wantErr string + }{ + {name: "cancel query", run: RunCancelQuery, wantErr: "query-id is required"}, + {name: "kill transaction", run: RunKillTransaction, wantErr: "transaction-id is required"}, + {name: "kill connection", run: RunKillConnection, wantErr: "connection-id is required"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + var out bytes.Buffer + ch := connectionsTestHelper("http://example.invalid", printer.Human, &out) + + err := tt.run(context.Background(), ch, "pgload", "main", " ", ConnectionTarget{}) + + c.Assert(err, qt.ErrorMatches, tt.wantErr) + }) + } +} + +func TestActionRunnersSurfaceServerErrors(t *testing.T) { + tests := []struct { + name string + run func(context.Context, *cmdutil.Helper, string, string, string, ConnectionTarget) error + id string + statusCode int + body string + wantErr string + }{ + { + name: "kill transaction", + run: RunKillTransaction, + id: "primary-123-t", + statusCode: http.StatusUnprocessableEntity, + body: `{"code":"verification_mismatch","message":"connection no longer matches the expected snapshot"}`, + wantErr: "terminate transaction: connection no longer matches the expected snapshot", + }, + { + name: "cancel query", + run: RunCancelQuery, + id: "primary-123-q", + statusCode: http.StatusNotFound, + wantErr: "cancel query: query_id not found.*", + }, + { + name: "kill connection", + run: RunKillConnection, + id: "primary-123-c", + statusCode: http.StatusNotFound, + wantErr: "terminate connection: connection_id not found.*", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(tt.statusCode) + _, _ = io.WriteString(w, tt.body) + })) + t.Cleanup(server.Close) + + var out bytes.Buffer + ch := connectionsTestHelper(server.URL, printer.Human, &out) + err := tt.run(context.Background(), ch, "pgload", "main", tt.id, ConnectionTarget{}) + + c.Assert(err, qt.ErrorMatches, tt.wantErr) + }) + } +} + +func connectionsTestHelper(baseURL string, format printer.Format, out *bytes.Buffer) *cmdutil.Helper { + p := printer.NewPrinter(&format) + p.SetHumanOutput(out) + p.SetResourceOutput(out) + + return &cmdutil.Helper{ + Config: &config.Config{ + BaseURL: baseURL, + Organization: "acme", + ServiceTokenID: "tid", + ServiceToken: "secret", + }, + Printer: p, + } +} diff --git a/internal/cmd/branch/connections/client.go b/internal/cmd/branch/connections/client.go new file mode 100644 index 00000000..137c9538 --- /dev/null +++ b/internal/cmd/branch/connections/client.go @@ -0,0 +1,20 @@ +package connections + +import ( + "github.com/planetscale/cli/internal/cmdutil" + live "github.com/planetscale/cli/internal/connections" +) + +func newConnectionsClient(ch *cmdutil.Helper, database, branch string, target ConnectionTarget) (*live.Client, error) { + return live.NewClient(live.ClientConfig{ + BaseURL: ch.Config.BaseURL, + Organization: ch.Config.Organization, + Database: database, + Branch: branch, + Keyspace: target.Keyspace, + Shard: target.Shard, + AccessToken: ch.Config.AccessToken, + ServiceTokenID: ch.Config.ServiceTokenID, + ServiceToken: ch.Config.ServiceToken, + }) +} diff --git a/internal/cmd/branch/connections/engine_flags.go b/internal/cmd/branch/connections/engine_flags.go new file mode 100644 index 00000000..f4e06b39 --- /dev/null +++ b/internal/cmd/branch/connections/engine_flags.go @@ -0,0 +1,26 @@ +package connections + +import ( + "errors" + + ps "github.com/planetscale/planetscale-go/planetscale" +) + +// ValidateEngineFlags rejects flags that only apply to another database engine. +func ValidateEngineFlags(engine ps.DatabaseEngine, filter ConnectionFilter, target ConnectionTarget) error { + return validateEngineFlags(engine, filter.connectionFilter(), target) +} + +func validateEngineFlags(engine ps.DatabaseEngine, filter connectionFilter, target ConnectionTarget) error { + switch engine { + case ps.DatabaseEnginePostgres: + if target.Keyspace != "" || target.Shard != "" { + return errors.New("--keyspace/--shard are only supported for Vitess databases") + } + case ps.DatabaseEngineMySQL: + if filter.active() { + return errors.New("--instance/--role are only supported for Postgres databases") + } + } + return nil +} diff --git a/internal/cmd/branch/connections/headless.go b/internal/cmd/branch/connections/headless.go new file mode 100644 index 00000000..825ee4b5 --- /dev/null +++ b/internal/cmd/branch/connections/headless.go @@ -0,0 +1,78 @@ +package connections + +import ( + "context" + "errors" + "time" + + live "github.com/planetscale/cli/internal/connections" + "github.com/planetscale/cli/internal/connections/history" +) + +// runHeadlessCapture writes captures through writer until duration elapses or +// ctx is canceled. A non-positive duration runs continuously until ctx fires. +func runHeadlessCapture(ctx context.Context, client clientInterface, writer captureWriter, duration, interval time.Duration) (err error) { + defer func() { + err = errors.Join(err, writer.Close()) + }() + + if duration > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, duration) + defer cancel() + } + + captured := false + finalize := func() error { + if duration <= 0 { + return nil + } + if captured { + return nil + } + return errors.New("capture produced no samples") + } + + list, err := client.List(ctx, live.SortByTransactionStart) + if err != nil { + if ctx.Err() != nil { + return finalize() + } + return err + } + if err := writer.Write(history.NewCapture(list)); err != nil { + return err + } + captured = true + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return finalize() + case <-ticker.C: + list, err := client.List(ctx, live.SortByTransactionStart) + if err != nil { + if ctx.Err() != nil { + return finalize() + } + return err + } + if err := writer.Write(history.NewCapture(list)); err != nil { + return err + } + captured = true + } + } +} + +type clientInterface interface { + List(context.Context, live.SortMode) (live.ConnectionList, error) +} + +type captureWriter interface { + Write(history.Capture) error + Close() error +} diff --git a/internal/cmd/branch/connections/headless_test.go b/internal/cmd/branch/connections/headless_test.go new file mode 100644 index 00000000..34d6877d --- /dev/null +++ b/internal/cmd/branch/connections/headless_test.go @@ -0,0 +1,122 @@ +package connections + +import ( + "bytes" + "context" + "errors" + "strings" + "testing" + "time" + + live "github.com/planetscale/cli/internal/connections" + "github.com/planetscale/cli/internal/connections/history" +) + +type stubClientInterface struct { + lists []live.ConnectionList + cursor int + err error +} + +func (s *stubClientInterface) List(ctx context.Context, sort live.SortMode) (live.ConnectionList, error) { + if s.err != nil { + return live.ConnectionList{}, s.err + } + if s.cursor >= len(s.lists) { + <-ctx.Done() + return live.ConnectionList{}, ctx.Err() + } + list := s.lists[s.cursor] + s.cursor++ + return list, nil +} + +type discardWriter struct{} + +func (discardWriter) Write(history.Capture) error { return nil } +func (discardWriter) Close() error { return nil } + +func TestRunHeadlessCaptureContinuousModeReturnsNilOnContextCancel(t *testing.T) { + src := &stubClientInterface{lists: []live.ConnectionList{{CapturedAt: time.UnixMilli(1)}}} + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond) + defer cancel() + + err := runHeadlessCapture(ctx, src, discardWriter{}, 0, time.Hour) + + if err != nil { + t.Fatalf("runHeadlessCapture: %v", err) + } +} + +func TestRunHeadlessCaptureNoSamplesForFiniteDuration(t *testing.T) { + src := &stubClientInterface{} + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel() + + err := runHeadlessCapture(ctx, src, discardWriter{}, 10*time.Millisecond, time.Hour) + + if err == nil || !strings.Contains(err.Error(), "capture produced no samples") { + t.Fatalf("err = %v, want no samples error", err) + } +} + +func TestRunHeadlessCaptureReturnsFirstSourceError(t *testing.T) { + sourceErr := errors.New("source failed") + src := &stubClientInterface{err: sourceErr} + + err := runHeadlessCapture(context.Background(), src, discardWriter{}, time.Second, time.Hour) + + if !errors.Is(err, sourceErr) { + t.Fatalf("err = %v, want source error", err) + } +} + +func TestRunHeadlessCaptureJoinsCloseError(t *testing.T) { + closeErr := errors.New("close failed") + src := &stubClientInterface{lists: []live.ConnectionList{{CapturedAt: time.UnixMilli(1)}}} + writer := &countingWriter{closeErr: closeErr} + + err := runHeadlessCapture(context.Background(), src, writer, 1*time.Millisecond, time.Hour) + + if !errors.Is(err, closeErr) { + t.Fatalf("err = %v, want close error", err) + } +} + +func TestRunHeadlessCaptureWritesInstancesMetadata(t *testing.T) { + list := live.ConnectionList{ + CapturedAt: time.Date(2026, 4, 28, 15, 12, 4, 0, time.UTC), + Connections: []live.Connection{}, + Sort: live.SortByTransactionStart, + Instances: []live.InstanceMeta{ + {ID: "primary", Role: "primary"}, + }, + } + src := &stubClientInterface{lists: []live.ConnectionList{list}} + + var buf bytes.Buffer + writer := history.NewCaptureWriter(&buf) + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _ = runHeadlessCapture(ctx, src, writer, 0, 10*time.Millisecond) + + if !strings.Contains(buf.String(), `"instances":[{"id":"primary","role":"primary"}]`) { + t.Fatalf("instances not found in capture output:\n%s", buf.String()) + } +} + +type countingWriter struct { + writes int + closeErr error +} + +func (w *countingWriter) Write(history.Capture) error { + w.writes++ + return nil +} + +func (w *countingWriter) Close() error { + return w.closeErr +} diff --git a/internal/cmd/branch/connections/replay.go b/internal/cmd/branch/connections/replay.go new file mode 100644 index 00000000..c3840238 --- /dev/null +++ b/internal/cmd/branch/connections/replay.go @@ -0,0 +1,39 @@ +package connections + +import ( + "context" + "errors" + + live "github.com/planetscale/cli/internal/connections" + "github.com/planetscale/cli/internal/connections/history" +) + +// replayClient adapts a *history.ReplaySource to the tui.ConnectionsClient +// interface so the TUI can render captured snapshots. Action methods all +// reject with the same operator-visible message so c/k/K surface as errors in +// the footer rather than being dispatched to the real wire. +type replayClient struct { + source *history.ReplaySource +} + +func newReplayClient(source *history.ReplaySource) *replayClient { + return &replayClient{source: source} +} + +func (r *replayClient) List(ctx context.Context, mode live.SortMode) (live.ConnectionList, error) { + return r.source.List(ctx, mode) +} + +func (r *replayClient) CancelQuery(context.Context, live.ActionTarget) error { + return errReplayActionRejected +} + +func (r *replayClient) TerminateTransaction(context.Context, live.ActionTarget) error { + return errReplayActionRejected +} + +func (r *replayClient) TerminateConnection(context.Context, live.ActionTarget) error { + return errReplayActionRejected +} + +var errReplayActionRejected = errors.New("not available in replay mode") diff --git a/internal/cmd/branch/connections/replay_test.go b/internal/cmd/branch/connections/replay_test.go new file mode 100644 index 00000000..685c1cd3 --- /dev/null +++ b/internal/cmd/branch/connections/replay_test.go @@ -0,0 +1,392 @@ +package connections + +import ( + "bytes" + "context" + "os" + "path/filepath" + "testing" + "time" + + tea "github.com/charmbracelet/bubbletea" + qt "github.com/frankban/quicktest" + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/config" + live "github.com/planetscale/cli/internal/connections" + "github.com/planetscale/cli/internal/connections/history" + "github.com/planetscale/cli/internal/connections/tui" +) + +func TestReplayClientListReturnsCapturedSnapshot(t *testing.T) { + c := qt.New(t) + source := newReplaySourceFromFixture(t, []live.Connection{{PID: 42, Instance: "primary"}}) + + client := newReplayClient(source) + list, err := client.List(context.Background(), live.SortByTransactionStart) + + c.Assert(err, qt.IsNil) + c.Assert(list.Connections[0].PID, qt.Equals, 42) +} + +func TestReplayClientRejectsAllActions(t *testing.T) { + c := qt.New(t) + source := newReplaySourceFromFixture(t, []live.Connection{{PID: 10}}) + client := newReplayClient(source) + target := live.ActionTarget{Instance: "primary", PID: 10} + + c.Assert(client.CancelQuery(context.Background(), target), qt.ErrorMatches, "not available in replay mode") + c.Assert(client.TerminateTransaction(context.Background(), target), qt.ErrorMatches, "not available in replay mode") + c.Assert(client.TerminateConnection(context.Background(), target), qt.ErrorMatches, "not available in replay mode") +} + +func TestTopCmdReplayRejectsMissingFile(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + + cmd := testTopCmd(&cmdutil.Helper{Config: &config.Config{}}) + cmd.SetArgs([]string{"--replay", "/nonexistent/path/trace.jsonl"}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "--replay: .*no such file.*") +} + +func TestTopCmdReplayRejectsCombinationWithCapture(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + tmp := writeReplayFixture(t, []live.Connection{{PID: 10, Instance: "primary"}}) + + cmd := testTopCmd(&cmdutil.Helper{Config: &config.Config{}}) + cmd.SetArgs([]string{"--replay", tmp, "--capture", filepath.Join(filepath.Dir(tmp), "out.jsonl")}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "--replay cannot be combined with --capture") +} + +// --duration kills the TUI after wall-clock elapses regardless of paused +// state, so combining it with replay would silently dismiss the operator +// mid-step. Reject explicitly the same way --capture is rejected. +func TestTopCmdReplayRejectsCombinationWithDuration(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + tmp := writeReplayFixture(t, []live.Connection{{PID: 10, Instance: "primary"}}) + + cmd := testTopCmd(&cmdutil.Helper{Config: &config.Config{}}) + cmd.SetArgs([]string{"--replay", tmp, "--duration", "30s"}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "--duration cannot be combined with --replay") +} + +func TestTopCmdReplayRejectsWithoutTTY(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, false) + defer restoreTTY() + tmp := writeReplayFixture(t, []live.Connection{{PID: 10, Instance: "primary"}}) + + cmd := testTopCmd(&cmdutil.Helper{Config: &config.Config{}}) + cmd.SetArgs([]string{"--replay", tmp}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "--replay requires an interactive terminal") +} + +func TestTopCmdReplayHappyPathRunsWithoutLiveClient(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + return nil + }) + defer restoreProgram() + tmp := writeReplayFixture(t, []live.Connection{{PID: 99, Instance: "primary"}}) + + cmd := testTopCmd(&cmdutil.Helper{Config: &config.Config{}}) + cmd.SetArgs([]string{"--replay", tmp}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) +} + +func TestTopCmdReplayPreloadsAllCapturesAndStartsPaused(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + + var capturedModel tea.Model + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + capturedModel = model + return nil + }) + defer restoreProgram() + + tmp := writeMultiSampleReplayFixture(t, 3) + cmd := testTopCmd(&cmdutil.Helper{Config: &config.Config{}}) + cmd.SetArgs([]string{"--replay", tmp}) + + c.Assert(cmd.Execute(), qt.IsNil) + c.Assert(capturedModel, qt.Not(qt.IsNil)) + + m := capturedModel.(tui.Model) + sized, _ := m.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + c.Assert(sized.(tui.Model).View(), qt.Contains, "paused") + + // Jumping to the oldest sample proves all 3 captures were preloaded into + // history and that the cursor isn't pinned at the latest. + updated, _ := sized.(tui.Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("{")}) + c.Assert(updated.(tui.Model).View(), qt.Contains, "step 1/3") +} + +func TestTopCmdReplayUsesPostgresViewForPostgreSQLTrace(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + + var capturedModel tea.Model + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + capturedModel = model + return nil + }) + defer restoreProgram() + + tmp := writeReplayFixture(t, []live.Connection{{ + PID: 99, + Instance: "primary", + State: "active", + QueryText: "SELECT pg_sleep(1)", + BlockedBy: []int{42}, + Duration: time.Second, + WaitEvent: "ClientRead", + XactStart: ptrTime(time.Date(2026, 4, 28, 14, 59, 0, 0, time.UTC)), + QueryStart: ptrTime(time.Date(2026, 4, 28, 14, 59, 30, 0, time.UTC)), + }}) + cmd := testTopCmd(&cmdutil.Helper{Config: &config.Config{}}) + cmd.SetArgs([]string{"--replay", tmp}) + + c.Assert(cmd.Execute(), qt.IsNil) + c.Assert(capturedModel, qt.Not(qt.IsNil)) + + sized, _ := capturedModel.(tui.Model).Update(tea.WindowSizeMsg{Width: 160, Height: 24}) + view := sized.(tui.Model).View() + + c.Assert(view, qt.Contains, "sort xact_start") + c.Assert(view, qt.Contains, "BLOCK") + c.Assert(view, qt.Contains, "WAIT") + c.Assert(view, qt.Contains, "SELECT pg_sleep(1)") + c.Assert(view, qt.Not(qt.Contains), "TABLET") + c.Assert(view, qt.Not(qt.Contains), "DB") +} + +func TestTopCmdReplayUsesVitessViewForMySQLTrace(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + + var capturedModel tea.Model + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + capturedModel = model + return nil + }) + defer restoreProgram() + + tmp := writeVitessReplayFixture(t) + cmd := testTopCmd(&cmdutil.Helper{Config: &config.Config{}}) + cmd.SetArgs([]string{"--replay", tmp}) + + c.Assert(cmd.Execute(), qt.IsNil) + c.Assert(capturedModel, qt.Not(qt.IsNil)) + + sized, _ := capturedModel.(tui.Model).Update(tea.WindowSizeMsg{Width: 160, Height: 24}) + view := sized.(tui.Model).View() + + c.Assert(view, qt.Contains, "sorted by duration") + c.Assert(view, qt.Contains, "TABLET") + c.Assert(view, qt.Contains, "DB") + c.Assert(view, qt.Contains, "SELECT 1") + c.Assert(view, qt.Not(qt.Contains), "BLOCK") + c.Assert(view, qt.Not(qt.Contains), "WAIT") +} + +func TestReplayHeaderShowsTraceTarget(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + + var capturedModel tea.Model + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + capturedModel = model + return nil + }) + defer restoreProgram() + + tmp := writeVitessReplayFixtureWithTarget(t) + cmd := testTopCmd(&cmdutil.Helper{Config: &config.Config{}}) + cmd.SetArgs([]string{"--replay", tmp}) + + c.Assert(cmd.Execute(), qt.IsNil) + c.Assert(capturedModel, qt.Not(qt.IsNil)) + + sized, _ := capturedModel.(tui.Model).Update(tea.WindowSizeMsg{Width: 160, Height: 24}) + view := sized.(tui.Model).View() + + c.Assert(view, qt.Contains, "kind-live-connections-mysql / main / commerce / -80") +} + +func writeMultiSampleReplayFixture(t *testing.T, count int) string { + t.Helper() + path := filepath.Join(t.TempDir(), "multi.jsonl") + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + t.Fatal(err) + } + writer := history.NewCaptureWriter(file) + at := time.Date(2026, 5, 27, 12, 0, 0, 0, time.UTC) + for i := 0; i < count; i++ { + list := live.NewConnectionList(at.Add(time.Duration(i)*time.Second), []live.Connection{{ + PID: 100 + i, Instance: "primary", + }}, live.SortByTransactionStart) + if err := writer.Write(history.NewCapture(list)); err != nil { + t.Fatal(err) + } + } + if err := file.Close(); err != nil { + t.Fatal(err) + } + return path +} + +func writeVitessReplayFixtureWithTarget(t *testing.T) string { + t.Helper() + path := filepath.Join(t.TempDir(), "vitess-target.jsonl") + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + t.Fatal(err) + } + writer := history.NewCaptureWriter(file) + if err := writer.WriteCaptureStart(history.CaptureStart{ + At: time.Date(2026, 6, 4, 12, 29, 59, 0, time.UTC), + Database: "kind-live-connections-mysql", + Branch: "main", + Target: &history.CaptureTarget{ + Keyspace: "commerce", + Shard: "-80", + }, + }); err != nil { + t.Fatal(err) + } + connectionID := "zone1-1001-101" + queryID := "zone1-1001-101" + list := live.NewConnectionList(time.Date(2026, 6, 4, 12, 30, 0, 0, time.UTC), []live.Connection{{ + PID: 101, + Instance: "zone1-1001", + State: "Query/executing", + Duration: 42 * time.Second, + Username: "vt_app", + DatabaseName: "checkout", + QueryText: "SELECT 1", + ConnectionID: &connectionID, + QueryID: &queryID, + }}, live.SortByDuration) + list.DatabaseKind = live.DatabaseKindMySQL + if err := writer.Write(history.NewCapture(list)); err != nil { + t.Fatal(err) + } + if err := file.Close(); err != nil { + t.Fatal(err) + } + return path +} + +func writeVitessReplayFixture(t *testing.T) string { + t.Helper() + path := filepath.Join(t.TempDir(), "vitess.jsonl") + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + t.Fatal(err) + } + writer := history.NewCaptureWriter(file) + connectionID := "zone1-1001-101" + queryID := "zone1-1001-101" + list := live.NewConnectionList(time.Date(2026, 6, 4, 12, 30, 0, 0, time.UTC), []live.Connection{{ + PID: 101, + Instance: "zone1-1001", + State: "Query/executing", + Duration: 42 * time.Second, + Username: "vt_app", + DatabaseName: "checkout", + QueryText: "SELECT 1", + ConnectionID: &connectionID, + QueryID: &queryID, + }}, live.SortByDuration) + list.DatabaseKind = live.DatabaseKindMySQL + list.Topology = &live.Topology{Keyspace: "commerce", Shard: "-80", Tablet: "zone1-1001"} + if err := writer.Write(history.NewCapture(list)); err != nil { + t.Fatal(err) + } + if err := file.Close(); err != nil { + t.Fatal(err) + } + return path +} + +func TestTopCmdReplayRejectsEmptyTraceFile(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + tmp := filepath.Join(t.TempDir(), "empty.jsonl") + if err := os.WriteFile(tmp, []byte{}, 0o600); err != nil { + t.Fatal(err) + } + + cmd := testTopCmd(&cmdutil.Helper{Config: &config.Config{}}) + cmd.SetArgs([]string{"--replay", tmp}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "--replay: capture file contains no replayable snapshots") +} + +func writeReplayFixture(t *testing.T, connections []live.Connection) string { + t.Helper() + path := filepath.Join(t.TempDir(), "trace.jsonl") + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + t.Fatal(err) + } + writer := history.NewCaptureWriter(file) + list := live.NewConnectionList(time.Date(2026, 4, 28, 15, 0, 0, 0, time.UTC), connections, live.SortByTransactionStart) + list.DatabaseKind = live.DatabaseKindPostgreSQL + if err := writer.Write(history.NewCapture(list)); err != nil { + t.Fatal(err) + } + if err := file.Close(); err != nil { + t.Fatal(err) + } + return path +} + +func ptrTime(t time.Time) *time.Time { + return &t +} + +func newReplaySourceFromFixture(t *testing.T, connections []live.Connection) *history.ReplaySource { + t.Helper() + var buffer bytes.Buffer + writer := history.NewCaptureWriter(&buffer) + list := live.NewConnectionList(time.Date(2026, 4, 28, 15, 0, 0, 0, time.UTC), connections, live.SortByTransactionStart) + if err := writer.Write(history.NewCapture(list)); err != nil { + t.Fatal(err) + } + source, err := history.NewReplaySource(&buffer) + if err != nil { + t.Fatal(err) + } + return source +} diff --git a/internal/cmd/branch/connections/show.go b/internal/cmd/branch/connections/show.go new file mode 100644 index 00000000..10b700c0 --- /dev/null +++ b/internal/cmd/branch/connections/show.go @@ -0,0 +1,280 @@ +package connections + +import ( + "context" + "fmt" + "io" + "strings" + "time" + + "github.com/planetscale/cli/internal/cmdutil" + live "github.com/planetscale/cli/internal/connections" + "github.com/planetscale/cli/internal/printer" +) + +// ConnectionFilter filters live connection rows by instance or instance role. +type ConnectionFilter struct { + Instance string + Role string +} + +// ConnectionTarget identifies an optional Vitess target for connection commands. +type ConnectionTarget struct { + Keyspace string + Shard string +} + +func (f ConnectionFilter) connectionFilter() connectionFilter { + return connectionFilter{instance: f.Instance, role: f.Role} +} + +// ValidateConnectionFilter validates mutually exclusive live connection filters. +func ValidateConnectionFilter(filter ConnectionFilter) error { + return validateConnectionFilter(filter.Instance, filter.Role) +} + +// RunList fetches and prints one live connection list. +func RunList(ctx context.Context, ch *cmdutil.Helper, database, branch string, filter ConnectionFilter, target ConnectionTarget) error { + if err := ValidateConnectionFilter(filter); err != nil { + return err + } + return runList(ctx, ch, database, branch, filter.connectionFilter(), target) +} + +func runList(ctx context.Context, ch *cmdutil.Helper, database, branch string, filter connectionFilter, target ConnectionTarget) error { + client, err := newConnectionsClient(ch, database, branch, target) + if err != nil { + return err + } + + list, err := filteredLister{client: client, filter: filter}.List(ctx, live.SortByTransactionStart) + if err != nil { + return live.UserFacingError(err, "view") + } + sortListForDisplay(&list) + + return PrintList(ch, list, ListTopology{}) +} + +// PrintList prints a live connection list in the configured CLI format. +func PrintList(ch *cmdutil.Helper, list live.ConnectionList, topology ListTopology) error { + topology = resolveListTopology(list, topology) + + if ch.Printer.Format() == printer.Human { + var out strings.Builder + printHumanConnectionList(&out, list, topology) + ch.Printer.Print(out.String()) + return nil + } + + return ch.Printer.PrintResource(toPrintableList(list, topology)) +} + +type printableList struct { + DatabaseKind live.DatabaseKind `json:"database_kind,omitempty"` + CapturedAt time.Time `json:"captured_at"` + Topology *ListTopology `json:"topology,omitempty"` + Instances []printableInstance `json:"instances"` + Connections []printableConnection `json:"connections"` +} + +type printableInstance struct { + ID string `json:"id"` + Role string `json:"role"` + Error string `json:"error,omitempty"` +} + +type printableConnection struct { + PID int `csv:"pid" header:"pid" json:"pid"` + Instance string `csv:"instance" header:"instance" json:"instance"` + InstanceRole string `csv:"role" header:"role" json:"instance_role"` + State string `csv:"state" header:"state" json:"state"` + DurationMS int64 `csv:"duration_ms" header:"duration_ms" json:"duration_ms"` + WaitEventType string `csv:"wait_event_type" header:"wait_event_type" json:"wait_event_type"` + WaitEvent string `csv:"wait_event" header:"wait_event" json:"wait_event"` + Username string `csv:"username" header:"username" json:"username"` + ApplicationName string `csv:"application_name" header:"application_name" json:"application_name"` + DatabaseName string `csv:"-" json:"database,omitempty"` + ClientAddr string `csv:"client_addr" header:"client_addr" json:"client_addr"` + QueryText string `csv:"query_text" header:"query_text" json:"query_text"` + BlockedBy []int `csv:"blocked_by" header:"blocked_by" json:"blocked_by"` + QueryID *string `csv:"query_id" header:"query_id" json:"query_id"` + TransactionID *string `csv:"transaction_id" header:"transaction_id" json:"transaction_id"` + ConnectionID *string `csv:"connection_id" header:"connection_id" json:"connection_id"` +} + +func toPrintableList(list live.ConnectionList, topology ListTopology) printableList { + out := printableList{ + DatabaseKind: list.DatabaseKind, + CapturedAt: list.CapturedAt, + Instances: toPrintableInstances(list.Instances), + Connections: toPrintableConnections(list), + } + if !topology.isEmpty() { + out.Topology = &topology + } + return out +} + +func toPrintableInstances(instanceList []live.InstanceMeta) []printableInstance { + instances := make([]printableInstance, 0, len(instanceList)) + for _, instance := range instanceList { + instances = append(instances, printableInstance{ + ID: instance.ID, + Role: instance.Role, + Error: instance.Error, + }) + } + return instances +} + +func (p printableList) MarshalCSVValue() interface{} { + if p.Topology != nil { + return p.connectionsWithTopology() + } + if p.DatabaseKind == live.DatabaseKindMySQL || p.hasDatabaseName() { + return p.connectionsWithDatabase() + } + return p.Connections +} + +func (p printableList) hasDatabaseName() bool { + for _, conn := range p.Connections { + if conn.DatabaseName != "" { + return true + } + } + return false +} + +func toPrintableConnections(list live.ConnectionList) []printableConnection { + connections := make([]printableConnection, 0, len(list.Connections)) + for _, conn := range list.Connections { + connections = append(connections, printableConnection{ + PID: conn.PID, + Instance: conn.Instance, + InstanceRole: conn.InstanceRole, + State: conn.State, + DurationMS: conn.Duration.Milliseconds(), + WaitEventType: conn.WaitEventType, + WaitEvent: conn.WaitEvent, + Username: conn.Username, + ApplicationName: conn.ApplicationName, + DatabaseName: conn.DatabaseName, + ClientAddr: conn.ClientAddr, + QueryText: conn.QueryText, + BlockedBy: printableBlockedBy(conn.BlockedBy), + QueryID: conn.QueryID, + TransactionID: conn.TransactionID, + ConnectionID: conn.ConnectionID, + }) + } + return connections +} + +func printableBlockedBy(blockedBy []int) []int { + // Normalize an absent blocker set to [] so the agent JSON always sees an + // array for blocked_by, consistent with instances. The wire field is + // omitempty, so an unblocked connection decodes to a nil slice. + if blockedBy == nil { + return []int{} + } + return blockedBy +} + +func printHumanConnectionList(out io.Writer, list live.ConnectionList, topology ListTopology) { + fmt.Fprintf(out, "captured_at: %s\n", list.CapturedAt.Format(time.RFC3339)) + vitess := list.DatabaseKind == live.DatabaseKindMySQL || !topology.isEmpty() + if !topology.isEmpty() { + fmt.Fprintln(out, "topology:") + fmt.Fprintf(out, " keyspace: %s\n", topology.Keyspace) + fmt.Fprintf(out, " shard: %s\n", topology.Shard) + fmt.Fprintf(out, " tablet: %s\n", topology.Tablet) + } + if warning := unreachableInstanceWarning(list.Instances); warning != "" { + fmt.Fprintln(out, warning) + } + + if len(list.Connections) == 0 { + fmt.Fprintln(out, "No live connections found.") + return + } + + for i, conn := range list.Connections { + if i > 0 { + fmt.Fprintln(out) + } + + fmt.Fprintf(out, "*************************** %d. row ***************************\n", i+1) + fields := conn.HumanFields() + if vitess { + fields = vitessHumanFields(conn) + } + for _, field := range fields { + writeHumanField(out, field[0], field[1]) + } + if conn.DatabaseName != "" { + writeHumanField(out, "database", conn.DatabaseName) + } + fmt.Fprintln(out, "query:") + fmt.Fprintln(out, conn.QueryText) + } +} + +func vitessHumanFields(conn live.Connection) [][2]string { + fields := conn.HumanFields() + out := make([][2]string, 0, len(fields)-1) + for _, field := range fields { + switch field[0] { + case "instance": + out = append(out, [2]string{"tablet", field[1]}) + case "role": + default: + out = append(out, field) + } + } + return out +} + +func writeHumanField(out io.Writer, name, value string) { + fmt.Fprintf(out, "%-16s %s\n", name+":", value) +} + +func unreachableInstanceWarning(instances []live.InstanceMeta) string { + var unreachable []string + for _, instance := range instances { + if instance.Error != "" { + unreachable = append(unreachable, instance.ID) + } + } + if len(unreachable) == 0 { + return "" + } + return "warning: partial results, unreachable instances: " + strings.Join(unreachable, ", ") +} + +func validateInstanceFilter(list live.ConnectionList, filter connectionFilter) error { + if filter.instance == "" { + return nil + } + + valid := make([]string, 0, len(list.Instances)) + for _, instance := range list.Instances { + valid = append(valid, instance.ID) + if instance.ID == filter.instance { + return nil + } + } + + return &live.UnknownInstanceError{Instance: filter.instance, Valid: valid} +} + +// sortListForDisplay orders Vitess processlist rows longest-running first so +// one-shot output matches the TUI's duration ordering. Postgres lists keep the +// server's transaction-start ordering. +func sortListForDisplay(list *live.ConnectionList) { + if list.DatabaseKind != live.DatabaseKindMySQL { + return + } + live.SortConnections(list.Connections, live.SortByDuration) +} diff --git a/internal/cmd/branch/connections/show_metadata.go b/internal/cmd/branch/connections/show_metadata.go new file mode 100644 index 00000000..b36ba3d5 --- /dev/null +++ b/internal/cmd/branch/connections/show_metadata.go @@ -0,0 +1,114 @@ +package connections + +import live "github.com/planetscale/cli/internal/connections" + +// ListTopology describes the Vitess tablet selected for a live connection list. +type ListTopology struct { + Keyspace string `json:"keyspace,omitempty"` + Shard string `json:"shard,omitempty"` + Tablet string `json:"tablet,omitempty"` +} + +func (t ListTopology) isEmpty() bool { + return t.Keyspace == "" && t.Shard == "" && t.Tablet == "" +} + +func resolveListTopology(list live.ConnectionList, topology ListTopology) ListTopology { + if !topology.isEmpty() || list.Topology == nil { + return topology + } + return ListTopology{ + Keyspace: list.Topology.Keyspace, + Shard: list.Topology.Shard, + Tablet: list.Topology.Tablet, + } +} + +type printableCSVConnectionWithDatabase struct { + PID int `csv:"pid" header:"pid" json:"pid"` + Instance string `csv:"instance" header:"instance" json:"instance"` + InstanceRole string `csv:"role" header:"role" json:"instance_role"` + State string `csv:"state" header:"state" json:"state"` + DurationMS int64 `csv:"duration_ms" header:"duration_ms" json:"duration_ms"` + WaitEventType string `csv:"wait_event_type" header:"wait_event_type" json:"wait_event_type"` + WaitEvent string `csv:"wait_event" header:"wait_event" json:"wait_event"` + Username string `csv:"username" header:"username" json:"username"` + ApplicationName string `csv:"application_name" header:"application_name" json:"application_name"` + DatabaseName string `csv:"database" header:"database" json:"database,omitempty"` + ClientAddr string `csv:"client_addr" header:"client_addr" json:"client_addr"` + QueryText string `csv:"query_text" header:"query_text" json:"query_text"` + BlockedBy []int `csv:"blocked_by" header:"blocked_by" json:"blocked_by"` + QueryID *string `csv:"query_id" header:"query_id" json:"query_id"` + TransactionID *string `csv:"transaction_id" header:"transaction_id" json:"transaction_id"` + ConnectionID *string `csv:"connection_id" header:"connection_id" json:"connection_id"` +} + +type printableConnectionWithTopology struct { + Keyspace string `csv:"keyspace" header:"keyspace" json:"keyspace"` + Shard string `csv:"shard" header:"shard" json:"shard"` + Tablet string `csv:"tablet" header:"tablet" json:"tablet"` + PID int `csv:"pid" header:"pid" json:"pid"` + Instance string `csv:"instance" header:"instance" json:"instance"` + InstanceRole string `csv:"role" header:"role" json:"instance_role"` + State string `csv:"state" header:"state" json:"state"` + DurationMS int64 `csv:"duration_ms" header:"duration_ms" json:"duration_ms"` + WaitEventType string `csv:"wait_event_type" header:"wait_event_type" json:"wait_event_type"` + WaitEvent string `csv:"wait_event" header:"wait_event" json:"wait_event"` + Username string `csv:"username" header:"username" json:"username"` + ApplicationName string `csv:"application_name" header:"application_name" json:"application_name"` + DatabaseName string `csv:"database" header:"database" json:"database,omitempty"` + ClientAddr string `csv:"client_addr" header:"client_addr" json:"client_addr"` + QueryText string `csv:"query_text" header:"query_text" json:"query_text"` + BlockedBy []int `csv:"blocked_by" header:"blocked_by" json:"blocked_by"` + QueryID *string `csv:"query_id" header:"query_id" json:"query_id"` + TransactionID *string `csv:"transaction_id" header:"transaction_id" json:"transaction_id"` + ConnectionID *string `csv:"connection_id" header:"connection_id" json:"connection_id"` +} + +func (p printableList) connectionsWithDatabase() []printableCSVConnectionWithDatabase { + connections := make([]printableCSVConnectionWithDatabase, 0, len(p.Connections)) + for _, conn := range p.Connections { + connections = append(connections, toPrintableCSVConnectionWithDatabase(conn)) + } + return connections +} + +func (p printableList) connectionsWithTopology() []printableConnectionWithTopology { + connections := make([]printableConnectionWithTopology, 0, len(p.Connections)) + for _, conn := range p.Connections { + var keyspace, shard, tablet string + if p.Topology != nil { + keyspace = p.Topology.Keyspace + shard = p.Topology.Shard + tablet = p.Topology.Tablet + } + connections = append(connections, printableConnectionWithTopology{ + Keyspace: keyspace, + Shard: shard, + Tablet: tablet, + PID: conn.PID, + Instance: conn.Instance, + InstanceRole: conn.InstanceRole, + State: conn.State, + DurationMS: conn.DurationMS, + WaitEventType: conn.WaitEventType, + WaitEvent: conn.WaitEvent, + Username: conn.Username, + ApplicationName: conn.ApplicationName, + DatabaseName: conn.DatabaseName, + ClientAddr: conn.ClientAddr, + QueryText: conn.QueryText, + BlockedBy: conn.BlockedBy, + QueryID: conn.QueryID, + TransactionID: conn.TransactionID, + ConnectionID: conn.ConnectionID, + }) + } + return connections +} + +func toPrintableCSVConnectionWithDatabase(conn printableConnection) printableCSVConnectionWithDatabase { + // Identical fields; the conversion re-applies the target type's csv tags so + // DatabaseName is emitted in CSV output (the base type hides it with csv:"-"). + return printableCSVConnectionWithDatabase(conn) +} diff --git a/internal/cmd/branch/connections/show_mysql_test.go b/internal/cmd/branch/connections/show_mysql_test.go new file mode 100644 index 00000000..96cd114f --- /dev/null +++ b/internal/cmd/branch/connections/show_mysql_test.go @@ -0,0 +1,262 @@ +package connections + +import ( + "bytes" + "encoding/csv" + "encoding/json" + "strings" + "testing" + "time" + + qt "github.com/frankban/quicktest" + "github.com/planetscale/cli/internal/cmdutil" + live "github.com/planetscale/cli/internal/connections" + "github.com/planetscale/cli/internal/printer" +) + +func TestPrintListJSONIncludesDatabaseAndTopology(t *testing.T) { + c := qt.New(t) + + got := printListForTest(c, printer.JSON, mysqlConnectionList(), commerceTopology()) + + var payload struct { + DatabaseKind live.DatabaseKind `json:"database_kind"` + Topology *ListTopology `json:"topology"` + Connections []struct { + DatabaseName string `json:"database"` + ApplicationName string `json:"application_name"` + } `json:"connections"` + } + decodeJSONForTest(c, got, &payload) + + c.Assert(payload.DatabaseKind, qt.Equals, live.DatabaseKindMySQL) + c.Assert(payload.Topology, qt.DeepEquals, &ListTopology{Keyspace: "commerce", Shard: "-80", Tablet: "zone1-1001"}) + c.Assert(payload.Connections, qt.HasLen, 1) + c.Assert(payload.Connections[0].DatabaseName, qt.Equals, "checkout") + c.Assert(payload.Connections[0].ApplicationName, qt.Equals, "") +} + +func TestPrintListHumanIncludesDatabaseAndTopology(t *testing.T) { + c := qt.New(t) + + got := printListForTest(c, printer.Human, mysqlConnectionList(), commerceTopology()) + + c.Assert(got, qt.Contains, "topology:\n") + c.Assert(got, qt.Contains, " keyspace: commerce\n") + c.Assert(got, qt.Contains, " shard: -80\n") + c.Assert(got, qt.Contains, " tablet: zone1-1001\n") + c.Assert(got, qt.Contains, "database: checkout\n") +} + +func TestPrintListVitessHumanUsesTabletLabels(t *testing.T) { + c := qt.New(t) + + got := printListForTest(c, printer.Human, mysqlConnectionList(), commerceTopology()) + + c.Assert(got, qt.Contains, "tablet: zone1-1001\n") + c.Assert(got, qt.Contains, "database: checkout\n") + c.Assert(got, qt.Not(qt.Contains), "instance: zone1-1001\n") + c.Assert(got, qt.Not(qt.Contains), "role:") + c.Assert(got, qt.Not(qt.Contains), "Database: checkout\n") +} + +func TestPrintListCSVIncludesDatabaseAndTopology(t *testing.T) { + c := qt.New(t) + + got := printListForTest(c, printer.CSV, mysqlConnectionList(), commerceTopology()) + rows := readCSVForTest(c, got) + + headers := rows[0] + c.Assert(headers, qt.Contains, "keyspace") + c.Assert(headers, qt.Contains, "shard") + c.Assert(headers, qt.Contains, "tablet") + c.Assert(headers, qt.Contains, "database") + c.Assert(rows[1], qt.Contains, "commerce") + c.Assert(rows[1], qt.Contains, "-80") + c.Assert(rows[1], qt.Contains, "zone1-1001") + c.Assert(rows[1], qt.Contains, "checkout") +} + +func TestPrintListCSVUsesConsistentHeaders(t *testing.T) { + c := qt.New(t) + + got := printListForTest(c, printer.CSV, mysqlConnectionList(), commerceTopology()) + rows := readCSVForTest(c, got) + + c.Assert(rows[0], qt.DeepEquals, []string{ + "keyspace", + "shard", + "tablet", + "pid", + "instance", + "role", + "state", + "duration_ms", + "wait_event_type", + "wait_event", + "username", + "application_name", + "database", + "client_addr", + "query_text", + "blocked_by", + "query_id", + "transaction_id", + "connection_id", + }) +} + +func TestPrintListMySQLJSONIncludesDatabaseWithoutTopology(t *testing.T) { + c := qt.New(t) + + got := printListForTest(c, printer.JSON, mysqlConnectionList(), ListTopology{}) + + var payload struct { + DatabaseKind live.DatabaseKind `json:"database_kind"` + Topology *ListTopology `json:"topology"` + Connections []struct { + DatabaseName string `json:"database"` + } `json:"connections"` + } + decodeJSONForTest(c, got, &payload) + + c.Assert(payload.DatabaseKind, qt.Equals, live.DatabaseKindMySQL) + c.Assert(payload.Topology, qt.IsNil) + c.Assert(payload.Connections, qt.HasLen, 1) + c.Assert(payload.Connections[0].DatabaseName, qt.Equals, "checkout") +} + +func TestPrintListMySQLHumanIncludesDatabaseWithoutTopology(t *testing.T) { + c := qt.New(t) + + got := printListForTest(c, printer.Human, mysqlConnectionList(), ListTopology{}) + + c.Assert(got, qt.Contains, "database: checkout\n") + c.Assert(got, qt.Not(qt.Contains), "topology:\n") +} + +func TestPrintListMySQLCSVIncludesDatabaseWithoutTopology(t *testing.T) { + c := qt.New(t) + + got := printListForTest(c, printer.CSV, mysqlConnectionList(), ListTopology{}) + rows := readCSVForTest(c, got) + + headers := rows[0] + c.Assert(headers, qt.Contains, "database") + c.Assert(headers, qt.Not(qt.Contains), "keyspace") + c.Assert(headers, qt.Not(qt.Contains), "shard") + c.Assert(headers, qt.Not(qt.Contains), "tablet") + c.Assert(rows[1], qt.Contains, "checkout") +} + +func TestPrintListUsesTopologyFromList(t *testing.T) { + c := qt.New(t) + list := mysqlConnectionList(func(list *live.ConnectionList) { + list.Topology = &live.Topology{ + Keyspace: "commerce", + Shard: "-80", + Tablet: "zone1-1001", + } + }) + + got := printListForTest(c, printer.JSON, list, ListTopology{}) + + var payload struct { + Topology *ListTopology `json:"topology"` + } + decodeJSONForTest(c, got, &payload) + c.Assert(payload.Topology, qt.DeepEquals, &ListTopology{Keyspace: "commerce", Shard: "-80", Tablet: "zone1-1001"}) +} + +func TestPrintListExplicitTopologyOverridesListTopology(t *testing.T) { + c := qt.New(t) + list := mysqlConnectionList(func(list *live.ConnectionList) { + list.Topology = &live.Topology{ + Keyspace: "stale", + Shard: "0", + Tablet: "stale-tablet", + } + }) + + got := printListForTest(c, printer.JSON, list, commerceTopology()) + + var payload struct { + Topology *ListTopology `json:"topology"` + } + decodeJSONForTest(c, got, &payload) + c.Assert(payload.Topology, qt.DeepEquals, &ListTopology{Keyspace: "commerce", Shard: "-80", Tablet: "zone1-1001"}) +} + +func printListForTest(c *qt.C, format printer.Format, list live.ConnectionList, topology ListTopology) string { + var out bytes.Buffer + p := printer.NewPrinter(&format) + p.SetHumanOutput(&out) + p.SetResourceOutput(&out) + ch := &cmdutil.Helper{Printer: p} + + c.Assert(PrintList(ch, list, topology), qt.IsNil) + return out.String() +} + +func mysqlConnectionList(overrides ...func(*live.ConnectionList)) live.ConnectionList { + connectionID := "101" + list := live.ConnectionList{ + CapturedAt: time.Date(2026, 4, 29, 12, 34, 56, 0, time.UTC), + DatabaseKind: live.DatabaseKindMySQL, + Connections: []live.Connection{ + { + PID: 101, + Instance: "zone1-1001", + Username: "vt_app", + DatabaseName: "checkout", + ClientAddr: "10.0.0.12:54231", + State: "Query", + Duration: 42 * time.Second, + ConnectionID: &connectionID, + QueryText: "select * from orders", + }, + }, + } + for _, override := range overrides { + override(&list) + } + return list +} + +func commerceTopology() ListTopology { + return ListTopology{Keyspace: "commerce", Shard: "-80", Tablet: "zone1-1001"} +} + +func decodeJSONForTest(c *qt.C, raw string, out any) { + c.Assert(json.Unmarshal([]byte(raw), out), qt.IsNil) +} + +func readCSVForTest(c *qt.C, raw string) [][]string { + rows, err := csv.NewReader(strings.NewReader(raw)).ReadAll() + c.Assert(err, qt.IsNil) + return rows +} + +func TestVitessShowSortsByDurationDesc(t *testing.T) { + c := qt.New(t) + list := live.ConnectionList{ + DatabaseKind: live.DatabaseKindMySQL, + Connections: []live.Connection{ + {PID: 1, Duration: 5 * time.Second}, + {PID: 2, Duration: 50 * time.Second}, + {PID: 3, Duration: 20 * time.Second}, + }, + } + sortListForDisplay(&list) + c.Assert([]int{list.Connections[0].PID, list.Connections[1].PID, list.Connections[2].PID}, qt.DeepEquals, []int{2, 3, 1}) + + pg := live.ConnectionList{ + DatabaseKind: live.DatabaseKindPostgreSQL, + Connections: []live.Connection{ + {PID: 1, Duration: 5 * time.Second}, + {PID: 2, Duration: 50 * time.Second}, + }, + } + sortListForDisplay(&pg) + c.Assert(pg.Connections[0].PID, qt.Equals, 1) +} diff --git a/internal/cmd/branch/connections/show_test.go b/internal/cmd/branch/connections/show_test.go new file mode 100644 index 00000000..6ca9cbf4 --- /dev/null +++ b/internal/cmd/branch/connections/show_test.go @@ -0,0 +1,492 @@ +package connections + +import ( + "bytes" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + qt "github.com/frankban/quicktest" + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/config" + live "github.com/planetscale/cli/internal/connections" + "github.com/planetscale/cli/internal/printer" + "github.com/spf13/cobra" +) + +func TestToPrintableListMapsEnvelopeAndUsesConnectionsForCSV(t *testing.T) { + c := qt.New(t) + queryID := "replica-321-query" + transactionID := "replica-321-transaction" + connectionID := "replica-321-connection" + capturedAt := time.Date(2026, 4, 29, 12, 34, 56, 0, time.UTC) + + got := toPrintableList(live.ConnectionList{ + CapturedAt: capturedAt, + Instances: []live.InstanceMeta{ + {ID: "primary", Role: "primary"}, + {ID: "replica-a", Role: "replica", Error: "timeout"}, + }, + Connections: []live.Connection{ + { + PID: 321, + Instance: "replica-a", + InstanceRole: "replica", + State: "idle", + Duration: 3*time.Second + 250*time.Millisecond, + WaitEventType: "Client", + WaitEvent: "ClientRead", + Username: "brett", + ApplicationName: "psql", + ClientAddr: "192.0.2.15", + QueryText: "select now()", + BlockedBy: []int{456, 789}, + QueryID: &queryID, + TransactionID: &transactionID, + ConnectionID: &connectionID, + }, + }, + }, ListTopology{}) + + c.Assert(got.CapturedAt, qt.Equals, capturedAt) + c.Assert(got.Instances, qt.DeepEquals, []printableInstance{ + {ID: "primary", Role: "primary"}, + {ID: "replica-a", Role: "replica", Error: "timeout"}, + }) + c.Assert(got.Connections, qt.DeepEquals, []printableConnection{ + { + PID: 321, + Instance: "replica-a", + InstanceRole: "replica", + State: "idle", + DurationMS: 3250, + WaitEventType: "Client", + WaitEvent: "ClientRead", + Username: "brett", + ApplicationName: "psql", + ClientAddr: "192.0.2.15", + QueryText: "select now()", + BlockedBy: []int{456, 789}, + QueryID: &queryID, + TransactionID: &transactionID, + ConnectionID: &connectionID, + }, + }) + c.Assert(got.MarshalCSVValue(), qt.DeepEquals, got.Connections) +} + +func TestPrintHumanConnectionListUsesVerticalRecordsWithoutTruncatingQueryText(t *testing.T) { + c := qt.New(t) + queryID := "primary-123-query" + transactionID := "primary-123-transaction" + connectionID := "primary-123-connection" + longQuery := "select " + string(bytes.Repeat([]byte("really_long_expression + "), 20)) + "42" + var out bytes.Buffer + + printHumanConnectionList(&out, live.ConnectionList{ + CapturedAt: time.Date(2026, 4, 29, 12, 34, 56, 0, time.UTC), + Connections: []live.Connection{ + { + PID: 123, + Instance: "primary", + InstanceRole: "primary", + State: "active", + Duration: 2*time.Second + 500*time.Millisecond, + WaitEventType: "Lock", + WaitEvent: "transactionid", + Username: "brett", + ApplicationName: "psql", + ClientAddr: "127.0.0.1", + QueryText: longQuery, + BlockedBy: []int{456, 789}, + QueryID: &queryID, + TransactionID: &transactionID, + ConnectionID: &connectionID, + }, + }, + }, ListTopology{}) + + got := out.String() + c.Assert(got, qt.Contains, "*************************** 1. row ***************************\n") + c.Assert(got, qt.Contains, "pid: 123\n") + c.Assert(got, qt.Contains, "blocked_by: 456,789\n") + c.Assert(got, qt.Contains, "query_id: primary-123-query\n") + c.Assert(got, qt.Contains, "transaction_id: primary-123-transaction\n") + c.Assert(got, qt.Contains, "connection_id: primary-123-connection\n") + c.Assert(got, qt.Contains, "query:\n"+longQuery+"\n") +} + +func TestPrintListPostgresHumanUnchanged(t *testing.T) { + c := qt.New(t) + queryID := "primary-123-query" + transactionID := "primary-123-transaction" + connectionID := "primary-123-connection" + + got := printListForTest(c, printer.Human, live.ConnectionList{ + CapturedAt: time.Date(2026, 4, 29, 12, 34, 56, 0, time.UTC), + Connections: []live.Connection{ + { + PID: 123, + Instance: "primary", + InstanceRole: "primary", + State: "active", + Duration: 2 * time.Second, + Username: "brett", + ApplicationName: "psql", + QueryText: "select 1", + QueryID: &queryID, + TransactionID: &transactionID, + ConnectionID: &connectionID, + }, + }, + }, ListTopology{}) + + c.Assert(got, qt.Contains, "instance: primary\n") + c.Assert(got, qt.Contains, "role: primary\n") + c.Assert(got, qt.Not(qt.Contains), "tablet: primary\n") + c.Assert(got, qt.Not(qt.Contains), "Database:") +} + +func TestPrintHumanConnectionListEmpty(t *testing.T) { + c := qt.New(t) + var out bytes.Buffer + + printHumanConnectionList(&out, live.ConnectionList{ + CapturedAt: time.Date(2026, 4, 29, 12, 34, 56, 0, time.UTC), + }, ListTopology{}) + + got := out.String() + c.Assert(got, qt.Contains, "captured_at: 2026-04-29T12:34:56Z\n") + c.Assert(got, qt.Contains, "No live connections found.\n") +} + +func TestPrintHumanConnectionListFlagsUnreachableInstances(t *testing.T) { + c := qt.New(t) + var out bytes.Buffer + + printHumanConnectionList(&out, live.ConnectionList{ + CapturedAt: time.Date(2026, 4, 29, 12, 34, 56, 0, time.UTC), + Instances: []live.InstanceMeta{ + {ID: "primary", Role: "primary"}, + {ID: "replica-b", Role: "replica", Error: "timeout"}, + }, + }, ListTopology{}) + + c.Assert(out.String(), qt.Contains, "warning: partial results, unreachable instances: replica-b\n") +} + +func TestShowCmdHumanOutputFetchesOnceAndPrintsVerticalRecords(t *testing.T) { + c := qt.New(t) + var out bytes.Buffer + var seenPath string + server := liveConnectionsListServer(t, sampleListCmdResponse(), &seenPath) + cmd := showCmdForServer(server.URL, printer.Human, &out) + + cmd.SetArgs([]string{"pgload", "main"}) + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + c.Assert(seenPath, qt.Equals, "/v1/organizations/acme/databases/pgload/branches/main/connections") + got := out.String() + c.Assert(got, qt.Contains, "*************************** 1. row ***************************\n") + c.Assert(got, qt.Contains, "pid: 123\n") + c.Assert(got, qt.Contains, "query_id: primary-123-q\n") + c.Assert(got, qt.Contains, "transaction_id: primary-123-t\n") + c.Assert(got, qt.Contains, "connection_id: primary-123-c\n") + c.Assert(got, qt.Not(qt.Contains), "database:") + c.Assert(got, qt.Contains, "query:\nSELECT pg_sleep(600)\n") +} + +func TestShowCmdCSVOutput(t *testing.T) { + c := qt.New(t) + var out bytes.Buffer + server := liveConnectionsListServer(t, sampleListCmdResponse(), nil) + cmd := showCmdForServer(server.URL, printer.CSV, &out) + + cmd.SetArgs([]string{"pgload", "main"}) + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + got := out.String() + c.Assert(got, qt.Not(qt.Contains), "database") + c.Assert(got, qt.Not(qt.Contains), "Database") + c.Assert(got, qt.Contains, "123,primary,primary,active,664000") +} + +func TestShowCmdJSONOutputPrintsEnvelope(t *testing.T) { + c := qt.New(t) + var out bytes.Buffer + server := liveConnectionsListServer(t, sampleListCmdResponse(), nil) + cmd := showCmdForServer(server.URL, printer.JSON, &out) + + cmd.SetArgs([]string{"pgload", "main"}) + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + var got printableList + c.Assert(json.Unmarshal(out.Bytes(), &got), qt.IsNil) + c.Assert(got.CapturedAt.IsZero(), qt.Equals, false) + c.Assert(got.Instances, qt.HasLen, 1) + c.Assert(got.Connections, qt.HasLen, 1) + c.Assert(got.Connections[0].PID, qt.Equals, 123) + c.Assert(got.Connections[0].DurationMS, qt.Equals, int64(664000)) + c.Assert(got.Connections[0].BlockedBy, qt.DeepEquals, []int{}) + c.Assert(got.Connections[0].QueryID, qt.DeepEquals, stringPtr("primary-123-q")) + c.Assert(got.Connections[0].TransactionID, qt.DeepEquals, stringPtr("primary-123-t")) + c.Assert(got.Connections[0].ConnectionID, qt.DeepEquals, stringPtr("primary-123-c")) + c.Assert(out.String(), qt.Contains, `"database_kind": "postgresql"`) +} + +func TestShowCmdFiltersByRoleAndInstance(t *testing.T) { + c := qt.New(t) + server := liveConnectionsListServer(t, sampleFilteredListCmdResponse(), nil) + + var primaryOut bytes.Buffer + primaryCmd := showCmdForServer(server.URL, printer.JSON, &primaryOut) + primaryCmd.SetArgs([]string{"--role", "primary", "pgload", "main"}) + c.Assert(primaryCmd.Execute(), qt.IsNil) + var primary printableList + c.Assert(json.Unmarshal(primaryOut.Bytes(), &primary), qt.IsNil) + c.Assert(primary.Connections, qt.HasLen, 1) + c.Assert(primary.Connections[0].Instance, qt.Equals, "primary") + + var replicaOut bytes.Buffer + replicaCmd := showCmdForServer(server.URL, printer.JSON, &replicaOut) + replicaCmd.SetArgs([]string{"--role", "replica", "pgload", "main"}) + c.Assert(replicaCmd.Execute(), qt.IsNil) + var replica printableList + c.Assert(json.Unmarshal(replicaOut.Bytes(), &replica), qt.IsNil) + c.Assert(replica.Connections, qt.HasLen, 1) + c.Assert(replica.Connections[0].Instance, qt.Equals, "replica-a") + + var instanceOut bytes.Buffer + instanceCmd := showCmdForServer(server.URL, printer.JSON, &instanceOut) + instanceCmd.SetArgs([]string{"--instance", "replica-a", "pgload", "main"}) + c.Assert(instanceCmd.Execute(), qt.IsNil) + var instance printableList + c.Assert(json.Unmarshal(instanceOut.Bytes(), &instance), qt.IsNil) + c.Assert(instance.Connections, qt.HasLen, 1) + c.Assert(instance.Connections[0].Instance, qt.Equals, "replica-a") +} + +func TestShowUnknownInstanceReturnsError(t *testing.T) { + c := qt.New(t) + var out bytes.Buffer + server := liveConnectionsListServer(t, sampleFilteredListCmdResponse(), nil) + cmd := showCmdForServer(server.URL, printer.JSON, &out) + cmd.SetArgs([]string{"--instance", "missing", "pgload", "main"}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, `unknown instance "missing" \(valid instances: primary, replica-a\)`) + c.Assert(out.String(), qt.Equals, "") +} + +func TestShowValidInstanceWithNoConnectionsSucceeds(t *testing.T) { + c := qt.New(t) + var out bytes.Buffer + server := liveConnectionsListServer(t, samplePrimaryOnlyListCmdResponse(), nil) + cmd := showCmdForServer(server.URL, printer.JSON, &out) + cmd.SetArgs([]string{"--instance", "replica-a", "pgload", "main"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + var got printableList + c.Assert(json.Unmarshal(out.Bytes(), &got), qt.IsNil) + c.Assert(got.Instances, qt.DeepEquals, []printableInstance{{ID: "replica-a", Role: "replica"}}) + c.Assert(got.Connections, qt.HasLen, 0) +} + +func TestShowForbiddenReturnsNonLeakyPermissionError(t *testing.T) { + c := qt.New(t) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.URL.Path, qt.Equals, "/v1/organizations/acme/databases/pgload/branches/main/connections") + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusForbidden) + _, _ = io.WriteString(w, `{"message":"denied by infra policy on tablet zone1-1001"}`) + })) + t.Cleanup(server.Close) + + var out bytes.Buffer + cmd := showCmdForServer(server.URL, printer.JSON, &out) + cmd.SetArgs([]string{"pgload", "main"}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "permission denied: you don't have permission to view live connections") + c.Assert(err.Error(), qt.Not(qt.Contains), "zone1-1001") + c.Assert(out.String(), qt.Equals, "") +} + +func TestShowCmdRejectsRoleWithInstance(t *testing.T) { + c := qt.New(t) + var out bytes.Buffer + cmd := showCmdForServer("http://127.0.0.1:1", printer.JSON, &out) + cmd.SetArgs([]string{"--role", "primary", "--instance", "replica-a", "pgload", "main"}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "--role cannot be combined with --instance") +} + +func TestShowCmdRejectsUnknownRole(t *testing.T) { + c := qt.New(t) + var out bytes.Buffer + cmd := showCmdForServer("http://127.0.0.1:1", printer.JSON, &out) + cmd.SetArgs([]string{"--role", "writer", "pgload", "main"}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "--role must be primary or replica") +} + +func TestShowCmdRejectsPrimaryFlag(t *testing.T) { + c := qt.New(t) + var out bytes.Buffer + cmd := showCmdForServer("http://127.0.0.1:1", printer.JSON, &out) + cmd.SetArgs([]string{"--primary", "pgload", "main"}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, `unknown flag: --primary`) +} + +func liveConnectionsListServer(t *testing.T, body string, seenPath *string) *httptest.Server { + t.Helper() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/organizations/acme/databases/pgload/branches/main/connections" { + t.Fatalf("path = %q", r.URL.Path) + } + if seenPath != nil { + *seenPath = r.URL.Path + } + _, _ = io.WriteString(w, body) + })) + t.Cleanup(server.Close) + return server +} + +// showCmdForServer builds a minimal cobra command that exercises the Postgres +// RunList path (filtering, printing, error sanitization) against a test server. +// The shipping command is branch.ConnectionsShowCmd, which lives one package up +// and adds engine detection; these tests cover the engine-agnostic list path it +// delegates to. +func showCmdForServer(baseURL string, format printer.Format, out *bytes.Buffer) *cobra.Command { + ch := liveConnectionsTestHelper(baseURL, format, out) + var flags struct { + instance string + role string + } + cmd := &cobra.Command{ + Use: "show ", + Args: cmdutil.RequiredArgs("database", "branch"), + RunE: func(cmd *cobra.Command, args []string) error { + return RunList(cmd.Context(), ch, args[0], args[1], ConnectionFilter{Instance: flags.instance, Role: flags.role}, ConnectionTarget{}) + }, + } + cmd.Flags().StringVar(&flags.instance, "instance", "", "Filter the list to a single instance.") + cmd.Flags().StringVar(&flags.role, "role", "", "Filter the list to rows whose instance role is primary or replica.") + cmd.SilenceErrors = true + cmd.SilenceUsage = true + cmd.SetErr(io.Discard) + return cmd +} + +func liveConnectionsTestHelper(baseURL string, format printer.Format, out *bytes.Buffer) *cmdutil.Helper { + p := printer.NewPrinter(&format) + p.SetHumanOutput(out) + p.SetResourceOutput(out) + + return &cmdutil.Helper{ + Config: &config.Config{ + BaseURL: baseURL, + Organization: "acme", + ServiceTokenID: "tid", + ServiceToken: "secret", + }, + Printer: p, + } +} + +func sampleListCmdResponse() string { + return `{"type":"list","database_kind":"postgresql","captured_at":"2026-04-29T12:34:56Z","instances":[{"id":"primary","role":"primary","error":null}],"data":[{"pid":123,"instance":"primary","duration_ms":664000,"state":"active","usename":"alice","application_name":"psql","client_addr":"10.0.0.1","wait_event_type":"Client","wait_event":"ClientRead","query_text":"SELECT pg_sleep(600)","xact_start":"2026-04-29T12:23:52Z","query_start":"2026-04-29T12:23:52Z","query_id":"primary-123-q","transaction_id":"primary-123-t","connection_id":"primary-123-c"}]}` +} + +func sampleFilteredListCmdResponse() string { + return `{ + "type": "list", + "captured_at": "2026-04-29T12:34:56Z", + "instances": [ + {"id": "primary", "role": "primary", "error": null}, + {"id": "replica-a", "role": "replica", "error": null} + ], + "data": [ + { + "pid": 123, + "instance": "primary", + "duration_ms": 664000, + "state": "active", + "usename": "alice", + "application_name": "psql", + "client_addr": "10.0.0.1", + "query_text": "SELECT pg_sleep(600)", + "xact_start": "2026-04-29T12:23:52Z", + "query_start": "2026-04-29T12:23:52Z", + "query_id": "primary-123-q", + "transaction_id": "primary-123-t", + "connection_id": "primary-123-c" + }, + { + "pid": 456, + "instance": "replica-a", + "duration_ms": 2000, + "state": "idle", + "usename": "bob", + "application_name": "psql", + "client_addr": "10.0.0.2", + "query_text": "SELECT 1", + "xact_start": "2026-04-29T12:34:54Z", + "query_start": "2026-04-29T12:34:54Z", + "query_id": "replica-456-q", + "transaction_id": "replica-456-t", + "connection_id": "replica-456-c" + } + ] + }` +} + +func samplePrimaryOnlyListCmdResponse() string { + return `{ + "type": "list", + "captured_at": "2026-04-29T12:34:56Z", + "instances": [ + {"id": "primary", "role": "primary", "error": null}, + {"id": "replica-a", "role": "replica", "error": null} + ], + "data": [ + { + "pid": 123, + "instance": "primary", + "duration_ms": 664000, + "state": "active", + "usename": "alice", + "application_name": "psql", + "client_addr": "10.0.0.1", + "query_text": "SELECT pg_sleep(600)", + "xact_start": "2026-04-29T12:23:52Z", + "query_start": "2026-04-29T12:23:52Z", + "query_id": "primary-123-q", + "transaction_id": "primary-123-t", + "connection_id": "primary-123-c" + } + ] + }` +} + +func stringPtr(value string) *string { + return &value +} diff --git a/internal/cmd/branch/connections/top.go b/internal/cmd/branch/connections/top.go new file mode 100644 index 00000000..44078452 --- /dev/null +++ b/internal/cmd/branch/connections/top.go @@ -0,0 +1,654 @@ +package connections + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/AlecAivazis/survey/v2" + tea "github.com/charmbracelet/bubbletea" + "github.com/planetscale/cli/internal/cmdutil" + live "github.com/planetscale/cli/internal/connections" + "github.com/planetscale/cli/internal/connections/history" + "github.com/planetscale/cli/internal/connections/tui" + "github.com/planetscale/cli/internal/printer" + "github.com/planetscale/cli/internal/promptutil" + ps "github.com/planetscale/planetscale-go/planetscale" + "github.com/spf13/cobra" +) + +func TopCmd(ch *cmdutil.Helper) *cobra.Command { + var flags topFlags + + cmd := &cobra.Command{ + Use: "top [database] [branch]", + Short: "Show live branch connection activity", + Long: `Show live branch connection activity. + +Run interactively in a terminal to launch the TUI. Pipe or redirect output to +run headlessly with --capture; --duration bounds either mode and without +--duration the command runs until interrupted. Pass --replay FILE to render a +previously captured trace in the TUI — actions are rejected in replay mode. + +For Postgres, connections top shows session activity across instances. For +Vitess, pass --keyspace and --shard or run interactively to select them when +the server reports available targets.`, + Args: func(cmd *cobra.Command, args []string) error { + if flags.replay != "" { + return nil + } + return cmdutil.RequiredArgs("database")(cmd, args) + }, + PreRunE: func(cmd *cobra.Command, args []string) error { + if flags.interval <= 0 { + return errors.New("--interval must be greater than 0") + } + if flags.duration < 0 { + return errors.New("--duration must not be negative") + } + if flags.replay != "" && flags.capture != "" { + return errors.New("--replay cannot be combined with --capture") + } + if flags.replay != "" && flags.duration > 0 { + return errors.New("--duration cannot be combined with --replay") + } + if err := validateConnectionFilter(flags.instance, flags.role); err != nil { + return err + } + if flags.replay != "" { + if _, err := os.Stat(flags.replay); err != nil { + return fmt.Errorf("--replay: %w", err) + } + if !isHumanMode(ch) { + return errors.New("--replay requires an interactive terminal") + } + } + if flags.capture != "" { + if err := validateCapturePath(flags.capture); err != nil { + return err + } + } + return nil + }, + RunE: func(cmd *cobra.Command, args []string) error { + return runTop(cmd.Context(), cmd, ch, args, flags) + }, + } + + cmd.Flags().DurationVar(&flags.interval, "interval", 1*time.Second, "Refresh interval.") + cmd.Flags().StringVar(&flags.capture, "capture", "", "Write captured samples to a trace file. Required in headless mode.") + cmd.Flags().StringVar(&flags.replay, "replay", "", "Replay a previously captured trace file in the TUI. Mutually exclusive with --capture.") + cmd.Flags().DurationVar(&flags.duration, "duration", 0, "Run for this duration. Default is to run until interrupted.") + cmd.Flags().StringVar(&flags.instance, "instance", "", "Filter the live view to a single instance (by id from the list response).") + cmd.Flags().StringVar(&flags.role, "role", "", "Filter the live view to rows whose instance role is primary or replica.") + cmd.Flags().StringVar(&flags.keyspace, "keyspace", "", "Vitess keyspace to target.") + cmd.Flags().StringVar(&flags.shard, "shard", "", "Vitess shard to target.") + + return cmd +} + +type topFlags struct { + interval time.Duration + capture string + replay string + duration time.Duration + instance string + role string + keyspace string + shard string +} + +func (f topFlags) filter() connectionFilter { + return connectionFilter{instance: f.instance, role: f.role} +} + +func (f topFlags) target() ConnectionTarget { + return ConnectionTarget{Keyspace: f.keyspace, Shard: f.shard} +} + +type topRequest struct { + Database string + Branch string + Engine ps.DatabaseEngine + Filter connectionFilter + Target ConnectionTarget + Interactive bool +} + +func runTop(ctx context.Context, cmd *cobra.Command, ch *cmdutil.Helper, args []string, flags topFlags) (err error) { + if flags.replay != "" { + return runReplay(ctx, flags.replay, flags.duration, flags.interval) + } + + request, err := newTopRequest(ctx, ch, args, flags) + if err != nil { + return err + } + + source, err := newTopSource(ctx, ch, request) + if err != nil { + return err + } + + return runTopWithSource(ctx, cmd, ch, request, source, flags) +} + +func newTopRequest(ctx context.Context, ch *cmdutil.Helper, args []string, flags topFlags) (topRequest, error) { + database := args[0] + engine, err := getTopDatabaseKind(ctx, ch, database) + if err != nil { + return topRequest{}, err + } + + filter := flags.filter() + target := flags.target() + if err := validateTopFlagsForEngine(engine, filter, target); err != nil { + return topRequest{}, err + } + + branch, err := resolveBranch(ctx, ch, database, args) + if err != nil { + return topRequest{}, err + } + + interactive := isHumanMode(ch) + if !interactive && flags.capture == "" { + return topRequest{}, errors.New("--capture is required when running without a TTY") + } + + return topRequest{ + Database: database, + Branch: branch, + Engine: engine, + Filter: filter, + Target: target, + Interactive: interactive, + }, nil +} + +func runTopWithSource(ctx context.Context, cmd *cobra.Command, ch *cmdutil.Helper, request topRequest, source topSource, flags topFlags) error { + if request.Filter.active() { + fmt.Fprintln(cmd.ErrOrStderr(), request.Filter.describe()) + } + + if request.Interactive { + return runTopInteractive(ctx, ch, request, source, flags) + } + return runTopHeadless(ctx, cmd, ch, request, source, flags) +} + +func runTopInteractive(ctx context.Context, ch *cmdutil.Helper, request topRequest, source topSource, flags topFlags) error { + control := newCaptureControl(flags.capture, ch.Config.Organization, request.Database, request.Branch, request.Filter, source.Target) + if flags.capture != "" { + writer, path, err := control.Open() + if err != nil { + return err + } + control.Writer = writer + control.Path = path + } + target := tui.Target{ + Database: request.Database, + Branch: request.Branch, + Keyspace: source.Target.Keyspace, + Shard: source.Target.Shard, + } + return runInteractive(ctx, source.Client, flags.duration, flags.interval, control, target, request.Filter.chip(), source.View) +} + +func runTopHeadless(ctx context.Context, cmd *cobra.Command, ch *cmdutil.Helper, request topRequest, source topSource, flags topFlags) error { + writer, err := openCaptureWriter(flags.capture, ch.Config.Organization, request.Database, request.Branch, request.Filter, source.Target) + if err != nil { + return err + } + + if flags.duration > 0 { + fmt.Fprintf(cmd.ErrOrStderr(), "Capturing for %s to %s\n", flags.duration, flags.capture) + } else { + fmt.Fprintf(cmd.ErrOrStderr(), "Capturing to %s (Ctrl-C to stop)\n", flags.capture) + } + + return runHeadlessCapture(ctx, sortedTopLister{client: source.Client, sort: source.View.DefaultSort()}, writer, flags.duration, flags.interval) +} + +func getTopDatabaseKind(ctx context.Context, ch *cmdutil.Helper, database string) (ps.DatabaseEngine, error) { + client, err := ch.Client() + if err != nil { + return "", err + } + db, err := client.Databases.Get(ctx, &ps.GetDatabaseRequest{ + Organization: ch.Config.Organization, + Database: database, + }) + if err != nil { + return "", err + } + if db == nil { + return "", errors.New("database not found") + } + return db.Kind, nil +} + +func validateTopFlagsForEngine(engine ps.DatabaseEngine, filter connectionFilter, target ConnectionTarget) error { + if err := validateEngineFlags(engine, filter, target); err != nil { + return err + } + switch engine { + case ps.DatabaseEnginePostgres, ps.DatabaseEngineMySQL: + return nil + default: + return fmt.Errorf("connections top is not supported for database kind %q", engine) + } +} + +func runReplay(ctx context.Context, path string, duration, interval time.Duration) error { + file, err := os.Open(path) + if err != nil { + return fmt.Errorf("--replay: %w", err) + } + defer file.Close() + + source, err := history.NewReplaySource(file) + if err != nil { + return fmt.Errorf("--replay: %w", err) + } + + captures := source.Captures() + samples := history.NewCaptureHistory(len(captures)) + for _, capture := range captures { + samples.Push(capture.List) + } + view := replayConnectionView(captures) + + runCtx, cancel := context.WithCancel(ctx) + defer cancel() + + model := tui.NewModel(runCtx, newReplayClient(source), interval, duration). + WithTarget(replayTarget(source)). + WithConnectionView(view). + WithCaptureHistory(samples). + WithReadOnlyActions(errReplayActionRejected.Error()) + return runTeaProgram(model, tea.WithAltScreen(), tea.WithContext(runCtx)) +} + +func replayTarget(source *history.ReplaySource) tui.Target { + start, ok := source.CaptureStart() + if !ok { + return tui.Target{} + } + target := tui.Target{ + Database: start.Database, + Branch: start.Branch, + } + if start.Target != nil { + target.Keyspace = start.Target.Keyspace + target.Shard = start.Target.Shard + } + return target +} + +func replayConnectionView(captures []history.Capture) tui.ConnectionViewProfile { + for _, capture := range captures { + if capture.List.DatabaseKind == live.DatabaseKindMySQL { + return tui.VitessConnectionView + } + } + return tui.PostgresConnectionView +} + +func isHumanMode(ch *cmdutil.Helper) bool { + if !printer.IsTTY { + return false + } + if ch.Printer != nil && ch.Printer.Format() != printer.Human { + return false + } + return true +} + +func resolveBranch(ctx context.Context, ch *cmdutil.Helper, database string, args []string) (string, error) { + if len(args) >= 2 && args[1] != "" { + return args[1], nil + } + client, err := ch.Client() + if err != nil { + return "", err + } + return promptutil.GetBranch(ctx, client, ch.Config.Organization, database) +} + +var runTeaProgram = func(model tea.Model, options ...tea.ProgramOption) error { + _, err := tea.NewProgram(model, options...).Run() + return err +} + +func runInteractive(ctx context.Context, client tui.ConnectionsClient, duration, interval time.Duration, control *tui.CaptureControl, target tui.Target, filterChip string, view tui.ConnectionViewProfile) (err error) { + runCtx, cancel := context.WithCancel(ctx) + defer cancel() + if control != nil { + defer func() { + err = errors.Join(err, control.Close()) + }() + } + + model := tui.NewModel(runCtx, client, interval, duration). + WithTarget(target). + WithFilter(filterChip). + WithConnectionView(view) + if control != nil { + model = model.WithCaptureControl(control) + } + return runTeaProgram(model, tea.WithAltScreen(), tea.WithContext(runCtx)) +} + +type topSource struct { + Client tui.ConnectionsClient + View tui.ConnectionViewProfile + Target ConnectionTarget +} + +func newTopSource(ctx context.Context, ch *cmdutil.Helper, request topRequest) (topSource, error) { + switch request.Engine { + case ps.DatabaseEnginePostgres: + client, err := newConnectionsClient(ch, request.Database, request.Branch, ConnectionTarget{}) + if err != nil { + return topSource{}, err + } + return topSource{ + Client: filteredLister{client: client, filter: request.Filter}, + View: tui.PostgresConnectionView, + }, nil + case ps.DatabaseEngineMySQL: + return newVitessTopSource(ctx, ch, request) + default: + return topSource{}, fmt.Errorf("connections top is not supported for database kind %q", request.Engine) + } +} + +func newVitessTopSource(ctx context.Context, ch *cmdutil.Helper, request topRequest) (topSource, error) { + resolved, err := probeAndResolveVitessTopTarget(ctx, ch, request.Database, request.Branch, request.Target, request.Interactive) + if err != nil { + return topSource{}, err + } + client, err := newConnectionsClient(ch, request.Database, request.Branch, resolved) + if err != nil { + return topSource{}, err + } + return topSource{ + Client: filteredLister{client: client}, + View: tui.VitessConnectionView, + Target: resolved, + }, nil +} + +const maxVitessTopTargetProbes = 3 + +func probeAndResolveVitessTopTarget(ctx context.Context, ch *cmdutil.Helper, database, branch string, target ConnectionTarget, canPrompt bool) (ConnectionTarget, error) { + resolved := target + for attempts := 0; attempts < maxVitessTopTargetProbes; attempts++ { + err := probeVitessTopTarget(ctx, ch, database, branch, resolved) + if err == nil { + return resolved, nil + } + + next, ok, resolveErr := nextVitessTopTarget(resolved, err, canPrompt) + if resolveErr != nil { + return ConnectionTarget{}, resolveErr + } + if !ok { + if canPrompt { + return resolved, nil + } + return ConnectionTarget{}, err + } + resolved = next + } + return ConnectionTarget{}, errors.New("could not resolve Vitess keyspace and shard") +} + +func probeVitessTopTarget(ctx context.Context, ch *cmdutil.Helper, database, branch string, target ConnectionTarget) error { + client, err := newConnectionsClient(ch, database, branch, target) + if err != nil { + return err + } + _, err = client.List(ctx, live.SortByDuration) + return err +} + +func nextVitessTopTarget(target ConnectionTarget, err error, canPrompt bool) (ConnectionTarget, bool, error) { + httpErr, ok := badRequestWithAlternatives(err) + if !ok || !canPrompt { + return ConnectionTarget{}, false, nil + } + + switch { + case target.Keyspace == "" && len(httpErr.Available.Keyspaces) > 0: + selected, err := selectTopTarget("Select a keyspace for connections top:", httpErr.Available.Keyspaces) + if err != nil { + return ConnectionTarget{}, false, err + } + target.Keyspace = selected + return target, true, nil + case target.Shard == "" && len(httpErr.Available.Shards) > 0: + selected, err := selectTopTarget("Select a shard for connections top:", httpErr.Available.Shards) + if err != nil { + return ConnectionTarget{}, false, err + } + target.Shard = selected + return target, true, nil + default: + return ConnectionTarget{}, false, nil + } +} + +func badRequestWithAlternatives(err error) (*live.HTTPError, bool) { + var httpErr *live.HTTPError + if !errors.As(err, &httpErr) { + return nil, false + } + return httpErr, httpErr.StatusCode == http.StatusBadRequest +} + +var selectTopTarget = func(message string, options []string) (string, error) { + prompt := &survey.Select{ + Message: message, + Options: options, + VimMode: true, + } + var selected string + err := survey.AskOne(prompt, &selected) + return selected, err +} + +type sortedTopLister struct { + client tui.ConnectionsClient + sort live.SortMode +} + +func (l sortedTopLister) List(ctx context.Context, _ live.SortMode) (live.ConnectionList, error) { + return l.client.List(ctx, l.sort) +} + +func newCaptureControl(path, org, database, branch string, filter connectionFilter, target ConnectionTarget) *tui.CaptureControl { + return &tui.CaptureControl{ + Open: func() (*history.CaptureWriter, string, error) { + capturePath := path + if capturePath == "" { + capturePath = defaultInteractiveCapturePath(time.Now()) + } + writer, err := openCaptureWriter(capturePath, org, database, branch, filter, target) + return writer, capturePath, err + }, + } +} + +type connectionFilter struct { + instance string + role string +} + +func (f connectionFilter) active() bool { + return f.instance != "" || f.role != "" +} + +func validateConnectionFilter(instance, role string) error { + if instance != "" && role != "" { + return errors.New("--role cannot be combined with --instance") + } + switch role { + case "", "primary", "replica": + return nil + default: + return errors.New("--role must be primary or replica") + } +} + +// captureFilter renders the filter for the capture file header. Returns nil +// when no filter is active so the header's "filter" field is omitted. +func (f connectionFilter) captureFilter() *history.CaptureFilter { + if !f.active() { + return nil + } + return &history.CaptureFilter{ + Instance: f.instance, + Role: f.role, + } +} + +// chip renders the compact header indicator shown in the interactive TUI so +// the operator can see the view is scoped. Empty when no filter is active. +func (f connectionFilter) chip() string { + switch { + case f.instance != "": + return fmt.Sprintf("filter: instance=%s", f.instance) + case f.role != "": + return fmt.Sprintf("filter: role=%s", f.role) + default: + return "" + } +} + +func (f connectionFilter) describe() string { + switch { + case f.instance != "": + return fmt.Sprintf("filtering to instance=%s", f.instance) + case f.role != "": + return fmt.Sprintf("filtering to role=%s", f.role) + default: + return "" + } +} + +func filterConnectionList(list live.ConnectionList, f connectionFilter) live.ConnectionList { + if !f.active() { + return list + } + kept := make([]live.Connection, 0, len(list.Connections)) + for _, conn := range list.Connections { + if f.instance != "" && conn.Instance != f.instance { + continue + } + if f.role != "" && conn.InstanceRole != f.role { + continue + } + kept = append(kept, conn) + } + keptInstances := make([]live.InstanceMeta, 0, len(list.Instances)) + for _, inst := range list.Instances { + if f.instance != "" && inst.ID != f.instance { + continue + } + if f.role != "" && inst.Role != f.role { + continue + } + keptInstances = append(keptInstances, inst) + } + out := list + out.Connections = kept + out.Instances = keptInstances + return out +} + +type filteredLister struct { + client *live.Client + filter connectionFilter +} + +func (f filteredLister) List(ctx context.Context, sort live.SortMode) (live.ConnectionList, error) { + list, err := f.client.List(ctx, sort) + if err != nil { + return list, err + } + if err := validateInstanceFilter(list, f.filter); err != nil { + return list, err + } + return filterConnectionList(list, f.filter), nil +} + +// Actions pass through to the wire client unchanged — the filter only scopes +// the read path so the operator sees a subset of rows; once they pick a row, +// the action targets that specific (instance, pid) on the real cluster. +func (f filteredLister) CancelQuery(ctx context.Context, target live.ActionTarget) error { + return f.client.CancelQuery(ctx, target) +} + +func (f filteredLister) TerminateTransaction(ctx context.Context, target live.ActionTarget) error { + return f.client.TerminateTransaction(ctx, target) +} + +func (f filteredLister) TerminateConnection(ctx context.Context, target live.ActionTarget) error { + return f.client.TerminateConnection(ctx, target) +} + +func openCaptureWriter(path, org, database, branch string, filter connectionFilter, target ConnectionTarget) (*history.CaptureWriter, error) { + file, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + return nil, err + } + writer := history.NewCaptureWriter(file) + if err := writer.WriteCaptureStart(history.CaptureStart{ + At: time.Now().UTC(), + Organization: org, + Database: database, + Branch: branch, + Filter: filter.captureFilter(), + Target: captureTarget(target), + }); err != nil { + _ = writer.Close() + return nil, err + } + return writer, nil +} + +func captureTarget(target ConnectionTarget) *history.CaptureTarget { + if target.Keyspace == "" && target.Shard == "" { + return nil + } + return &history.CaptureTarget{ + Keyspace: target.Keyspace, + Shard: target.Shard, + } +} + +func defaultInteractiveCapturePath(now time.Time) string { + return "connections-" + now.UTC().Format("20060102T150405.000000000Z") + ".jsonl" +} + +func validateCapturePath(path string) error { + parent := filepath.Dir(path) + info, err := os.Stat(parent) + if err != nil { + if os.IsNotExist(err) { + return errors.New("capture parent directory " + parent + " does not exist") + } + return err + } + if !info.IsDir() { + return errors.New("capture parent path " + parent + " is not a directory") + } + return nil +} diff --git a/internal/cmd/branch/connections/top_test.go b/internal/cmd/branch/connections/top_test.go new file mode 100644 index 00000000..77be6288 --- /dev/null +++ b/internal/cmd/branch/connections/top_test.go @@ -0,0 +1,821 @@ +package connections + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + tea "github.com/charmbracelet/bubbletea" + qt "github.com/frankban/quicktest" + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/config" + live "github.com/planetscale/cli/internal/connections" + "github.com/planetscale/cli/internal/mock" + "github.com/planetscale/cli/internal/printer" + ps "github.com/planetscale/planetscale-go/planetscale" + "github.com/spf13/cobra" +) + +func TestFilterConnectionList(t *testing.T) { + tests := []struct { + name string + input live.ConnectionList + filter connectionFilter + wantPIDs []int + wantInstances []live.InstanceMeta + }{ + { + name: "by instance", + input: live.ConnectionList{Connections: []live.Connection{ + {PID: 1, Instance: "primary", InstanceRole: "primary"}, + {PID: 2, Instance: "replica-a", InstanceRole: "replica"}, + {PID: 3, Instance: "replica-b", InstanceRole: "replica"}, + }}, + filter: connectionFilter{instance: "replica-a"}, + wantPIDs: []int{2}, + }, + { + name: "by primary role", + input: live.ConnectionList{Connections: []live.Connection{ + {PID: 1, Instance: "primary", InstanceRole: "primary"}, + {PID: 2, Instance: "replica-a", InstanceRole: "replica"}, + }}, + filter: connectionFilter{role: "primary"}, + wantPIDs: []int{1}, + }, + { + name: "by replica role", + input: live.ConnectionList{Connections: []live.Connection{ + {PID: 1, Instance: "primary", InstanceRole: "primary"}, + {PID: 2, Instance: "replica-a", InstanceRole: "replica"}, + {PID: 3, Instance: "replica-b", InstanceRole: "replica"}, + }}, + filter: connectionFilter{role: "replica"}, + wantPIDs: []int{2, 3}, + }, + { + name: "no filter passes through", + input: live.ConnectionList{Connections: []live.Connection{ + {PID: 1, Instance: "primary", InstanceRole: "primary"}, + }}, + filter: connectionFilter{}, + wantPIDs: []int{1}, + }, + { + name: "role filters instances metadata", + input: live.ConnectionList{ + Connections: []live.Connection{ + {PID: 1, Instance: "primary", InstanceRole: "primary"}, + }, + Instances: []live.InstanceMeta{ + {ID: "primary", Role: "primary"}, + {ID: "replica-1", Role: "replica"}, + {ID: "replica-2", Role: "replica", Error: "timeout"}, + }, + }, + filter: connectionFilter{role: "primary"}, + wantPIDs: []int{1}, + wantInstances: []live.InstanceMeta{ + {ID: "primary", Role: "primary"}, + }, + }, + { + name: "instance filter scopes instances to target", + input: live.ConnectionList{ + Instances: []live.InstanceMeta{ + {ID: "primary", Role: "primary"}, + {ID: "replica-1", Role: "replica"}, + {ID: "replica-2", Role: "replica", Error: "timeout"}, + }, + }, + filter: connectionFilter{instance: "replica-2"}, + wantPIDs: []int{}, + wantInstances: []live.InstanceMeta{ + {ID: "replica-2", Role: "replica", Error: "timeout"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + got := filterConnectionList(tt.input, tt.filter) + gotPIDs := make([]int, 0, len(got.Connections)) + for _, conn := range got.Connections { + gotPIDs = append(gotPIDs, conn.PID) + } + c.Assert(gotPIDs, qt.DeepEquals, tt.wantPIDs) + if tt.wantInstances != nil { + c.Assert(got.Instances, qt.DeepEquals, tt.wantInstances) + } + }) + } +} + +func TestTopCmdValidation(t *testing.T) { + tests := []struct { + name string + engine ps.DatabaseEngine + args []string + wantErr string + }{ + { + name: "role with instance", + engine: ps.DatabaseEnginePostgres, + args: []string{"--role", "primary", "--instance", "replica-a", "pgload", "main"}, + wantErr: "--role cannot be combined with --instance", + }, + { + name: "zero interval", + engine: ps.DatabaseEnginePostgres, + args: []string{"pgload", "main", "--interval", "0s"}, + wantErr: "--interval must be greater than 0", + }, + { + name: "negative interval", + engine: ps.DatabaseEnginePostgres, + args: []string{"pgload", "main", "--interval", "-1s"}, + wantErr: "--interval must be greater than 0", + }, + { + name: "negative duration", + engine: ps.DatabaseEnginePostgres, + args: []string{"pgload", "main", "--duration", "-1s"}, + wantErr: "--duration must not be negative", + }, + { + name: "vitess target flags on postgres", + engine: ps.DatabaseEnginePostgres, + args: []string{"pgload", "main", "--keyspace", "commerce"}, + wantErr: "--keyspace/--shard are only supported for Vitess databases", + }, + { + name: "postgres filters on vitess", + engine: ps.DatabaseEngineMySQL, + args: []string{"shop", "main", "--role", "primary"}, + wantErr: "--instance/--role are only supported for Postgres databases", + }, + { + name: "unknown role", + engine: ps.DatabaseEnginePostgres, + args: []string{"--role", "writer", "pgload", "main"}, + wantErr: "--role must be primary or replica", + }, + { + name: "primary flag removed", + engine: ps.DatabaseEnginePostgres, + args: []string{"--primary", "pgload", "main"}, + wantErr: `unknown flag: --primary`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + cmd := topCmdForServerAndEngine("http://example.invalid", tt.engine, &bytes.Buffer{}) + cmd.SetArgs(tt.args) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, tt.wantErr) + }) + } +} + +func TestTopCmdRunEHeadlessHappyPathWritesCaptureHeaderAndCapture(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, false) + defer restoreTTY() + server := liveConnectionsServer(t, sampleTopResponse()) + capture := filepath.Join(t.TempDir(), "trace.jsonl") + var out bytes.Buffer + cmd := topCmdForServer(server.URL, &out) + cmd.SetArgs([]string{"pgload", "main", "--capture", capture, "--duration", "200ms", "--interval", "1s"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + records := readJSONLines(t, capture) + c.Assert(records, qt.HasLen, 2) + header := records[0] + c.Assert(header["type"], qt.Equals, "capture_start") + c.Assert(header["org"], qt.Equals, "acme") + c.Assert(header["database"], qt.Equals, "pgload") + c.Assert(header["branch"], qt.Equals, "main") + c.Assert(header["schema_version"], qt.Equals, float64(1)) + captureRecord := capturedConnectionList(c, records[1]) + c.Assert(captureRecord["database_kind"], qt.Equals, "postgresql") + c.Assert(captureRecord["instances"], qt.DeepEquals, []any{map[string]any{"id": "primary", "role": "primary"}}) + info, err := os.Stat(capture) + c.Assert(err, qt.IsNil) + c.Assert(info.Mode().Perm(), qt.Equals, os.FileMode(0o600)) +} + +func TestTopCmdRunERecordsRoleFilterInCaptureHeader(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, false) + defer restoreTTY() + server := liveConnectionsServer(t, sampleTopResponse()) + capture := filepath.Join(t.TempDir(), "trace.jsonl") + var out bytes.Buffer + cmd := topCmdForServer(server.URL, &out) + cmd.SetArgs([]string{"pgload", "main", "--role", "primary", "--capture", capture, "--duration", "200ms", "--interval", "1s"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + records := readJSONLines(t, capture) + c.Assert(records, qt.HasLen, 2) + c.Assert(records[0]["filter"], qt.DeepEquals, map[string]any{"role": "primary"}) +} + +func TestTopCmdInteractiveHappyPathFetchesAndExits(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + return nil + }) + defer restoreProgram() + server := liveConnectionsServer(t, sampleTopResponse()) + cmd := topCmdForServer(server.URL, &bytes.Buffer{}) + cmd.SetArgs([]string{"pgload", "main"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) +} + +func TestTopCmdInteractivePassesTargetAndFilterToTUI(t *testing.T) { + tests := []struct { + name string + args []string + want string + }{ + { + name: "target", + args: []string{"pgload", "main"}, + want: "pgload / main", + }, + { + name: "role filter", + args: []string{"pgload", "main", "--role", "primary"}, + want: "filter: role=primary", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + var view string + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + updated, _ := model.Update(tea.WindowSizeMsg{Width: 200, Height: 24}) + view = updated.View() + return nil + }) + defer restoreProgram() + server := liveConnectionsServer(t, sampleTopResponse()) + cmd := topCmdForServer(server.URL, &bytes.Buffer{}) + cmd.SetArgs(tt.args) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + c.Assert(view, qt.Contains, tt.want) + }) + } +} + +func TestTopCmdKeepsPostgresConnectionCapabilities(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + var blockersView string + var actionView string + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + updated := fetchInitialTopModel(t, model) + withBlockers, _ := updated.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("b")}) + blockersView = withBlockers.View() + withAction, _ := updated.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("k")}) + actionView = withAction.View() + return nil + }) + defer restoreProgram() + server := liveConnectionsServer(t, sampleTopResponseWithoutActionIDs()) + cmd := topCmdForServer(server.URL, &bytes.Buffer{}) + cmd.SetArgs([]string{"pgload", "main"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + c.Assert(blockersView, qt.Contains, "[blockers]") + c.Assert(actionView, qt.Contains, "no open transaction to terminate on this connection") +} + +func TestTopCmdRunsVitessTopWithTarget(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + + server := liveConnectionsServerForTop(t, func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.URL.Path, qt.Equals, "/v1/organizations/acme/databases/shop/branches/main/connections") + assertTopQueryParam(c, r, "keyspace", "commerce") + assertTopQueryParam(c, r, "shard", "-80") + _, _ = io.WriteString(w, sampleVitessTopResponse()) + }) + + var view string + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + updated := fetchInitialTopModel(t, model) + view = updated.View() + return nil + }) + defer restoreProgram() + + cmd := topCmdForServerAndEngine(server.URL, ps.DatabaseEngineMySQL, &bytes.Buffer{}) + cmd.SetArgs([]string{"shop", "main", "--keyspace", "commerce", "--shard", "-80"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + c.Assert(view, qt.Contains, "shop / main / commerce / -80") + c.Assert(view, qt.Contains, "PID") + c.Assert(view, qt.Contains, "USER") + c.Assert(view, qt.Not(qt.Contains), "BLOCK") +} + +func TestTopCmdResolvesVitessTargetByPrompting(t *testing.T) { + tests := []struct { + name string + args []string + wantRequests int + wantPrompts []string + }{ + { + name: "keyspace and shard", + args: []string{"shop", "main"}, + wantRequests: 3, + wantPrompts: []string{"Select a keyspace for connections top:", "Select a shard for connections top:"}, + }, + { + name: "shard only", + args: []string{"shop", "main", "--keyspace", "commerce"}, + wantRequests: 2, + wantPrompts: []string{"Select a shard for connections top:"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + + var requests []string + server := liveConnectionsServerForTop(t, func(w http.ResponseWriter, r *http.Request) { + requests = append(requests, r.URL.RawQuery) + keyspace := r.URL.Query().Get("keyspace") + switch { + case keyspace == "": + writeVitessMultipleKeyspacesResponse(w) + case r.URL.Query().Get("shard") == "": + c.Assert(keyspace, qt.Equals, "commerce") + writeVitessShardedKeyspaceResponse(w) + default: + c.Assert(keyspace, qt.Equals, "commerce") + c.Assert(r.URL.Query().Get("shard"), qt.Equals, "-80") + _, _ = io.WriteString(w, sampleVitessTopResponse()) + } + }) + + var prompts []string + restorePrompt := setSelectTargetForTest(t, func(message string, options []string) (string, error) { + prompts = append(prompts, message) + switch message { + case "Select a keyspace for connections top:": + c.Assert(options, qt.DeepEquals, []string{"commerce", "lookup"}) + return "commerce", nil + case "Select a shard for connections top:": + c.Assert(options, qt.DeepEquals, []string{"-80", "80-"}) + return "-80", nil + default: + t.Fatalf("unexpected prompt %q", message) + return "", nil + } + }) + defer restorePrompt() + + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + return nil + }) + defer restoreProgram() + + cmd := topCmdForServerAndEngine(server.URL, ps.DatabaseEngineMySQL, &bytes.Buffer{}) + cmd.SetArgs(tt.args) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + c.Assert(requests, qt.HasLen, tt.wantRequests) + c.Assert(prompts, qt.DeepEquals, tt.wantPrompts) + }) + } +} + +func TestTopCmdNonInteractiveVitessAmbiguityDoesNotPrompt(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, false) + defer restoreTTY() + + prompts := 0 + restorePrompt := setSelectTargetForTest(t, func(message string, options []string) (string, error) { + prompts++ + return "", nil + }) + defer restorePrompt() + + server := liveConnectionsServerForTop(t, func(w http.ResponseWriter, r *http.Request) { + writeVitessMultipleKeyspacesResponse(w) + }) + capture := filepath.Join(t.TempDir(), "trace.jsonl") + cmd := topCmdForServerAndEngine(server.URL, ps.DatabaseEngineMySQL, &bytes.Buffer{}) + cmd.SetArgs([]string{"shop", "main", "--capture", capture, "--duration", "200ms", "--interval", "1s"}) + + err := cmd.Execute() + + var httpErr *live.HTTPError + c.Assert(errors.As(err, &httpErr), qt.IsTrue) + c.Assert(httpErr.StatusCode, qt.Equals, http.StatusBadRequest) + c.Assert(prompts, qt.Equals, 0) +} + +func TestTopCmdHeadlessVitessMissingCaptureDoesNotPreflight(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, false) + defer restoreTTY() + + requests := 0 + server := liveConnectionsServerForTop(t, func(w http.ResponseWriter, r *http.Request) { + requests++ + t.Errorf("unexpected connections request: %s", r.URL.String()) + _, _ = io.WriteString(w, sampleVitessTopResponse()) + }) + cmd := topCmdForServerAndEngine(server.URL, ps.DatabaseEngineMySQL, &bytes.Buffer{}) + cmd.SetArgs([]string{"shop", "main"}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "--capture is required when running without a TTY") + c.Assert(requests, qt.Equals, 0) +} + +func TestTopCmdVitessPromptCancellation(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + + requests := 0 + server := liveConnectionsServerForTop(t, func(w http.ResponseWriter, r *http.Request) { + requests++ + writeVitessMultipleKeyspacesResponse(w) + }) + + promptErr := errors.New("prompt canceled") + restorePrompt := setSelectTargetForTest(t, func(message string, options []string) (string, error) { + return "", promptErr + }) + defer restorePrompt() + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + t.Fatal("unexpected TUI launch") + return nil + }) + defer restoreProgram() + + cmd := topCmdForServerAndEngine(server.URL, ps.DatabaseEngineMySQL, &bytes.Buffer{}) + cmd.SetArgs([]string{"shop", "main"}) + + err := cmd.Execute() + + c.Assert(errors.Is(err, promptErr), qt.IsTrue) + c.Assert(requests, qt.Equals, 1) +} + +func TestTopCmdVitessTargetRetryExhaustion(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + + requests := 0 + server := liveConnectionsServerForTop(t, func(w http.ResponseWriter, r *http.Request) { + requests++ + writeVitessShardedKeyspaceResponse(w) + }) + + prompts := 0 + restorePrompt := setSelectTargetForTest(t, func(message string, options []string) (string, error) { + prompts++ + return "", nil + }) + defer restorePrompt() + + cmd := topCmdForServerAndEngine(server.URL, ps.DatabaseEngineMySQL, &bytes.Buffer{}) + cmd.SetArgs([]string{"shop", "main", "--keyspace", "commerce"}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "could not resolve Vitess keyspace and shard") + c.Assert(requests, qt.Equals, 3) + c.Assert(prompts, qt.Equals, 3) +} + +func TestTopCmdVitessBackendErrorRendersInTUI(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + + prompts := 0 + restorePrompt := setSelectTargetForTest(t, func(message string, options []string) (string, error) { + prompts++ + return "", nil + }) + defer restorePrompt() + + server := liveConnectionsServerForTop(t, func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusTeapot) + _, _ = io.WriteString(w, `{"code":"teapot","message":"not a target prompt","available":{"keyspaces":["commerce"]}}`) + }) + + var view string + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + updated := fetchInitialTopModel(t, model) + view = updated.View() + return nil + }) + defer restoreProgram() + + cmd := topCmdForServerAndEngine(server.URL, ps.DatabaseEngineMySQL, &bytes.Buffer{}) + cmd.SetArgs([]string{"shop", "main"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + c.Assert(view, qt.Contains, "unable to load live connections") + c.Assert(view, qt.Contains, "not a target prompt") + c.Assert(prompts, qt.Equals, 0) +} + +func TestTopCmdRunEHeadlessVitessWritesCapture(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, false) + defer restoreTTY() + + var requests []string + server := liveConnectionsServerForTop(t, func(w http.ResponseWriter, r *http.Request) { + requests = append(requests, r.URL.RawQuery) + c.Assert(r.URL.Path, qt.Equals, "/v1/organizations/acme/databases/shop/branches/main/connections") + assertTopQueryParam(c, r, "keyspace", "commerce") + assertTopQueryParam(c, r, "shard", "-80") + _, _ = io.WriteString(w, sampleVitessTopResponse()) + }) + capture := filepath.Join(t.TempDir(), "trace.jsonl") + var out bytes.Buffer + cmd := topCmdForServerAndEngine(server.URL, ps.DatabaseEngineMySQL, &out) + cmd.SetArgs([]string{"shop", "main", "--keyspace", "commerce", "--shard", "-80", "--capture", capture, "--duration", "200ms", "--interval", "1s"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + c.Assert(requests, qt.HasLen, 2) + records := readJSONLines(t, capture) + c.Assert(records, qt.HasLen, 2) + captureRecord := capturedConnectionList(c, records[1]) + c.Assert(captureRecord["database_kind"], qt.Equals, "mysql") + c.Assert(captureRecord["topology"], qt.DeepEquals, map[string]any{ + "keyspace": "commerce", + "shard": "-80", + "tablet": "zone1-1001", + }) +} + +func TestTopCmdInteractiveCaptureStartsEnabled(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + var view string + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + updated, _ := model.Update(tea.WindowSizeMsg{Width: 200, Height: 24}) + view = updated.View() + return nil + }) + defer restoreProgram() + server := liveConnectionsServer(t, sampleTopResponse()) + capture := filepath.Join(t.TempDir(), "trace.jsonl") + cmd := topCmdForServer(server.URL, &bytes.Buffer{}) + cmd.SetArgs([]string{"pgload", "main", "--capture", capture}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + c.Assert(view, qt.Contains, "rec "+capture) + + records := readJSONLines(t, capture) + c.Assert(records, qt.HasLen, 1) + c.Assert(records[0]["type"], qt.Equals, "capture_start") +} + +func TestTopCmdInteractiveToggleCaptureCreatesDefaultFile(t *testing.T) { + c := qt.New(t) + restoreTTY := setPrinterTTY(t, true) + defer restoreTTY() + dir := t.TempDir() + originalCwd, err := os.Getwd() + c.Assert(err, qt.IsNil) + t.Cleanup(func() { _ = os.Chdir(originalCwd) }) + c.Assert(os.Chdir(dir), qt.IsNil) + + var view string + restoreProgram := setRunTeaProgram(t, func(model tea.Model, options ...tea.ProgramOption) error { + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("C")}) + updated, _ = updated.Update(tea.WindowSizeMsg{Width: 200, Height: 24}) + view = updated.View() + return nil + }) + defer restoreProgram() + server := liveConnectionsServer(t, sampleTopResponse()) + cmd := topCmdForServer(server.URL, &bytes.Buffer{}) + cmd.SetArgs([]string{"pgload", "main"}) + + err = cmd.Execute() + + c.Assert(err, qt.IsNil) + c.Assert(view, qt.Contains, "rec connections-") + matches, err := filepath.Glob(filepath.Join(dir, "connections-*.jsonl")) + c.Assert(err, qt.IsNil) + c.Assert(matches, qt.HasLen, 1) + records := readJSONLines(t, matches[0]) + c.Assert(records, qt.HasLen, 1) + c.Assert(records[0]["type"], qt.Equals, "capture_start") +} + +func setRunTeaProgram(t *testing.T, run func(tea.Model, ...tea.ProgramOption) error) func() { + t.Helper() + previous := runTeaProgram + runTeaProgram = run + return func() { runTeaProgram = previous } +} + +func liveConnectionsServer(t *testing.T, body string) *httptest.Server { + t.Helper() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/organizations/acme/databases/pgload/branches/main/connections" { + t.Fatalf("path = %q", r.URL.Path) + } + _, _ = io.WriteString(w, body) + })) + t.Cleanup(server.Close) + return server +} + +func liveConnectionsServerForTop(t *testing.T, handler http.HandlerFunc) *httptest.Server { + t.Helper() + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + return server +} + +func topCmdForServer(baseURL string, out *bytes.Buffer) *cobra.Command { + return topCmdForServerAndEngine(baseURL, ps.DatabaseEnginePostgres, out) +} + +func topCmdForServerAndEngine(baseURL string, engine ps.DatabaseEngine, out *bytes.Buffer) *cobra.Command { + cmd := testTopCmd(&cmdutil.Helper{Config: &config.Config{ + BaseURL: baseURL, + Organization: "acme", + ServiceTokenID: "tid", + ServiceToken: "secret", + }, Client: topDatabaseClient(engine)}) + cmd.SetOut(out) + return cmd +} + +func topDatabaseClient(engine ps.DatabaseEngine) func() (*ps.Client, error) { + return func() (*ps.Client, error) { + return &ps.Client{ + Databases: &mock.DatabaseService{ + GetFn: func(ctx context.Context, req *ps.GetDatabaseRequest) (*ps.Database, error) { + return &ps.Database{Name: req.Database, Kind: engine}, nil + }, + }, + }, nil + } +} + +func testTopCmd(ch *cmdutil.Helper) *cobra.Command { + cmd := TopCmd(ch) + cmd.SilenceErrors = true + cmd.SilenceUsage = true + cmd.SetErr(io.Discard) + return cmd +} + +func sampleTopResponse() string { + return `{"type":"list","database_kind":"postgresql","captured_at":"2026-04-29T12:34:56Z","instances":[{"id":"primary","role":"primary","error":null}],"data":[{"pid":123,"instance":"primary","duration_ms":1000,"state":"active","usename":"alice","query_text":"select 1","xact_start":"2026-04-29T12:34:00.000Z","query_start":"2026-04-29T12:34:30.000Z","transaction_id":"primary-123-1777466040000000","query_id":"primary-123-1777466070000000"}]}` +} + +func sampleTopResponseWithoutActionIDs() string { + return `{"type":"list","database_kind":"postgresql","captured_at":"2026-04-29T12:34:56Z","instances":[{"id":"primary","role":"primary","error":null}],"data":[{"pid":123,"instance":"primary","duration_ms":1000,"state":"active","usename":"alice","query_text":"select 1","xact_start":"2026-04-29T12:34:00.000Z","query_start":"2026-04-29T12:34:30.000Z"}]}` +} + +func sampleVitessTopResponse() string { + return `{"type":"list","database_kind":"mysql","captured_at":"2026-06-04T12:30:00Z","instances":[],"topology":{"keyspace":"commerce","shard":"-80","tablet":"zone1-1001"},"data":[{"pid":101,"instance":"zone1-1001","duration_ms":42000,"state":"Query/executing","usename":"vt_app","datname":"checkout","client_addr":"10.0.0.1:1234","query_text":"SELECT 1","connection_id":"zone1-1001-101","query_id":"zone1-1001-101"}]}` +} + +func writeVitessMultipleKeyspacesResponse(w http.ResponseWriter) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"code":"bad_request","message":"This database has multiple keyspaces. Specify which keyspace to target.","available":{"keyspaces":["commerce","lookup"]}}`) +} + +func writeVitessShardedKeyspaceResponse(w http.ResponseWriter) { + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{"code":"bad_request","message":"Keyspace 'commerce' is sharded. Specify which shard to target.","available":{"shards":["-80","80-"]}}`) +} + +func fetchInitialTopModel(t *testing.T, model tea.Model) tea.Model { + t.Helper() + updated, _ := model.Update(tea.WindowSizeMsg{Width: 200, Height: 24}) + cmd := updated.Init() + if cmd == nil { + return updated + } + msg := cmd() + if batch, ok := msg.(tea.BatchMsg); ok { + if len(batch) == 0 { + return updated + } + msg = batch[0]() + } + updated, _ = updated.Update(msg) + return updated +} + +func setSelectTargetForTest(t *testing.T, fn func(string, []string) (string, error)) func() { + t.Helper() + previous := selectTopTarget + selectTopTarget = fn + return func() { selectTopTarget = previous } +} + +func assertTopQueryParam(c *qt.C, r *http.Request, key, want string) { + c.Assert(r.URL.Query().Get(key), qt.Equals, want) +} + +func capturedConnectionList(c *qt.C, record map[string]any) map[string]any { + capture, ok := record["capture"].(map[string]any) + c.Assert(ok, qt.IsTrue) + return capture +} + +func setPrinterTTY(t *testing.T, value bool) func() { + t.Helper() + previous := printer.IsTTY + printer.IsTTY = value + return func() { printer.IsTTY = previous } +} + +func readJSONLines(t *testing.T, path string) []map[string]any { + t.Helper() + lines := readLines(t, path) + records := make([]map[string]any, 0, len(lines)) + for _, line := range lines { + var record map[string]any + if err := json.Unmarshal([]byte(line), &record); err != nil { + t.Fatal(err) + } + records = append(records, record) + } + return records +} + +func readLines(t *testing.T, path string) []string { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + var lines []string + for _, line := range strings.Split(strings.TrimSpace(string(data)), "\n") { + if line != "" { + lines = append(lines, line) + } + } + return lines +} diff --git a/internal/cmd/branch/connections_commands.go b/internal/cmd/branch/connections_commands.go new file mode 100644 index 00000000..c196070b --- /dev/null +++ b/internal/cmd/branch/connections_commands.go @@ -0,0 +1,176 @@ +package branch + +import ( + "context" + "errors" + "strconv" + "strings" + + "github.com/planetscale/cli/internal/cmd/branch/connections" + "github.com/planetscale/cli/internal/cmdutil" + ps "github.com/planetscale/planetscale-go/planetscale" + "github.com/spf13/cobra" +) + +func ConnectionsShowCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + keyspace string + shard string + instance string + role string + } + + cmd := &cobra.Command{ + Use: "show ", + Short: "Show branch connections once", + Long: `Show branch connections once. + +Use --format json when an agent or script needs to inspect query_id, +transaction_id, and connection_id fields. Human output uses vertical records so +query text and action IDs are not truncated.`, + Args: cmdutil.RequiredArgs("database", "branch"), + RunE: func(cmd *cobra.Command, args []string) error { + filter := connections.ConnectionFilter{Instance: flags.instance, Role: flags.role} + if err := connections.ValidateConnectionFilter(filter); err != nil { + return err + } + engine, err := databaseEngine(cmd.Context(), ch, args[0]) + if err != nil { + return err + } + target := connections.ConnectionTarget{Keyspace: flags.keyspace, Shard: flags.shard} + if err := connections.ValidateEngineFlags(engine, filter, target); err != nil { + return err + } + return connections.RunList(cmd.Context(), ch, args[0], args[1], filter, target) + }, + } + + cmd.Flags().StringVar(&flags.keyspace, "keyspace", "", "Vitess keyspace to target") + cmd.Flags().StringVar(&flags.shard, "shard", "", "Vitess shard to target") + cmd.Flags().StringVar(&flags.instance, "instance", "", "Postgres instance to target") + cmd.Flags().StringVar(&flags.role, "role", "", "Postgres instance role to target: primary or replica") + + return cmd +} + +func ConnectionsKillCmd(ch *cmdutil.Helper) *cobra.Command { + var flags struct { + keyspace string + shard string + query bool + } + + cmd := &cobra.Command{ + Use: "kill ", + Short: "Kill a branch connection or query", + Long: `Kill a branch connection or query. + +This is destructive. Pass a connection_id from connections show to terminate a +connection, or pass --query with a query_id from connections show to cancel only +the current query.`, + Args: cmdutil.RequiredArgs("database", "branch", "id"), + RunE: func(cmd *cobra.Command, args []string) error { + if flags.query { + if err := connections.ValidateQueryID(args[2]); err != nil { + return err + } + } else { + if err := connections.ValidateConnectionID(args[2]); err != nil { + return err + } + } + engine, err := databaseEngine(cmd.Context(), ch, args[0]) + if err != nil { + return err + } + target := connections.ConnectionTarget{Keyspace: flags.keyspace, Shard: flags.shard} + if err := connections.ValidateEngineFlags(engine, connections.ConnectionFilter{}, target); err != nil { + return err + } + if flags.query { + return connections.RunCancelQueryForEngine(cmd.Context(), ch, args[0], args[1], args[2], engine, target) + } + return connections.RunKillConnectionForEngine(cmd.Context(), ch, args[0], args[1], args[2], engine, target) + }, + } + + cmd.Flags().StringVar(&flags.keyspace, "keyspace", "", "Vitess keyspace to target") + cmd.Flags().StringVar(&flags.shard, "shard", "", "Vitess shard to target") + cmd.Flags().BoolVar(&flags.query, "query", false, "Cancel the query_id instead of terminating the connection_id") + + return cmd +} + +func ConnectionsKillTransactionCmd(ch *cmdutil.Helper) *cobra.Command { + cmd := &cobra.Command{ + Use: "kill-transaction ", + Short: "Kill a Postgres branch transaction", + Long: `Kill a Postgres branch transaction. + +This is destructive. Pass a transaction_id from connections show to terminate +the matching Postgres connection.`, + Args: cmdutil.RequiredArgs("database", "branch", "transaction-id"), + RunE: func(cmd *cobra.Command, args []string) error { + if err := connections.ValidateTransactionID(args[2]); err != nil { + return err + } + engine, err := databaseEngine(cmd.Context(), ch, args[0]) + if err != nil { + return err + } + if engine != ps.DatabaseEnginePostgres { + return errors.New("connections kill-transaction is only supported for Postgres databases") + } + return connections.RunKillTransactionForEngine(cmd.Context(), ch, args[0], args[1], args[2], engine, connections.ConnectionTarget{}) + }, + } + + return cmd +} + +func databaseEngine(ctx context.Context, ch *cmdutil.Helper, database string) (ps.DatabaseEngine, error) { + client, err := ch.Client() + if err != nil { + return "", err + } + + db, err := client.Databases.Get(ctx, &ps.GetDatabaseRequest{ + Organization: ch.Config.Organization, + Database: database, + }) + if err != nil { + return "", err + } + if db == nil { + return "", errors.New("database not found") + } + return db.Kind, nil +} + +func vitessConnectionID(raw string) (int64, error) { + id, err := strconv.ParseInt(strings.TrimSpace(raw), 10, 64) + if err != nil || id <= 0 { + return 0, errors.New("id must be a positive integer") + } + return id, nil +} + +func isNegativeIDFlagError(err error) bool { + const marker = " in -" + + _, suffix, ok := strings.Cut(err.Error(), marker) + if !ok { + return false + } + + _, parseErr := strconv.ParseInt("-"+suffix, 10, 64) + return parseErr == nil +} + +func vitessConnectionIDFlagError(_ *cobra.Command, err error) error { + if isNegativeIDFlagError(err) { + return errors.New("id must be a positive integer") + } + return err +} diff --git a/internal/cmd/branch/connections_test.go b/internal/cmd/branch/connections_test.go new file mode 100644 index 00000000..a880e6ba --- /dev/null +++ b/internal/cmd/branch/connections_test.go @@ -0,0 +1,763 @@ +package branch + +import ( + "bytes" + "context" + "encoding/csv" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "slices" + "strings" + "testing" + + qt "github.com/frankban/quicktest" + + "github.com/planetscale/cli/internal/cmdutil" + "github.com/planetscale/cli/internal/config" + "github.com/planetscale/cli/internal/mock" + "github.com/planetscale/cli/internal/printer" + ps "github.com/planetscale/planetscale-go/planetscale" + "github.com/spf13/cobra" +) + +func TestConnectionsCmdConstruction(t *testing.T) { + c := qt.New(t) + + cmd := ConnectionsCmd(connectionsTestHelper("acme", ps.DatabaseEngineMySQL, nil, "http://example.invalid", printer.JSON, &bytes.Buffer{})) + + c.Assert(cmd.Use, qt.Equals, "connections ") + c.Assert(cmd.Aliases, qt.HasLen, 0) + names := commandNames(cmd) + for _, name := range []string{"kill", "kill-transaction", "show", "top"} { + c.Assert(slices.Contains(names, name), qt.IsTrue) + } +} + +func TestConnectionsShowHelpListsTargetAndFilterFlags(t *testing.T) { + c := qt.New(t) + + help := connectionsHelpForTest(c, "show", "--help") + c.Assert(help, qt.Contains, "--keyspace") + c.Assert(help, qt.Contains, "--shard") + c.Assert(help, qt.Contains, "--instance") + c.Assert(help, qt.Contains, "--role") +} + +func TestConnectionsShowHelpKeepsAgentWorkflow(t *testing.T) { + c := qt.New(t) + + help := connectionsHelpForTest(c, "show", "--help") + c.Assert(help, qt.Contains, "Use --format json when an agent or script needs to inspect query_id,") + c.Assert(help, qt.Contains, "transaction_id, and connection_id fields.") + c.Assert(help, qt.Contains, "Human output uses vertical records so") + c.Assert(help, qt.Contains, "query text and action IDs are not truncated.") +} + +func TestConnectionsKillHelpWarnsAboutDestructiveActions(t *testing.T) { + tests := []struct { + name string + args []string + want []string + }{ + { + name: "kill", + args: []string{"kill", "--help"}, + want: []string{"destructive", "connection_id", "query_id", "--query"}, + }, + { + name: "kill transaction", + args: []string{"kill-transaction", "--help"}, + want: []string{"Postgres", "destructive", "transaction_id"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + + help := connectionsHelpForTest(c, tt.args...) + for _, want := range tt.want { + c.Assert(help, qt.Contains, want) + } + }) + } +} + +func TestBranchCmdRegistersConnectionsAndHiddenProcesslist(t *testing.T) { + c := qt.New(t) + + cmd := BranchCmd(connectionsTestHelper("acme", ps.DatabaseEngineMySQL, nil, "http://example.invalid", printer.JSON, &bytes.Buffer{})) + + connections := findCommand(cmd, "connections") + c.Assert(connections, qt.Not(qt.IsNil)) + c.Assert(connections.Hidden, qt.Equals, false) + + processlist := findCommand(cmd, "processlist") + c.Assert(processlist, qt.Not(qt.IsNil)) + c.Assert(processlist.Hidden, qt.Equals, true) +} + +func TestProcesslistHelpPointsToConnectionsCommands(t *testing.T) { + tests := []struct { + name string + args []string + want []string + wantNot []string + }{ + { + name: "show", + args: []string{"--org", "acme", "processlist", "show", "--help"}, + want: []string{ + "pscale branch connections show", + "pscale branch connections kill", + "connection_id", + "query_id", + }, + wantNot: []string{"pscale branch processlist kill"}, + }, + { + name: "kill", + args: []string{"--org", "acme", "processlist", "kill", "--help"}, + want: []string{ + "pscale branch connections kill", + "connection_id", + "query_id", + }, + wantNot: []string{"as shown in \"pscale branch processlist show\""}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + + help := branchHelpForTest(c, tt.args...) + for _, want := range tt.want { + c.Assert(help, qt.Contains, want) + } + for _, wantNot := range tt.wantNot { + c.Assert(help, qt.Not(qt.Contains), wantNot) + } + }) + } +} + +func TestProcesslistShowMatchesConnectionsShowOutput(t *testing.T) { + c := qt.New(t) + + body := processlistListBody("commerce", "-80", "zone1-1001", `{"pid":101,"instance":"zone1-1001","usename":"vt_app","client_addr":"10.0.0.12:54231","datname":"checkout","state":"Query","duration_ms":42000,"connection_id":"zone1-1001-101","query_id":"zone1-1001-101","query_text":"SELECT 1"}`) + + var connectionsOut bytes.Buffer + connectionsServer := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.URL.Path, qt.Equals, "/v1/organizations/acme/databases/shop/branches/main/connections") + assertQueryParam(c, r, "keyspace", "commerce") + assertQueryParam(c, r, "shard", "-80") + _, _ = io.WriteString(w, body) + }) + connections := BranchCmd(connectionsTestHelper("acme", ps.DatabaseEngineMySQL, nil, connectionsServer.URL, printer.JSON, &connectionsOut)) + connections.SetArgs([]string{"--org", "acme", "connections", "show", "shop", "main", "--keyspace", "commerce", "--shard", "-80"}) + + var processlistOut bytes.Buffer + processlistServer := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.URL.Path, qt.Equals, "/v1/organizations/acme/databases/shop/branches/main/connections") + assertQueryParam(c, r, "keyspace", "commerce") + assertQueryParam(c, r, "shard", "-80") + _, _ = io.WriteString(w, body) + }) + processlist := BranchCmd(connectionsTestHelper("acme", ps.DatabaseEngineMySQL, nil, processlistServer.URL, printer.JSON, &processlistOut)) + processlist.SetArgs([]string{"--org", "acme", "processlist", "show", "shop", "main", "--keyspace", "commerce", "--shard", "-80"}) + + c.Assert(connections.Execute(), qt.IsNil) + c.Assert(processlist.Execute(), qt.IsNil) + assertJSONEqual(c, connectionsOut.String(), processlistOut.String()) +} + +func TestProcesslistRejectsPostgresBranches(t *testing.T) { + tests := []struct { + name string + args []string + }{ + {name: "show", args: []string{"--org", "acme", "processlist", "show", "pgload", "main"}}, + {name: "kill", args: []string{"--org", "acme", "processlist", "kill", "pgload", "main", "101"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Fatalf("processlist should reject Postgres before calling %s", r.URL.Path) + }) + + out, errOut, err := executeBranchCommandForTest( + connectionsTestHelper("acme", ps.DatabaseEnginePostgres, nil, server.URL, printer.JSON, &bytes.Buffer{}), + tt.args, + ) + + c.Assert(err, qt.ErrorMatches, "processlist is only supported for Vitess databases") + c.Assert(out, qt.Equals, "") + c.Assert(errOut, qt.Equals, "") + }) + } +} + +func TestConnectionsShowDispatchesVitess(t *testing.T) { + c := qt.New(t) + + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.URL.Path, qt.Equals, "/v1/organizations/acme/databases/shop/branches/main/connections") + assertQueryParam(c, r, "keyspace", "commerce") + assertQueryParam(c, r, "shard", "-80") + _, _ = io.WriteString(w, processlistListBody("commerce", "-80", "zone1-1001", `{"pid":101,"instance":"zone1-1001","usename":"vt_app","client_addr":"10.0.0.12:54231","datname":"checkout","state":"Query","duration_ms":42000,"connection_id":"zone1-1001-101","query_id":"zone1-1001-101","query_text":"SELECT 1"}`)) + }) + + var out bytes.Buffer + cmd := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEngineMySQL, nil, server.URL, printer.JSON, &out)) + cmd.SetArgs([]string{"show", "shop", "main", "--keyspace", "commerce", "--shard", "-80"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + c.Assert(out.String(), qt.Contains, `"topology": {`) + c.Assert(out.String(), qt.Contains, `"tablet": "zone1-1001"`) + c.Assert(out.String(), qt.Contains, `"database": "checkout"`) + c.Assert(out.String(), qt.Contains, `"connection_id": "zone1-1001-101"`) +} + +func TestConnectionsShowRejectsEngineSpecificFlags(t *testing.T) { + tests := []struct { + name string + engine ps.DatabaseEngine + args []string + wantErr string + }{ + { + name: "vitess rejects postgres instance filter", + engine: ps.DatabaseEngineMySQL, + args: []string{"show", "shop", "main", "--instance", "primary"}, + wantErr: "--instance/--role are only supported for Postgres databases", + }, + { + name: "vitess rejects postgres role filter", + engine: ps.DatabaseEngineMySQL, + args: []string{"show", "shop", "main", "--role", "primary"}, + wantErr: "--instance/--role are only supported for Postgres databases", + }, + { + name: "postgres rejects vitess target", + engine: ps.DatabaseEnginePostgres, + args: []string{"show", "pgload", "main", "--keyspace", "commerce", "--shard", "-80"}, + wantErr: "--keyspace/--shard are only supported for Vitess databases", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Fatalf("show should reject flags before calling %s", r.URL.Path) + }) + cmd := connectionsCmdForTest(connectionsTestHelper("acme", tt.engine, nil, server.URL, printer.JSON, &bytes.Buffer{})) + cmd.SetArgs(tt.args) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, tt.wantErr) + }) + } +} + +func TestEngineFlagValidationAfterLookupFailure(t *testing.T) { + c := qt.New(t) + + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Fatalf("show should surface database lookup errors before calling %s", r.URL.Path) + }) + databases := &mock.DatabaseService{ + GetFn: func(context.Context, *ps.GetDatabaseRequest) (*ps.Database, error) { + return nil, errors.New("database lookup failed") + }, + } + cmd := connectionsCmdForTest(connectionsTestHelperWithDatabaseService("acme", databases, nil, server.URL, printer.JSON, &bytes.Buffer{})) + cmd.SetArgs([]string{"show", "shop", "main", "--keyspace", "commerce", "--instance", "primary"}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "database lookup failed") +} + +func TestConnectionsShowDispatchesPostgres(t *testing.T) { + c := qt.New(t) + + var gotPath string + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + c.Assert(r.Method, qt.Equals, http.MethodGet) + _, _ = io.WriteString(w, sampleBranchConnectionsListResponse()) + }) + + var out bytes.Buffer + cmd := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEnginePostgres, nil, server.URL, printer.JSON, &out)) + cmd.SetArgs([]string{"show", "pgload", "main", "--instance", "primary"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + c.Assert(gotPath, qt.Equals, "/v1/organizations/acme/databases/pgload/branches/main/connections") + c.Assert(out.String(), qt.Contains, `"connection_id": "primary-123-c"`) + c.Assert(out.String(), qt.Not(qt.Contains), `"topology"`) +} + +func TestConnectionsKillDispatchesVitess(t *testing.T) { + tests := []struct { + name string + args []string + wantKind string + wantID string + }{ + { + name: "connection", + args: []string{"kill", "shop", "main", "zone1-1001-101", "--keyspace", "commerce", "--shard", "-80"}, + wantKind: "connection", + wantID: "zone1-1001-101", + }, + { + name: "query", + args: []string{"kill", "shop", "main", "zone1-1001-101", "--query", "--keyspace", "commerce", "--shard", "-80"}, + wantKind: "query", + wantID: "zone1-1001-101", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + wantPath := "/v1/organizations/acme/databases/shop/branches/main/connections/connection/" + tt.wantID + if tt.wantKind == "query" { + wantPath = "/v1/organizations/acme/databases/shop/branches/main/connections/query/" + tt.wantID + } + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.Method, qt.Equals, http.MethodDelete) + c.Assert(r.URL.Path, qt.Equals, wantPath) + assertQueryParam(c, r, "keyspace", "commerce") + assertQueryParam(c, r, "shard", "-80") + _, _ = io.WriteString(w, `{"success":true,"id":101,"kind":"`+tt.wantKind+`","keyspace":"commerce","shard":"-80","tablet":"zone1-1001"}`) + }) + + var out bytes.Buffer + cmd := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEngineMySQL, nil, server.URL, printer.JSON, &out)) + cmd.SetArgs(tt.args) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + }) + } +} + +func TestConnectionsKillRejectsWrongEngineTargetFlags(t *testing.T) { + tests := []struct { + name string + args []string + }{ + {name: "connection", args: []string{"kill", "pgload", "main", "primary-123-c", "--keyspace", "commerce", "--shard", "-80"}}, + {name: "query", args: []string{"kill", "pgload", "main", "primary-123-q", "--query", "--keyspace", "commerce", "--shard", "-80"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Fatalf("kill should reject flags before calling %s", r.URL.Path) + }) + cmd := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEnginePostgres, nil, server.URL, printer.JSON, &bytes.Buffer{})) + cmd.SetArgs(tt.args) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "--keyspace/--shard are only supported for Vitess databases") + }) + } +} + +func TestProcesslistKillKeepsNumericIDValidation(t *testing.T) { + tests := []struct { + name string + args []string + }{ + {name: "not numeric", args: []string{"--org", "acme", "processlist", "kill", "shop", "main", "primary-123-c"}}, + {name: "zero", args: []string{"--org", "acme", "processlist", "kill", "shop", "main", "0"}}, + {name: "negative", args: []string{"--org", "acme", "processlist", "kill", "shop", "main", "--", "-1"}}, + {name: "negative without separator", args: []string{"--org", "acme", "processlist", "kill", "shop", "main", "-1"}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + + out, errOut, err := executeBranchCommandForTest( + processlistTestHelper("acme", "http://127.0.0.1:1", printer.JSON, &bytes.Buffer{}), + tt.args, + ) + + c.Assert(err, qt.ErrorMatches, "id must be a positive integer") + c.Assert(out, qt.Equals, "") + c.Assert(errOut, qt.Equals, "") + }) + } +} + +func TestConnectionsKillDispatchesPostgres(t *testing.T) { + tests := []struct { + name string + args []string + wantPath string + }{ + {name: "connection", args: []string{"kill", "pgload", "main", "primary-123-c"}, wantPath: "/v1/organizations/acme/databases/pgload/branches/main/connections/connection/primary-123-c"}, + {name: "query", args: []string{"kill", "pgload", "main", "primary-123-q", "--query"}, wantPath: "/v1/organizations/acme/databases/pgload/branches/main/connections/query/primary-123-q"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.Method, qt.Equals, http.MethodDelete) + c.Assert(r.URL.Path, qt.Equals, tt.wantPath) + w.WriteHeader(http.StatusNoContent) + }) + + cmd := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEnginePostgres, nil, server.URL, printer.JSON, &bytes.Buffer{})) + cmd.SetArgs(tt.args) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + }) + } +} + +func TestConnectionsKillTrimsPostgresIDAndRejectsEmpty(t *testing.T) { + c := qt.New(t) + + var gotPath string + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + w.WriteHeader(http.StatusNoContent) + }) + + cmd := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEnginePostgres, nil, server.URL, printer.JSON, &bytes.Buffer{})) + cmd.SetArgs([]string{"kill", "pgload", "main", " primary-123-c "}) + err := cmd.Execute() + c.Assert(err, qt.IsNil) + c.Assert(gotPath, qt.Equals, "/v1/organizations/acme/databases/pgload/branches/main/connections/connection/primary-123-c") + + empty := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEnginePostgres, nil, server.URL, printer.JSON, &bytes.Buffer{})) + empty.SetArgs([]string{"kill", "pgload", "main", " "}) + err = empty.Execute() + c.Assert(err, qt.ErrorMatches, "connection-id is required") +} + +func TestConnectionsKillRejectsEmptyIDBeforeLookup(t *testing.T) { + c := qt.New(t) + + databases := &mock.DatabaseService{ + GetFn: func(context.Context, *ps.GetDatabaseRequest) (*ps.Database, error) { + return nil, errors.New("database lookup should not be called") + }, + } + cmd := connectionsCmdForTest(connectionsTestHelperWithDatabaseService("acme", databases, nil, "http://127.0.0.1:1", printer.JSON, &bytes.Buffer{})) + cmd.SetArgs([]string{"kill", "pgload", "main", " "}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "connection-id is required") + c.Assert(databases.GetFnInvoked, qt.IsFalse) +} + +func TestConnectionsKillTransactionDispatches(t *testing.T) { + c := qt.New(t) + + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.Method, qt.Equals, http.MethodDelete) + c.Assert(r.URL.Path, qt.Equals, "/v1/organizations/acme/databases/pgload/branches/main/connections/transaction/primary-123-t") + w.WriteHeader(http.StatusNoContent) + }) + + pg := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEnginePostgres, nil, server.URL, printer.JSON, &bytes.Buffer{})) + pg.SetArgs([]string{"kill-transaction", "pgload", "main", "primary-123-t"}) + c.Assert(pg.Execute(), qt.IsNil) +} + +func TestConnectionsKillTransactionRejectsEmptyIDBeforeLookup(t *testing.T) { + c := qt.New(t) + + databases := &mock.DatabaseService{ + GetFn: func(context.Context, *ps.GetDatabaseRequest) (*ps.Database, error) { + return nil, errors.New("database lookup should not be called") + }, + } + cmd := connectionsCmdForTest(connectionsTestHelperWithDatabaseService("acme", databases, nil, "http://127.0.0.1:1", printer.JSON, &bytes.Buffer{})) + cmd.SetArgs([]string{"kill-transaction", "pgload", "main", " "}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "transaction-id is required") + c.Assert(databases.GetFnInvoked, qt.IsFalse) +} + +func TestConnectionsKillTransactionRejectsVitess(t *testing.T) { + c := qt.New(t) + + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Fatalf("kill-transaction should reject Vitess before calling %s", r.URL.Path) + }) + cmd := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEngineMySQL, nil, server.URL, printer.JSON, &bytes.Buffer{})) + cmd.SetArgs([]string{"kill-transaction", "shop", "main", "tx-123"}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "connections kill-transaction is only supported for Postgres databases") +} + +func TestPostgresActionResultOmitsVitessTopologyColumns(t *testing.T) { + c := qt.New(t) + + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.Method, qt.Equals, http.MethodDelete) + c.Assert(r.URL.Path, qt.Equals, "/v1/organizations/acme/databases/pgload/branches/main/connections/connection/primary-123-c") + _, _ = io.WriteString(w, `{"success":true,"id":101,"kind":"connection","keyspace":"commerce","shard":"-80","tablet":"zone1-1001"}`) + }) + var out bytes.Buffer + cmd := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEnginePostgres, nil, server.URL, printer.CSV, &out)) + cmd.SetArgs([]string{"kill", "pgload", "main", "primary-123-c"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + headers := readCSVRows(c, out.String())[0] + c.Assert(headers, qt.Not(qt.Contains), "keyspace") + c.Assert(headers, qt.Not(qt.Contains), "shard") + c.Assert(headers, qt.Not(qt.Contains), "tablet") + c.Assert(headers, qt.Contains, "success") + c.Assert(headers, qt.Contains, "id") + c.Assert(headers, qt.Contains, "kind") +} + +func TestVitessActionResultKeepsTopologyShape(t *testing.T) { + c := qt.New(t) + + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.Method, qt.Equals, http.MethodDelete) + c.Assert(r.URL.Path, qt.Equals, "/v1/organizations/acme/databases/shop/branches/main/connections/connection/zone1-1001-101") + assertQueryParam(c, r, "keyspace", "commerce") + assertQueryParam(c, r, "shard", "-80") + _, _ = io.WriteString(w, `{"success":true,"id":101,"kind":"connection","keyspace":"commerce","shard":"-80","tablet":"zone1-1001"}`) + }) + var out bytes.Buffer + cmd := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEngineMySQL, nil, server.URL, printer.CSV, &out)) + cmd.SetArgs([]string{"kill", "shop", "main", "zone1-1001-101", "--keyspace", "commerce", "--shard", "-80"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + rows := readCSVRows(c, out.String()) + c.Assert(rows[0], qt.Contains, "keyspace") + c.Assert(rows[0], qt.Contains, "shard") + c.Assert(rows[0], qt.Contains, "tablet") + c.Assert(rows[1], qt.Contains, "commerce") + c.Assert(rows[1], qt.Contains, "-80") + c.Assert(rows[1], qt.Contains, "zone1-1001") +} + +func TestActionResultJSONShapeStable(t *testing.T) { + c := qt.New(t) + + server := liveConnectionsBranchServer(t, func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.Method, qt.Equals, http.MethodDelete) + _, _ = io.WriteString(w, `{"success":true,"id":101,"kind":"connection","keyspace":"commerce","shard":"-80","tablet":"zone1-1001"}`) + }) + var out bytes.Buffer + cmd := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEnginePostgres, nil, server.URL, printer.JSON, &out)) + cmd.SetArgs([]string{"kill", "pgload", "main", "primary-123-c"}) + + err := cmd.Execute() + + c.Assert(err, qt.IsNil) + var got map[string]any + c.Assert(json.Unmarshal(out.Bytes(), &got), qt.IsNil) + c.Assert(got, qt.DeepEquals, map[string]any{ + "success": true, + "keyspace": "commerce", + "shard": "-80", + "tablet": "zone1-1001", + "id": float64(101), + "kind": "connection", + }) +} + +func TestConnectionsRoleValidationStillWorksOnPostgres(t *testing.T) { + tests := []struct { + name string + args []string + wantErr string + }{ + {name: "unknown role", args: []string{"show", "pgload", "main", "--role", "writer"}, wantErr: "--role must be primary or replica"}, + {name: "role with instance", args: []string{"show", "pgload", "main", "--role", "primary", "--instance", "primary"}, wantErr: "--role cannot be combined with --instance"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + cmd := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEnginePostgres, nil, "http://127.0.0.1:1", printer.JSON, &bytes.Buffer{})) + cmd.SetArgs(tt.args) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, tt.wantErr) + }) + } +} + +func TestConnectionsShowRejectsInvalidRoleBeforeLookup(t *testing.T) { + c := qt.New(t) + + databases := &mock.DatabaseService{ + GetFn: func(context.Context, *ps.GetDatabaseRequest) (*ps.Database, error) { + return nil, errors.New("database lookup should not be called") + }, + } + cmd := connectionsCmdForTest(connectionsTestHelperWithDatabaseService("acme", databases, nil, "http://127.0.0.1:1", printer.JSON, &bytes.Buffer{})) + cmd.SetArgs([]string{"show", "pgload", "main", "--role", "writer"}) + + err := cmd.Execute() + + c.Assert(err, qt.ErrorMatches, "--role must be primary or replica") + c.Assert(databases.GetFnInvoked, qt.IsFalse) +} + +func connectionsTestHelper(org string, engine ps.DatabaseEngine, processlist ps.ProcesslistService, baseURL string, format printer.Format, out *bytes.Buffer) *cmdutil.Helper { + return connectionsTestHelperWithDatabaseService(org, databaseServiceForEngine(org, engine), processlist, baseURL, format, out) +} + +func connectionsTestHelperWithDatabaseService(org string, databases ps.DatabasesService, processlist ps.ProcesslistService, baseURL string, format printer.Format, out *bytes.Buffer) *cmdutil.Helper { + p := printer.NewPrinter(&format) + p.SetHumanOutput(out) + p.SetResourceOutput(out) + + return &cmdutil.Helper{ + Config: &config.Config{ + AccessToken: "token", + BaseURL: baseURL, + Organization: org, + }, + Printer: p, + Client: func() (*ps.Client, error) { + return &ps.Client{Databases: databases, Processlist: processlist}, nil + }, + } +} + +func databaseServiceForEngine(org string, engine ps.DatabaseEngine) *mock.DatabaseService { + return &mock.DatabaseService{ + GetFn: func(ctx context.Context, req *ps.GetDatabaseRequest) (*ps.Database, error) { + if req.Organization != org { + return nil, errors.New("unexpected organization") + } + return &ps.Database{Name: req.Database, Kind: engine}, nil + }, + } +} + +func connectionsCmdForTest(ch *cmdutil.Helper) *cobra.Command { + cmd := ConnectionsCmd(ch) + cmd.SilenceErrors = true + cmd.SilenceUsage = true + cmd.SetErr(io.Discard) + return cmd +} + +func connectionsHelpForTest(c *qt.C, args ...string) string { + var out bytes.Buffer + cmd := connectionsCmdForTest(connectionsTestHelper("acme", ps.DatabaseEngineMySQL, nil, "http://example.invalid", printer.Human, &out)) + cmd.SetOut(&out) + cmd.SetArgs(args) + + c.Assert(cmd.Execute(), qt.IsNil) + return out.String() +} + +func branchHelpForTest(c *qt.C, args ...string) string { + var out bytes.Buffer + cmd := BranchCmd(connectionsTestHelper("acme", ps.DatabaseEngineMySQL, nil, "http://example.invalid", printer.Human, &out)) + cmd.SilenceErrors = true + cmd.SilenceUsage = true + cmd.SetOut(&out) + cmd.SetArgs(args) + + c.Assert(cmd.Execute(), qt.IsNil) + return out.String() +} + +func assertQueryParam(c *qt.C, r *http.Request, key, want string) { + c.Assert(r.URL.Query().Get(key), qt.Equals, want) +} + +func assertJSONEqual(c *qt.C, want, got string) { + var wantJSON any + var gotJSON any + c.Assert(json.Unmarshal([]byte(want), &wantJSON), qt.IsNil) + c.Assert(json.Unmarshal([]byte(got), &gotJSON), qt.IsNil) + c.Assert(gotJSON, qt.DeepEquals, wantJSON) +} + +func liveConnectionsBranchServer(t *testing.T, handler http.HandlerFunc) *httptest.Server { + t.Helper() + server := httptest.NewServer(handler) + t.Cleanup(server.Close) + return server +} + +func sampleBranchConnectionsListResponse() string { + return `{"type":"list","database_kind":"postgresql","captured_at":"2026-04-29T12:34:56Z","instances":[{"id":"primary","role":"primary","error":null}],"data":[{"pid":123,"instance":"primary","duration_ms":664000,"state":"active","usename":"alice","application_name":"psql","client_addr":"10.0.0.1","query_text":"SELECT pg_sleep(600)","xact_start":"2026-04-29T12:23:52Z","query_start":"2026-04-29T12:23:52Z","query_id":"primary-123-q","transaction_id":"primary-123-t","connection_id":"primary-123-c"}]}` +} + +func commandNames(cmd *cobra.Command) []string { + names := make([]string, 0, len(cmd.Commands())) + for _, child := range cmd.Commands() { + names = append(names, child.Name()) + } + return names +} + +func findCommand(cmd *cobra.Command, name string) *cobra.Command { + for _, child := range cmd.Commands() { + if child.Name() == name { + return child + } + } + return nil +} + +func executeBranchCommandForTest(ch *cmdutil.Helper, args []string) (string, string, error) { + var out bytes.Buffer + var errOut bytes.Buffer + cmd := BranchCmd(ch) + cmd.SilenceErrors = true + cmd.SilenceUsage = true + cmd.SetOut(&out) + cmd.SetErr(&errOut) + cmd.SetArgs(args) + + err := cmd.Execute() + + return out.String(), errOut.String(), err +} + +func readCSVRows(c *qt.C, raw string) [][]string { + rows, err := csv.NewReader(strings.NewReader(raw)).ReadAll() + c.Assert(err, qt.IsNil) + return rows +} diff --git a/internal/cmd/branch/kill.go b/internal/cmd/branch/kill.go index cfd5720c..26004095 100644 --- a/internal/cmd/branch/kill.go +++ b/internal/cmd/branch/kill.go @@ -1,40 +1,13 @@ package branch import ( - "fmt" "strconv" + "github.com/planetscale/cli/internal/cmd/branch/connections" "github.com/planetscale/cli/internal/cmdutil" - "github.com/planetscale/cli/internal/printer" - ps "github.com/planetscale/planetscale-go/planetscale" "github.com/spf13/cobra" ) -// KillProcessResult is the CLI representation of a killed process response. -type KillProcessResult struct { - Success bool `header:"success" json:"success"` - Keyspace string `header:"keyspace" json:"keyspace"` - Shard string `header:"shard" json:"shard"` - Tablet string `header:"tablet" json:"tablet"` - ID int64 `header:"id,text" json:"id"` - Kind string `header:"kind" json:"kind"` -} - -func (k *KillProcessResult) MarshalCSVValue() interface{} { - return []*KillProcessResult{k} -} - -func toKillProcessResult(result *ps.KillProcessResult) *KillProcessResult { - return &KillProcessResult{ - Success: result.Success, - Keyspace: result.Keyspace, - Shard: result.Shard, - Tablet: result.Tablet, - ID: result.ID, - Kind: result.Kind, - } -} - // ProcesslistKillCmd kills a running MySQL process on a Vitess branch. func ProcesslistKillCmd(ch *cmdutil.Helper) *cobra.Command { var flags struct { @@ -46,76 +19,43 @@ func ProcesslistKillCmd(ch *cmdutil.Helper) *cobra.Command { cmd := &cobra.Command{ Use: "kill ", Short: "Kill a running MySQL process on a branch (Vitess only)", - Long: `Kill a MySQL process by ID, as shown in "pscale branch processlist show". + Long: `Compatibility command for killing Vitess processlist IDs. This command is only supported for Vitess (MySQL) databases. +Use "pscale branch connections kill" with connection_id and query_id values for +new workflows. + By default the entire connection is killed (KILL ). Pass --query to kill only the currently running statement (KILL QUERY ). The process must live on the same primary tablet that the process list was read from, so the same --keyspace/--shard targeting rules apply.`, Args: cmdutil.RequiredArgs("database", "branch", "id"), RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() database, branch := args[0], args[1] - id, err := strconv.ParseInt(args[2], 10, 64) - if err != nil || id <= 0 { - return fmt.Errorf("id must be a positive integer, got %q", args[2]) - } - - kind := "connection" - if flags.query { - kind = "query" - } - - client, err := ch.Client() - if err != nil { + if err := requireProcesslistDatabase(cmd.Context(), ch, database); err != nil { return err } - end := ch.Printer.PrintProgress( - fmt.Sprintf("Killing process %s on %s\u2026", - printer.BoldBlue(id), printer.BoldBlue(fmt.Sprintf("%s/%s/%s", ch.Config.Organization, database, branch)))) - defer end() - - result, err := client.Processlist.Kill(ctx, &ps.KillProcessRequest{ - Organization: ch.Config.Organization, - Database: database, - Branch: branch, - Keyspace: flags.keyspace, - Shard: flags.shard, - ID: id, - Kind: kind, - }) + id, err := vitessConnectionID(args[2]) if err != nil { - switch cmdutil.ErrCode(err) { - case ps.ErrNotFound: - return cmdutil.HandleNotFoundWithServiceTokenCheck( - ctx, cmd, ch.Config, ch.Client, err, "read_branch", - "process %s or branch %s does not exist in database %s (organization: %s)", - printer.BoldBlue(id), printer.BoldBlue(branch), printer.BoldBlue(database), printer.BoldBlue(ch.Config.Organization)) - default: - return cmdutil.HandleError(err) - } + return err } - end() - - if ch.Printer.Format() == printer.Human { - ch.Printer.Printf("Killed %s %s on keyspace %s shard %s (tablet %s).\n", - result.Kind, printer.BoldBlue(result.ID), - printer.BoldBlue(result.Keyspace), printer.BoldBlue(result.Shard), printer.BoldBlue(result.Tablet)) - return nil + target := connections.ConnectionTarget{Keyspace: flags.keyspace, Shard: flags.shard} + idText := strconv.FormatInt(id, 10) + if flags.query { + return connections.RunCancelQuery(cmd.Context(), ch, database, branch, idText, target) } - - return ch.Printer.PrintResource(toKillProcessResult(result)) + return connections.RunKillConnection(cmd.Context(), ch, database, branch, idText, target) }, } cmd.Flags().StringVar(&flags.keyspace, "keyspace", "", "Keyspace to target (required when the database has multiple keyspaces)") cmd.Flags().StringVar(&flags.shard, "shard", "", "Shard to target (required when the targeted keyspace is sharded)") cmd.Flags().BoolVar(&flags.query, "query", false, "Kill only the running query (KILL QUERY) instead of the whole connection") + cmd.SetFlagErrorFunc(vitessConnectionIDFlagError) return cmd } diff --git a/internal/cmd/branch/kill_test.go b/internal/cmd/branch/kill_test.go index aaad40c5..eb3a1237 100644 --- a/internal/cmd/branch/kill_test.go +++ b/internal/cmd/branch/kill_test.go @@ -2,61 +2,85 @@ package branch import ( "bytes" - "context" + "io" + "net/http" + "net/http/httptest" "testing" qt "github.com/frankban/quicktest" - "github.com/planetscale/cli/internal/mock" "github.com/planetscale/cli/internal/printer" - ps "github.com/planetscale/planetscale-go/planetscale" ) +func processlistKillServer(t *testing.T, c *qt.C, wantMethod, wantPath, wantQuery, body string) *httptest.Server { + t.Helper() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.Method, qt.Equals, wantMethod) + c.Assert(r.URL.Path, qt.Equals, wantPath) + c.Assert(r.URL.RawQuery, qt.Equals, wantQuery) + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, body) + })) + t.Cleanup(server.Close) + return server +} + func TestKill(t *testing.T) { c := qt.New(t) org, db, branch := "my-org", "my-db", "my-branch" - svc := &mock.ProcesslistService{ - KillFn: func(ctx context.Context, req *ps.KillProcessRequest) (*ps.KillProcessResult, error) { - c.Assert(req.Organization, qt.Equals, org) - c.Assert(req.Database, qt.Equals, db) - c.Assert(req.Branch, qt.Equals, branch) - c.Assert(req.ID, qt.Equals, int64(101)) - c.Assert(req.Kind, qt.Equals, "connection") - return &ps.KillProcessResult{ - Success: true, Keyspace: "main", Shard: "-", Tablet: "zone1-2001", ID: 101, Kind: "connection", - }, nil - }, - } + server := processlistKillServer(t, c, + http.MethodDelete, + "/v1/organizations/my-org/databases/my-db/branches/my-branch/connections/connection/101", + "", + `{"success":true,"keyspace":"main","shard":"-","tablet":"zone1-2001","id":101,"kind":"connection"}`) var buf bytes.Buffer - ch := processlistTestHelper(org, svc, printer.JSON, &buf) + ch := processlistTestHelper(org, server.URL, printer.JSON, &buf) cmd := ProcesslistCmd(ch) cmd.SetArgs([]string{"kill", db, branch, "101"}) err := cmd.Execute() c.Assert(err, qt.IsNil) - c.Assert(svc.KillFnInvoked, qt.IsTrue) c.Assert(buf.String(), qt.Contains, `"success": true`) } +func TestConnectionsKillUsesConnectionsEndpoint(t *testing.T) { + c := qt.New(t) + + org, db, branch := "my-org", "my-db", "my-branch" + + server := processlistKillServer(t, c, + http.MethodDelete, + "/v1/organizations/my-org/databases/my-db/branches/my-branch/connections/query/zone1-1001-101", + "keyspace=commerce&shard=-80", + `{"success":true,"keyspace":"commerce","shard":"-80","tablet":"zone1-1001","id":101,"kind":"query"}`) + + var buf bytes.Buffer + ch := processlistTestHelper(org, server.URL, printer.JSON, &buf) + + cmd := ConnectionsCmd(ch) + cmd.SetArgs([]string{"kill", db, branch, "zone1-1001-101", "--query", "--keyspace", "commerce", "--shard", "-80"}) + err := cmd.Execute() + + c.Assert(err, qt.IsNil) +} + func TestKill_CSVOutput(t *testing.T) { c := qt.New(t) org, db, branch := "my-org", "my-db", "my-branch" - svc := &mock.ProcesslistService{ - KillFn: func(ctx context.Context, req *ps.KillProcessRequest) (*ps.KillProcessResult, error) { - return &ps.KillProcessResult{ - Success: true, Keyspace: "main", Shard: "-", Tablet: "zone1-2001", ID: req.ID, Kind: "connection", - }, nil - }, - } + server := processlistKillServer(t, c, + http.MethodDelete, + "/v1/organizations/my-org/databases/my-db/branches/my-branch/connections/connection/101", + "", + `{"success":true,"keyspace":"main","shard":"-","tablet":"zone1-2001","id":101,"kind":"connection"}`) var buf bytes.Buffer - ch := processlistTestHelper(org, svc, printer.CSV, &buf) + ch := processlistTestHelper(org, server.URL, printer.CSV, &buf) cmd := ProcesslistCmd(ch) cmd.SetArgs([]string{"kill", db, branch, "101"}) @@ -73,42 +97,33 @@ func TestKill_QueryFlag(t *testing.T) { org, db, branch := "my-org", "my-db", "my-branch" - svc := &mock.ProcesslistService{ - KillFn: func(ctx context.Context, req *ps.KillProcessRequest) (*ps.KillProcessResult, error) { - c.Assert(req.Kind, qt.Equals, "query") - return &ps.KillProcessResult{Success: true, ID: req.ID, Kind: "query"}, nil - }, - } + server := processlistKillServer(t, c, + http.MethodDelete, + "/v1/organizations/my-org/databases/my-db/branches/my-branch/connections/query/101", + "", + `{"success":true,"id":101,"kind":"query"}`) var buf bytes.Buffer - ch := processlistTestHelper(org, svc, printer.JSON, &buf) + ch := processlistTestHelper(org, server.URL, printer.JSON, &buf) cmd := ProcesslistCmd(ch) cmd.SetArgs([]string{"kill", db, branch, "101", "--query"}) err := cmd.Execute() c.Assert(err, qt.IsNil) - c.Assert(svc.KillFnInvoked, qt.IsTrue) } func TestKill_InvalidID(t *testing.T) { c := qt.New(t) - svc := &mock.ProcesslistService{ - KillFn: func(ctx context.Context, req *ps.KillProcessRequest) (*ps.KillProcessResult, error) { - return nil, nil - }, - } - var buf bytes.Buffer - ch := processlistTestHelper("my-org", svc, printer.JSON, &buf) + ch := processlistTestHelper("my-org", "http://127.0.0.1:1", printer.JSON, &buf) cmd := ProcesslistCmd(ch) cmd.SetArgs([]string{"kill", "my-db", "my-branch", "not-a-number"}) err := cmd.Execute() c.Assert(err, qt.IsNotNil) - c.Assert(svc.KillFnInvoked, qt.IsFalse) } func TestKill_NotFound(t *testing.T) { @@ -116,26 +131,22 @@ func TestKill_NotFound(t *testing.T) { org, db, branch := "my-org", "missing-db", "missing-branch" - svc := &mock.ProcesslistService{ - KillFn: func(ctx context.Context, req *ps.KillProcessRequest) (*ps.KillProcessResult, error) { - return nil, &ps.Error{Code: ps.ErrNotFound} - }, - } + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.Method, qt.Equals, http.MethodDelete) + c.Assert(r.URL.Path, qt.Equals, "/v1/organizations/my-org/databases/missing-db/branches/missing-branch/connections/connection/101") + w.WriteHeader(http.StatusNotFound) + _, _ = io.WriteString(w, `{"message":"not found"}`) + })) + t.Cleanup(server.Close) var buf bytes.Buffer - ch := processlistTestHelper(org, svc, printer.JSON, &buf) + ch := processlistTestHelper(org, server.URL, printer.JSON, &buf) cmd := ProcesslistCmd(ch) cmd.SetArgs([]string{"kill", db, branch, "101"}) err := cmd.Execute() c.Assert(err, qt.IsNotNil) - c.Assert(err.Error(), qt.Contains, "process") - c.Assert(err.Error(), qt.Contains, "or branch") - c.Assert(err.Error(), qt.Contains, "101") - c.Assert(err.Error(), qt.Contains, branch) - c.Assert(err.Error(), qt.Contains, db) - c.Assert(err.Error(), qt.Contains, org) + c.Assert(err.Error(), qt.Contains, "connection_id not found") c.Assert(err.Error(), qt.Not(qt.Contains), "Not Found") - c.Assert(svc.KillFnInvoked, qt.IsTrue) } diff --git a/internal/cmd/branch/processlist.go b/internal/cmd/branch/processlist.go index be764482..8b6e3349 100644 --- a/internal/cmd/branch/processlist.go +++ b/internal/cmd/branch/processlist.go @@ -1,64 +1,15 @@ package branch import ( - "fmt" + "context" + "errors" + "github.com/planetscale/cli/internal/cmd/branch/connections" "github.com/planetscale/cli/internal/cmdutil" - "github.com/planetscale/cli/internal/printer" ps "github.com/planetscale/planetscale-go/planetscale" "github.com/spf13/cobra" ) -// Process is the table/json representation of a single MySQL process. -type Process struct { - ID int64 `header:"id,text" json:"id"` - User string `header:"user" json:"user"` - Host string `header:"host" json:"host"` - DB string `header:"db" json:"db"` - Command string `header:"command" json:"command"` - Time int64 `header:"time (seconds),text" json:"time"` - State string `header:"state" json:"state"` - Info string `header:"info" json:"info"` -} - -// ProcesslistResult is the CLI representation of a process list response. -type ProcesslistResult struct { - Keyspace string `json:"keyspace"` - Shard string `json:"shard"` - Tablet string `json:"tablet"` - Processes []ps.Process `json:"processes"` -} - -func (p *ProcesslistResult) MarshalCSVValue() interface{} { - return toProcesses(p.Processes) -} - -func toProcesses(processes []ps.Process) []*Process { - rows := make([]*Process, 0, len(processes)) - for _, p := range processes { - rows = append(rows, &Process{ - ID: p.ID, - User: p.User, - Host: p.Host, - DB: p.DB, - Command: p.Command, - Time: p.Time, - State: p.State, - Info: p.Info, - }) - } - return rows -} - -func toProcesslistResult(result *ps.ProcesslistResult) *ProcesslistResult { - return &ProcesslistResult{ - Keyspace: result.Keyspace, - Shard: result.Shard, - Tablet: result.Tablet, - Processes: result.Processes, - } -} - // ProcesslistCmd manages MySQL process lists for a Vitess branch. func ProcesslistCmd(ch *cmdutil.Helper) *cobra.Command { cmd := &cobra.Command{ @@ -67,6 +18,7 @@ func ProcesslistCmd(ch *cmdutil.Helper) *cobra.Command { Long: `Show and kill running MySQL processes for a branch. This command is only supported for Vitess databases.`, + Hidden: true, } cmd.AddCommand(ProcesslistShowCmd(ch)) @@ -85,58 +37,27 @@ func ProcesslistShowCmd(ch *cmdutil.Helper) *cobra.Command { cmd := &cobra.Command{ Use: "show ", Short: "Show the running MySQL processes for a branch", - Long: `Show the output of "SHOW FULL PROCESSLIST" for a branch. + Long: `Compatibility alias for "pscale branch connections show". This command is only supported for Vitess databases. +Use "pscale branch connections show" for new workflows. Use connection_id and +query_id values with "pscale branch connections kill". + The process list is read from a single primary tablet. If the database has a single unsharded keyspace, that primary is targeted automatically. If it has multiple keyspaces, pass --keyspace; if the targeted keyspace is sharded, also -pass --shard. Process IDs shown here can be passed to -"pscale branch processlist kill".`, +pass --shard.`, Args: cmdutil.RequiredArgs("database", "branch"), RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - database, branch := args[0], args[1] - - client, err := ch.Client() - if err != nil { + if err := requireProcesslistDatabase(cmd.Context(), ch, args[0]); err != nil { return err } - end := ch.Printer.PrintProgress( - fmt.Sprintf("Fetching process list for %s\u2026", - printer.BoldBlue(fmt.Sprintf("%s/%s/%s", ch.Config.Organization, database, branch)))) - defer end() - - result, err := client.Processlist.List(ctx, &ps.ProcesslistRequest{ - Organization: ch.Config.Organization, - Database: database, - Branch: branch, - Keyspace: flags.keyspace, - Shard: flags.shard, + return connections.RunList(cmd.Context(), ch, args[0], args[1], connections.ConnectionFilter{}, connections.ConnectionTarget{ + Keyspace: flags.keyspace, + Shard: flags.shard, }) - if err != nil { - switch cmdutil.ErrCode(err) { - case ps.ErrNotFound: - return cmdutil.HandleNotFoundWithServiceTokenCheck( - ctx, cmd, ch.Config, ch.Client, err, "read_branch", - "branch %s does not exist in database %s (organization: %s)", - printer.BoldBlue(branch), printer.BoldBlue(database), printer.BoldBlue(ch.Config.Organization)) - default: - return cmdutil.HandleError(err) - } - } - - end() - - if ch.Printer.Format() == printer.Human { - ch.Printer.Printf("Process list for keyspace %s shard %s (tablet %s):\n", - printer.BoldBlue(result.Keyspace), printer.BoldBlue(result.Shard), printer.BoldBlue(result.Tablet)) - return ch.Printer.PrintResource(toProcesses(result.Processes)) - } - - return ch.Printer.PrintResource(toProcesslistResult(result)) }, } @@ -145,3 +66,14 @@ pass --shard. Process IDs shown here can be passed to return cmd } + +func requireProcesslistDatabase(ctx context.Context, ch *cmdutil.Helper, database string) error { + engine, err := databaseEngine(ctx, ch, database) + if err != nil { + return err + } + if engine != ps.DatabaseEngineMySQL { + return errors.New("processlist is only supported for Vitess databases") + } + return nil +} diff --git a/internal/cmd/branch/processlist_test.go b/internal/cmd/branch/processlist_test.go index f3fc2640..b27337aa 100644 --- a/internal/cmd/branch/processlist_test.go +++ b/internal/cmd/branch/processlist_test.go @@ -2,65 +2,69 @@ package branch import ( "bytes" - "context" + "io" + "net/http" + "net/http/httptest" "testing" qt "github.com/frankban/quicktest" "github.com/planetscale/cli/internal/cmdutil" "github.com/planetscale/cli/internal/config" - "github.com/planetscale/cli/internal/mock" "github.com/planetscale/cli/internal/printer" ps "github.com/planetscale/planetscale-go/planetscale" ) -func processlistTestHelper(org string, svc ps.ProcesslistService, format printer.Format, buf *bytes.Buffer) *cmdutil.Helper { +func processlistTestHelper(org, baseURL string, format printer.Format, buf *bytes.Buffer) *cmdutil.Helper { p := printer.NewPrinter(&format) p.SetResourceOutput(buf) + p.SetHumanOutput(buf) return &cmdutil.Helper{ Printer: p, - Config: &config.Config{Organization: org}, + Config: &config.Config{AccessToken: "token", Organization: org, BaseURL: baseURL}, Client: func() (*ps.Client, error) { - return &ps.Client{Processlist: svc}, nil + return &ps.Client{Databases: databaseServiceForEngine(org, ps.DatabaseEngineMySQL)}, nil }, } } +func processlistListServer(t *testing.T, c *qt.C, wantQuery string, body string) *httptest.Server { + t.Helper() + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.Method, qt.Equals, http.MethodGet) + c.Assert(r.URL.Path, qt.Equals, "/v1/organizations/my-org/databases/my-db/branches/my-branch/connections") + c.Assert(r.URL.RawQuery, qt.Equals, wantQuery) + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, body) + })) + t.Cleanup(server.Close) + return server +} + +func processlistListBody(keyspace, shard, tablet, connections string) string { + return `{"type":"list","database_kind":"mysql","next_page":null,"prev_page":null,"captured_at":"2026-06-04T12:30:00Z","instances":[],"topology":{"keyspace":"` + keyspace + `","shard":"` + shard + `","tablet":"` + tablet + `"},"data":[` + connections + `]}` +} + func TestProcesslist(t *testing.T) { c := qt.New(t) org, db, branch := "my-org", "my-db", "my-branch" - svc := &mock.ProcesslistService{ - ListFn: func(ctx context.Context, req *ps.ProcesslistRequest) (*ps.ProcesslistResult, error) { - c.Assert(req.Organization, qt.Equals, org) - c.Assert(req.Database, qt.Equals, db) - c.Assert(req.Branch, qt.Equals, branch) - c.Assert(req.Keyspace, qt.Equals, "commerce") - c.Assert(req.Shard, qt.Equals, "-80") - return &ps.ProcesslistResult{ - Keyspace: "commerce", - Shard: "-80", - Tablet: "zone1-1001", - Processes: []ps.Process{ - {ID: 101, User: "vt_app", Command: "Query", Time: 42, Info: "SELECT 1"}, - }, - }, nil - }, - } + server := processlistListServer(t, c, "keyspace=commerce&shard=-80", + processlistListBody("commerce", "-80", "zone1-1001", `{"pid":101,"instance":"zone1-1001","usename":"vt_app","state":"Query","duration_ms":42000,"connection_id":"101","query_id":"101","query_text":"SELECT 1"}`)) var buf bytes.Buffer - ch := processlistTestHelper(org, svc, printer.JSON, &buf) + ch := processlistTestHelper(org, server.URL, printer.JSON, &buf) cmd := ProcesslistCmd(ch) cmd.SetArgs([]string{"show", db, branch, "--keyspace", "commerce", "--shard", "-80"}) err := cmd.Execute() c.Assert(err, qt.IsNil) - c.Assert(svc.ListFnInvoked, qt.IsTrue) c.Assert(buf.String(), qt.Contains, `"tablet": "zone1-1001"`) - c.Assert(buf.String(), qt.Contains, `"user": "vt_app"`) + c.Assert(buf.String(), qt.Contains, `"username": "vt_app"`) + c.Assert(buf.String(), qt.Contains, `"connection_id": "101"`) } func TestProcesslist_CSVOutput(t *testing.T) { @@ -68,28 +72,20 @@ func TestProcesslist_CSVOutput(t *testing.T) { org, db, branch := "my-org", "my-db", "my-branch" - svc := &mock.ProcesslistService{ - ListFn: func(ctx context.Context, req *ps.ProcesslistRequest) (*ps.ProcesslistResult, error) { - return &ps.ProcesslistResult{ - Keyspace: "commerce", - Shard: "-80", - Tablet: "zone1-1001", - Processes: []ps.Process{ - {ID: 101, User: "vt_app", Host: "10.0.0.1", DB: "main", Command: "Query", Time: 42, State: "running", Info: "SELECT 1"}, - }, - }, nil - }, - } + server := processlistListServer(t, c, "", + processlistListBody("commerce", "-80", "zone1-1001", `{"pid":101,"instance":"zone1-1001","usename":"vt_app","client_addr":"10.0.0.1","datname":"main","state":"Query/running","duration_ms":42000,"connection_id":"101","query_id":"101","query_text":"SELECT 1"}`)) var buf bytes.Buffer - ch := processlistTestHelper(org, svc, printer.CSV, &buf) + ch := processlistTestHelper(org, server.URL, printer.CSV, &buf) cmd := ProcesslistCmd(ch) cmd.SetArgs([]string{"show", db, branch}) err := cmd.Execute() c.Assert(err, qt.IsNil) - c.Assert(buf.String(), qt.Contains, "101,vt_app,10.0.0.1,main,Query,42,running,SELECT 1") + c.Assert(buf.String(), qt.Contains, "keyspace,shard,tablet") + c.Assert(buf.String(), qt.Contains, "commerce,-80,zone1-1001,101,zone1-1001,,Query/running,42000") + c.Assert(buf.String(), qt.Contains, "vt_app,,main,10.0.0.1,SELECT 1") c.Assert(buf.String(), qt.Not(qt.Contains), "{") c.Assert(buf.String(), qt.Not(qt.Contains), `"processes"`) } @@ -99,23 +95,17 @@ func TestProcesslist_NoTargetFlags(t *testing.T) { org, db, branch := "my-org", "my-db", "my-branch" - svc := &mock.ProcesslistService{ - ListFn: func(ctx context.Context, req *ps.ProcesslistRequest) (*ps.ProcesslistResult, error) { - c.Assert(req.Keyspace, qt.Equals, "") - c.Assert(req.Shard, qt.Equals, "") - return &ps.ProcesslistResult{Keyspace: "main", Shard: "-", Tablet: "zone1-2001"}, nil - }, - } + server := processlistListServer(t, c, "", + processlistListBody("main", "-", "zone1-2001", "")) var buf bytes.Buffer - ch := processlistTestHelper(org, svc, printer.JSON, &buf) + ch := processlistTestHelper(org, server.URL, printer.JSON, &buf) cmd := ProcesslistCmd(ch) cmd.SetArgs([]string{"show", db, branch}) err := cmd.Execute() c.Assert(err, qt.IsNil) - c.Assert(svc.ListFnInvoked, qt.IsTrue) } func TestProcesslist_HumanOutputDoesNotAbbreviateNumericFields(t *testing.T) { @@ -123,31 +113,22 @@ func TestProcesslist_HumanOutputDoesNotAbbreviateNumericFields(t *testing.T) { org, db, branch := "my-org", "my-db", "my-branch" - svc := &mock.ProcesslistService{ - ListFn: func(ctx context.Context, req *ps.ProcesslistRequest) (*ps.ProcesslistResult, error) { - return &ps.ProcesslistResult{ - Keyspace: "main", - Shard: "-", - Tablet: "zone1-2001", - Processes: []ps.Process{ - {ID: 121500, User: "vt_app", Command: "Sleep", Time: 2100}, - }, - }, nil - }, - } + server := processlistListServer(t, c, "", + processlistListBody("main", "-", "zone1-2001", `{"pid":121500,"instance":"zone1-2001","usename":"vt_app","state":"Sleep","duration_ms":2100000,"connection_id":"121500","query_id":"121500"}`)) var buf bytes.Buffer - ch := processlistTestHelper(org, svc, printer.Human, &buf) + ch := processlistTestHelper(org, server.URL, printer.Human, &buf) cmd := ProcesslistCmd(ch) cmd.SetArgs([]string{"show", db, branch}) err := cmd.Execute() c.Assert(err, qt.IsNil) - c.Assert(buf.String(), qt.Contains, "TIME (SECONDS)") + c.Assert(buf.String(), qt.Contains, "pid: 121500") + c.Assert(buf.String(), qt.Contains, "duration: 35m0s") + c.Assert(buf.String(), qt.Contains, "connection_id: 121500") c.Assert(buf.String(), qt.Contains, "121500") c.Assert(buf.String(), qt.Not(qt.Contains), "121.5K") - c.Assert(buf.String(), qt.Contains, "2100") c.Assert(buf.String(), qt.Not(qt.Contains), "2.1K") } @@ -156,14 +137,16 @@ func TestProcesslist_NotFound(t *testing.T) { org, db, branch := "my-org", "missing-db", "missing-branch" - svc := &mock.ProcesslistService{ - ListFn: func(ctx context.Context, req *ps.ProcesslistRequest) (*ps.ProcesslistResult, error) { - return nil, &ps.Error{Code: ps.ErrNotFound} - }, - } + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + c.Assert(r.Method, qt.Equals, http.MethodGet) + c.Assert(r.URL.Path, qt.Equals, "/v1/organizations/my-org/databases/missing-db/branches/missing-branch/connections") + w.WriteHeader(http.StatusNotFound) + _, _ = io.WriteString(w, `{"message":"not found"}`) + })) + t.Cleanup(server.Close) var buf bytes.Buffer - ch := processlistTestHelper(org, svc, printer.JSON, &buf) + ch := processlistTestHelper(org, server.URL, printer.JSON, &buf) cmd := ProcesslistCmd(ch) cmd.SetArgs([]string{"show", db, branch}) @@ -175,5 +158,4 @@ func TestProcesslist_NotFound(t *testing.T) { c.Assert(err.Error(), qt.Contains, db) c.Assert(err.Error(), qt.Contains, org) c.Assert(err.Error(), qt.Not(qt.Contains), "Not Found") - c.Assert(svc.ListFnInvoked, qt.IsTrue) } diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 673ecd9c..38f03e09 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -25,7 +25,6 @@ import ( "strings" "time" - "github.com/planetscale/cli/internal/cmd/dataimports" "github.com/planetscale/cli/internal/cmd/mcp" "github.com/planetscale/cli/internal/cmd/role" "github.com/planetscale/cli/internal/cmd/size" @@ -40,6 +39,7 @@ import ( "github.com/planetscale/cli/internal/cmd/branch" "github.com/planetscale/cli/internal/cmd/connect" "github.com/planetscale/cli/internal/cmd/database" + "github.com/planetscale/cli/internal/cmd/dataimports" "github.com/planetscale/cli/internal/cmd/deployrequest" "github.com/planetscale/cli/internal/cmd/keyspace" "github.com/planetscale/cli/internal/cmd/org" @@ -316,7 +316,6 @@ func runCmd(ctx context.Context, ver, commit, buildDate string, format *printer. workflowCmd.GroupID = "vitess" rootCmd.AddCommand(workflowCmd) - // Postgres-specific commands roleCmd := role.RoleCmd(ch) roleCmd.GroupID = "postgres" rootCmd.AddCommand(roleCmd) diff --git a/internal/connections/actions.go b/internal/connections/actions.go new file mode 100644 index 00000000..8f66fc9b --- /dev/null +++ b/internal/connections/actions.go @@ -0,0 +1,12 @@ +package connections + +// ActionTarget identifies the connection an action acts on. Instance and PID +// describe the selected row; ConnectionID, TransactionID, and QueryID are the +// server-issued IDs for connection, transaction, and query verification. +type ActionTarget struct { + Instance string + PID int + ConnectionID *string + TransactionID *string + QueryID *string +} diff --git a/internal/connections/client.go b/internal/connections/client.go new file mode 100644 index 00000000..63617868 --- /dev/null +++ b/internal/connections/client.go @@ -0,0 +1,580 @@ +// Package connections is a branch-connections HTTP client that follows the +// patterns of github.com/planetscale/planetscale-go: a typed Client over +// shared auth, base-URL, and transport machinery. The auth header +// construction, error formatting, and response decode helpers mirror the +// equivalents in planetscale-go. +package connections + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand/v2" + "net/http" + "net/url" + "strconv" + "strings" + "time" +) + +const ( + clientUserAgent = "pscale-cli" + defaultTimeout = 5 * time.Second + + // Retry defaults for list-only 503s returned while the server's + // connection snapshot cache is being populated. The floor is 50ms + // because the server typically repopulates that cache in tens of + // milliseconds; retrying at 100ms would routinely wait 50ms after the + // cache is already warm, while 50ms catches it shortly after the lock + // releases. Jitter (25ms) defends against lock-step retries across + // multiple concurrent CLI users on the same branch. Budget (2s) caps + // total perceived list latency before surfacing a friendly warming + // message: ~2x the polling cadence, so a single hiccup doesn't surface + // as an immediate user-visible failure. + defaultListRetryBudget = 2 * time.Second + defaultListRetryBackoff = 50 * time.Millisecond + defaultListRetryJitter = 25 * time.Millisecond +) + +var ( + errListWarming = errors.New("list connections: server warming") + errListWarmingExhausted = errors.New("list connections: server is warming up, please retry in a moment") + errListInvalidResponse = errors.New("list connections: received an invalid response, please retry") +) + +type ClientConfig struct { + BaseURL string + Organization string + Database string + Branch string + Keyspace string + Shard string + AccessToken string + ServiceTokenID string + ServiceToken string + HTTPClient *http.Client + RequestTimeout time.Duration +} + +type Client struct { + cfg ClientConfig + client *http.Client + retryBudget time.Duration + retryBackoff time.Duration + retryJitter time.Duration +} + +type AvailableTargets struct { + Keyspaces []string `json:"keyspaces,omitempty"` + Shards []string `json:"shards,omitempty"` +} + +type HTTPError struct { + Op string + StatusCode int + Message string + Available AvailableTargets +} + +func (e *HTTPError) Error() string { + message := e.Message + if message == "" { + message = http.StatusText(e.StatusCode) + } + detail := fmt.Sprintf("%s: HTTP %d: %s", e.Op, e.StatusCode, message) + if len(e.Available.Keyspaces) > 0 { + detail += fmt.Sprintf(" (available keyspaces: %s)", strings.Join(e.Available.Keyspaces, ", ")) + } + if len(e.Available.Shards) > 0 { + detail += fmt.Sprintf(" (available shards: %s)", strings.Join(e.Available.Shards, ", ")) + } + return detail +} + +func UserFacingError(err error, action string) error { + if err == nil { + return nil + } + return errors.New(UserFacingErrorText(err, action)) +} + +func UserFacingErrorText(err error, action string) string { + var httpErr *HTTPError + if errors.As(err, &httpErr) && httpErr.StatusCode == http.StatusForbidden { + return fmt.Sprintf("permission denied: you don't have permission to %s live connections", action) + } + return err.Error() +} + +func NewClient(cfg ClientConfig) (*Client, error) { + if cfg.Organization == "" || cfg.Database == "" || cfg.Branch == "" { + return nil, errors.New("organization, database, and branch are required") + } + if cfg.AccessToken == "" && (cfg.ServiceTokenID == "" || cfg.ServiceToken == "") { + return nil, errors.New("not authenticated: provide AccessToken or ServiceTokenID/ServiceToken") + } + client := cfg.HTTPClient + if client == nil { + client = &http.Client{} + } + if cfg.RequestTimeout == 0 { + cfg.RequestTimeout = defaultTimeout + } + return &Client{ + cfg: cfg, + client: client, + retryBudget: defaultListRetryBudget, + retryBackoff: defaultListRetryBackoff, + retryJitter: defaultListRetryJitter, + }, nil +} + +type listResponse struct { + DatabaseKind DatabaseKind `json:"database_kind"` + CapturedAt time.Time `json:"captured_at"` + Instances []InstanceMeta `json:"instances"` + Topology *Topology `json:"topology"` + Data []listEntry `json:"data"` +} + +type listEntry struct { + PID int `json:"pid"` + Instance string `json:"instance"` + DatabaseName string `json:"datname"` + Username string `json:"usename"` + ApplicationName string `json:"application_name"` + ClientAddr string `json:"client_addr"` + State string `json:"state"` + WaitEventType string `json:"wait_event_type"` + WaitEvent string `json:"wait_event"` + BackendType string `json:"backend_type"` + XactStart *time.Time `json:"xact_start"` + QueryStart *time.Time `json:"query_start"` + ConnectionID *string `json:"connection_id"` + TransactionID *string `json:"transaction_id"` + QueryID *string `json:"query_id"` + DurationMS int64 `json:"duration_ms"` + BlockedBy []int `json:"blocked_by"` + QueryText string `json:"query_text"` +} + +func (e listEntry) connection() Connection { + return Connection{ + PID: e.PID, + Instance: e.Instance, + DatabaseName: e.DatabaseName, + Username: e.Username, + ApplicationName: e.ApplicationName, + ClientAddr: e.ClientAddr, + State: e.State, + WaitEventType: e.WaitEventType, + WaitEvent: e.WaitEvent, + BackendType: e.BackendType, + XactStart: e.XactStart, + QueryStart: e.QueryStart, + ConnectionID: e.ConnectionID, + TransactionID: e.TransactionID, + QueryID: e.QueryID, + Duration: time.Duration(e.DurationMS) * time.Millisecond, + BlockedBy: e.BlockedBy, + QueryText: e.QueryText, + } +} + +func (s *Client) List(ctx context.Context, sort SortMode) (ConnectionList, error) { + callerCtx := ctx + if s.cfg.RequestTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, s.cfg.RequestTimeout) + defer cancel() + } + + deadline := time.Now().Add(s.retryBudget) + for { + list, retryAfter, err := s.tryList(ctx, sort) + if err == nil { + return list, nil + } + if isTimeoutError(err) && ctx.Err() != nil { + if callerCtx.Err() != nil { + return ConnectionList{}, fmt.Errorf("list connections: %w", callerCtx.Err()) + } + return ConnectionList{}, fmt.Errorf("list connections: request timed out after %s, please retry", s.cfg.RequestTimeout) + } + if !errors.Is(err, errListWarming) { + return ConnectionList{}, err + } + if s.retryBudget <= 0 { + return ConnectionList{}, errListWarmingExhausted + } + + delay := retryAfter + if delay <= 0 { + delay = s.retryDelay() + } + if time.Now().Add(delay).After(deadline) { + return ConnectionList{}, errListWarmingExhausted + } + + select { + case <-time.After(delay): + case <-ctx.Done(): + return ConnectionList{}, fmt.Errorf("list connections: %w", ctx.Err()) + } + } +} + +func (s *Client) tryList(ctx context.Context, sort SortMode) (ConnectionList, time.Duration, error) { + resp, err := s.do(ctx, http.MethodGet, s.connectionsURL("")) + if err != nil { + return ConnectionList{}, 0, err + } + defer resp.Body.Close() + + if resp.StatusCode == http.StatusServiceUnavailable { + _, _ = io.Copy(io.Discard, resp.Body) + return ConnectionList{}, parseRetryAfter(resp.Header.Get("Retry-After")), errListWarming + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return ConnectionList{}, 0, fmt.Errorf("list connections: read body: %w", err) + } + + if resp.StatusCode > 299 { + return ConnectionList{}, 0, s.formatHTTPError("list connections", resp.StatusCode, body) + } + + var listed listResponse + if err := json.Unmarshal(body, &listed); err != nil { + return ConnectionList{}, 0, errListInvalidResponse + } + + capturedAt := listed.CapturedAt + if capturedAt.IsZero() { + return ConnectionList{}, 0, errors.New("list connections: response missing captured_at") + } + + list := listed.connectionList(sort) + if len(list.Instances) > 0 { + roles := make(map[string]string, len(list.Instances)) + for _, m := range list.Instances { + roles[m.ID] = m.Role + } + for i := range list.Connections { + list.Connections[i].InstanceRole = roles[list.Connections[i].Instance] + } + } + return list, 0, nil +} + +func (r listResponse) connectionList(sort SortMode) ConnectionList { + list := NewConnectionList(r.CapturedAt, connectionsFromEntries(r.Data), sort) + list.DatabaseKind = r.DatabaseKind + list.Instances = r.Instances + list.Topology = r.Topology + return list +} + +func connectionsFromEntries(entries []listEntry) []Connection { + connections := make([]Connection, 0, len(entries)) + for _, entry := range entries { + connections = append(connections, entry.connection()) + } + return connections +} + +func isTimeoutError(err error) bool { + if errors.Is(err, context.DeadlineExceeded) { + return true + } + var netErr interface{ Timeout() bool } + return errors.As(err, &netErr) && netErr.Timeout() +} + +func (s *Client) retryDelay() time.Duration { + if s.retryJitter <= 0 { + return s.retryBackoff + } + return s.retryBackoff + rand.N(s.retryJitter) +} + +func parseRetryAfter(h string) time.Duration { + h = strings.TrimSpace(h) + if h == "" { + return 0 + } + if isPlainDecimalSeconds(h) { + seconds, err := strconv.ParseFloat(h, 64) + if err == nil && seconds > 0 { + return time.Duration(seconds * float64(time.Second)) + } + } + t, err := http.ParseTime(h) + if err != nil { + return 0 + } + d := time.Until(t) + if d <= 0 { + return 0 + } + return d +} + +func isPlainDecimalSeconds(s string) bool { + seenDot := false + digitsBeforeDot := 0 + digitsAfterDot := 0 + for _, r := range s { + switch { + case r >= '0' && r <= '9': + if seenDot { + digitsAfterDot++ + } else { + digitsBeforeDot++ + } + case r == '.': + if seenDot { + return false + } + seenDot = true + default: + return false + } + } + if digitsBeforeDot == 0 { + return false + } + return !seenDot || digitsAfterDot > 0 +} + +// CancelQuery asks the backend to cancel the active query identified by +// target.QueryID. +func (s *Client) CancelQuery(ctx context.Context, target ActionTarget) error { + _, err := s.CancelQueryResult(ctx, target) + return err +} + +func (s *Client) CancelQueryResult(ctx context.Context, target ActionTarget) (ActionResult, error) { + if target.QueryID == nil || *target.QueryID == "" { + return ActionResult{}, errors.New("cancel query: query_id is required") + } + suffix := fmt.Sprintf("query/%s", url.PathEscape(*target.QueryID)) + return s.deleteAction(ctx, "cancel query", suffix) +} + +// TerminateTransaction asks the backend to terminate the connection only if +// its current transaction matches target.TransactionID — i.e., reject if the +// connection has moved on. Same-transaction semantics live on the server. +func (s *Client) TerminateTransaction(ctx context.Context, target ActionTarget) error { + _, err := s.TerminateTransactionResult(ctx, target) + return err +} + +// TerminateTransactionResult asks the backend to terminate a transaction and +// returns any action metadata the backend provides. +func (s *Client) TerminateTransactionResult(ctx context.Context, target ActionTarget) (ActionResult, error) { + if target.TransactionID == nil || *target.TransactionID == "" { + return ActionResult{}, errors.New("terminate transaction: transaction_id is required") + } + suffix := fmt.Sprintf("transaction/%s", url.PathEscape(*target.TransactionID)) + return s.deleteAction(ctx, "terminate transaction", suffix) +} + +// TerminateConnection force-terminates the connection identified by +// target.ConnectionID without regard to its current query or transaction. +// Operator-initiated; gated by a y/n confirmation in the TUI. +func (s *Client) TerminateConnection(ctx context.Context, target ActionTarget) error { + _, err := s.TerminateConnectionResult(ctx, target) + return err +} + +func (s *Client) TerminateConnectionResult(ctx context.Context, target ActionTarget) (ActionResult, error) { + if target.ConnectionID == nil || *target.ConnectionID == "" { + return ActionResult{}, errors.New("terminate connection: connection_id is required") + } + suffix := fmt.Sprintf("connection/%s", url.PathEscape(*target.ConnectionID)) + return s.deleteAction(ctx, "terminate connection", suffix) +} + +type ActionResult struct { + Success bool `json:"success"` + Keyspace string `json:"keyspace,omitempty"` + Shard string `json:"shard,omitempty"` + Tablet string `json:"tablet,omitempty"` + ID int64 `json:"id,omitempty"` + Kind string `json:"kind,omitempty"` +} + +// deleteAction issues a typed DELETE against the connection path with the +// caller-supplied suffix and returns a sanitized error on non-2xx. +func (s *Client) deleteAction(ctx context.Context, op, suffix string) (ActionResult, error) { + if s.cfg.RequestTimeout > 0 { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, s.cfg.RequestTimeout) + defer cancel() + } + resp, err := s.do(ctx, http.MethodDelete, s.connectionsURL(suffix)) + if err != nil { + return ActionResult{}, err + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return ActionResult{}, fmt.Errorf("%s: read body: %w", op, err) + } + if resp.StatusCode > 299 { + if resp.StatusCode == http.StatusNotFound { + if idName := actionIDName(op); idName != "" { + return ActionResult{}, fmt.Errorf("%s: %s not found; run connections show again and use a current %s", op, idName, idName) + } + } + if resp.StatusCode == http.StatusUnprocessableEntity { + if message := httpErrorEnvelopeMessage(body); message != "" { + return ActionResult{}, fmt.Errorf("%s: %s", op, message) + } + } + return ActionResult{}, s.formatHTTPError(op, resp.StatusCode, body) + } + if len(strings.TrimSpace(string(body))) == 0 { + return ActionResult{Success: true}, nil + } + var result ActionResult + if err := json.Unmarshal(body, &result); err != nil { + return ActionResult{}, fmt.Errorf("%s: invalid response", op) + } + if !result.Success { + return ActionResult{}, fmt.Errorf("%s: action did not succeed", op) + } + return result, nil +} + +func actionIDName(op string) string { + switch op { + case "cancel query": + return "query_id" + case "terminate transaction": + return "transaction_id" + case "terminate connection": + return "connection_id" + default: + return "" + } +} + +func (s *Client) formatHTTPError(op string, status int, body []byte) error { + if status == http.StatusNotFound { + hint := fmt.Sprintf( + "verify --org=%q database=%q branch=%q exists and live connections are enabled", + s.cfg.Organization, s.cfg.Database, s.cfg.Branch, + ) + return fmt.Errorf("%s: not found (%s)", op, hint) + } + if status == http.StatusTooManyRequests { + return fmt.Errorf("%s: rate limited, please retry in a moment", op) + } + if status >= 500 { + return fmt.Errorf("%s: HTTP %d: %s", op, status, http.StatusText(status)) + } + message, available := httpErrorDetails(body) + return &HTTPError{ + Op: op, + StatusCode: status, + Message: message, + Available: available, + } +} + +func httpErrorDetails(body []byte) (string, AvailableTargets) { + var envelope struct { + Message string `json:"message"` + Available AvailableTargets `json:"available"` + } + if err := json.Unmarshal(body, &envelope); err == nil { + return envelope.Message, envelope.Available + } + if detail := nonJSONBody(body); detail != "" { + return detail, AvailableTargets{} + } + return "", AvailableTargets{} +} + +func nonJSONBody(body []byte) string { + trimmed := strings.TrimSpace(string(body)) + if trimmed == "" { + return "" + } + if trimmed[0] == '{' || trimmed[0] == '[' { + return "" + } + return trimmed +} + +func httpErrorEnvelopeMessage(body []byte) string { + var envelope struct { + Message string `json:"message"` + } + if err := json.Unmarshal(body, &envelope); err == nil { + return envelope.Message + } + return "" +} + +// UnknownInstanceError reports an --instance filter value that matches none of +// the instances in the list response. +type UnknownInstanceError struct { + Instance string + Valid []string +} + +func (e *UnknownInstanceError) Error() string { + return fmt.Sprintf("unknown instance %q (valid instances: %s)", e.Instance, strings.Join(e.Valid, ", ")) +} + +func (s *Client) connectionsURL(suffix string) string { + raw := fmt.Sprintf( + "%s/v1/organizations/%s/databases/%s/branches/%s/connections", + strings.TrimRight(s.cfg.BaseURL, "/"), + url.PathEscape(s.cfg.Organization), + url.PathEscape(s.cfg.Database), + url.PathEscape(s.cfg.Branch), + ) + if suffix != "" { + raw += "/" + suffix + } + u, err := url.Parse(raw) + if err != nil { + return raw + } + q := u.Query() + if s.cfg.Keyspace != "" { + q.Set("keyspace", s.cfg.Keyspace) + } + if s.cfg.Shard != "" { + q.Set("shard", s.cfg.Shard) + } + u.RawQuery = q.Encode() + return u.String() +} + +func (s *Client) do(ctx context.Context, method, urlStr string) (*http.Response, error) { + req, err := http.NewRequestWithContext(ctx, method, urlStr, nil) + if err != nil { + return nil, fmt.Errorf("build request: %w", err) + } + req.Header.Set("User-Agent", clientUserAgent) + req.Header.Set("Accept", "application/json") + if s.cfg.AccessToken != "" { + req.Header.Set("Authorization", "Bearer "+s.cfg.AccessToken) + } else { + req.Header.Set("Authorization", s.cfg.ServiceTokenID+":"+s.cfg.ServiceToken) + } + + resp, err := s.client.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + return resp, nil +} diff --git a/internal/connections/client_test.go b/internal/connections/client_test.go new file mode 100644 index 00000000..c2b273aa --- /dev/null +++ b/internal/connections/client_test.go @@ -0,0 +1,992 @@ +package connections + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "slices" + "strings" + "sync/atomic" + "testing" + "time" +) + +const sampleListResponse = `{ + "type": "list", + "next_page": null, + "prev_page": null, + "captured_at": "2026-04-29T12:34:56.789Z", + "instances": [ + {"id": "primary", "role": "primary", "error": null}, + {"id": "replica-1", "role": "replica", "error": "timeout after 2s"} + ], + "data": [ + { + "pid": 123, + "instance": "primary", + "captured_at": "2026-04-29T12:34:56.789Z", + "duration_ms": 1234, + "blocked_by": [456], + "query_text": "SELECT 1 FROM t", + "state": "active", + "wait_event": "ClientRead", + "wait_event_type": "Client", + "client_addr": "10.0.0.1", + "application_name": "psql", + "backend_type": "client backend", + "usename": "alice", + "xact_start": "2026-04-29T12:34:00.000Z", + "query_start": "2026-04-29T12:34:30.000Z", + "connection_id": "primary-123-1779113716123456", + "transaction_id": "primary-123-1777466040000000", + "query_id": "primary-123-1777466070000000" + }, + { + "pid": 456, + "instance": "primary", + "captured_at": "2026-04-29T12:34:56.789Z", + "duration_ms": 50, + "blocked_by": [], + "query_text": "", + "state": "idle", + "wait_event": "", + "wait_event_type": "", + "client_addr": "", + "application_name": "", + "backend_type": "client backend", + "usename": "bob", + "xact_start": null, + "query_start": null, + "transaction_id": null, + "query_id": null + } + ] +}` + +func TestClient_ListDecodesConnectionList(t *testing.T) { + var gotPath, gotMethod, gotAuth, gotUA string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotMethod = r.Method + gotAuth = r.Header.Get("Authorization") + gotUA = r.Header.Get("User-Agent") + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, sampleListResponse) + })) + defer srv.Close() + + client, err := NewClient(ClientConfig{ + BaseURL: srv.URL, + Organization: "acme", + Database: "prod", + Branch: "main", + ServiceTokenID: "tid", + ServiceToken: "secret", + }) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + + list, err := client.List(context.Background(), SortByTransactionStart) + if err != nil { + t.Fatalf("List: %v", err) + } + + if gotMethod != http.MethodGet { + t.Errorf("method = %q, want GET", gotMethod) + } + wantPath := "/v1/organizations/acme/databases/prod/branches/main/connections" + if gotPath != wantPath { + t.Errorf("path = %q, want %q", gotPath, wantPath) + } + if gotAuth != "tid:secret" { + t.Errorf("Authorization = %q, want %q", gotAuth, "tid:secret") + } + if gotUA != "pscale-cli" { + t.Errorf("User-Agent = %q, want %q", gotUA, "pscale-cli") + } + + wantCaptured, _ := time.Parse(time.RFC3339Nano, "2026-04-29T12:34:56.789Z") + if !list.CapturedAt.Equal(wantCaptured) { + t.Errorf("CapturedAt = %v, want %v", list.CapturedAt, wantCaptured) + } + if list.Sort != SortByTransactionStart { + t.Errorf("Sort = %q, want %q", list.Sort, SortByTransactionStart) + } + if len(list.Connections) != 2 { + t.Fatalf("len(Connections) = %d, want 2", len(list.Connections)) + } + + first := list.Connections[0] + if first.PID != 123 { + t.Errorf("first.PID = %d, want 123", first.PID) + } + if first.Instance != "primary" { + t.Errorf("first.Instance = %q, want primary", first.Instance) + } + if first.Username != "alice" { + t.Errorf("first.Username = %q, want alice", first.Username) + } + if first.State != "active" { + t.Errorf("first.State = %q, want active", first.State) + } + if first.WaitEvent != "ClientRead" { + t.Errorf("first.WaitEvent = %q, want ClientRead", first.WaitEvent) + } + if first.WaitEventType != "Client" { + t.Errorf("first.WaitEventType = %q, want Client", first.WaitEventType) + } + if first.ApplicationName != "psql" { + t.Errorf("first.ApplicationName = %q, want psql", first.ApplicationName) + } + if first.ClientAddr != "10.0.0.1" { + t.Errorf("first.ClientAddr = %q, want 10.0.0.1", first.ClientAddr) + } + if first.BackendType != "client backend" { + t.Errorf("first.BackendType = %q, want client backend", first.BackendType) + } + if first.QueryText != "SELECT 1 FROM t" { + t.Errorf("first.QueryText = %q, want SELECT 1 FROM t", first.QueryText) + } + if first.Duration != 1234*time.Millisecond { + t.Errorf("first.Duration = %v, want 1234ms", first.Duration) + } + if len(first.BlockedBy) != 1 || first.BlockedBy[0] != 456 { + t.Errorf("first.BlockedBy = %v, want [456]", first.BlockedBy) + } + if first.XactStart == nil { + t.Error("first.XactStart = nil, want non-nil") + } + if first.QueryStart == nil { + t.Error("first.QueryStart = nil, want non-nil") + } + + second := list.Connections[1] + if second.Username != "bob" { + t.Errorf("second.Username = %q, want bob", second.Username) + } + if len(second.BlockedBy) != 0 { + t.Errorf("second.BlockedBy = %v, want empty", second.BlockedBy) + } + if second.XactStart != nil { + t.Errorf("second.XactStart = %v, want nil", second.XactStart) + } + + if len(list.Instances) != 2 { + t.Fatalf("len(Instances) = %d, want 2", len(list.Instances)) + } + if list.Instances[0] != (InstanceMeta{ID: "primary", Role: "primary"}) { + t.Errorf("Instances[0] = %+v, want primary/primary/(no error)", list.Instances[0]) + } + if list.Instances[1] != (InstanceMeta{ID: "replica-1", Role: "replica", Error: "timeout after 2s"}) { + t.Errorf("Instances[1] = %+v, want replica-1/replica/timeout", list.Instances[1]) + } + + if first.InstanceRole != "primary" { + t.Errorf("first.InstanceRole = %q, want primary (joined from Instances metadata)", first.InstanceRole) + } + if second.InstanceRole != "primary" { + t.Errorf("second.InstanceRole = %q, want primary (both rows on primary)", second.InstanceRole) + } + + if first.TransactionID == nil || *first.TransactionID != "primary-123-1777466040000000" { + got := "nil" + if first.TransactionID != nil { + got = *first.TransactionID + } + t.Errorf("first.TransactionID = %s, want primary-123-1777466040000000", got) + } + if first.QueryID == nil || *first.QueryID != "primary-123-1777466070000000" { + got := "nil" + if first.QueryID != nil { + got = *first.QueryID + } + t.Errorf("first.QueryID = %s, want primary-123-1777466070000000", got) + } + if first.ConnectionID == nil || *first.ConnectionID != "primary-123-1779113716123456" { + got := "nil" + if first.ConnectionID != nil { + got = *first.ConnectionID + } + t.Errorf("first.ConnectionID = %s, want primary-123-1779113716123456", got) + } + if second.TransactionID != nil { + t.Errorf("second.TransactionID = %v, want nil (wire null)", *second.TransactionID) + } + if second.QueryID != nil { + t.Errorf("second.QueryID = %v, want nil (wire null)", *second.QueryID) + } +} + +func TestClient_ListDecodesVitessProcesslist(t *testing.T) { + var gotPath, gotQuery string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotQuery = r.URL.RawQuery + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{ + "type": "list", + "database_kind": "mysql", + "next_page": null, + "prev_page": null, + "captured_at": "2026-06-04T12:30:00.000Z", + "instances": [], + "topology": { + "keyspace": "commerce", + "shard": "-80", + "tablet": "zone1-1001" + }, + "data": [ + { + "pid": 101, + "instance": "zone1-1001", + "connection_id": "101", + "transaction_id": null, + "query_id": "101", + "leader_pid": null, + "datname": "checkout", + "state": "Query/executing", + "usename": "vt_app", + "wait_event": null, + "client_addr": "10.0.0.12:54231", + "wait_event_type": null, + "application_name": "", + "backend_type": "client backend", + "query_start": null, + "xact_start": null, + "backend_start": null, + "state_change": null, + "duration_ms": 42000, + "blocked_by": [], + "query_text": "SELECT 1" + } + ] + }`) + })) + defer srv.Close() + + client, err := NewClient(ClientConfig{ + BaseURL: srv.URL, + Organization: "acme", + Database: "shop", + Branch: "main", + Keyspace: "commerce", + Shard: "-80", + ServiceTokenID: "tid", + ServiceToken: "secret", + }) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + + list, err := client.List(context.Background(), SortByDuration) + if err != nil { + t.Fatalf("List: %v", err) + } + + if gotPath != "/v1/organizations/acme/databases/shop/branches/main/connections" { + t.Errorf("path = %q", gotPath) + } + if gotQuery != "keyspace=commerce&shard=-80" { + t.Errorf("query = %q, want keyspace=commerce&shard=-80", gotQuery) + } + wantCaptured, _ := time.Parse(time.RFC3339Nano, "2026-06-04T12:30:00.000Z") + if !list.CapturedAt.Equal(wantCaptured) { + t.Errorf("CapturedAt = %v, want %v", list.CapturedAt, wantCaptured) + } + if list.DatabaseKind != DatabaseKindMySQL { + t.Errorf("DatabaseKind = %q, want %q", list.DatabaseKind, DatabaseKindMySQL) + } + if list.Topology == nil { + t.Fatal("Topology = nil, want Vitess topology") + } + if list.Topology.Keyspace != "commerce" || list.Topology.Shard != "-80" || list.Topology.Tablet != "zone1-1001" { + t.Errorf("Topology = %+v", list.Topology) + } + if len(list.Connections) != 1 { + t.Fatalf("len(Connections) = %d, want 1", len(list.Connections)) + } + conn := list.Connections[0] + if conn.PID != 101 || conn.Instance != "zone1-1001" || conn.Username != "vt_app" || conn.DatabaseName != "checkout" { + t.Errorf("Connection = %+v", conn) + } + if conn.State != "Query/executing" { + t.Errorf("State = %q, want Query/executing", conn.State) + } + if conn.Duration != 42*time.Second { + t.Errorf("Duration = %v, want 42s", conn.Duration) + } + if conn.ConnectionID == nil || *conn.ConnectionID != "101" { + t.Errorf("ConnectionID = %v, want 101", conn.ConnectionID) + } + if conn.QueryID == nil || *conn.QueryID != "101" { + t.Errorf("QueryID = %v, want 101", conn.QueryID) + } +} + +func TestClient_ListLeavesInstanceRoleEmptyForUnknownInstance(t *testing.T) { + const body = `{ + "type": "list", + "captured_at": "2026-04-29T12:34:56.789Z", + "instances": [{"id": "primary", "role": "primary"}], + "data": [{"pid": 1, "instance": "ghost", "duration_ms": 0, "state": "idle"}] + }` + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, body) + })) + defer srv.Close() + + list, err := mustClient(t, srv.URL).List(context.Background(), SortByTransactionStart) + if err != nil { + t.Fatalf("List: %v", err) + } + if list.Connections[0].InstanceRole != "" { + t.Errorf("InstanceRole = %q, want empty for unknown instance id", list.Connections[0].InstanceRole) + } +} + +func TestClient_ListLeavesInstanceRoleEmptyWithoutInstancesField(t *testing.T) { + const body = `{ + "type": "list", + "captured_at": "2026-04-29T12:34:56.789Z", + "data": [{"pid": 1, "instance": "primary", "duration_ms": 0, "state": "idle"}] + }` + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, body) + })) + defer srv.Close() + + list, err := mustClient(t, srv.URL).List(context.Background(), SortByTransactionStart) + if err != nil { + t.Fatalf("List: %v", err) + } + if list.Connections[0].InstanceRole != "" { + t.Errorf("InstanceRole = %q, want empty when Instances is absent", list.Connections[0].InstanceRole) + } +} + +func TestClient_ListRetriesOn503AndReturnsList(t *testing.T) { + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + call := calls.Add(1) + if call < 3 { + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = io.WriteString(w, "cache lock held") + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, sampleListResponse) + })) + defer srv.Close() + + c := newClientWithTimings(t, srv.URL, 500*time.Millisecond, 10*time.Millisecond, 0) + list, err := c.List(context.Background(), SortByTransactionStart) + if err != nil { + t.Fatalf("List: %v", err) + } + if got := calls.Load(); got != 3 { + t.Fatalf("calls = %d, want 3", got) + } + if len(list.Connections) != 2 { + t.Fatalf("len(Connections) = %d, want 2", len(list.Connections)) + } +} + +func TestClient_ListReturnsFriendlyWarmingMessageAfter503Budget(t *testing.T) { + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusServiceUnavailable) + _, _ = io.WriteString(w, "cache lock held") + })) + defer srv.Close() + + c := newClientWithTimings(t, srv.URL, 35*time.Millisecond, 10*time.Millisecond, 0) + start := time.Now() + _, err := c.List(context.Background(), SortByTransactionStart) + elapsed := time.Since(start) + if err == nil { + t.Fatal("List: want warming error, got nil") + } + if got, want := err.Error(), "list connections: server is warming up, please retry in a moment"; got != want { + t.Fatalf("err = %q, want %q", got, want) + } + if calls.Load() < 2 { + t.Fatalf("calls = %d, want at least 2", calls.Load()) + } + if elapsed > time.Second { + t.Fatalf("List took %v, want retry budget to expire promptly", elapsed) + } +} + +func TestClient_ListDoesNotRetryOn500(t *testing.T) { + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusInternalServerError) + _, _ = io.WriteString(w, "internal host detail") + })) + defer srv.Close() + + c := newClientWithTimings(t, srv.URL, 500*time.Millisecond, 10*time.Millisecond, 0) + _, err := c.List(context.Background(), SortByTransactionStart) + if err == nil { + t.Fatal("List: want error, got nil") + } + if got := calls.Load(); got != 1 { + t.Fatalf("calls = %d, want 1", got) + } + if strings.Contains(err.Error(), "internal host detail") { + t.Fatalf("err = %q, must not surface 5xx body", err.Error()) + } + if !strings.Contains(err.Error(), "HTTP 500: Internal Server Error") { + t.Fatalf("err = %q, want HTTP 500 status text", err.Error()) + } +} + +func TestClient_ListDoesNotRetryOn429(t *testing.T) { + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusTooManyRequests) + _, _ = io.WriteString(w, "slow down") + })) + defer srv.Close() + + c := newClientWithTimings(t, srv.URL, 500*time.Millisecond, 10*time.Millisecond, 0) + _, err := c.List(context.Background(), SortByTransactionStart) + if err == nil { + t.Fatal("List: want error, got nil") + } + if got := calls.Load(); got != 1 { + t.Fatalf("calls = %d, want 1", got) + } + if got, want := err.Error(), "list connections: rate limited, please retry in a moment"; got != want { + t.Fatalf("err = %q, want %q", got, want) + } +} + +func TestClient_ListReturnsStructuredAvailableTargets(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusBadRequest) + _, _ = io.WriteString(w, `{ + "message": "keyspace is required", + "available": { + "keyspaces": ["lookup", "main"], + "shards": ["-80", "80-"] + } + }`) + })) + defer srv.Close() + + _, err := mustClient(t, srv.URL).List(context.Background(), SortByTransactionStart) + if err == nil { + t.Fatal("List: want error, got nil") + } + + var httpErr *HTTPError + if !errors.As(err, &httpErr) { + t.Fatalf("List error = %T %[1]v, want *HTTPError", err) + } + if httpErr.StatusCode != http.StatusBadRequest { + t.Errorf("StatusCode = %d, want 400", httpErr.StatusCode) + } + if httpErr.Message != "keyspace is required" { + t.Errorf("Message = %q, want keyspace is required", httpErr.Message) + } + if !slices.Equal(httpErr.Available.Keyspaces, []string{"lookup", "main"}) { + t.Errorf("Available.Keyspaces = %v, want [lookup main]", httpErr.Available.Keyspaces) + } + if !slices.Equal(httpErr.Available.Shards, []string{"-80", "80-"}) { + t.Errorf("Available.Shards = %v, want [-80 80-]", httpErr.Available.Shards) + } + if got, want := err.Error(), "list connections: HTTP 400: keyspace is required (available keyspaces: lookup, main) (available shards: -80, 80-)"; got != want { + t.Fatalf("err = %q, want %q", got, want) + } +} + +func TestClientListPartialResponseFriendlyMessage(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `{"connections":[`) // deliberately truncated JSON + })) + defer srv.Close() + + c := newClientWithTimings(t, srv.URL, 500*time.Millisecond, 10*time.Millisecond, 0) + _, err := c.List(context.Background(), SortByTransactionStart) + if err == nil { + t.Fatal("List: want error on truncated body, got nil") + } + if got, want := err.Error(), "list connections: received an invalid response, please retry"; got != want { + t.Fatalf("err = %q, want %q", got, want) + } +} + +func TestClient_ListRetryWaitHonorsContextCancellation(t *testing.T) { + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + calls.Add(1) + w.WriteHeader(http.StatusServiceUnavailable) + })) + defer srv.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + c := newClientWithTimings(t, srv.URL, 500*time.Millisecond, 250*time.Millisecond, 0) + start := time.Now() + _, err := c.List(ctx, SortByTransactionStart) + elapsed := time.Since(start) + if err == nil { + t.Fatal("List: want context error, got nil") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("err = %v, want errors.Is(err, context.DeadlineExceeded)", err) + } + if got := calls.Load(); got != 1 { + t.Fatalf("calls = %d, want 1 because context expired during retry wait", got) + } + if elapsed > time.Second { + t.Fatalf("List took %v, want context cancellation to interrupt retry wait", elapsed) + } +} + +func TestClient_ListHonorsRetryAfterHeader(t *testing.T) { + var calls atomic.Int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + call := calls.Add(1) + if call == 1 { + // api-bb format from phase-2 RFC line 234: fractional seconds. + w.Header().Set("Retry-After", "0.020") + w.WriteHeader(http.StatusServiceUnavailable) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, sampleListResponse) + })) + defer srv.Close() + + // Backoff floor is longer than RequestTimeout. The call can only succeed if + // the server-supplied Retry-After hint overrides the local floor. + c := newClientWithTimings(t, srv.URL, 2*time.Second, 2*time.Second, 0) + c.cfg.RequestTimeout = 750 * time.Millisecond + _, err := c.List(context.Background(), SortByTransactionStart) + if err != nil { + t.Fatalf("List: %v", err) + } +} + +func TestClient_RetryDelayUsesJitter(t *testing.T) { + c := &Client{retryBackoff: 100 * time.Millisecond, retryJitter: 25 * time.Millisecond} + seen := make(map[time.Duration]bool) + for range 20 { + d := c.retryDelay() + if d < 100*time.Millisecond { + t.Fatalf("retryDelay() = %v, want >= 100ms floor", d) + } + if d >= 125*time.Millisecond { + t.Fatalf("retryDelay() = %v, want < 125ms ceiling", d) + } + seen[d] = true + } + if len(seen) <= 1 { + t.Fatalf("retryDelay() produced %d distinct values across 20 calls, want > 1", len(seen)) + } +} + +func TestParseRetryAfterAcceptsPlainDecimalSeconds(t *testing.T) { + cases := []struct { + value string + want time.Duration + }{ + {value: "1", want: time.Second}, + {value: "0.020", want: 20 * time.Millisecond}, + {value: " 2.5 ", want: 2500 * time.Millisecond}, + } + + for _, tc := range cases { + t.Run(tc.value, func(t *testing.T) { + if got := parseRetryAfter(tc.value); got != tc.want { + t.Fatalf("parseRetryAfter(%q) = %v, want %v", tc.value, got, tc.want) + } + }) + } +} + +func TestParseRetryAfterRejectsInvalidDeltaSeconds(t *testing.T) { + for _, value := range []string{ + "", + "NaN", + "+Inf", + "-1", + "+1", + "1e3", + "0x1p+2", + ".5", + "1.", + "1..2", + "1.2.3", + } { + t.Run(value, func(t *testing.T) { + if got := parseRetryAfter(value); got != 0 { + t.Fatalf("parseRetryAfter(%q) = %v, want 0", value, got) + } + }) + } +} + +func TestParseRetryAfterAcceptsFutureHTTPDate(t *testing.T) { + value := time.Now().Add(time.Hour).UTC().Format(http.TimeFormat) + got := parseRetryAfter(value) + if got <= 0 { + t.Fatalf("parseRetryAfter(%q) = %v, want positive duration", value, got) + } +} + +func TestParseRetryAfterRejectsPastHTTPDate(t *testing.T) { + value := time.Now().Add(-time.Hour).UTC().Format(http.TimeFormat) + if got := parseRetryAfter(value); got != 0 { + t.Fatalf("parseRetryAfter(%q) = %v, want 0", value, got) + } +} + +func TestClient_ListServerError(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusBadGateway) + _, _ = io.WriteString(w, "upstream exploded: pod=postgres-primary-7-internal") + })) + defer srv.Close() + + client := mustClient(t, srv.URL) + _, err := client.List(context.Background(), SortByTransactionStart) + if err == nil { + t.Fatal("List: want error, got nil") + } + if strings.Contains(err.Error(), "upstream exploded") || strings.Contains(err.Error(), "postgres-primary-7-internal") { + t.Errorf("err = %v, must not surface 5xx body (would leak backend topology)", err) + } + if strings.Contains(err.Error(), "api-bb") { + t.Errorf("err = %v, must not leak internal service name", err) + } + if !strings.Contains(err.Error(), "502") { + t.Errorf("err = %v, want status 502 in message", err) + } + if !strings.Contains(err.Error(), "Bad Gateway") { + t.Errorf("err = %v, want generic status text 'Bad Gateway'", err) + } + if !strings.Contains(err.Error(), "list connections") { + t.Errorf("err = %v, want user-facing 'list connections' prefix", err) + } +} + +func TestClient_ListRequiresCapturedAt(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.WriteString(w, `{"type":"list","data":[]}`) + })) + defer srv.Close() + + client := mustClient(t, srv.URL) + _, err := client.List(context.Background(), SortByTransactionStart) + if err == nil { + t.Fatal("List: want missing captured_at error, got nil") + } + if !strings.Contains(err.Error(), "captured_at") { + t.Fatalf("err = %v, want captured_at", err) + } +} + +func mustClient(t *testing.T, baseURL string) *Client { + t.Helper() + client, err := NewClient(ClientConfig{ + BaseURL: baseURL, + Organization: "acme", + Database: "prod", + Branch: "main", + ServiceTokenID: "tid", + ServiceToken: "secret", + }) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + return client +} + +// newClientWithTimings builds a Client with custom retry timings for tests. +// Direct field mutation is intentional and safe: no package-level state, so +// parallel tests cannot interfere with one another. +func newClientWithTimings(t *testing.T, url string, budget, backoff, jitter time.Duration) *Client { + t.Helper() + c := mustClient(t, url) + c.retryBudget = budget + c.retryBackoff = backoff + c.retryJitter = jitter + return c +} + +func TestClient_ListHonorsContextDeadline(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + case <-time.After(30 * time.Second): + } + })) + defer srv.Close() + + client := mustClient(t, srv.URL) + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + _, err := client.List(ctx, SortByTransactionStart) + elapsed := time.Since(start) + if err == nil { + t.Fatal("List: want context error, got nil") + } + if elapsed > 5*time.Second { + t.Errorf("List took %v, want context deadline to fire promptly", elapsed) + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("List error = %v, want context deadline", err) + } +} + +func TestClient_ListHonorsRequestTimeout(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case <-r.Context().Done(): + case <-time.After(30 * time.Second): + } + })) + defer srv.Close() + + client, err := NewClient(ClientConfig{ + BaseURL: srv.URL, + Organization: "acme", + Database: "prod", + Branch: "main", + ServiceTokenID: "tid", + ServiceToken: "secret", + RequestTimeout: 50 * time.Millisecond, + }) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + + start := time.Now() + _, err = client.List(context.Background(), SortByTransactionStart) + elapsed := time.Since(start) + if err == nil { + t.Fatal("List: want timeout error, got nil") + } + if elapsed > time.Second { + t.Errorf("List took %v, want request timeout to fire promptly", elapsed) + } +} + +func TestClient_ActionMethodsIssuePathSegmentedDeletes(t *testing.T) { + qid := "primary-123-1777466070000000" + xid := "primary-123-1777466040000000" + cid := "primary-123-1779113716123456" + const base = "/v1/organizations/acme/databases/prod/branches/main/connections" + + cases := []struct { + name string + call func(*Client) error + wantPath string + }{ + { + name: "CancelQuery", + call: func(c *Client) error { + return c.CancelQuery(context.Background(), ActionTarget{Instance: "primary", PID: 123, QueryID: &qid}) + }, + wantPath: base + "/query/primary-123-1777466070000000", + }, + { + name: "TerminateTransaction", + call: func(c *Client) error { + return c.TerminateTransaction(context.Background(), ActionTarget{Instance: "primary", PID: 123, TransactionID: &xid}) + }, + wantPath: base + "/transaction/primary-123-1777466040000000", + }, + { + name: "TerminateConnection", + call: func(c *Client) error { + return c.TerminateConnection(context.Background(), ActionTarget{Instance: "primary", PID: 123, ConnectionID: &cid}) + }, + wantPath: base + "/connection/primary-123-1779113716123456", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var gotMethod, gotPath, gotQuery string + var gotBody []byte + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotMethod = r.Method + gotPath = r.URL.Path + gotQuery = r.URL.RawQuery + gotBody, _ = io.ReadAll(r.Body) + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + if err := tc.call(mustClient(t, srv.URL)); err != nil { + t.Fatalf("%s: %v", tc.name, err) + } + + if gotMethod != http.MethodDelete { + t.Errorf("method = %q, want DELETE", gotMethod) + } + if gotPath != tc.wantPath { + t.Errorf("path = %q, want %q", gotPath, tc.wantPath) + } + if gotQuery != "" { + t.Errorf("query string = %q, want empty (no mode= param)", gotQuery) + } + if len(gotBody) != 0 { + t.Errorf("body = %q, want empty", gotBody) + } + }) + } +} + +func TestDeleteActionUsesStaleGuardMessage(t *testing.T) { + qid := "primary-123-1777466070000000" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnprocessableEntity) + _, _ = io.WriteString(w, `{"message":"selected query has already ended; refresh and try again"}`) + })) + defer srv.Close() + + err := mustClient(t, srv.URL).CancelQuery(context.Background(), ActionTarget{QueryID: &qid}) + if err == nil { + t.Fatal("CancelQuery: want error, got nil") + } + + if got, want := err.Error(), "cancel query: selected query has already ended; refresh and try again"; got != want { + t.Fatalf("err = %q, want %q", got, want) + } + if strings.Contains(err.Error(), "HTTP 422") { + t.Fatalf("err = %q, must not include HTTP 422", err.Error()) + } +} + +func TestClient_ActionResultMethodsIncludeTargetQueryParams(t *testing.T) { + var gotPath, gotQuery string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotPath = r.URL.Path + gotQuery = r.URL.RawQuery + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"success":true,"keyspace":"commerce","shard":"-80","tablet":"zone1-1001","id":101,"kind":"query"}`) + })) + defer srv.Close() + + queryID := "101" + client, err := NewClient(ClientConfig{ + BaseURL: srv.URL, + Organization: "acme", + Database: "shop", + Branch: "main", + Keyspace: "commerce", + Shard: "-80", + ServiceTokenID: "tid", + ServiceToken: "secret", + }) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + + result, err := client.CancelQueryResult(context.Background(), ActionTarget{QueryID: &queryID}) + if err != nil { + t.Fatalf("CancelQueryResult: %v", err) + } + + if gotPath != "/v1/organizations/acme/databases/shop/branches/main/connections/query/101" { + t.Errorf("path = %q", gotPath) + } + if gotQuery != "keyspace=commerce&shard=-80" { + t.Errorf("query = %q, want keyspace=commerce&shard=-80", gotQuery) + } + if !result.Success || result.ID != 101 || result.Kind != "query" || result.Tablet != "zone1-1001" { + t.Errorf("result = %+v", result) + } +} + +func TestClient_ActionResultMethodsRejectUnsuccessfulResponse(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"success":false}`) + })) + defer srv.Close() + + connectionID := "primary-123-c" + client := mustClient(t, srv.URL) + + _, err := client.TerminateConnectionResult(context.Background(), ActionTarget{ConnectionID: &connectionID}) + if err == nil { + t.Fatal("TerminateConnectionResult returned nil error for success:false") + } + if !strings.Contains(err.Error(), "action did not succeed") { + t.Fatalf("error = %q, want action did not succeed", err.Error()) + } +} + +func TestClient_CancelQueryRejectsMissingQueryID(t *testing.T) { + client := mustClient(t, "http://example") + + for _, target := range []ActionTarget{ + {Instance: "primary", PID: 123}, + {Instance: "primary", PID: 123, QueryID: ptrString("")}, + } { + err := client.CancelQuery(context.Background(), target) + if err == nil { + t.Fatal("CancelQuery returned nil error for missing QueryID") + } + if !strings.Contains(err.Error(), "query_id") { + t.Errorf("error = %q, want message mentioning query_id", err.Error()) + } + } +} + +func TestClient_TerminateTransactionRejectsMissingTransactionID(t *testing.T) { + client := mustClient(t, "http://example") + + for _, target := range []ActionTarget{ + {Instance: "primary", PID: 123}, + {Instance: "primary", PID: 123, TransactionID: ptrString("")}, + } { + err := client.TerminateTransaction(context.Background(), target) + if err == nil { + t.Fatal("TerminateTransaction returned nil error for missing TransactionID") + } + if !strings.Contains(err.Error(), "transaction_id") { + t.Errorf("error = %q, want message mentioning transaction_id", err.Error()) + } + } +} + +func TestClient_TerminateConnectionRejectsMissingConnectionID(t *testing.T) { + client := mustClient(t, "http://example") + + for _, target := range []ActionTarget{ + {Instance: "primary", PID: 123}, + {Instance: "primary", PID: 123, ConnectionID: ptrString("")}, + } { + err := client.TerminateConnection(context.Background(), target) + if err == nil { + t.Fatal("TerminateConnection returned nil error for missing ConnectionID") + } + if !strings.Contains(err.Error(), "connection_id") { + t.Errorf("error = %q, want message mentioning connection_id", err.Error()) + } + } +} + +func ptrString(v string) *string { + return &v +} diff --git a/internal/connections/connection_list.go b/internal/connections/connection_list.go new file mode 100644 index 00000000..45ba3f78 --- /dev/null +++ b/internal/connections/connection_list.go @@ -0,0 +1,260 @@ +package connections + +import ( + "slices" + "strconv" + "strings" + "time" +) + +type SortMode string + +const ( + SortByTransactionStart SortMode = "xact_start" + SortByDuration SortMode = "duration" + SortByBlocked SortMode = "blocked" +) + +type DatabaseKind string + +const ( + DatabaseKindMySQL DatabaseKind = "mysql" + DatabaseKindPostgreSQL DatabaseKind = "postgresql" +) + +type InstanceMeta struct { + ID string `json:"id"` + Role string `json:"role"` + Error string `json:"error,omitempty"` +} + +type Topology struct { + Keyspace string `json:"keyspace,omitempty"` + Shard string `json:"shard,omitempty"` + Tablet string `json:"tablet,omitempty"` +} + +type ConnectionList struct { + CapturedAt time.Time `json:"captured_at"` + DatabaseKind DatabaseKind `json:"database_kind,omitempty"` + Connections []Connection `json:"connections"` + Sort SortMode `json:"sort"` + Instances []InstanceMeta `json:"instances,omitempty"` + Topology *Topology `json:"topology,omitempty"` +} + +type Connection struct { + PID int `json:"pid"` + Instance string `json:"instance"` + InstanceRole string `json:"instance_role,omitempty"` + Username string `json:"username"` + ApplicationName string `json:"application_name"` + DatabaseName string `json:"database,omitempty"` + ClientAddr string `json:"client_addr"` + State string `json:"state"` + WaitEventType string `json:"wait_event_type"` + WaitEvent string `json:"wait_event"` + BackendType string `json:"backend_type"` + XactStart *time.Time `json:"xact_start,omitempty"` + QueryStart *time.Time `json:"query_start,omitempty"` + ConnectionID *string `json:"connection_id,omitempty"` + TransactionID *string `json:"transaction_id,omitempty"` + QueryID *string `json:"query_id,omitempty"` + Duration time.Duration `json:"duration"` + BlockedBy []int `json:"blocked_by,omitempty"` + QueryText string `json:"query_text,omitempty"` +} + +func NewConnectionList(capturedAt time.Time, connections []Connection, sort SortMode) ConnectionList { + out := slices.Clone(connections) + SortConnections(out, sort) + + return ConnectionList{ + CapturedAt: capturedAt, + Connections: out, + Sort: sort, + } +} + +func SortConnections(connections []Connection, mode SortMode) { + switch mode { + case SortByBlocked: + counts := BlockingCounts(connections) + slices.SortStableFunc(connections, func(a, b Connection) int { + return compareBlocked(a, b, counts) + }) + case SortByDuration: + slices.SortStableFunc(connections, compareDuration) + default: + slices.SortStableFunc(connections, compareTransactionStart) + } +} + +func compareDuration(a, b Connection) int { + if a.Duration > b.Duration { + return -1 + } + if a.Duration < b.Duration { + return 1 + } + if cmp := compareBool(durationSortHasWork(a), durationSortHasWork(b)); cmp != 0 { + return cmp + } + return a.PID - b.PID +} + +func durationSortHasWork(c Connection) bool { + state := strings.ToLower(strings.TrimSpace(c.State)) + if state == "sleep" || state == "idle" || strings.HasPrefix(state, "idle ") { + return false + } + return state != "" || strings.TrimSpace(c.QueryText) != "" +} + +func compareBlocked(a, b Connection, counts map[int]int) int { + if cmp := compareIntDesc(counts[a.PID], counts[b.PID]); cmp != 0 { + return cmp + } + if cmp := compareBool(activeAndBlocked(a), activeAndBlocked(b)); cmp != 0 { + return cmp + } + if cmp := compareBool(len(a.BlockedBy) > 0, len(b.BlockedBy) > 0); cmp != 0 { + return cmp + } + return compareTransactionStart(a, b) +} + +func activeAndBlocked(c Connection) bool { + return c.State == "active" && len(c.BlockedBy) > 0 +} + +func compareBool(a, b bool) int { + switch { + case a == b: + return 0 + case a: + return -1 + default: + return 1 + } +} + +func compareIntDesc(a, b int) int { + if a > b { + return -1 + } + if a < b { + return 1 + } + return 0 +} + +// BlockingCounts returns each connection's downstream blocking depth by PID. +func BlockingCounts(connections []Connection) map[int]int { + downstream := make(map[int][]int) + for _, conn := range connections { + for _, blockerPID := range conn.BlockedBy { + downstream[blockerPID] = append(downstream[blockerPID], conn.PID) + } + } + + counts := make(map[int]int) + for _, conn := range connections { + count := downstreamCount(conn.PID, downstream, map[int]bool{conn.PID: true}) + if count > 0 { + counts[conn.PID] = count + } + } + return counts +} + +func downstreamCount(pid int, downstream map[int][]int, seen map[int]bool) int { + count := 0 + for _, blockedPID := range downstream[pid] { + if seen[blockedPID] { + continue + } + seen[blockedPID] = true + count += 1 + downstreamCount(blockedPID, downstream, seen) + } + return count +} + +func compareTransactionStart(a, b Connection) int { + if cmp := compareNullableTime(a.XactStart, b.XactStart); cmp != 0 { + return cmp + } + if cmp := compareNullableTime(a.QueryStart, b.QueryStart); cmp != 0 { + return cmp + } + return a.PID - b.PID +} + +func compareNullableTime(a, b *time.Time) int { + switch { + case a == nil && b == nil: + return 0 + case a == nil: + return 1 + case b == nil: + return -1 + case a.Before(*b): + return -1 + case a.After(*b): + return 1 + default: + return 0 + } +} + +// HumanFields returns the connection's scalar fields as ordered (name, value) +// pairs for vertical, MySQL `\G`-style rendering. The query text is excluded — +// callers render it separately because it is multi-line. This is the single +// source of truth shared by the agent-cli `list --format human` output and the +// interactive detail view, so the two never drift. +func (c Connection) HumanFields() [][2]string { + return [][2]string{ + {"pid", strconv.Itoa(c.PID)}, + {"instance", c.Instance}, + {"role", c.InstanceRole}, + {"state", c.State}, + {"duration", c.Duration.String()}, + {"wait", JoinWaitEvents(c.WaitEventType, c.WaitEvent)}, + {"user", c.Username}, + {"application", c.ApplicationName}, + {"client_addr", c.ClientAddr}, + {"blocked_by", JoinInts(c.BlockedBy)}, + {"query_id", DerefString(c.QueryID)}, + {"transaction_id", DerefString(c.TransactionID)}, + {"connection_id", DerefString(c.ConnectionID)}, + } +} + +// JoinWaitEvents renders the wait-event type/event pair as "type/event", +// collapsing to whichever side is present when one is empty. +func JoinWaitEvents(waitEventType, waitEvent string) string { + if waitEventType == "" { + return waitEvent + } + if waitEvent == "" { + return waitEventType + } + return waitEventType + "/" + waitEvent +} + +// JoinInts renders a blocked-by PID slice as a comma-separated string. +func JoinInts(values []int) string { + parts := make([]string, 0, len(values)) + for _, value := range values { + parts = append(parts, strconv.Itoa(value)) + } + return strings.Join(parts, ",") +} + +// DerefString returns the pointed-to string, or "" when nil. +func DerefString(value *string) string { + if value == nil { + return "" + } + return *value +} diff --git a/internal/connections/connection_list_test.go b/internal/connections/connection_list_test.go new file mode 100644 index 00000000..32c52405 --- /dev/null +++ b/internal/connections/connection_list_test.go @@ -0,0 +1,267 @@ +package connections + +import ( + "encoding/json" + "testing" + "time" + + qt "github.com/frankban/quicktest" +) + +func TestNewConnectionListSortsByTransactionStartWithNullsLast(t *testing.T) { + c := qt.New(t) + now := time.Date(2026, 4, 28, 15, 12, 4, 0, time.UTC) + oldXact := now.Add(-20 * time.Minute) + newXact := now.Add(-5 * time.Minute) + oldQuery := now.Add(-10 * time.Minute) + newQuery := now.Add(-2 * time.Minute) + + list := NewConnectionList(now, []Connection{ + {PID: 4, XactStart: nil, QueryStart: &newQuery}, + {PID: 2, XactStart: &newXact, QueryStart: &oldQuery}, + {PID: 3, XactStart: nil, QueryStart: &oldQuery}, + {PID: 1, XactStart: &oldXact, QueryStart: &newQuery}, + }, SortByTransactionStart) + + c.Assert(pids(list.Connections), qt.DeepEquals, []int{1, 2, 3, 4}) +} + +func TestNewConnectionListSortsByTransactionStartWithQueryStartAndPIDTieBreakers(t *testing.T) { + c := qt.New(t) + now := time.Date(2026, 4, 28, 15, 12, 4, 0, time.UTC) + xactStart := now.Add(-20 * time.Minute) + oldQuery := now.Add(-10 * time.Minute) + newQuery := now.Add(-2 * time.Minute) + + list := NewConnectionList(now, []Connection{ + {PID: 6, XactStart: &xactStart, QueryStart: nil}, + {PID: 4, XactStart: &xactStart, QueryStart: &oldQuery}, + {PID: 2, XactStart: &xactStart, QueryStart: &oldQuery}, + {PID: 5, XactStart: &xactStart, QueryStart: &newQuery}, + {PID: 3, XactStart: &xactStart, QueryStart: nil}, + }, SortByTransactionStart) + + c.Assert(pids(list.Connections), qt.DeepEquals, []int{2, 4, 5, 3, 6}) +} + +func TestConnectionListJSONUsesFoundationShape(t *testing.T) { + c := qt.New(t) + capturedAt := time.Date(2026, 4, 28, 15, 12, 4, 0, time.UTC) + xactStart := capturedAt.Add(-20 * time.Minute) + queryStart := capturedAt.Add(-5 * time.Minute) + + list := ConnectionList{ + CapturedAt: capturedAt, + Connections: []Connection{{ + PID: 101, + Instance: "primary", + Username: "brett", + ApplicationName: "psql", + ClientAddr: "127.0.0.1", + State: "active", + WaitEventType: "Lock", + WaitEvent: "transactionid", + BackendType: "client backend", + XactStart: &xactStart, + QueryStart: &queryStart, + Duration: 5 * time.Second, + BlockedBy: []int{201}, + QueryText: "SELECT * FROM widgets", + }}, + Sort: SortByTransactionStart, + } + + data, err := json.Marshal(list) + c.Assert(err, qt.IsNil) + c.Assert(string(data), qt.JSONEquals, map[string]any{ + "captured_at": capturedAt.Format(time.RFC3339), + "connections": []any{map[string]any{ + "pid": 101, + "instance": "primary", + "username": "brett", + "application_name": "psql", + "client_addr": "127.0.0.1", + "state": "active", + "wait_event_type": "Lock", + "wait_event": "transactionid", + "backend_type": "client backend", + "xact_start": xactStart.Format(time.RFC3339), + "query_start": queryStart.Format(time.RFC3339), + "duration": 5000000000, + "blocked_by": []any{201}, + "query_text": "SELECT * FROM widgets", + }}, + "sort": "xact_start", + }) +} + +func TestConnectionRoundTripsCompositeKey(t *testing.T) { + c := qt.New(t) + + original := Connection{PID: 123, Instance: "primary"} + + data, err := json.Marshal(original) + c.Assert(err, qt.IsNil) + + var decoded Connection + c.Assert(json.Unmarshal(data, &decoded), qt.IsNil) + c.Assert(decoded.PID, qt.Equals, 123) + c.Assert(decoded.Instance, qt.Equals, "primary") +} + +func TestConnectionListJSONRoundTripsInstances(t *testing.T) { + c := qt.New(t) + capturedAt := time.Date(2026, 4, 28, 15, 12, 4, 0, time.UTC) + + original := ConnectionList{ + CapturedAt: capturedAt, + Connections: []Connection{}, + Sort: SortByTransactionStart, + Instances: []InstanceMeta{ + {ID: "primary", Role: "primary"}, + {ID: "replica-1", Role: "replica", Error: "timeout after 2s"}, + }, + } + + data, err := json.Marshal(original) + c.Assert(err, qt.IsNil) + + var decoded ConnectionList + c.Assert(json.Unmarshal(data, &decoded), qt.IsNil) + c.Assert(decoded.Instances, qt.DeepEquals, original.Instances) +} + +func TestConnectionJSONOmitsOpaqueIDField(t *testing.T) { + c := qt.New(t) + conn := Connection{PID: 123, Instance: "primary"} + + data, err := json.Marshal(conn) + c.Assert(err, qt.IsNil) + c.Assert(string(data), qt.Not(qt.Contains), `"id"`) +} + +func TestConnectionCarriesActionIDs(t *testing.T) { + c := qt.New(t) + xactStart := time.Date(2026, 4, 28, 15, 12, 4, 123_000, time.UTC) + queryStart := xactStart.Add(time.Second) + connID := "primary-101-1779113716123456" + txID := "primary-101-1777476480123456" + qID := "primary-101-1777476481456789" + + conn := Connection{ + PID: 101, + Instance: "primary", + XactStart: &xactStart, + QueryStart: &queryStart, + ConnectionID: &connID, + TransactionID: &txID, + QueryID: &qID, + } + + data, err := json.Marshal(conn) + c.Assert(err, qt.IsNil) + c.Assert(string(data), qt.Contains, `"connection_id":"primary-101-1779113716123456"`) + c.Assert(string(data), qt.Contains, `"transaction_id":"primary-101-1777476480123456"`) + c.Assert(string(data), qt.Contains, `"query_id":"primary-101-1777476481456789"`) + + var decoded Connection + c.Assert(json.Unmarshal(data, &decoded), qt.IsNil) + c.Assert(decoded.ConnectionID, qt.IsNotNil) + c.Assert(*decoded.ConnectionID, qt.Equals, "primary-101-1779113716123456") + c.Assert(decoded.TransactionID, qt.IsNotNil) + c.Assert(*decoded.TransactionID, qt.Equals, "primary-101-1777476480123456") + c.Assert(decoded.QueryID, qt.IsNotNil) + c.Assert(*decoded.QueryID, qt.Equals, "primary-101-1777476481456789") +} + +func TestConnectionOmitsNilActionIDs(t *testing.T) { + c := qt.New(t) + conn := Connection{PID: 1, Instance: "primary"} + + data, err := json.Marshal(conn) + c.Assert(err, qt.IsNil) + c.Assert(string(data), qt.Not(qt.Contains), `"connection_id"`) + c.Assert(string(data), qt.Not(qt.Contains), `"transaction_id"`) + c.Assert(string(data), qt.Not(qt.Contains), `"query_id"`) +} + +func TestSortConnectionsByDurationDescending(t *testing.T) { + c := qt.New(t) + + connections := []Connection{ + {PID: 1, Duration: 1 * time.Second}, + {PID: 2, Duration: 10 * time.Second}, + {PID: 3, Duration: 5 * time.Second}, + } + SortConnections(connections, SortByDuration) + + c.Assert(pids(connections), qt.DeepEquals, []int{2, 3, 1}) +} + +func TestSortConnectionsByDurationPrioritizesActiveRowsOnTie(t *testing.T) { + c := qt.New(t) + + connections := []Connection{ + {PID: 10, State: "Sleep", Duration: 0}, + {PID: 20, State: "Query/update", Duration: 0, QueryText: "INSERT INTO events VALUES (1)"}, + {PID: 30, State: "Sleep", Duration: 0}, + } + SortConnections(connections, SortByDuration) + + c.Assert(pids(connections), qt.DeepEquals, []int{20, 10, 30}) +} + +func TestSortConnectionsByBlockedPutsActiveBlockedFirst(t *testing.T) { + c := qt.New(t) + xactStart := time.Date(2026, 4, 28, 15, 12, 4, 0, time.UTC) + + connections := []Connection{ + {PID: 1, State: "idle", BlockedBy: []int{99}, XactStart: &xactStart}, + {PID: 2, State: "active", BlockedBy: []int{99}, XactStart: &xactStart}, + {PID: 3, State: "active", XactStart: &xactStart}, + } + SortConnections(connections, SortByBlocked) + + c.Assert(pids(connections), qt.DeepEquals, []int{2, 1, 3}) +} + +func TestSortConnectionsByBlockedPutsRootBlockersFirst(t *testing.T) { + c := qt.New(t) + xactStart := time.Date(2026, 4, 28, 15, 12, 4, 0, time.UTC) + + connections := []Connection{ + {PID: 20, State: "active", BlockedBy: []int{10}, XactStart: &xactStart}, + {PID: 30, State: "active", BlockedBy: []int{20}, XactStart: &xactStart}, + {PID: 40, State: "active", BlockedBy: []int{10}, XactStart: &xactStart}, + {PID: 10, State: "idle in transaction", XactStart: &xactStart}, + {PID: 50, State: "active", XactStart: &xactStart}, + } + SortConnections(connections, SortByBlocked) + + c.Assert(pids(connections), qt.DeepEquals, []int{10, 20, 30, 40, 50}) +} + +func TestBlockingCountsCountsDownstreamDepth(t *testing.T) { + c := qt.New(t) + + connections := []Connection{ + {PID: 1}, + {PID: 2, BlockedBy: []int{1}}, + {PID: 3, BlockedBy: []int{2}}, + {PID: 4, BlockedBy: []int{1}}, + } + + counts := BlockingCounts(connections) + + c.Assert(counts[1], qt.Equals, 3) + c.Assert(counts[2], qt.Equals, 1) + c.Assert(counts[3], qt.Equals, 0) +} + +func pids(connections []Connection) []int { + out := make([]int, 0, len(connections)) + for _, conn := range connections { + out = append(out, conn.PID) + } + return out +} diff --git a/internal/connections/history/capture.go b/internal/connections/history/capture.go new file mode 100644 index 00000000..4f117c94 --- /dev/null +++ b/internal/connections/history/capture.go @@ -0,0 +1,19 @@ +package history + +import ( + "time" + + "github.com/planetscale/cli/internal/connections" +) + +type Capture struct { + At time.Time `json:"at"` + List connections.ConnectionList `json:"capture"` +} + +func NewCapture(list connections.ConnectionList) Capture { + return Capture{ + At: list.CapturedAt, + List: list, + } +} diff --git a/internal/connections/history/capture_history.go b/internal/connections/history/capture_history.go new file mode 100644 index 00000000..a59808b1 --- /dev/null +++ b/internal/connections/history/capture_history.go @@ -0,0 +1,127 @@ +package history + +import ( + "slices" + + "github.com/planetscale/cli/internal/connections" +) + +// CaptureCursor identifies a capture stored in a CaptureHistory. Cursors are +// monotonic and stable across eviction — once a cursor's capture is evicted, +// At returns false for it, but subsequent captures keep advancing the cursor +// space so the model can hold a step position without recomputing on push. +type CaptureCursor int + +// CaptureHistory is a capped ring of captured ConnectionLists. New captures +// push out the oldest once capacity is reached. +type CaptureHistory struct { + capacity int + samples []connections.ConnectionList + base CaptureCursor +} + +func NewCaptureHistory(capacity int) *CaptureHistory { + if capacity < 1 { + capacity = 1 + } + return &CaptureHistory{capacity: capacity} +} + +// Push stores list and returns its cursor. The Connections slice is cloned so +// callers can mutate (e.g. sort in place) without corrupting history. +func (h *CaptureHistory) Push(list connections.ConnectionList) CaptureCursor { + list.Connections = slices.Clone(list.Connections) + h.samples = append(h.samples, list) + if len(h.samples) > h.capacity { + h.samples = h.samples[len(h.samples)-h.capacity:] + h.base += CaptureCursor(1) + } + return h.base + CaptureCursor(len(h.samples)-1) +} + +func (h *CaptureHistory) At(cursor CaptureCursor) (connections.ConnectionList, bool) { + idx := int(cursor - h.base) + if idx < 0 || idx >= len(h.samples) { + return connections.ConnectionList{}, false + } + return h.samples[idx], true +} + +func (h *CaptureHistory) All() []connections.ConnectionList { + out := make([]connections.ConnectionList, 0, len(h.samples)) + for _, list := range h.samples { + list.Connections = slices.Clone(list.Connections) + out = append(out, list) + } + return out +} + +// Step moves cursor by delta, clamped to [Oldest, Latest]. Returns the new +// cursor and whether it moved. A cursor outside the live range is treated as +// "off the oldest edge" / "off the newest edge" — positive delta from below +// snaps to the oldest, negative from above snaps to the latest. +func (h *CaptureHistory) Step(cursor CaptureCursor, delta int) (CaptureCursor, bool) { + if len(h.samples) == 0 || delta == 0 { + return cursor, false + } + oldest := h.base + latest := h.base + CaptureCursor(len(h.samples)-1) + if cursor < oldest { + if delta > 0 { + return oldest, true + } + return cursor, false + } + if cursor > latest { + if delta < 0 { + return latest, true + } + return cursor, false + } + target := cursor + CaptureCursor(delta) + if target < oldest { + target = oldest + } + if target > latest { + target = latest + } + return target, target != cursor +} + +func (h *CaptureHistory) Oldest() (CaptureCursor, bool) { + if len(h.samples) == 0 { + return 0, false + } + return h.base, true +} + +func (h *CaptureHistory) Latest() (CaptureCursor, bool) { + if len(h.samples) == 0 { + return 0, false + } + return h.base + CaptureCursor(len(h.samples)-1), true +} + +// MustLatest panics when history is empty — convenient in tests where a +// caller has just pushed and knows the history is non-empty. +func (h *CaptureHistory) MustLatest() CaptureCursor { + cursor, ok := h.Latest() + if !ok { + panic("CaptureHistory: MustLatest called on empty history") + } + return cursor +} + +// Position returns the 1-based index of cursor in the live window and the +// total number of retained samples. Returns (0, total) when cursor has been +// evicted. +func (h *CaptureHistory) Position(cursor CaptureCursor) (n, total int) { + total = len(h.samples) + idx := int(cursor - h.base) + if idx < 0 || idx >= total { + return 0, total + } + return idx + 1, total +} + +func (h *CaptureHistory) Len() int { return len(h.samples) } diff --git a/internal/connections/history/capture_history_test.go b/internal/connections/history/capture_history_test.go new file mode 100644 index 00000000..098114a7 --- /dev/null +++ b/internal/connections/history/capture_history_test.go @@ -0,0 +1,150 @@ +package history + +import ( + "testing" + "time" + + qt "github.com/frankban/quicktest" + "github.com/planetscale/cli/internal/connections" +) + +func TestCaptureHistoryAtReturnsPushedSample(t *testing.T) { + c := qt.New(t) + h := NewCaptureHistory(10) + cursor := h.Push(listWithPID(101)) + + list, ok := h.At(cursor) + + c.Assert(ok, qt.IsTrue) + c.Assert(list.Connections[0].PID, qt.Equals, 101) +} + +func TestCaptureHistoryEvictsOldestPastCapacity(t *testing.T) { + c := qt.New(t) + h := NewCaptureHistory(3) + + first := h.Push(listWithPID(100)) + h.Push(listWithPID(101)) + h.Push(listWithPID(102)) + h.Push(listWithPID(103)) + + _, ok := h.At(first) + c.Assert(ok, qt.IsFalse) + c.Assert(h.Len(), qt.Equals, 3) +} + +// Cursors remain valid across eviction: a stable identity lets the model pin a +// step position without recomputing on every Push. +func TestCaptureHistoryCursorsSurviveEvictionUntilEvicted(t *testing.T) { + c := qt.New(t) + h := NewCaptureHistory(2) + + first := h.Push(listWithPID(100)) + second := h.Push(listWithPID(101)) + h.Push(listWithPID(102)) + + _, firstOK := h.At(first) + list, secondOK := h.At(second) + + c.Assert(firstOK, qt.IsFalse) + c.Assert(secondOK, qt.IsTrue) + c.Assert(list.Connections[0].PID, qt.Equals, 101) +} + +func TestCaptureHistoryStepMovesCursorWithinBounds(t *testing.T) { + c := qt.New(t) + h := NewCaptureHistory(10) + first := h.Push(listWithPID(100)) + h.Push(listWithPID(101)) + third := h.Push(listWithPID(102)) + + moved, ok := h.Step(first, 2) + c.Assert(ok, qt.IsTrue) + c.Assert(moved, qt.Equals, third) + + back, ok := h.Step(third, -2) + c.Assert(ok, qt.IsTrue) + c.Assert(back, qt.Equals, first) +} + +func TestCaptureHistoryStepClampsAtBounds(t *testing.T) { + c := qt.New(t) + h := NewCaptureHistory(10) + first := h.Push(listWithPID(100)) + h.Push(listWithPID(101)) + + beforeOldest, ok := h.Step(first, -5) + c.Assert(ok, qt.IsFalse) + c.Assert(beforeOldest, qt.Equals, first) + + pastNewest, ok := h.Step(first, 99) + c.Assert(ok, qt.IsTrue) + c.Assert(pastNewest, qt.Equals, h.MustLatest()) +} + +func TestCaptureHistoryOldestLatestPosition(t *testing.T) { + c := qt.New(t) + h := NewCaptureHistory(10) + first := h.Push(listWithPID(100)) + h.Push(listWithPID(101)) + third := h.Push(listWithPID(102)) + + oldest, ok := h.Oldest() + c.Assert(ok, qt.IsTrue) + c.Assert(oldest, qt.Equals, first) + + latest, ok := h.Latest() + c.Assert(ok, qt.IsTrue) + c.Assert(latest, qt.Equals, third) + + pos, total := h.Position(first) + c.Assert(pos, qt.Equals, 1) + c.Assert(total, qt.Equals, 3) + + pos, total = h.Position(third) + c.Assert(pos, qt.Equals, 3) + c.Assert(total, qt.Equals, 3) +} + +// Push must clone the Connections slice so callers (the TUI) can sort it in +// place without corrupting the stored history. +func TestCaptureHistoryPushClonesConnections(t *testing.T) { + c := qt.New(t) + h := NewCaptureHistory(10) + input := listWithPID(100) + cursor := h.Push(input) + input.Connections[0].PID = 999 + + stored, ok := h.At(cursor) + c.Assert(ok, qt.IsTrue) + c.Assert(stored.Connections[0].PID, qt.Equals, 100) +} + +func TestCaptureHistoryAllReturnsSamplesInOrder(t *testing.T) { + c := qt.New(t) + h := NewCaptureHistory(2) + h.Push(listWithPID(100)) + h.Push(listWithPID(101)) + h.Push(listWithPID(102)) + + got := h.All() + + c.Assert(got, qt.HasLen, 2) + c.Assert(got[0].Connections[0].PID, qt.Equals, 101) + c.Assert(got[1].Connections[0].PID, qt.Equals, 102) + + got[0].Connections[0].PID = 999 + oldest, ok := h.Oldest() + c.Assert(ok, qt.IsTrue) + stored, ok := h.At(oldest) + c.Assert(ok, qt.IsTrue) + c.Assert(stored.Connections[0].PID, qt.Equals, 101) +} + +func listWithPID(pid int) connections.ConnectionList { + return connections.NewConnectionList( + time.Date(2026, 5, 27, 12, 0, pid, 0, time.UTC), + []connections.Connection{{PID: pid, Instance: "primary"}}, + connections.SortByTransactionStart, + ) +} diff --git a/internal/connections/history/capture_reader.go b/internal/connections/history/capture_reader.go new file mode 100644 index 00000000..b6435e98 --- /dev/null +++ b/internal/connections/history/capture_reader.go @@ -0,0 +1,113 @@ +package history + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "strings" +) + +// CaptureReader reads Capture records from a stream previously written by +// CaptureWriter. It skips capture_start headers and rejects unsupported schema +// versions on capture_start. +type CaptureReader struct { + reader *bufio.Reader + line int + done bool + captureStart *CaptureStart +} + +func NewCaptureReader(r io.Reader) *CaptureReader { + return &CaptureReader{reader: bufio.NewReader(r)} +} + +// Read returns the next Capture in the stream, or io.EOF when exhausted. A +// torn final line (no trailing newline, partial JSON) is treated as EOF so +// captures interrupted by SIGINT remain replayable. +func (r *CaptureReader) Read() (Capture, error) { + if r.done { + return Capture{}, io.EOF + } + + for { + line, err := r.reader.ReadString('\n') + if err != nil && err != io.EOF { + return Capture{}, err + } + if err == io.EOF { + r.done = true + } + if line == "" && err == io.EOF { + return Capture{}, io.EOF + } + + r.line++ + if strings.TrimSpace(line) == "" { + if err == io.EOF { + return Capture{}, io.EOF + } + continue + } + + complete := strings.HasSuffix(line, "\n") + var envelope struct { + Type string `json:"type"` + SchemaVersion int `json:"schema_version"` + } + if decodeErr := json.Unmarshal([]byte(line), &envelope); decodeErr != nil { + if err == io.EOF && !complete { + return Capture{}, io.EOF + } + return Capture{}, fmt.Errorf("capture line %d: %w", r.line, decodeErr) + } + if envelope.Type == "capture_start" { + if envelope.SchemaVersion != traceSchemaVersion { + return Capture{}, fmt.Errorf("capture line %d: schema version %d is not supported (expected %d)", r.line, envelope.SchemaVersion, traceSchemaVersion) + } + var start CaptureStart + if decodeErr := json.Unmarshal([]byte(line), &start); decodeErr == nil { + r.captureStart = &start + } + if err == io.EOF { + return Capture{}, io.EOF + } + continue + } + if envelope.Type != "" { + return Capture{}, fmt.Errorf("capture line %d: record type %q is not supported", r.line, envelope.Type) + } + + var capture Capture + decoder := json.NewDecoder(strings.NewReader(line)) + decoder.DisallowUnknownFields() + if decodeErr := decoder.Decode(&capture); decodeErr != nil { + return Capture{}, fmt.Errorf("capture line %d: %w", r.line, decodeErr) + } + return capture, nil + } +} + +// CaptureStart returns the parsed capture_start metadata when the trace +// contains a well-formed header. +func (r *CaptureReader) CaptureStart() (CaptureStart, bool) { + if r.captureStart == nil { + return CaptureStart{}, false + } + return *r.captureStart, true +} + +// ReadAll consumes the stream and returns every complete capture in order. +func (r *CaptureReader) ReadAll() ([]Capture, error) { + var captures []Capture + for { + capture, err := r.Read() + if err == io.EOF { + return captures, nil + } + if err != nil { + return nil, err + } + captures = append(captures, capture) + } +} diff --git a/internal/connections/history/capture_reader_test.go b/internal/connections/history/capture_reader_test.go new file mode 100644 index 00000000..0512b64d --- /dev/null +++ b/internal/connections/history/capture_reader_test.go @@ -0,0 +1,166 @@ +package history + +import ( + "bytes" + "io" + "strings" + "testing" + "time" + + "github.com/planetscale/cli/internal/connections" + + qt "github.com/frankban/quicktest" +) + +func TestCaptureReaderReadsCapturesWrittenByCaptureWriter(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + writer := NewCaptureWriter(&buffer) + c.Assert(writer.WriteCaptureStart(CaptureStart{ + At: time.Date(2026, 4, 28, 14, 59, 0, 0, time.UTC), + Organization: "acme", + Database: "prod", + Branch: "main", + }), qt.IsNil) + c.Assert(writer.Write(traceCapture()), qt.IsNil) + + reader := NewCaptureReader(&buffer) + capture, err := reader.Read() + + c.Assert(err, qt.IsNil) + c.Assert(capture.List.Connections[0].PID, qt.Equals, 101) + c.Assert(capture.List.Connections[0].QueryText, qt.Equals, "SELECT * FROM widgets") + + _, err = reader.Read() + c.Assert(err, qt.Equals, io.EOF) +} + +func TestCaptureReaderRejectsUnsupportedSchemaVersion(t *testing.T) { + c := qt.New(t) + input := `{"type":"capture_start","schema_version":2}` + "\n" + reader := NewCaptureReader(strings.NewReader(input)) + + _, err := reader.Read() + + c.Assert(err, qt.ErrorMatches, `capture line 1: schema version 2 is not supported \(expected 1\)`) +} + +// Tail tolerance: a SIGINT during write can leave a torn final line. The +// reader should return EOF, not surface a JSON error, so replay still works on +// a partially-flushed file. +func TestCaptureReaderToleratesTornFinalLine(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + writer := NewCaptureWriter(&buffer) + c.Assert(writer.Write(traceCapture()), qt.IsNil) + buffer.WriteString(`{"at":"2026-04-28T15:00:01Z"`) + + reader := NewCaptureReader(&buffer) + _, err := reader.Read() + c.Assert(err, qt.IsNil) + _, err = reader.Read() + c.Assert(err, qt.Equals, io.EOF) +} + +// A complete final line that is malformed JSON must surface as an error, not +// be silently dropped — otherwise file corruption goes undetected. +func TestCaptureReaderReportsMalformedCompleteFinalLine(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + writer := NewCaptureWriter(&buffer) + c.Assert(writer.Write(traceCapture()), qt.IsNil) + buffer.WriteString(`{"at":` + "\n") + + reader := NewCaptureReader(&buffer) + _, err := reader.Read() + c.Assert(err, qt.IsNil) + _, err = reader.Read() + c.Assert(err, qt.ErrorMatches, `capture line 2: .*`) +} + +func TestCaptureReaderReadAllReturnsCapturesInOrder(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + writer := NewCaptureWriter(&buffer) + c.Assert(writer.WriteCaptureStart(CaptureStart{At: time.Date(2026, 4, 28, 14, 59, 0, 0, time.UTC)}), qt.IsNil) + + at := time.Date(2026, 4, 28, 15, 0, 0, 0, time.UTC) + for i := 0; i < 3; i++ { + list := connections.NewConnectionList(at.Add(time.Duration(i)*time.Second), []connections.Connection{{ + PID: 100 + i, + Instance: "primary", + }}, connections.SortByTransactionStart) + c.Assert(writer.Write(NewCapture(list)), qt.IsNil) + } + + reader := NewCaptureReader(&buffer) + captures, err := reader.ReadAll() + + c.Assert(err, qt.IsNil) + c.Assert(captures, qt.HasLen, 3) + for i, capture := range captures { + c.Assert(capture.List.Connections[0].PID, qt.Equals, 100+i) + } +} + +func TestCaptureReaderReadAllSkipsMalformedCaptureStartMetadata(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + buffer.WriteString(`{"type":"capture_start","schema_version":1,"at":{}}` + "\n") + writer := NewCaptureWriter(&buffer) + c.Assert(writer.Write(traceCapture()), qt.IsNil) + + reader := NewCaptureReader(bytes.NewReader(buffer.Bytes())) + captures, err := reader.ReadAll() + + c.Assert(err, qt.IsNil) + c.Assert(captures, qt.HasLen, 1) + c.Assert(captures[0].List.Connections[0].PID, qt.Equals, 101) +} + +// Round-trip byte identity: writing → reading → writing the same captures +// produces byte-identical output. Locks the format against accidental +// reordering or escaping drift in the serializer. +func TestCaptureRoundTripIsByteIdentical(t *testing.T) { + c := qt.New(t) + var first bytes.Buffer + w1 := NewCaptureWriter(&first) + c.Assert(w1.WriteCaptureStart(CaptureStart{ + At: time.Date(2026, 4, 28, 14, 59, 0, 0, time.UTC), + Organization: "acme", + Database: "prod", + Branch: "main", + }), qt.IsNil) + c.Assert(w1.Write(traceCapture()), qt.IsNil) + + reader := NewCaptureReader(bytes.NewReader(first.Bytes())) + captures, err := reader.ReadAll() + c.Assert(err, qt.IsNil) + + var second bytes.Buffer + w2 := NewCaptureWriter(&second) + c.Assert(w2.WriteCaptureStart(CaptureStart{ + At: time.Date(2026, 4, 28, 14, 59, 0, 0, time.UTC), + Organization: "acme", + Database: "prod", + Branch: "main", + }), qt.IsNil) + for _, capture := range captures { + c.Assert(w2.Write(capture), qt.IsNil) + } + + c.Assert(second.Bytes(), qt.DeepEquals, first.Bytes()) +} + +// Anti-regression: a pre-amendment capture file with old-only fields must +// surface a clear decode error rather than being accepted as a replayable +// capture. +func TestCaptureReaderRejectsOldShapeCaptureCleanly(t *testing.T) { + c := qt.New(t) + oldShape := `{"at":"2026-04-28T15:00:00Z","capture":{"captured_at":"2026-04-28T15:00:00Z","connections":[{"id":"primary-10","pid":10,"instance":"primary","username":"brett","application_name":"psql","client_addr":"127.0.0.1","state":"active","wait_event_type":"","wait_event":"","backend_type":"","query_id":"primary-10-1779113716123456","duration":5000000000}],"sort":"xact_start"}}` + "\n" + + reader := NewCaptureReader(strings.NewReader(oldShape)) + _, err := reader.Read() + + c.Assert(err, qt.ErrorMatches, `capture line 1: json: unknown field "id"`) +} diff --git a/internal/connections/history/capture_writer.go b/internal/connections/history/capture_writer.go new file mode 100644 index 00000000..a85a5e80 --- /dev/null +++ b/internal/connections/history/capture_writer.go @@ -0,0 +1,92 @@ +package history + +import ( + "encoding/json" + "io" + "time" +) + +const traceSchemaVersion = 1 + +type CaptureStart struct { + Type string `json:"type"` + At time.Time `json:"at"` + Organization string `json:"org"` + Database string `json:"database"` + Branch string `json:"branch"` + SchemaVersion int `json:"schema_version"` + Filter *CaptureFilter `json:"filter,omitempty"` + Target *CaptureTarget `json:"target,omitempty"` +} + +// CaptureTarget records the Vitess target active at capture time. +type CaptureTarget struct { + Keyspace string `json:"keyspace,omitempty"` + Shard string `json:"shard,omitempty"` +} + +// CaptureFilter records any client-side row filter active at capture time so +// replay tooling can tell a partial snapshot from a complete branch view. +// Omitted when no filter was set. +type CaptureFilter struct { + Instance string `json:"instance,omitempty"` + Role string `json:"role,omitempty"` +} + +// CaptureWriter writes captures as newline-delimited JSON records. +type CaptureWriter struct { + writer io.Writer +} + +func NewCaptureWriter(writer io.Writer) *CaptureWriter { + return &CaptureWriter{writer: writer} +} + +func (w *CaptureWriter) Write(capture Capture) error { + return w.writeRecord(capture) +} + +func (w *CaptureWriter) WriteCaptureStart(start CaptureStart) error { + start.Type = "capture_start" + if start.SchemaVersion == 0 { + start.SchemaVersion = traceSchemaVersion + } + return w.writeRecord(start) +} + +func (w *CaptureWriter) writeRecord(value any) error { + record, err := json.Marshal(value) + if err != nil { + return err + } + record = append(record, '\n') + n, err := w.writer.Write(record) + if err != nil { + return err + } + if n != len(record) { + return io.ErrShortWrite + } + return w.Flush() +} + +func (w *CaptureWriter) Flush() error { + flusher, ok := w.writer.(interface{ Flush() error }) + if !ok { + return nil + } + return flusher.Flush() +} + +func (w *CaptureWriter) Close() error { + flushErr := w.Flush() + closer, ok := w.writer.(io.Closer) + if !ok { + return flushErr + } + closeErr := closer.Close() + if flushErr != nil { + return flushErr + } + return closeErr +} diff --git a/internal/connections/history/capture_writer_test.go b/internal/connections/history/capture_writer_test.go new file mode 100644 index 00000000..b264bf8e --- /dev/null +++ b/internal/connections/history/capture_writer_test.go @@ -0,0 +1,136 @@ +package history + +import ( + "bytes" + "io" + "testing" + "time" + + "github.com/planetscale/cli/internal/connections" + + qt "github.com/frankban/quicktest" +) + +func TestTraceFormatStability(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + writer := NewCaptureWriter(&buffer) + + err := writer.Write(traceCapture()) + + c.Assert(err, qt.IsNil) + c.Assert(buffer.String(), qt.JSONEquals, map[string]any{ + "at": "2026-04-28T15:00:00Z", + "capture": map[string]any{ + "captured_at": "2026-04-28T15:00:00Z", + "connections": []any{map[string]any{ + "pid": 101, + "instance": "", + "username": "brett", + "application_name": "psql", + "client_addr": "127.0.0.1", + "state": "active", + "wait_event_type": "", + "wait_event": "", + "backend_type": "", + "duration": 5000000000, + "query_text": "SELECT * FROM widgets", + }}, + "sort": "xact_start", + }, + }) +} + +func TestCaptureWriterWritesCaptureStart(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + writer := NewCaptureWriter(&buffer) + c.Assert(writer.WriteCaptureStart(CaptureStart{ + At: time.Date(2026, 4, 28, 14, 59, 0, 0, time.UTC), + Organization: "acme", + Database: "prod", + Branch: "main", + SchemaVersion: 1, + }), qt.IsNil) + c.Assert(buffer.String(), qt.Equals, `{"type":"capture_start","at":"2026-04-28T14:59:00Z","org":"acme","database":"prod","branch":"main","schema_version":1}`+"\n") +} + +// When a filter is active at capture time, the header records it so replay +// tooling can distinguish a partial snapshot from a complete branch view. +func TestCaptureWriterRecordsFilter(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + writer := NewCaptureWriter(&buffer) + c.Assert(writer.WriteCaptureStart(CaptureStart{ + At: time.Date(2026, 4, 28, 14, 59, 0, 0, time.UTC), + Organization: "acme", + Database: "prod", + Branch: "main", + SchemaVersion: 1, + Filter: &CaptureFilter{Role: "replica"}, + }), qt.IsNil) + c.Assert(buffer.String(), qt.JSONEquals, map[string]any{ + "type": "capture_start", + "at": "2026-04-28T14:59:00Z", + "org": "acme", + "database": "prod", + "branch": "main", + "schema_version": 1, + "filter": map[string]any{ + "role": "replica", + }, + }) +} + +func TestCaptureStartRecordsVitessTarget(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + writer := NewCaptureWriter(&buffer) + c.Assert(writer.WriteCaptureStart(CaptureStart{ + At: time.Date(2026, 4, 28, 14, 59, 0, 0, time.UTC), + Organization: "acme", + Database: "prod", + Branch: "main", + SchemaVersion: 1, + Target: &CaptureTarget{ + Keyspace: "commerce", + Shard: "-80", + }, + }), qt.IsNil) + + c.Assert(buffer.String(), qt.JSONEquals, map[string]any{ + "type": "capture_start", + "at": "2026-04-28T14:59:00Z", + "org": "acme", + "database": "prod", + "branch": "main", + "schema_version": 1, + "target": map[string]any{ + "keyspace": "commerce", + "shard": "-80", + }, + }) + + reader := NewCaptureReader(bytes.NewReader(buffer.Bytes())) + _, err := reader.Read() + c.Assert(err, qt.Equals, io.EOF) + start, ok := reader.CaptureStart() + c.Assert(ok, qt.IsTrue) + c.Assert(start.Target, qt.DeepEquals, &CaptureTarget{ + Keyspace: "commerce", + Shard: "-80", + }) +} + +func traceCapture() Capture { + capturedAt := time.Date(2026, 4, 28, 15, 0, 0, 0, time.UTC) + return NewCapture(connections.NewConnectionList(capturedAt, []connections.Connection{{ + PID: 101, + Username: "brett", + ApplicationName: "psql", + ClientAddr: "127.0.0.1", + State: "active", + Duration: 5 * time.Second, + QueryText: "SELECT * FROM widgets", + }}, connections.SortByTransactionStart)) +} diff --git a/internal/connections/history/replay_source.go b/internal/connections/history/replay_source.go new file mode 100644 index 00000000..df511205 --- /dev/null +++ b/internal/connections/history/replay_source.go @@ -0,0 +1,52 @@ +package history + +import ( + "context" + "errors" + "io" + + "github.com/planetscale/cli/internal/connections" +) + +// ReplaySource serves captured snapshots from a trace file as if they were +// live. +type ReplaySource struct { + captures []Capture + start CaptureStart + hasStart bool +} + +// NewReplaySource loads every Capture from r into memory. Returns an error +// when the trace contains no complete capture records — replay against an +// empty or header-only file is not meaningful. +func NewReplaySource(r io.Reader) (*ReplaySource, error) { + reader := NewCaptureReader(r) + captures, err := reader.ReadAll() + if err != nil { + return nil, err + } + if len(captures) == 0 { + return nil, errors.New("capture file contains no replayable snapshots") + } + start, hasStart := reader.CaptureStart() + return &ReplaySource{captures: captures, start: start, hasStart: hasStart}, nil +} + +// List returns the latest captured snapshot, sorted by mode. +func (s *ReplaySource) List(_ context.Context, mode connections.SortMode) (connections.ConnectionList, error) { + out := s.captures[len(s.captures)-1].List + connections.SortConnections(out.Connections, mode) + out.Sort = mode + return out, nil +} + +// Captures returns the loaded captures in order. Used by the CLI to seed the +// TUI's capture history so the operator can step the full timeline. +func (s *ReplaySource) Captures() []Capture { + return s.captures +} + +// CaptureStart returns the trace metadata header when the capture file has one. +func (s *ReplaySource) CaptureStart() (CaptureStart, bool) { + return s.start, s.hasStart +} diff --git a/internal/connections/history/replay_source_test.go b/internal/connections/history/replay_source_test.go new file mode 100644 index 00000000..12a23b58 --- /dev/null +++ b/internal/connections/history/replay_source_test.go @@ -0,0 +1,102 @@ +package history + +import ( + "bytes" + "context" + "strings" + "testing" + "time" + + "github.com/planetscale/cli/internal/connections" + + qt "github.com/frankban/quicktest" +) + +func TestReplaySourceReturnsLatestCapture(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + writer := NewCaptureWriter(&buffer) + + at := time.Date(2026, 4, 28, 15, 0, 0, 0, time.UTC) + for i := 0; i < 3; i++ { + list := connections.NewConnectionList(at.Add(time.Duration(i)*time.Second), []connections.Connection{{ + PID: 100 + i, + Instance: "primary", + }}, connections.SortByTransactionStart) + c.Assert(writer.Write(NewCapture(list)), qt.IsNil) + } + + source, err := NewReplaySource(&buffer) + c.Assert(err, qt.IsNil) + + list, err := source.List(context.Background(), connections.SortByTransactionStart) + c.Assert(err, qt.IsNil) + c.Assert(list.Connections, qt.HasLen, 1) + c.Assert(list.Connections[0].PID, qt.Equals, 102) +} + +func TestReplaySourceAppliesSortOnEachList(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + writer := NewCaptureWriter(&buffer) + at := time.Date(2026, 4, 28, 15, 0, 0, 0, time.UTC) + xact1 := at.Add(-30 * time.Second) + xact2 := at.Add(-10 * time.Second) + list := connections.NewConnectionList(at, []connections.Connection{ + {PID: 10, Instance: "primary", State: "active", XactStart: &xact2, Duration: 1 * time.Second}, + {PID: 20, Instance: "primary", State: "active", XactStart: &xact1, Duration: 30 * time.Second}, + }, connections.SortByTransactionStart) + c.Assert(writer.Write(NewCapture(list)), qt.IsNil) + + source, err := NewReplaySource(&buffer) + c.Assert(err, qt.IsNil) + + byXact, err := source.List(context.Background(), connections.SortByTransactionStart) + c.Assert(err, qt.IsNil) + c.Assert(byXact.Connections[0].PID, qt.Equals, 20) + c.Assert(byXact.Sort, qt.Equals, connections.SortByTransactionStart) + + byDuration, err := source.List(context.Background(), connections.SortByDuration) + c.Assert(err, qt.IsNil) + c.Assert(byDuration.Connections[0].PID, qt.Equals, 20) + c.Assert(byDuration.Sort, qt.Equals, connections.SortByDuration) +} + +// Replay must preserve InstanceMeta.Error so the partial-failure banner that +// reads it renders against captured data the same way it does live. +func TestReplaySourcePreservesInstancesMetadata(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + writer := NewCaptureWriter(&buffer) + at := time.Date(2026, 4, 28, 15, 0, 0, 0, time.UTC) + list := connections.NewConnectionList(at, nil, connections.SortByTransactionStart) + list.Instances = []connections.InstanceMeta{ + {ID: "primary", Role: "primary"}, + {ID: "replica-1", Role: "replica", Error: "connection refused"}, + } + c.Assert(writer.Write(NewCapture(list)), qt.IsNil) + + source, err := NewReplaySource(&buffer) + c.Assert(err, qt.IsNil) + + replayed, err := source.List(context.Background(), connections.SortByTransactionStart) + c.Assert(err, qt.IsNil) + c.Assert(replayed.Instances, qt.HasLen, 2) + c.Assert(replayed.Instances[1].Error, qt.Equals, "connection refused") +} + +func TestNewReplaySourceRejectsEmptyTrace(t *testing.T) { + c := qt.New(t) + _, err := NewReplaySource(strings.NewReader("")) + c.Assert(err, qt.ErrorMatches, "capture file contains no replayable snapshots") +} + +func TestNewReplaySourceRejectsHeaderOnlyTrace(t *testing.T) { + c := qt.New(t) + var buffer bytes.Buffer + writer := NewCaptureWriter(&buffer) + c.Assert(writer.WriteCaptureStart(CaptureStart{At: time.Now()}), qt.IsNil) + + _, err := NewReplaySource(&buffer) + c.Assert(err, qt.ErrorMatches, "capture file contains no replayable snapshots") +} diff --git a/internal/connections/tui/blocking_graph.go b/internal/connections/tui/blocking_graph.go new file mode 100644 index 00000000..47dc3b38 --- /dev/null +++ b/internal/connections/tui/blocking_graph.go @@ -0,0 +1,231 @@ +package tui + +import ( + "fmt" + + live "github.com/planetscale/cli/internal/connections" +) + +type blockerRow struct { + Depth int + PID int + Instance string + Connection live.Connection + Present bool + Cycle bool + Truncated bool + Remaining int + // WaitOn is the lock type the waiter on this edge is blocked on (e.g. + // "Lock/transactionid"), so the wait-chain names what kind of lock the + // blocked session is waiting for, not just who holds it. + WaitOn string +} + +const maxBlockerRows = 32 + +func blockerRows(list live.ConnectionList, root live.Connection) []blockerRow { + byKey := connectionsByKey(list.Connections) + return walkBlockerRows(byKey, root, func(c live.Connection) []int { + return c.BlockedBy + }, func(parent, child live.Connection) live.Connection { + // Upstream: the parent waits on each child (the holder). + return parent + }) +} + +func blockingRows(list live.ConnectionList, root live.Connection) []blockerRow { + byKey := connectionsByKey(list.Connections) + downstream := make(map[int][]int) + for _, conn := range list.Connections { + for _, blockerPID := range conn.BlockedBy { + downstream[blockerPID] = append(downstream[blockerPID], conn.PID) + } + } + return walkBlockerRows(byKey, root, func(c live.Connection) []int { + return downstream[c.PID] + }, func(parent, child live.Connection) live.Connection { + // Downstream: each child is the session waiting on the parent. + return child + }) +} + +func connectionsByKey(connections []live.Connection) map[int]live.Connection { + byKey := make(map[int]live.Connection, len(connections)) + for _, conn := range connections { + byKey[conn.PID] = conn + } + return byKey +} + +func walkBlockerRows(byKey map[int]live.Connection, root live.Connection, nextPIDs func(live.Connection) []int, waiterFor func(parent, child live.Connection) live.Connection) []blockerRow { + var rows []blockerRow + path := map[int]bool{root.PID: true} + var walk func(conn live.Connection, depth int) + walk = func(conn live.Connection, depth int) { + for _, pid := range nextPIDs(conn) { + if len(rows) >= maxBlockerRows { + rows = append(rows, blockerRow{Depth: depth, Truncated: true, Remaining: countBlockerLevels(byKey, conn, clonePath(path), nextPIDs)}) + return + } + row := blockerRow{Depth: depth, PID: pid} + blocker, ok := byKey[pid] + if ok { + row.Connection = blocker + row.Instance = blocker.Instance + row.Present = true + row.WaitOn = waitOnText(waiterFor(conn, blocker)) + } + if path[pid] { + row.Cycle = true + rows = append(rows, row) + continue + } + rows = append(rows, row) + if ok { + path[pid] = true + walk(blocker, depth+1) + delete(path, pid) + } + } + } + walk(root, 0) + return collapseRootSuffixRows(rows) +} + +func clonePath(path map[int]bool) map[int]bool { + out := make(map[int]bool, len(path)) + for k, v := range path { + out[k] = v + } + return out +} + +func countBlockerLevels(byKey map[int]live.Connection, conn live.Connection, path map[int]bool, nextPIDs func(live.Connection) []int) int { + count := 0 + for _, pid := range nextPIDs(conn) { + count++ + blocker, ok := byKey[pid] + if !ok || path[pid] { + continue + } + path[pid] = true + count += countBlockerLevels(byKey, blocker, path, nextPIDs) + delete(path, pid) + } + return count +} + +// collapseRootSuffixRows drops a root-level subtree when its blocker chain +// is a proper suffix of another root's chain. Without this, the tree double- +// renders the shared tail: e.g. when sessions A and B both lead to the same +// upstream C, we'd otherwise show [A→C, B→C, C] — operators end up scrolling +// past redundant chains in real lock-contention snapshots where many waiters +// converge on one lock holder. +func collapseRootSuffixRows(rows []blockerRow) []blockerRow { + subtrees := rootSubtrees(rows) + if len(subtrees) < 2 { + return rows + } + signatures := make([][]int, len(subtrees)) + for i, subtree := range subtrees { + signatures[i] = blockerRowSignature(subtree) + } + + collapsed := make([]blockerRow, 0, len(rows)) + for i, subtree := range subtrees { + keep := true + for j, candidate := range signatures { + if i == j { + continue + } + if isProperIntSuffix(signatures[i], candidate) { + keep = false + break + } + } + if keep { + collapsed = append(collapsed, subtree...) + } + } + return collapsed +} + +func rootSubtrees(rows []blockerRow) [][]blockerRow { + var subtrees [][]blockerRow + for _, row := range rows { + if row.Depth == 0 || len(subtrees) == 0 { + subtrees = append(subtrees, []blockerRow{row}) + continue + } + last := len(subtrees) - 1 + subtrees[last] = append(subtrees[last], row) + } + return subtrees +} + +func blockerRowSignature(rows []blockerRow) []int { + signature := make([]int, 0, len(rows)) + for _, row := range rows { + value := row.PID + if row.Cycle { + value = -value + } + signature = append(signature, value) + } + return signature +} + +func isProperIntSuffix(value, candidate []int) bool { + if len(value) == 0 || len(value) >= len(candidate) { + return false + } + offset := len(candidate) - len(value) + for i := range value { + if value[i] != candidate[offset+i] { + return false + } + } + return true +} + +func detailBlockerRows(list live.ConnectionList, root live.Connection) []blockerRow { + rows := blockerRows(list, root) + return append(rows, blockingRows(list, root)...) +} + +// blockerLabel returns the text rendered for a node in the Blockers tab tree. +// PID / app / state are fixed-width so eye can column-scan; the trailing +// query preview is variable. +func blockerLabel(row blockerRow) string { + if row.Truncated { + return fmt.Sprintf("... (truncated, %d more levels)", row.Remaining) + } + if !row.Present { + return fmt.Sprintf("%-7d %s", row.PID, mutedStyle.Render("(session ended)")) + } + suffix := "" + if row.Cycle { + suffix = " " + mutedStyle.Render("(cycle)") + } + if row.WaitOn != "" { + suffix = " " + mutedStyle.Render("("+row.WaitOn+")") + suffix + } + return fmt.Sprintf( + "%-7d %-14s %-12s %s%s", + row.PID, + clipLine(emptyDash(row.Connection.ApplicationName), 14), + clipLine(emptyDash(stateText(row.Connection.State)), 12), + queryPreview(row.Connection.QueryText), + suffix, + ) +} + +// waitOnText returns the lock type the connection is waiting on (e.g. +// "Lock/transactionid"), or "" when it is not waiting on anything. +func waitOnText(conn live.Connection) string { + wait := waitText(conn) + if wait == "-" { + return "" + } + return wait +} diff --git a/internal/connections/tui/blocking_graph_test.go b/internal/connections/tui/blocking_graph_test.go new file mode 100644 index 00000000..0247ab54 --- /dev/null +++ b/internal/connections/tui/blocking_graph_test.go @@ -0,0 +1,138 @@ +package tui + +import ( + "strings" + "testing" + + qt "github.com/frankban/quicktest" + live "github.com/planetscale/cli/internal/connections" +) + +func TestBlockerRowsWalksUpstreamChain(t *testing.T) { + c := qt.New(t) + connections := []live.Connection{ + {PID: 10, BlockedBy: []int{20}, WaitEventType: "Lock", WaitEvent: "transactionid"}, + {PID: 20, BlockedBy: []int{30}}, + {PID: 30}, + } + list := live.ConnectionList{Connections: connections} + + rows := blockerRows(list, connections[0]) + c.Assert(len(rows), qt.Equals, 2) + c.Assert(rows[0].PID, qt.Equals, 20) + c.Assert(rows[0].Depth, qt.Equals, 0) + c.Assert(rows[0].Present, qt.IsTrue) + c.Assert(rows[0].WaitOn, qt.Equals, "Lock/transactionid") + c.Assert(rows[1].PID, qt.Equals, 30) + c.Assert(rows[1].Depth, qt.Equals, 1) +} + +func TestBlockerRowsMarksUnknownBlockerAbsent(t *testing.T) { + c := qt.New(t) + connections := []live.Connection{ + {PID: 10, BlockedBy: []int{99}}, + } + list := live.ConnectionList{Connections: connections} + + rows := blockerRows(list, connections[0]) + c.Assert(len(rows), qt.Equals, 1) + c.Assert(rows[0].PID, qt.Equals, 99) + c.Assert(rows[0].Present, qt.IsFalse) +} + +func TestBlockerRowsDetectsCycle(t *testing.T) { + c := qt.New(t) + connections := []live.Connection{ + {PID: 10, BlockedBy: []int{20}}, + {PID: 20, BlockedBy: []int{10}}, + } + list := live.ConnectionList{Connections: connections} + + rows := blockerRows(list, connections[0]) + c.Assert(len(rows), qt.Equals, 2) + c.Assert(rows[1].PID, qt.Equals, 10) + c.Assert(rows[1].Cycle, qt.IsTrue) +} + +func TestBlockingRowsWalksDownstream(t *testing.T) { + c := qt.New(t) + connections := []live.Connection{ + {PID: 30}, + {PID: 20, BlockedBy: []int{30}}, + {PID: 10, BlockedBy: []int{20}}, + } + list := live.ConnectionList{Connections: connections} + + rows := blockingRows(list, connections[0]) + c.Assert(len(rows), qt.Equals, 2) + c.Assert(rows[0].PID, qt.Equals, 20) + c.Assert(rows[1].PID, qt.Equals, 10) + c.Assert(rows[1].Depth, qt.Equals, 1) +} + +func TestCollapseRootSuffixRowsDropsRedundantSubtrees(t *testing.T) { + c := qt.New(t) + rows := []blockerRow{ + {Depth: 0, PID: 30}, + {Depth: 0, PID: 20}, + {Depth: 1, PID: 30}, + } + collapsed := collapseRootSuffixRows(rows) + c.Assert(len(collapsed), qt.Equals, 2) + c.Assert(collapsed[0].PID, qt.Equals, 20) + c.Assert(collapsed[1].PID, qt.Equals, 30) +} + +func TestWalkBlockerRowsTruncatesPastMaxDepth(t *testing.T) { + c := qt.New(t) + connections := make([]live.Connection, 0, maxBlockerRows+5) + connections = append(connections, live.Connection{PID: 1, BlockedBy: []int{2}}) + for i := 2; i < maxBlockerRows+5; i++ { + connections = append(connections, live.Connection{PID: i, BlockedBy: []int{i + 1}}) + } + connections = append(connections, live.Connection{PID: maxBlockerRows + 5}) + list := live.ConnectionList{Connections: connections} + + rows := blockerRows(list, connections[0]) + c.Assert(len(rows) >= maxBlockerRows, qt.IsTrue) + last := rows[len(rows)-1] + c.Assert(last.Truncated, qt.IsTrue) + c.Assert(last.Remaining > 0, qt.IsTrue) +} + +func TestBlockerLabelDescribesAbsentAndCycleRows(t *testing.T) { + c := qt.New(t) + c.Assert(blockerLabel(blockerRow{PID: 99}), qt.Contains, "session ended") + cycleRow := blockerRow{ + PID: 77, + Present: true, + Cycle: true, + Connection: live.Connection{ApplicationName: "psql", State: "active", QueryText: "SELECT 1"}, + } + c.Assert(blockerLabel(cycleRow), qt.Contains, "(cycle)") + c.Assert(blockerLabel(cycleRow), qt.Contains, "psql") + + waitRow := blockerRow{ + PID: 242763, + Present: true, + WaitOn: "Lock/transactionid", + Connection: live.Connection{ + State: "active", + QueryText: "SELECT pg_sleep(86400);", + }, + } + c.Assert(stripANSI(blockerLabel(waitRow)), qt.Contains, "Lock/transactionid") + + truncatedRow := blockerRow{ + PID: 242763, + Present: true, + Connection: live.Connection{ + ApplicationName: "qa-block-holder", + State: "active", + QueryText: "SELECT pg_sleep(86400);", + }, + } + label := blockerLabel(truncatedRow) + c.Assert(strings.Contains(label, "…"), qt.IsTrue) + c.Assert(strings.Contains(label, "qa-block-holde "), qt.IsFalse) +} diff --git a/internal/connections/tui/capture.go b/internal/connections/tui/capture.go new file mode 100644 index 00000000..e7720393 --- /dev/null +++ b/internal/connections/tui/capture.go @@ -0,0 +1,89 @@ +package tui + +import ( + "fmt" + + tea "github.com/charmbracelet/bubbletea" + "github.com/planetscale/cli/internal/connections/history" +) + +type CaptureOpener func() (*history.CaptureWriter, string, error) + +type CaptureControl struct { + Open CaptureOpener + Writer *history.CaptureWriter + Path string +} + +func (c *CaptureControl) Close() error { + if c == nil || c.Writer == nil { + return nil + } + err := c.Writer.Close() + c.Writer = nil + c.Path = "" + return err +} + +func (m Model) toggleCapture() (Model, tea.Cmd) { + if m.capture == nil { + m.lastError = "capture is not available in this mode" + return m, nil + } + if m.capture.Writer != nil { + closed := m.capture.Path + if err := m.capture.Close(); err != nil { + m.lastError = fmt.Sprintf("stop capture: %v", err) + m = m.clearNotice() + return m, nil + } + m.lastError = "" + if closed != "" { + return m.setNotice("stopped capture: " + closed) + } + return m.setNotice("stopped capture") + } + if m.capture.Open == nil { + m.lastError = "capture is not available in this mode" + return m, nil + } + writer, path, err := m.capture.Open() + if err != nil { + m.lastError = fmt.Sprintf("start capture: %v", err) + m = m.clearNotice() + return m, nil + } + m.capture.Writer = writer + m.capture.Path = path + if err := m.writeCaptureBackfill(); err != nil { + m.captureStopped = "capture stopped: " + err.Error() + _ = m.capture.Close() + m = m.clearNotice() + return m, nil + } + m.captureStopped = "" + m.lastError = "" + return m.setNotice("capturing to " + path) +} + +func (m Model) writeCaptureBackfill() error { + for _, list := range m.samples.All() { + if err := m.capture.Writer.Write(history.NewCapture(list)); err != nil { + return err + } + } + return nil +} + +func (m Model) captureStatusText() string { + if m.capture == nil { + return "" + } + if m.capture.Writer == nil { + return "rec off" + } + if m.capture.Path != "" { + return "rec " + m.capture.Path + } + return "rec on" +} diff --git a/internal/connections/tui/connection_capabilities.go b/internal/connections/tui/connection_capabilities.go new file mode 100644 index 00000000..bc2278d1 --- /dev/null +++ b/internal/connections/tui/connection_capabilities.go @@ -0,0 +1,113 @@ +package tui + +import live "github.com/planetscale/cli/internal/connections" + +type actionTargetField int + +const ( + actionTargetUnset actionTargetField = iota + actionTargetQueryID + actionTargetTransactionID + actionTargetConnectionID +) + +// ActionRequirement describes the row identifier required before an action can +// be offered for a selected connection. A zero requirement means the action is +// not available for the connection source. +type ActionRequirement struct { + field actionTargetField +} + +var ( + // ActionTargetQueryID requires a query_id before the action can run. + ActionTargetQueryID = ActionRequirement{field: actionTargetQueryID} + // ActionTargetTransactionID requires a transaction_id before the action can run. + ActionTargetTransactionID = ActionRequirement{field: actionTargetTransactionID} + // ActionTargetConnectionID requires a connection_id before the action can run. + ActionTargetConnectionID = ActionRequirement{field: actionTargetConnectionID} +) + +// ConnectionCapabilities is the TUI capability contract for a connection source. +// It tells the view which selected-row IDs are needed for each action, whether +// blocker navigation should be shown, and which action controls should be +// visible. +// +// It does not execute actions; the model still calls ConnectionsClient. The +// support value only gates visible controls and validates that the selected row +// has the identifier the backend expects. +type ConnectionCapabilities struct { + CancelQuery ActionRequirement + TerminateTransaction ActionRequirement + TerminateConnection ActionRequirement + ShowBlockers bool + + configured bool +} + +// DefaultConnectionCapabilities is the Postgres capability set. A zero +// ConnectionCapabilities also resolves to this default so existing tests and +// model construction keep Postgres behavior unless a source opts into different +// capabilities. +func DefaultConnectionCapabilities() ConnectionCapabilities { + return ConnectionCapabilities{ + CancelQuery: ActionTargetQueryID, + TerminateTransaction: ActionTargetTransactionID, + TerminateConnection: ActionTargetConnectionID, + ShowBlockers: true, + configured: true, + } +} + +func (s ConnectionCapabilities) effective() ConnectionCapabilities { + if !s.configured && s.isZero() { + return DefaultConnectionCapabilities() + } + return s +} + +func (s ConnectionCapabilities) isZero() bool { + return s.CancelQuery == (ActionRequirement{}) && + s.TerminateTransaction == (ActionRequirement{}) && + s.TerminateConnection == (ActionRequirement{}) && + !s.ShowBlockers +} + +func (s ConnectionCapabilities) supports(kind actionKind) bool { + requirement := s.requirement(kind) + return requirement.field != actionTargetUnset +} + +func (s ConnectionCapabilities) requirement(kind actionKind) ActionRequirement { + s = s.effective() + switch kind { + case actionCancelQuery: + return s.CancelQuery + case actionTerminateTxn: + return s.TerminateTransaction + case actionTerminateConn: + return s.TerminateConnection + default: + return ActionRequirement{} + } +} + +func (s ConnectionCapabilities) missingActionID(kind actionKind, target live.ActionTarget) string { + requirement := s.requirement(kind) + switch requirement.field { + case actionTargetQueryID: + if live.DerefString(target.QueryID) == "" { + return "no active query to cancel on this connection" + } + case actionTargetTransactionID: + if live.DerefString(target.TransactionID) == "" { + return "no open transaction to terminate on this connection" + } + case actionTargetConnectionID: + if live.DerefString(target.ConnectionID) == "" { + return "no connection id available to terminate this connection" + } + default: + return "action is not supported for this connection" + } + return "" +} diff --git a/internal/connections/tui/detail.go b/internal/connections/tui/detail.go new file mode 100644 index 00000000..49518b54 --- /dev/null +++ b/internal/connections/tui/detail.go @@ -0,0 +1,454 @@ +package tui + +import ( + "fmt" + "strings" + "time" + + "github.com/charmbracelet/bubbles/key" + lgtree "github.com/charmbracelet/lipgloss/tree" + "github.com/charmbracelet/x/ansi" + live "github.com/planetscale/cli/internal/connections" +) + +type detailTab string + +const ( + tabQuery detailTab = "query" + tabBlockers detailTab = "blockers" + + detailLabelWidth = 16 +) + +type detailState struct { + List live.ConnectionList + Subject live.Connection + SubjectFound bool + Tab detailTab + BlockerSelection int + QueryOffset int + Width int + Height int + Paused bool + Refresh refreshDotState + ReadOnlyActions bool + Replay bool + LastError string + Notice string + CaptureStopped string // sticky reason when capture writer detaches on error; "" when capture is healthy or absent + CaptureStatus string + Confirm string + Now time.Time + Interval time.Duration + StepPos int // 1-based position when stepping; 0 means following live + StepTotal int // total samples held in history + Target Target + DisplayPreset connectionDisplayPreset + Capabilities ConnectionCapabilities +} + +func renderDetail(state detailState) string { + state = normalizeDetailState(state) + width := tableWidth(state.Width) + height := state.Height + if height <= 0 { + height = 24 + } + + headerLines := []string{ + clipLine(renderDetailHeader(state), width), + "", + clipLine(renderDetailTabs(state.Tab, state.Capabilities), width), + "", + } + + footerLines := strings.Split(renderDetailFooter(state), "\n") + bodyHeight := height - len(headerLines) - len(footerLines) + if bodyHeight < 0 { + bodyHeight = 0 + } + var bodyLines []string + if !state.SubjectFound { + bodyLines = append(bodyLines, + clipLine(mutedStyle.Render("This connection is no longer in the live snapshot — it closed or was terminated."), width), + clipLine(mutedStyle.Render("Press q or esc to return to the connection list."), width), + ) + } else { + switch state.Tab { + case tabBlockers: + bodyLines = append(bodyLines, renderBlockers(state, width, bodyHeight)...) + default: + bodyLines = append(bodyLines, renderQuery(state.Subject, state.DisplayPreset, width, bodyHeight, state.QueryOffset)...) + } + } + + if len(bodyLines) > bodyHeight { + bodyLines = bodyLines[:bodyHeight] + } + lines := append(headerLines, bodyLines...) + for len(lines)+len(footerLines) < height { + lines = append(lines, "") + } + lines = append(lines, footerLines...) + return strings.Join(lines, "\n") +} + +func normalizeDetailState(state detailState) detailState { + state.Capabilities = state.Capabilities.effective() + state.Tab = effectiveDetailTab(state.Tab, state.Capabilities) + return state +} + +func renderDetailHeader(state detailState) string { + if !state.SubjectFound { + if target := renderTarget(state.Target); target != "" { + return target + " | " + mutedStyle.Render("connection ended") + } + return "live connection detail | " + mutedStyle.Render("connection ended") + } + // The header carries only identity + operational state. The per-connection + // metadata (user, app, state, duration, wait, IDs) lives in the body record + // below, so the header stays short instead of overflowing and clipping. + parts := []string{} + if target := renderTarget(state.Target); target != "" { + parts = append(parts, target) + } + parts = append(parts, refreshIndicator(state.Refresh)+" "+fmt.Sprintf("pid %d", state.Subject.PID)) + if state.Subject.Instance != "" { + parts = append(parts, "on "+state.Subject.Instance) + } + if state.Paused { + parts = append(parts, pausedStyle.Render("paused")) + } + if state.StepPos > 0 { + parts = append(parts, fmt.Sprintf("step %d/%d", state.StepPos, state.StepTotal)) + } + if state.CaptureStopped != "" { + parts = append(parts, errorStyle.Render(state.CaptureStopped)) + } + if state.CaptureStatus != "" { + parts = append(parts, state.CaptureStatus) + } + if token := capturedToken(state.List.CapturedAt, state.Now, state.Interval); token != "" { + parts = append(parts, token) + } + return strings.Join(parts, " | ") +} + +func renderDetailTabs(active detailTab, capabilities ConnectionCapabilities) string { + if !capabilities.ShowBlockers { + return tabActiveStyle.Render("[query]") + } + blockers, query := "blockers", "query" + if active == tabBlockers { + blockers = tabActiveStyle.Render("[blockers]") + } else { + query = tabActiveStyle.Render("[query]") + } + return blockers + " " + query +} + +func effectiveDetailTab(active detailTab, capabilities ConnectionCapabilities) detailTab { + if active == tabBlockers && !capabilities.ShowBlockers { + return tabQuery + } + return active +} + +// connectionRecordLines renders the full per-connection detail as a MySQL +// `\G`-style vertical record: one aligned field per line (from the shared +// live.Connection.HumanFields), then the wrapped query text. This fills the +// detail body — which otherwise showed only the query over a sea of whitespace +// — and reuses the exact field set the agent-cli `list --format human` emits. +func connectionRecordLines(conn live.Connection, preset connectionDisplayPreset, width int) []string { + return renderConnectionRecord(connectionDetailFields(preset, conn), conn.QueryText, width) +} + +func renderConnectionRecord(fields [][2]string, query string, width int) []string { + var lines []string + for _, field := range fields { + label := mutedStyle.Render(fmt.Sprintf("%-*s", detailLabelWidth, field[0]+":")) + values := detailFieldValueLines(field[1], width-detailLabelWidth-1) + for i, value := range values { + if i == 0 { + lines = append(lines, clipLine(label+" "+value, width)) + continue + } + lines = append(lines, clipLine(strings.Repeat(" ", detailLabelWidth+1)+value, width)) + } + } + lines = append(lines, headerStyle.Render("query:")) + queryLines := queryDisplayLines(query, width) + if len(queryLines) == 0 { + queryLines = []string{mutedStyle.Render("no query")} + } + lines = append(lines, queryLines...) + return lines +} + +func detailFieldValueLines(value string, width int) []string { + rendered := detailFieldValue(value) + if strings.TrimSpace(value) == "" || width <= 0 || ansi.StringWidth(rendered) <= width { + return []string{rendered} + } + return wrapLines(value, width) +} + +func detailFieldValue(value string) string { + if strings.TrimSpace(value) == "" { + return mutedStyle.Render("none") + } + return value +} + +func connectionDetailFields(preset connectionDisplayPreset, conn live.Connection) [][2]string { + if preset != connectionDisplayProcesslist { + return conn.HumanFields() + } + + return [][2]string{ + {"pid", fmt.Sprint(conn.PID)}, + {"tablet", conn.Instance}, + {"state", conn.State}, + {"duration", processlistDetailDuration(conn)}, + {"user", conn.Username}, + {"database", conn.DatabaseName}, + {"client_addr", conn.ClientAddr}, + {"connection_id", live.DerefString(conn.ConnectionID)}, + {"query_id", live.DerefString(conn.QueryID)}, + } +} + +func processlistDetailDuration(conn live.Connection) string { + if conn.Duration <= 0 && !processlistConnectionHasWork(conn) { + return "" + } + return conn.Duration.String() +} + +func renderQuery(conn live.Connection, preset connectionDisplayPreset, width, height, offset int) []string { + if height <= 0 { + return nil + } + wrapped := connectionRecordLines(conn, preset, width) + if len(wrapped) == 0 { + return []string{mutedStyle.Render("no query")} + } + if len(wrapped) > height-1 { + offset = clampInt(offset, 0, maxQueryOffset(len(wrapped), height)) + end := offset + height - 1 + visible := append([]string{}, wrapped[offset:end]...) + endLine := end + if endLine < offset+1 { + endLine = offset + 1 + } + visible = append(visible, mutedStyle.Render(fmt.Sprintf("lines %d-%d/%d", offset+1, endLine, len(wrapped)))) + return visible + } + return wrapped +} + +func renderBlockers(state detailState, width, height int) []string { + lines := []string{headerStyle.Render("BLOCKED BY")} + selection := 0 + selectedLine := -1 + + blockedBy := blockerRows(state.List, state.Subject) + if len(blockedBy) == 0 { + lines = append(lines, mutedStyle.Render("No upstream blocker")) + } else { + if state.BlockerSelection < len(blockedBy) { + selectedLine = len(lines) + state.BlockerSelection + } + rendered := renderBlockerTreeRows(blockedBy, state.BlockerSelection, selection, width) + lines = append(lines, rendered...) + selection += len(blockedBy) + } + + // Two blank lines separate the upstream and downstream trees so the + // BLOCKING header reads as a section break, not just another tree row, + // even when the upstream tree has dozens of entries. + lines = append(lines, "", "", headerStyle.Render("BLOCKING")) + blocking := blockingRows(state.List, state.Subject) + if len(blocking) == 0 { + lines = append(lines, mutedStyle.Render("Not blocking other connections")) + } else { + if state.BlockerSelection >= selection { + selectedLine = len(lines) + (state.BlockerSelection - selection) + } + lines = append(lines, renderBlockerTreeRows(blocking, state.BlockerSelection, selection, width)...) + } + return visibleDetailBody(lines, selectedLine, height) +} + +func visibleDetailBody(lines []string, selectedLine, height int) []string { + if height <= 0 || len(lines) <= height { + return lines + } + if selectedLine < 0 { + return lines[:height] + } + start := centeredViewportStart(selectedLine, len(lines), height) + return lines[start : start+height] +} + +func renderDetailFooter(state detailState) string { + var lines []string + if state.Confirm != "" { + lines = append(lines, errorStyle.Render(clipLine(state.Confirm, tableWidth(state.Width)))) + } else if state.LastError != "" { + lines = append(lines, errorStyle.Render(clipLine("error: "+state.LastError, tableWidth(state.Width)))) + } else if state.Notice != "" { + lines = append(lines, clipLine("status: "+state.Notice, tableWidth(state.Width))) + } + if status := renderDetailSelectedStatus(state); status != "" { + lines = append(lines, status) + } + lines = append(lines, renderDetailHelp(state)) + return strings.Join(lines, "\n") +} + +func renderDetailSelectedStatus(state detailState) string { + if !state.SubjectFound { + return "" + } + width := tableWidth(state.Width) + if state.Tab == tabBlockers { + rows := detailBlockerRows(state.List, state.Subject) + if idx := state.BlockerSelection; idx >= 0 && idx < len(rows) { + row := rows[idx] + if row.Present { + return renderConnectionStatus("→ selected blocker", row.Connection, width) + } + return clipLine(fmt.Sprintf("→ selected blocker pid %d | %s", row.PID, mutedStyle.Render("connection ended")), width) + } + } + return renderConnectionStatus("→ selected connection", state.Subject, width) +} + +func renderConnectionStatus(label string, conn live.Connection, width int) string { + status := fmt.Sprintf("%s pid %d", label, conn.PID) + query := sanitizeFooterText(strings.Join(strings.Fields(conn.QueryText), " ")) + if query != "" { + status += " | " + query + } + return clipLine(status, width) +} + +func renderDetailHelp(state detailState) string { + labels := actionHelpLabels(state.DisplayPreset) + cancelHelp := labels.cancel + killTxnHelp := "kill transaction" + terminateHelp := labels.terminateConn + support := state.Capabilities.effective() + + bindings := []key.Binding{} + if support.ShowBlockers { + bindings = append(bindings, key.NewBinding(key.WithKeys("left", "right"), key.WithHelp("left/right", "tabs"))) + } + if state.Tab == tabBlockers { + bindings = append(bindings, key.NewBinding(key.WithKeys("enter"), key.WithHelp("enter", "open selected"))) + } + if !state.Replay && !state.Paused { + bindings = append(bindings, key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "refresh"))) + } + pauseLabel := "pause" + if state.Paused { + pauseLabel = "resume" + } + bindings = append(bindings, key.NewBinding(key.WithKeys("space"), key.WithHelp("space", pauseLabel))) + if !state.Replay { + bindings = append(bindings, key.NewBinding(key.WithKeys("C"), key.WithHelp("C", "capture"))) + } + bindings = append(bindings, key.NewBinding(key.WithKeys("q", "esc"), key.WithHelp("q/esc", "back"))) + // An ended connection has nothing to act on, so suppress the destructive + // action hints rather than advertise keys that no-op. + if state.SubjectFound && !state.ReadOnlyActions && !state.Replay { + if support.supports(actionCancelQuery) { + bindings = append(bindings, key.NewBinding(key.WithKeys("c"), key.WithHelp("c", cancelHelp))) + } + if support.supports(actionTerminateTxn) { + bindings = append(bindings, key.NewBinding(key.WithKeys("k"), key.WithHelp("k", killTxnHelp))) + } + if support.supports(actionTerminateConn) { + bindings = append(bindings, key.NewBinding(key.WithKeys("K"), key.WithHelp("shift+K", terminateHelp))) + } + } + bindings = append(bindings, + key.NewBinding(key.WithKeys("?"), key.WithHelp("?", "help")), + key.NewBinding(key.WithKeys("ctrl+c"), key.WithHelp("ctrl+c", "quit")), + ) + return renderWrappedHelp(bindings, " | ", tableWidth(state.Width)) +} + +func renderBlockerTreeRows(rows []blockerRow, selected, selectionOffset, width int) []string { + if isSingleLevelBlockerTree(rows) { + lines := make([]string, 0, len(rows)) + for i, row := range rows { + lines = append(lines, clipLine("• "+blockerTreeLabel(row, selectionOffset+i == selected), width)) + } + return lines + } + + tree := lgtree.New() + stack := []*lgtree.Tree{tree} + for i, row := range rows { + depth := row.Depth + if depth < 0 { + depth = 0 + } + if depth >= len(stack) { + depth = len(stack) - 1 + } + node := lgtree.Root(blockerTreeLabel(row, selectionOffset+i == selected)) + stack[depth].Child(node) + stack = append(stack[:depth+1], node) + } + rendered := tree.String() + if rendered == "" { + return nil + } + lines := strings.Split(rendered, "\n") + for i, line := range lines { + lines[i] = clipLine(line, width) + } + return lines +} + +func isSingleLevelBlockerTree(rows []blockerRow) bool { + if len(rows) == 0 { + return false + } + for _, row := range rows { + if row.Depth > 0 { + return false + } + } + return true +} + +func blockerTreeLabel(row blockerRow, selected bool) string { + label := blockerLabel(row) + if selected { + // Use the same ▶ cursor as the table so the selection affordance is + // consistent across views. + return selectedRowStyle.Render("▶ " + label) + } + return label +} + +func wrapLines(text string, width int) []string { + if width <= 0 { + return []string{text} + } + runes := []rune(text) + var lines []string + for len(runes) > width { + lines = append(lines, string(runes[:width])) + runes = runes[width:] + } + if len(runes) > 0 { + lines = append(lines, string(runes)) + } + return lines +} diff --git a/internal/connections/tui/detail_test.go b/internal/connections/tui/detail_test.go new file mode 100644 index 00000000..643fb87e --- /dev/null +++ b/internal/connections/tui/detail_test.go @@ -0,0 +1,690 @@ +package tui + +import ( + "context" + "strings" + "testing" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/x/ansi" + qt "github.com/frankban/quicktest" + live "github.com/planetscale/cli/internal/connections" +) + +func newDetailModel(t *testing.T, stub *clientStub, list live.ConnectionList) Model { + t.Helper() + model := NewModel(context.Background(), stub, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 30}) + model = updated.(Model) + updated, _ = model.Update(listMsg{list: list}) + return updated.(Model) +} + +func TestEnterFromTableOpensDetailOnQueryTab(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + Instance: "primary", + QueryText: "SELECT * FROM widgets WHERE id = 7", + }}, live.SortByTransactionStart) + model := newDetailModel(t, &clientStub{}, list) + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + view := updated.(Model).View() + + c.Assert(view, qt.Contains, "pid 10") + c.Assert(view, qt.Contains, "on primary") + c.Assert(view, qt.Contains, "SELECT * FROM widgets WHERE id = 7") + c.Assert(view, qt.Contains, "blockers") + c.Assert(view, qt.Contains, "[query]") + c.Assert(view, qt.Not(qt.Contains), "V Query") + c.Assert(view, qt.Not(qt.Contains), "B Blockers") +} + +func TestTableUppercaseVOpensDetailOnQueryTab(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + Instance: "primary", + QueryText: "SELECT * FROM widgets WHERE id = 7", + }}, live.SortByTransactionStart) + model := newDetailModel(t, &clientStub{}, list) + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("V")}) + got := updated.(Model) + + c.Assert(got.detailOpen, qt.IsTrue) + c.Assert(got.detailTab, qt.Equals, tabQuery) + c.Assert(got.View(), qt.Contains, "SELECT * FROM widgets WHERE id = 7") +} + +func TestTableLowercaseVOpensDetailOnQueryTab(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + Instance: "primary", + QueryText: "SELECT * FROM widgets WHERE id = 7", + }}, live.SortByTransactionStart) + model := newDetailModel(t, &clientStub{}, list) + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("v")}) + got := updated.(Model) + + c.Assert(got.detailOpen, qt.IsTrue) + c.Assert(got.detailTab, qt.Equals, tabQuery) + c.Assert(got.View(), qt.Contains, "SELECT * FROM widgets WHERE id = 7") +} + +func TestTableUppercaseBOpensDetailOnBlockersTab(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 10, Instance: "primary", BlockedBy: []int{20}, QueryText: "blocked"}, + {PID: 20, Instance: "primary", QueryText: "blocker holding lock"}, + }, live.SortByTransactionStart) + model := newDetailModel(t, &clientStub{}, list) + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("B")}) + got := updated.(Model) + + c.Assert(got.detailOpen, qt.IsTrue) + c.Assert(got.detailTab, qt.Equals, tabBlockers) + c.Assert(got.View(), qt.Contains, "BLOCKED BY") +} + +func TestDetailTabKeysUseArrowsAndBlockersShortcut(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + Instance: "primary", + QueryText: "SELECT 1", + }}, live.SortByTransactionStart) + model := newDetailModel(t, &clientStub{}, list) + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + model = updated.(Model) + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'b'}}) + got := updated.(Model) + c.Assert(got.detailTab, qt.Equals, tabBlockers) + + updated, _ = got.Update(tea.KeyMsg{Type: tea.KeyRight}) + got = updated.(Model) + c.Assert(got.detailTab, qt.Equals, tabQuery) +} + +func TestDetailVDoesNotSwitchTabs(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + Instance: "primary", + QueryText: "SELECT 1", + }}, live.SortByTransactionStart) + model := newDetailModel(t, &clientStub{}, list) + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'b'}}) + got := updated.(Model) + c.Assert(got.detailTab, qt.Equals, tabBlockers) + + updated, _ = got.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'v'}}) + + c.Assert(updated.(Model).detailTab, qt.Equals, tabBlockers) +} + +func TestEscapeReturnsToTable(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10}}, live.SortByTransactionStart) + model := newDetailModel(t, &clientStub{}, list) + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyEsc}) + + view := updated.(Model).View() + c.Assert(view, qt.Contains, "connections 1") +} + +func TestDetailSwitchesBetweenQueryAndBlockersTabs(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 10, Instance: "primary", BlockedBy: []int{20}, QueryText: "blocked"}, + {PID: 20, Instance: "primary", QueryText: "blocker holding lock"}, + }, live.SortByTransactionStart) + model := newDetailModel(t, &clientStub{}, list) + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("b")}) + view := updated.(Model).View() + + c.Assert(view, qt.Contains, "BLOCKED BY") + c.Assert(view, qt.Contains, "20") + c.Assert(view, qt.Contains, "blocker holding lock") + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRight}) + view = updated.(Model).View() + c.Assert(view, qt.Contains, "blocked") + c.Assert(view, qt.Not(qt.Contains), "BLOCKED BY") +} + +func TestDetailArrowTabSwitchResetsQueryOffset(t *testing.T) { + c := qt.New(t) + query := "select a, b, c\nfrom t\nwhere a = 1\nand b = 2\nand c = 3\nand d = 4\nand e = 5\nand f = 6\norder by a\nlimit 10" + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10, Instance: "primary", QueryText: query}}, live.SortByTransactionStart) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 80, Height: 12}) + updated, _ = updated.(Model).Update(listMsg{list: list}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyEnter}) + for i := 0; i < 4; i++ { + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyDown}) + } + got := updated.(Model) + c.Assert(got.queryOffset > 0, qt.IsTrue) + got.detailTab = tabBlockers + updated = got + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyLeft}) + + c.Assert(updated.(Model).detailTab, qt.Equals, tabQuery) + c.Assert(updated.(Model).queryOffset, qt.Equals, 0) +} + +func TestDetailExplicitTabKeysResetQueryOffset(t *testing.T) { + c := qt.New(t) + query := "select a, b, c\nfrom t\nwhere a = 1\nand b = 2\nand c = 3\nand d = 4\nand e = 5\nand f = 6\norder by a\nlimit 10" + list := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 10, Instance: "primary", BlockedBy: []int{20}, QueryText: query}, + {PID: 20, Instance: "primary", QueryText: "blocker holding lock"}, + }, live.SortByTransactionStart) + + for _, key := range []string{"b", "B"} { + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 80, Height: 12}) + updated, _ = updated.(Model).Update(listMsg{list: list}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyEnter}) + for i := 0; i < 4; i++ { + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyDown}) + } + got := updated.(Model) + c.Assert(got.queryOffset > 0, qt.IsTrue) + + updated, _ = got.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)}) + + got = updated.(Model) + c.Assert(got.detailTab, qt.Equals, tabBlockers) + c.Assert(got.queryOffset, qt.Equals, 0) + } + + for _, key := range []tea.KeyType{tea.KeyLeft, tea.KeyRight} { + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 80, Height: 12}) + updated, _ = updated.(Model).Update(listMsg{list: list}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyEnter}) + for i := 0; i < 4; i++ { + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyDown}) + } + got := updated.(Model) + c.Assert(got.queryOffset > 0, qt.IsTrue) + got.detailTab = tabBlockers + + updated, _ = got.Update(tea.KeyMsg{Type: key}) + + got = updated.(Model) + c.Assert(got.detailTab, qt.Equals, tabQuery) + c.Assert(got.queryOffset, qt.Equals, 0) + } +} + +func TestDetailBlockerTabClampsOutOfRangeSelection(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 10, Instance: "primary", BlockedBy: []int{20, 30}, QueryText: "blocked"}, + {PID: 20, Instance: "primary", QueryText: "first blocker"}, + {PID: 30, Instance: "primary", QueryText: "second blocker"}, + }, live.SortByTransactionStart) + + for _, key := range []string{"b", "B"} { + model := newDetailModel(t, &clientStub{}, list) + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + got := updated.(Model) + got.blockerRow = 99 + + updated, _ = got.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)}) + + got = updated.(Model) + c.Assert(got.detailTab, qt.Equals, tabBlockers) + c.Assert(got.blockerRow, qt.Equals, 1) + } +} + +func TestDetailDownArrowSelectsNextBlockerOnBlockersTab(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 10, BlockedBy: []int{20}}, + {PID: 20, BlockedBy: []int{30}}, + {PID: 30}, + }, live.SortByTransactionStart) + model := newDetailModel(t, &clientStub{}, list) + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("b")}) + got := updated.(Model) + c.Assert(got.detailTab, qt.Equals, tabBlockers) + c.Assert(got.blockerRow, qt.Equals, 0) + + updated, _ = got.Update(tea.KeyMsg{Type: tea.KeyDown}) + c.Assert(updated.(Model).blockerRow, qt.Equals, 1) +} + +func TestDetailCancelDispatchesActionForSubject(t *testing.T) { + c := qt.New(t) + xid := "10-1" + qid := "10-2" + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + Instance: "primary", + TransactionID: &xid, + QueryID: &qid, + }}, live.SortByTransactionStart) + stub := &clientStub{} + model := newDetailModel(t, stub, list) + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + updated, cmd := updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("c")}) + + got := updated.(Model) + c.Assert(cmd, qt.IsNil) + c.Assert(got.confirming, qt.IsTrue) + c.Assert(got.pendingKind, qt.Equals, actionCancelQuery) + c.Assert(stub.cancelCalls, qt.Equals, 0) + + updated, cmd = got.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("y")}) + c.Assert(cmd, qt.IsNotNil) + cmd() + c.Assert(stub.cancelCalls, qt.Equals, 1) + c.Assert(stub.lastTarget.PID, qt.Equals, 10) + c.Assert(stub.lastTarget.Instance, qt.Equals, "primary") + c.Assert(stub.lastTarget.QueryID, qt.Not(qt.IsNil)) + c.Assert(*stub.lastTarget.QueryID, qt.Equals, "10-2") + _ = updated +} + +func TestDetailCancelOnBlockersTabTargetsSelectedBlocker(t *testing.T) { + c := qt.New(t) + bxid := "20-99" + bqid := "20-q" + list := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 10, Instance: "primary", BlockedBy: []int{20}}, + {PID: 20, Instance: "primary", TransactionID: &bxid, QueryID: &bqid}, + }, live.SortByTransactionStart) + stub := &clientStub{} + model := newDetailModel(t, stub, list) + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("b")}) + updated, cmd := updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("c")}) + + got := updated.(Model) + c.Assert(cmd, qt.IsNil) + c.Assert(got.confirming, qt.IsTrue) + c.Assert(got.pendingKind, qt.Equals, actionCancelQuery) + c.Assert(stub.cancelCalls, qt.Equals, 0) + + updated, cmd = got.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("y")}) + c.Assert(cmd, qt.IsNotNil) + cmd() + c.Assert(stub.cancelCalls, qt.Equals, 1) + c.Assert(stub.lastTarget.PID, qt.Equals, 20) + c.Assert(*stub.lastTarget.TransactionID, qt.Equals, "20-99") + _ = updated +} + +func TestDetailKillTransactionPromptsConfirm(t *testing.T) { + c := qt.New(t) + xid := "10-x" + list := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 10, Instance: "primary", TransactionID: &xid}, + }, live.SortByTransactionStart) + stub := &clientStub{} + model := newDetailModel(t, stub, list) + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("k")}) + + got := updated.(Model) + c.Assert(got.confirming, qt.IsTrue) + c.Assert(got.pendingKind, qt.Equals, actionTerminateTxn) + view := got.View() + c.Assert(view, qt.Contains, "Terminate transaction on PID 10") +} + +func TestDetailEnterOnBlockersTabReanchorsToBlocker(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 10, Instance: "primary", BlockedBy: []int{20}}, + {PID: 20, Instance: "primary"}, + }, live.SortByTransactionStart) + model := newDetailModel(t, &clientStub{}, list) + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("b")}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyEnter}) + + got := updated.(Model) + c.Assert(got.detailPID, qt.Equals, 20) +} + +func TestDetailRendersConnectionEndedWhenSubjectMissing(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10, Instance: "primary"}}, live.SortByTransactionStart) + model := newDetailModel(t, &clientStub{}, list) + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + + updated, _ = updated.(Model).Update(listMsg{list: live.ConnectionList{}}) + + view := stripANSI(updated.(Model).View()) + c.Assert(view, qt.Contains, "connection ended") + c.Assert(view, qt.Contains, "Press q or esc to return") + + endedFooter := stripANSI(renderDetailHelp(detailState{SubjectFound: false, Width: 120, Height: 24})) + c.Assert(strings.Contains(endedFooter, "cancel"), qt.IsFalse) + c.Assert(strings.Contains(endedFooter, "kill"), qt.IsFalse) + c.Assert(strings.Contains(endedFooter, "terminate"), qt.IsFalse) + c.Assert(strings.Contains(endedFooter, "q/esc back"), qt.IsTrue) + + liveFooter := stripANSI(renderDetailHelp(detailState{SubjectFound: true, Width: 200})) + c.Assert(strings.Contains(liveFooter, "cancel query"), qt.IsTrue) + c.Assert(strings.Contains(liveFooter, "kill transaction"), qt.IsTrue) +} + +func TestRenderDetailFieldsShowNoneForEmptyValues(t *testing.T) { + c := qt.New(t) + view := stripANSI(renderDetail(detailState{ + List: live.NewConnectionList(time.Now(), nil, live.SortByDuration), + Subject: live.Connection{ + PID: 101, + Duration: 42 * time.Second, + QueryText: "SELECT 1", + }, + SubjectFound: true, + Tab: tabQuery, + DisplayPreset: connectionDisplayProcesslist, + Capabilities: DefaultConnectionCapabilities(), + Width: 120, + Height: 24, + })) + + for _, field := range []string{ + "tablet:", + "state:", + "user:", + "database:", + "client_addr:", + "connection_id:", + "query_id:", + } { + c.Assert(view, qt.Contains, field) + c.Assert(view, qt.Contains, field+" ") + } + c.Assert(strings.Count(view, "none"), qt.Equals, 7) +} + +func TestRenderDetailWrapsLongIDFields(t *testing.T) { + c := qt.New(t) + queryID := "primary-1234567890abcdef-primary-1234567890abcdef-query" + view := stripANSI(renderDetail(detailState{ + List: live.NewConnectionList(time.Now(), nil, live.SortByTransactionStart), + Subject: live.Connection{ + PID: 101, + QueryID: &queryID, + }, + SubjectFound: true, + Tab: tabQuery, + Capabilities: DefaultConnectionCapabilities(), + Width: 44, + Height: 24, + })) + + c.Assert(view, qt.Contains, "query_id:") + c.Assert(view, qt.Contains, "primary-1234567890abcdef-pr") + c.Assert(view, qt.Contains, "imary-1234567890abcdef-quer") + c.Assert(view, qt.Contains, " y") + c.Assert(view, qt.Not(qt.Contains), "query_id: primary-1234567890abcdef-prima…") +} + +func TestDetailEndedStateKeepsTargetAndActionFeedback(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10, Instance: "primary"}}, live.SortByTransactionStart) + model := newDetailModel(t, &clientStub{}, list).WithTarget(Target{ + Database: "prod", + Branch: "main", + }) + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + updated, _ = updated.(Model).Update(listMsg{list: live.ConnectionList{}}) + + view := stripANSI(updated.(Model).View()) + c.Assert(view, qt.Contains, "prod / main | connection ended") + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("c")}) + got := stripANSI(updated.(Model).View()) + c.Assert(got, qt.Contains, "connection ended — actions unavailable; esc to go back") + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("k")}) + got = stripANSI(updated.(Model).View()) + c.Assert(got, qt.Contains, "connection ended — actions unavailable; esc to go back") + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("K")}) + got = stripANSI(updated.(Model).View()) + c.Assert(got, qt.Contains, "connection ended — actions unavailable; esc to go back") +} + +func TestDetailFooterWrapsAtNarrowWidth(t *testing.T) { + c := qt.New(t) + + footer := stripANSI(renderDetailHelp(detailState{ + SubjectFound: true, + Tab: tabBlockers, + Width: 80, + })) + + c.Assert(strings.Contains(footer, "ctrl+c quit"), qt.IsTrue) + c.Assert(strings.Contains(footer, "? help"), qt.IsTrue) + c.Assert(strings.Contains(footer, "q/esc back"), qt.IsTrue) + + for _, line := range strings.Split(footer, "\n") { + c.Assert(ansi.StringWidth(line) <= 80, qt.IsTrue) + } + c.Assert(strings.Count(footer, "\n") >= 1, qt.IsTrue) +} + +func TestDetailFooterShowsShiftKForForceTerminate(t *testing.T) { + c := qt.New(t) + + footer := stripANSI(renderDetailHelp(detailState{ + SubjectFound: true, + Width: 160, + })) + + c.Assert(footer, qt.Contains, "shift+K force terminate") +} + +// The Blockers-tab contextual footer includes both navigation and destructive +// action hints, so it must stay within the default 200-column recording width. +func TestDetailBlockersFooterFitsOneLineAtDefaultRecordingWidth(t *testing.T) { + c := qt.New(t) + + footer := stripANSI(renderDetailHelp(detailState{ + SubjectFound: true, + Tab: tabBlockers, + Width: 200, + })) + + c.Assert(strings.Contains(footer, "\n"), qt.IsFalse) + c.Assert(ansi.StringWidth(footer) <= tableWidth(200), qt.IsTrue) + c.Assert(strings.Contains(footer, "ctrl+c quit"), qt.IsTrue) +} + +func TestRenderBlockersShowsBlockedByAndBlockingHeadings(t *testing.T) { + c := qt.New(t) + subject := live.Connection{PID: 10, BlockedBy: []int{20}} + list := live.ConnectionList{Connections: []live.Connection{ + subject, + {PID: 20, BlockedBy: []int{30}}, + {PID: 30}, + {PID: 40, BlockedBy: []int{10}}, + }} + + state := detailState{ + List: list, + Subject: subject, + SubjectFound: true, + Tab: tabBlockers, + Width: 120, + Height: 30, + } + out := renderDetail(state) + c.Assert(out, qt.Contains, "BLOCKED BY") + c.Assert(out, qt.Contains, "BLOCKING") + c.Assert(out, qt.Contains, "20") + c.Assert(out, qt.Contains, "40") +} + +func TestRenderBlockersUsesFriendlyEmptyStates(t *testing.T) { + c := qt.New(t) + conn := live.Connection{PID: 10} + lines := renderBlockers(detailState{ + List: live.ConnectionList{Connections: []live.Connection{conn}}, + Subject: conn, + SubjectFound: true, + }, 120, 20) + got := stripANSI(strings.Join(lines, "\n")) + + c.Assert(got, qt.Contains, "No upstream blocker") + c.Assert(got, qt.Contains, "Not blocking other connections") +} + +func TestRenderDetailCoercesUnavailableBlockersToQuery(t *testing.T) { + c := qt.New(t) + subject := live.Connection{PID: 10, BlockedBy: []int{20}, QueryText: "SELECT 1"} + view := stripANSI(renderDetail(detailState{ + List: live.ConnectionList{Connections: []live.Connection{ + subject, + {PID: 20, QueryText: "blocker"}, + }}, + Subject: subject, + SubjectFound: true, + Tab: tabBlockers, + Capabilities: ConnectionCapabilities{ + CancelQuery: ActionTargetQueryID, + TerminateConnection: ActionTargetConnectionID, + ShowBlockers: false, + }, + Width: 120, + Height: 24, + })) + + c.Assert(view, qt.Contains, "[query]") + c.Assert(view, qt.Contains, "SELECT 1") + c.Assert(view, qt.Not(qt.Contains), "BLOCKED BY") + c.Assert(view, qt.Not(qt.Contains), "BLOCKING") + c.Assert(view, qt.Not(qt.Contains), "selected blocker") +} + +func TestDetailQueryTabRendersProcesslistRecord(t *testing.T) { + c := qt.New(t) + connectionID := "zone1-2001-101" + queryID := "zone1-2001-101" + view := renderDetail(detailState{ + List: live.NewConnectionList(time.Now(), nil, live.SortByDuration), + Subject: live.Connection{ + PID: 101, + Instance: "zone1-2001", + State: "Query/executing", + Duration: 42 * time.Second, + Username: "vt_app", + DatabaseName: "checkout", + ClientAddr: "10.0.0.1:1234", + ConnectionID: &connectionID, + QueryID: &queryID, + QueryText: "SELECT 1", + }, + SubjectFound: true, + Tab: tabQuery, + DisplayPreset: connectionDisplayProcesslist, + Capabilities: DefaultConnectionCapabilities(), + Width: 120, + Height: 24, + }) + + stripped := stripANSI(view) + c.Assert(stripped, qt.Contains, "pid:") + c.Assert(stripped, qt.Contains, "tablet:") + c.Assert(stripped, qt.Contains, "zone1-2001") + c.Assert(stripped, qt.Contains, "state:") + c.Assert(stripped, qt.Contains, "Query/executing") + c.Assert(stripped, qt.Contains, "duration:") + c.Assert(stripped, qt.Contains, "user:") + c.Assert(stripped, qt.Contains, "database:") + c.Assert(stripped, qt.Contains, "checkout") + c.Assert(stripped, qt.Contains, "client_addr:") + c.Assert(stripped, qt.Contains, "connection_id:") + c.Assert(stripped, qt.Contains, "query_id:") + c.Assert(stripped, qt.Contains, "SELECT 1") + c.Assert(stripped, qt.Not(qt.Contains), "blocked_by:") + c.Assert(stripped, qt.Not(qt.Contains), "wait:") + c.Assert(stripped, qt.Not(qt.Contains), "transaction_id:") +} + +func TestDetailProcesslistZeroDurationShowsNone(t *testing.T) { + c := qt.New(t) + view := stripANSI(renderDetail(detailState{ + List: live.NewConnectionList(time.Now(), nil, live.SortByDuration), + Subject: live.Connection{ + PID: 101, + Instance: "zone1-2001", + State: "Sleep", + }, + SubjectFound: true, + Tab: tabQuery, + DisplayPreset: connectionDisplayProcesslist, + Capabilities: DefaultConnectionCapabilities(), + Width: 120, + Height: 24, + })) + + c.Assert(view, qt.Contains, "duration: none") + c.Assert(view, qt.Not(qt.Contains), "duration: 0s") +} + +func TestRenderQueryShowsEmptyMessageWhenTextBlank(t *testing.T) { + c := qt.New(t) + subject := live.Connection{PID: 10} + state := detailState{ + List: live.ConnectionList{Connections: []live.Connection{subject}}, + Subject: subject, + SubjectFound: true, + Tab: tabQuery, + Width: 80, + Height: 24, + } + view := stripANSI(renderDetail(state)) + c.Assert(view, qt.Contains, "no query") + c.Assert(view, qt.Not(qt.Contains), "query empty") +} + +func TestRenderQueryShowsOverflowIndicatorWhenViewportIsOneLine(t *testing.T) { + c := qt.New(t) + conn := live.Connection{ + QueryText: "select a from t where a = 1 and b = 2 and c = 3 order by a limit 10", + } + + lines := renderQuery(conn, connectionDisplayDefault, 80, 1, 0) + + c.Assert(lines, qt.HasLen, 1) + c.Assert(lines[0], qt.Contains, "lines ") + c.Assert(lines[0], qt.Not(qt.Contains), "SELECT") +} + +func TestDetailHelpHidesRefreshWhilePaused(t *testing.T) { + c := qt.New(t) + footer := stripANSI(renderDetailHelp(detailState{SubjectFound: true, Width: 200, Paused: true})) + + c.Assert(footer, qt.Not(qt.Contains), "r refresh") + c.Assert(footer, qt.Contains, "space resume") +} diff --git a/internal/connections/tui/help.go b/internal/connections/tui/help.go new file mode 100644 index 00000000..d2129b6e --- /dev/null +++ b/internal/connections/tui/help.go @@ -0,0 +1,143 @@ +package tui + +import ( + "fmt" + "strings" +) + +type helpState struct { + Target Target + Width int + Height int + Offset int + CanSort bool + Paused bool + Replay bool + DisplayPreset connectionDisplayPreset + Capabilities ConnectionCapabilities +} + +func renderHelpModal(state helpState) string { + width := tableWidth(state.Width) + lines := clippedHelpLines(helpModalLines(state), width) + if state.Height > 0 { + if len(lines) > state.Height { + return strings.Join(visibleHelpLines(lines, state.Height, state.Offset), "\n") + } + for len(lines) < state.Height { + lines = append(lines, "") + } + } + return strings.Join(lines, "\n") +} + +// maxHelpOffset reports the largest useful scroll offset for the help modal; +// scrolling past it would not reveal any new line. +func maxHelpOffset(state helpState) int { + if state.Height <= 1 { + return 0 + } + lines := helpModalLines(state) + if len(lines) <= state.Height { + return 0 + } + return len(lines) - helpBodyHeight(state.Height) +} + +func helpModalLines(state helpState) []string { + support := state.Capabilities.effective() + lines := []string{ + headerStyle.Render("Connections Help"), + "", + "Target: " + emptyDash(renderTarget(state.Target)), + "", + headerStyle.Render("Reading The Table"), + " Rows are live database sessions; the selected row is the action target.", + " Columns are abbreviated to fit the terminal; open detail for full query text.", + "", + headerStyle.Render("States And Status"), + statusHelp(state.DisplayPreset), + " Paused keeps captured age visible; refreshing means fetching.", + "", + } + if support.ShowBlockers { + lines = append(lines, + headerStyle.Render("Markers And Blocking"), + " R marks replica sessions. BLOCK digits show downstream sessions blocked.", + " W in BLOCK means the session is waiting on a lock.", + "", + ) + } + if !support.supports(actionTerminateTxn) { + lines = append(lines, + headerStyle.Render("Actions"), + " c Kill the selected query (KILL QUERY) using the query_id from the selected process.", + " K Kill the selected connection (KILL) using the connection_id from the selected process.", + " c and K require confirmation. Replay mode blocks backend actions.", + "", + ) + } else { + lines = append(lines, + headerStyle.Render("Actions"), + " c Cancel the selected query (pg_cancel_backend) only if it is the same active query we observed.", + " k Kill the selected transaction (pg_terminate_backend) only if it is the same transaction we observed.", + " K Force terminate the selected connection (pg_terminate_backend) only if backend start matches what we observed.", + " c, k, and K require confirmation. Replay mode blocks backend actions.", + "", + ) + } + navigation := " up/down select; enter/v detail" + if support.ShowBlockers { + navigation += "; left/right switch detail tabs; b blockers" + } + controls := []string{} + if !state.Replay && !state.Paused { + controls = append(controls, "r refresh") + } + if state.Paused { + controls = append(controls, "space resume") + } else { + controls = append(controls, "space pause") + } + if !state.Replay { + controls = append(controls, "C capture") + } + controls = append(controls, "[ ] { } step history") + if state.CanSort { + controls = append(controls, "s sort") + } + lines = append(lines, + headerStyle.Render("Navigation"), + navigation, + " "+strings.Join(controls, "; "), + " q/esc close or go back; ctrl+c quit; ?, esc, or q closes help", + ) + return lines +} + +func clippedHelpLines(lines []string, width int) []string { + out := make([]string, len(lines)) + for i, line := range lines { + out[i] = clipLine(line, width) + } + return out +} + +func visibleHelpLines(lines []string, height, offset int) []string { + if height <= 1 { + return lines[:1] + } + bodyHeight := height - 1 + offset = clampInt(offset, 0, len(lines)-bodyHeight) + end := offset + bodyHeight + visible := append([]string{}, lines[offset:end]...) + visible = append(visible, mutedStyle.Render(fmt.Sprintf("lines %d-%d/%d", offset+1, end, len(lines)))) + return visible +} + +func statusHelp(display connectionDisplayPreset) string { + if display == connectionDisplayProcesslist { + return " Query is running; Sleep is quiet; other states come from the Vitess processlist." + } + return " active is running; idle is quiet; idle/xact may hold locks." +} diff --git a/internal/connections/tui/model.go b/internal/connections/tui/model.go new file mode 100644 index 00000000..1a3169ae --- /dev/null +++ b/internal/connections/tui/model.go @@ -0,0 +1,1193 @@ +package tui + +import ( + "context" + "errors" + "fmt" + "net/http" + "strings" + "time" + + tea "github.com/charmbracelet/bubbletea" + live "github.com/planetscale/cli/internal/connections" + "github.com/planetscale/cli/internal/connections/history" +) + +const ( + // defaultHistoryCapacity bounds the in-memory ring of recent samples kept for + // stepping. ~5 minutes at the default 1s interval. + defaultHistoryCapacity = 300 + noticeTTL = 5 * time.Second +) + +type listMsg struct { + list live.ConnectionList + err error +} + +type tickMsg time.Time + +type noticeTimeoutMsg struct { + id uint64 +} + +type durationDoneMsg struct{} + +type noticeState struct { + id uint64 + text string +} + +type actionKind int + +const ( + actionCancelQuery actionKind = iota + 1 + actionTerminateTxn + actionTerminateConn +) + +type actionResultMsg struct { + kind actionKind + err error +} + +type connectionDisplayPreset string + +const ( + connectionDisplayDefault connectionDisplayPreset = "" + connectionDisplayProcesslist connectionDisplayPreset = "processlist" +) + +type ConnectionViewProfile struct { + displayPreset connectionDisplayPreset + capabilities ConnectionCapabilities + defaultSort live.SortMode + sortModes []live.SortMode +} + +var ( + // PostgresConnectionView presents live connections with Postgres session semantics. + PostgresConnectionView = ConnectionViewProfile{ + displayPreset: connectionDisplayDefault, + capabilities: DefaultConnectionCapabilities(), + defaultSort: live.SortByTransactionStart, + sortModes: postgresSortModes(), + } + // VitessConnectionView presents live connections with Vitess processlist semantics. + VitessConnectionView = ConnectionViewProfile{ + displayPreset: connectionDisplayProcesslist, + capabilities: ConnectionCapabilities{ + CancelQuery: ActionTargetQueryID, + TerminateConnection: ActionTargetConnectionID, + ShowBlockers: false, + configured: true, + }, + defaultSort: live.SortByDuration, + sortModes: []live.SortMode{live.SortByDuration}, + } +) + +func postgresSortModes() []live.SortMode { + return []live.SortMode{ + live.SortByTransactionStart, + live.SortByDuration, + live.SortByBlocked, + } +} + +// DefaultSort is the sort order a connection source should use for this view. +func (p ConnectionViewProfile) DefaultSort() live.SortMode { + return p.defaultSort +} + +func (p ConnectionViewProfile) sortOptions() []live.SortMode { + if len(p.sortModes) == 0 { + return []live.SortMode{p.defaultSort} + } + return append([]live.SortMode(nil), p.sortModes...) +} + +type Target struct { + Database string + Branch string + Keyspace string + Shard string +} + +// ConnectionsClient is the live-connections data + action dependency the +// view model talks to. Production wires this to the wire-level *live.Client +// (directly or through a filtering wrapper); tests provide a stub that records +// calls. +type ConnectionsClient interface { + List(context.Context, live.SortMode) (live.ConnectionList, error) + CancelQuery(context.Context, live.ActionTarget) error + TerminateTransaction(context.Context, live.ActionTarget) error + TerminateConnection(context.Context, live.ActionTarget) error +} + +// Model is the Bubble Tea view model for the interactive table. +type Model struct { + client ConnectionsClient + ctx context.Context + interval time.Duration + duration time.Duration + sort live.SortMode + sortModes []live.SortMode + now func() time.Time + capture *CaptureControl + target Target + filter string + + displayPreset connectionDisplayPreset + capabilities ConnectionCapabilities + + samples *history.CaptureHistory + cursor history.CaptureCursor + following bool + liveRefresh bool + + // stepAnchorBase is the history base recorded when the cursor last moved. + // The step-history numerator counts from this anchor rather than the live + // oldest edge, so a held frame's position holds steady while paused even as + // eviction advances the live base underneath it. + stepAnchorBase history.CaptureCursor + + lastSuccessfulList live.ConnectionList + hasList bool + lastError string + initialAccessDenied bool + // actionError holds the result of an explicit user action (cancel/kill, a + // permission denial, or a "nothing to act on" guard). Unlike lastError — + // which a successful auto-refresh clears — it survives refreshes and is + // cleared only by the operator's next keystroke, so a destructive action's + // outcome can't flash past unread on the ~1s refresh. + actionError string + notice noticeState + captureStopped string + paused bool + loading bool + // consecutiveErrors counts list fetches that have failed in a row; reset to + // 0 on the next success. Drives the header refresh dot: 1 reads as a + // transient blip (cyan), refreshDotFailThreshold+ reads as a sustained + // outage (red). + consecutiveErrors int + selected int + viewportStart int + width int + height int + confirming bool + pendingTarget live.ActionTarget + pendingKind actionKind + readOnlyActionError string + helpOpen bool + helpOffset int + + detailOpen bool + detailInstance string + detailPID int + detailTab detailTab + blockerRow int + queryOffset int +} + +func NewModel(ctx context.Context, client ConnectionsClient, interval, duration time.Duration) Model { + return Model{ + client: client, + ctx: ctx, + interval: interval, + duration: duration, + sort: live.SortByTransactionStart, + sortModes: postgresSortModes(), + now: time.Now, + width: 100, + height: 24, + detailTab: tabQuery, + capabilities: DefaultConnectionCapabilities(), + samples: history.NewCaptureHistory(defaultHistoryCapacity), + following: true, + liveRefresh: true, + } +} + +func (m Model) WithTarget(target Target) Model { + m.target = target + return m +} + +// WithFilter records the active row filter (e.g. "filter: role=primary") so the +// header can show that the view is scoped. Empty when no filter is active. +func (m Model) WithFilter(filter string) Model { + m.filter = filter + return m +} + +func (m Model) WithConnectionView(profile ConnectionViewProfile) Model { + m.displayPreset = profile.displayPreset + m.capabilities = profile.capabilities.effective() + m.sort = profile.DefaultSort() + m.sortModes = profile.sortOptions() + return m +} + +func (m Model) WithReadOnlyActions(message string) Model { + m.readOnlyActionError = message + return m +} + +// WithCaptureWriter persists every successful list to the capture trace file +// while the TUI runs. Returns the model unchanged when w is nil. +func (m Model) WithCaptureWriter(w *history.CaptureWriter) Model { + if w != nil { + m.capture = &CaptureControl{Writer: w} + } + return m +} + +func (m Model) WithCaptureControl(control *CaptureControl) Model { + m.capture = control + return m +} + +// WithCaptureHistory replaces the default in-memory ring with one preloaded +// by the caller, used by replay mode to seed the model with every captured +// snapshot up front. The cursor points at the latest capture, rendering +// reflects it immediately, and the model starts paused so the operator +// controls advance via the step keybindings instead of an auto-tick chewing +// through the trace. +func (m Model) WithCaptureHistory(h *history.CaptureHistory) Model { + m.samples = h + m.liveRefresh = false + if cursor, ok := h.Latest(); ok { + m.cursor = cursor + m.lastSuccessfulList = m.currentList() + m.hasList = true + m.paused = true + m.recordStepPosition() + } + return m +} + +// isReplay reports whether the model is driving a replayed trace rather than a +// live source. +func (m Model) isReplay() bool { + return !m.liveRefresh +} + +func (m Model) setNotice(text string) (Model, tea.Cmd) { + m.notice.id++ + m.notice.text = text + id := m.notice.id + return m, tea.Tick(noticeTTL, func(time.Time) tea.Msg { + return noticeTimeoutMsg{id: id} + }) +} + +func (m Model) clearNotice() Model { + m.notice.text = "" + return m +} + +func (m Model) rejectReadOnlyAction() (Model, bool) { + if m.readOnlyActionError == "" { + return m, false + } + m = m.setActionError(m.readOnlyActionError) + return m, true +} + +func (m Model) setActionError(text string) Model { + m.actionError = text + return m.clearNotice() +} + +func (m Model) Init() tea.Cmd { + var cmds []tea.Cmd + if m.liveRefresh && !m.hasList { + cmds = append(cmds, m.fetch()) + } + if m.liveRefresh { + cmds = append(cmds, m.tick()) + } + cmds = append(cmds, m.durationTimer()) + return tea.Batch(cmds...) +} + +func (m Model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.WindowSizeMsg: + m.width = msg.Width + m.height = msg.Height + m.clampViewport() + m.clampBlockerSelection() + m.clampQueryOffset() + return m, tea.ClearScreen + case tickMsg: + if !m.liveRefresh { + return m, nil + } + if m.loading { + return m, m.tick() + } + m.loading = true + return m, tea.Batch(m.fetch(), m.tick()) + case noticeTimeoutMsg: + if msg.id == m.notice.id { + m = m.clearNotice() + } + return m, nil + case listMsg: + m.loading = false + if msg.err != nil { + m.lastError = listErrorText(msg.err, m.hasList) + m.initialAccessDenied = !m.hasList && isForbiddenHTTPError(msg.err) + m.consecutiveErrors++ + return m, nil + } + m.consecutiveErrors = 0 + m.initialAccessDenied = false + cursor := m.samples.Push(msg.list) + prevCursor := m.cursor + if !m.hasList { + m.cursor = cursor + m.following = !m.paused + } else if m.paused { + if _, ok := m.samples.At(m.cursor); !ok { + m.cursor, _ = m.samples.Oldest() + } + m.following = false + } else if m.following { + m.cursor = cursor + } else if _, ok := m.samples.At(m.cursor); !ok { + m.cursor, _ = m.samples.Oldest() + } + // Re-anchor the step label only when the cursor actually moved (initial + // load, follow-live advance, or an eviction reset). A held paused frame + // keeps its label so it does not drift as the buffer tail grows. + if m.cursor != prevCursor || !m.hasList { + m.recordStepPosition() + } + prevConn, hadSelection := m.selectedConnection() + prevIndex := m.selected + m.lastSuccessfulList = m.currentList() + m.hasList = true + m.lastError = "" + if hadSelection { + m.reanchorSelection(prevConn, prevIndex) + } + m.clampViewport() + m.clampBlockerSelection() + m.clampQueryOffset() + if m.capture != nil && m.capture.Writer != nil { + if err := m.capture.Writer.Write(history.NewCapture(msg.list)); err != nil { + // Detach the writer so we stop calling Write on a dead + // destination (ENOSPC, EPIPE, EBADF won't recover), and + // stash the reason in a sticky field that listMsg does + // not clear — m.lastError gets reset on every successful + // refresh, which would flash the warning instead of + // holding it. + m.captureStopped = "capture stopped: " + err.Error() + _ = m.capture.Close() + } + } + return m, nil + case durationDoneMsg: + return m, tea.Quit + case actionResultMsg: + if msg.err != nil { + m = m.setActionError(live.UserFacingErrorText(msg.err, "modify")) + return m, nil + } + var cmd tea.Cmd + m, cmd = m.setNotice(actionNotice(msg.kind) + " sent") + m.lastError = "" + return m, cmd + case tea.KeyMsg: + // An action error (cancel/kill outcome, permission denial, "nothing to + // act on" guard) is sticky until the operator's next keystroke, so it + // can't be wiped by an auto-refresh before it's read. Clear it here, on + // that next keystroke. A keystroke that starts a fresh action sets it + // again below. + m.actionError = "" + if m.confirming { + return m.handleConfirmKey(msg) + } + if m.helpOpen { + return m.handleHelpKey(msg) + } + if m.detailOpen { + return m.handleDetailKey(msg) + } + return m.handleTableKey(msg) + } + return m, nil +} + +func (m Model) handleTableKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.String() { + case "q", "ctrl+c": + return m, tea.Quit + case "?": + m.helpOpen = true + m.helpOffset = 0 + return m, nil + case "r": + if !m.liveRefresh { + return m, nil + } + if m.loading { + return m, nil + } + m.loading = true + return m, m.fetch() + case " ": + return m.togglePause() + case "C": + if m.capture == nil { + return m, nil + } + return m.toggleCapture() + case "s": + next, ok := nextSort(m.sort, m.sortModes) + if !ok { + return m, nil + } + m.sort = next + if m.hasList { + m.lastSuccessfulList = m.currentList() + m.clampViewport() + } + return m, nil + case "[", "]", "{", "}": + return m.handleStepKey(msg.String()), nil + case "c": + if next, rejected := m.rejectReadOnlyAction(); rejected { + return next, nil + } + conn, ok := m.selectedConnection() + if !ok { + return m, nil + } + return m.startConfirm(actionCancelQuery, conn) + case "k", "K": + if next, rejected := m.rejectReadOnlyAction(); rejected { + return next, nil + } + conn, ok := m.selectedConnection() + if !ok { + return m, nil + } + kind := actionTerminateTxn + if msg.String() == "K" { + kind = actionTerminateConn + } + return m.startConfirm(kind, conn) + case "enter", "v", "V", "b", "B": + conn, ok := m.selectedConnection() + if !ok { + return m, nil + } + if strings.EqualFold(msg.String(), "b") && !m.capabilities.effective().ShowBlockers { + return m, nil + } + m.detailOpen = true + m.detailInstance = conn.Instance + m.detailPID = conn.PID + m.blockerRow = 0 + m.queryOffset = 0 + if strings.EqualFold(msg.String(), "b") { + m.detailTab = tabBlockers + m.clampBlockerSelection() + } else { + m.detailTab = tabQuery + } + return m, nil + case "up": + m.moveSelection(-1) + return m, nil + case "down": + m.moveSelection(1) + return m, nil + } + return m, nil +} + +func (m Model) handleDetailKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.String() { + case "ctrl+c": + return m, tea.Quit + case "?": + m.helpOpen = true + m.helpOffset = 0 + return m, nil + case "q", "esc", "backspace": + m.detailOpen = false + m.blockerRow = 0 + m.queryOffset = 0 + return m, nil + case "r": + if !m.liveRefresh { + return m, nil + } + if m.loading { + return m, nil + } + m.loading = true + return m, m.fetch() + case " ": + return m.togglePause() + case "C": + if m.capture == nil { + return m, nil + } + return m.toggleCapture() + case "b", "B": + if !m.capabilities.effective().ShowBlockers { + return m, nil + } + m.detailTab = tabBlockers + m.queryOffset = 0 + m.clampBlockerSelection() + return m, nil + case "left", "right": + if !m.capabilities.effective().ShowBlockers { + m.detailTab = tabQuery + m.queryOffset = 0 + m.clampQueryOffset() + return m, nil + } + if m.detailTab == tabBlockers { + m.detailTab = tabQuery + m.queryOffset = 0 + m.clampQueryOffset() + } else { + m.detailTab = tabBlockers + m.queryOffset = 0 + m.clampBlockerSelection() + m.clampQueryOffset() + } + return m, nil + case "up": + if m.detailTab == tabQuery && m.queryOffset > 0 { + m.queryOffset-- + return m, nil + } + if m.detailTab == tabBlockers && m.blockerRow > 0 { + m.blockerRow-- + } + return m, nil + case "down": + if m.detailTab == tabQuery { + m.queryOffset++ + m.clampQueryOffset() + return m, nil + } + if m.detailTab == tabBlockers { + m.blockerRow++ + m.clampBlockerSelection() + } + return m, nil + case "enter": + if !m.capabilities.effective().ShowBlockers { + return m, nil + } + if m.detailTab == tabBlockers { + target, ok := m.actionTargetConnection() + if ok { + m.detailInstance = target.Instance + m.detailPID = target.PID + m.blockerRow = 0 + m.queryOffset = 0 + m.clampBlockerSelection() + m.clampQueryOffset() + } + } + return m, nil + case "c": + if next, rejected := m.rejectReadOnlyAction(); rejected { + return next, nil + } + conn, ok := m.actionTargetConnection() + if !ok { + return m.rejectEndedDetailAction(), nil + } + return m.startConfirm(actionCancelQuery, conn) + case "k", "K": + if next, rejected := m.rejectReadOnlyAction(); rejected { + return next, nil + } + conn, ok := m.actionTargetConnection() + if !ok { + return m.rejectEndedDetailAction(), nil + } + kind := actionTerminateTxn + if msg.String() == "K" { + kind = actionTerminateConn + } + return m.startConfirm(kind, conn) + case "[", "]", "{", "}": + return m.handleStepKey(msg.String()), nil + } + return m, nil +} + +func (m Model) rejectEndedDetailAction() Model { + if m.detailOpen { + m.actionError = "connection ended — actions unavailable; esc to go back" + m = m.clearNotice() + } + return m +} + +func (m Model) handleHelpKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.String() { + case "?", "esc", "q": + m.helpOpen = false + return m, nil + case "ctrl+c": + return m, tea.Quit + case "up": + m.helpOffset = clampInt(m.helpOffset-1, 0, maxHelpOffset(m.helpState())) + return m, nil + case "down": + m.helpOffset = clampInt(m.helpOffset+1, 0, maxHelpOffset(m.helpState())) + return m, nil + case "pgup": + m.helpOffset = clampInt(m.helpOffset-helpBodyHeight(m.height), 0, maxHelpOffset(m.helpState())) + return m, nil + case "pgdown": + m.helpOffset = clampInt(m.helpOffset+helpBodyHeight(m.height), 0, maxHelpOffset(m.helpState())) + return m, nil + } + return m, nil +} + +func helpBodyHeight(height int) int { + if height <= 1 { + return 1 + } + return height - 1 +} + +func (m Model) helpState() helpState { + return helpState{Target: m.target, Width: m.width, Height: m.height, Offset: m.helpOffset, CanSort: m.canChangeSort(), Paused: m.paused, Replay: m.isReplay(), DisplayPreset: m.displayPreset, Capabilities: m.capabilities} +} + +func (m Model) handleStepKey(key string) Model { + var ( + cursor history.CaptureCursor + ok bool + ) + switch key { + case "[": + cursor, ok = m.samples.Step(m.cursor, -1) + case "]": + cursor, ok = m.samples.Step(m.cursor, 1) + case "{": + cursor, ok = m.samples.Oldest() + ok = ok && cursor != m.cursor + case "}": + cursor, ok = m.samples.Latest() + ok = ok && (cursor != m.cursor || m.canResumeLiveFollow()) + } + if !ok { + return m + } + m.cursor = cursor + latest, hasLatest := m.samples.Latest() + if hasLatest && cursor == latest && m.canResumeLiveFollow() { + m.following = true + m.paused = false + } else { + m.following = false + } + m.recordStepPosition() + m.lastSuccessfulList = m.currentList() + m.clampViewport() + m.clampBlockerSelection() + m.clampQueryOffset() + return m +} + +func (m Model) togglePause() (Model, tea.Cmd) { + m.paused = !m.paused + if m.paused { + m.following = false + m.recordStepPosition() + return m.setNotice("paused") + } + var cmd tea.Cmd + m, cmd = m.setNotice("resumed") + if cursor, ok := m.samples.Latest(); ok { + m.cursor = cursor + m.following = true + m.recordStepPosition() + m.lastSuccessfulList = m.currentList() + m.clampViewport() + m.clampBlockerSelection() + m.clampQueryOffset() + } + return m, cmd +} + +func (m Model) View() string { + if m.helpOpen { + return renderHelpModal(m.helpState()) + } + if m.detailOpen { + return renderDetail(m.detailState()) + } + return renderTable(m.tableState()) +} + +func (m Model) tableState() tableState { + pos, total := m.stepPosition() + return tableState{ + List: m.lastSuccessfulList, + HasList: m.hasList, + Sort: m.sort, + CanSort: m.canChangeSort(), + Selected: m.selected, + ViewportStart: m.viewportStart, + Width: m.width, + Height: m.height, + Paused: m.paused, + Refresh: computeRefreshDot(m.loading, m.consecutiveErrors, m.isReplay()), + ReadOnlyActions: m.readOnlyActionError != "", + Replay: m.isReplay(), + LastError: firstNonEmpty(m.actionError, m.lastError), + AccessDenied: m.initialAccessDenied, + CanStepHistory: m.samples.Len() > 0, + Notice: m.notice.text, + CaptureStopped: m.captureStopped, + CaptureStatus: m.captureStatusText(), + Confirm: m.confirmPrompt(), + Now: m.now(), + Interval: m.interval, + StepPos: pos, + StepTotal: total, + Target: m.target, + Filter: m.filter, + DisplayPreset: m.displayPreset, + Capabilities: m.capabilities, + } +} + +// currentList returns the capture at the model's cursor, sorted by m.sort. +// SortConnections runs in place on the history-owned slice; row order drifts +// across sort changes but values are unchanged. +func (m Model) currentList() live.ConnectionList { + list, ok := m.samples.At(m.cursor) + if !ok { + return live.ConnectionList{} + } + live.SortConnections(list.Connections, m.sort) + list.Sort = m.sort + return list +} + +// stepPosition returns the 1-based position of the displayed capture and the +// total captures held in history, but only while the user is stepping back +// from latest. When following live, both are zero so the header omits the +// indicator. The numerator counts from stepAnchorBase (the base when the +// cursor last moved), so a held frame's position does not drift backward while +// paused as eviction advances the live base. +func (m Model) stepPosition() (pos, total int) { + if m.following { + return 0, 0 + } + total = m.samples.Len() + if total == 0 { + return 0, 0 + } + pos = int(m.cursor-m.stepAnchorBase) + 1 + if pos < 1 { + pos = 1 + } + if pos >= total { + return 0, 0 + } + return pos, total +} + +// recordStepPosition re-anchors the step numerator to the current live base. +// Call it only when the cursor actually moves (step keys, pause/resume, +// eviction reset) — never on a routine push that leaves the held cursor in +// place, or the numerator would drift backward as eviction advances the base. +func (m *Model) recordStepPosition() { + if oldest, ok := m.samples.Oldest(); ok { + m.stepAnchorBase = oldest + } +} + +func (m Model) startConfirm(kind actionKind, conn live.Connection) (tea.Model, tea.Cmd) { + target := actionTargetFor(conn) + if !m.capabilities.supports(kind) { + return m, nil + } + // Don't raise a destructive confirmation for an action that can't run: if + // the required server-issued ID is absent (e.g. an idle backend with no + // active query, or a partial row), surface that immediately instead of + // prompting y/N and only revealing "X is required" after the operator + // confirms. + if reason := m.capabilities.missingActionID(kind, target); reason != "" { + m = m.setActionError(reason) + return m, nil + } + m.confirming = true + m.pendingKind = kind + m.pendingTarget = target + return m, nil +} + +func firstNonEmpty(a, b string) string { + if a != "" { + return a + } + return b +} + +// listErrorText renders a refresh failure. An instance filter that stops +// matching mid-session means the instance left the branch's instance set, not +// that the operator mistyped it — reword so the live view stays honest. +func listErrorText(err error, hasList bool) string { + var unknownInstance *live.UnknownInstanceError + if hasList && errors.As(err, &unknownInstance) { + return fmt.Sprintf("instance %q is no longer in the branch's instance set", unknownInstance.Instance) + } + return live.UserFacingErrorText(err, "view") +} + +func isForbiddenHTTPError(err error) bool { + var httpErr *live.HTTPError + return errors.As(err, &httpErr) && httpErr.StatusCode == http.StatusForbidden +} + +func (m Model) canResumeLiveFollow() bool { + return m.liveRefresh +} + +func actionTargetFor(conn live.Connection) live.ActionTarget { + return live.ActionTarget{ + Instance: conn.Instance, + PID: conn.PID, + ConnectionID: conn.ConnectionID, + TransactionID: conn.TransactionID, + QueryID: conn.QueryID, + } +} + +func (m Model) handleConfirmKey(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.String() { + case "y": + target := m.pendingTarget + kind := m.pendingKind + m.confirming = false + m.pendingTarget = live.ActionTarget{} + m.pendingKind = 0 + return m, m.fireAction(kind, target) + case "n", "esc", "enter": + kind := m.pendingKind + m.confirming = false + m.pendingTarget = live.ActionTarget{} + m.pendingKind = 0 + return m.setNotice(actionNotice(kind) + " cancelled") + case "q", "ctrl+c": + return m, tea.Quit + } + return m, nil +} + +func (m Model) confirmPrompt() string { + if !m.confirming { + return "" + } + var verb string + switch m.pendingKind { + case actionCancelQuery: + verb = "Cancel query on" + case actionTerminateTxn: + verb = "Terminate transaction on" + case actionTerminateConn: + verb = "Force terminate" + default: + return "" + } + return fmt.Sprintf("%s PID %d on %s? [y/N]", verb, m.pendingTarget.PID, m.pendingTarget.Instance) +} + +func (m Model) fetch() tea.Cmd { + return func() tea.Msg { + list, err := m.client.List(m.ctx, m.sort) + return listMsg{list: list, err: err} + } +} + +func (m Model) tick() tea.Cmd { + if m.interval <= 0 { + return nil + } + return tea.Tick(m.interval, func(t time.Time) tea.Msg { + return tickMsg(t) + }) +} + +func (m Model) durationTimer() tea.Cmd { + if m.duration <= 0 { + return nil + } + return tea.Tick(m.duration, func(time.Time) tea.Msg { + return durationDoneMsg{} + }) +} + +func (m *Model) moveSelection(delta int) { + count := len(m.lastSuccessfulList.Connections) + if count == 0 { + m.selected = 0 + m.viewportStart = 0 + return + } + m.selected = clampInt(m.selected+delta, 0, count-1) + m.clampViewport() +} + +func (m *Model) clampViewport() { + count := len(m.lastSuccessfulList.Connections) + if count == 0 { + m.selected = 0 + m.viewportStart = 0 + return + } + visibleRows := visibleRowCount(count, bodyHeight(m.tableState())) + m.selected = clampInt(m.selected, 0, count-1) + m.viewportStart = viewportStartForSelection(m.viewportStart, m.selected, count, visibleRows) +} + +func (m *Model) clampBlockerSelection() { + if !m.detailOpen || m.detailTab != tabBlockers { + return + } + subject, ok := m.detailSubject() + if !ok { + m.blockerRow = 0 + return + } + rows := detailBlockerRows(m.lastSuccessfulList, subject) + if len(rows) == 0 { + m.blockerRow = 0 + return + } + m.blockerRow = clampInt(m.blockerRow, 0, len(rows)-1) +} + +func (m *Model) clampQueryOffset() { + if !m.detailOpen || m.detailTab != tabQuery { + m.queryOffset = 0 + return + } + subject, ok := m.detailSubject() + if !ok { + m.queryOffset = 0 + return + } + bodyHeight := queryBodyHeight(m.height, m.detailFooterLineCount()) + // The Query tab renders the full \G record (fields + query), so the scroll + // clamp must span the whole record, not just the query lines — otherwise the + // tail of the query is unreachable. + total := len(connectionRecordLines(subject, m.displayPreset, tableWidth(m.width))) + m.queryOffset = clampInt(m.queryOffset, 0, maxQueryOffset(total, bodyHeight)) +} + +func queryBodyHeight(height, footerLines int) int { + if height <= 0 { + height = 24 + } + headerLines := 4 + bodyHeight := height - headerLines - footerLines + if bodyHeight < 0 { + return 0 + } + return bodyHeight +} + +// detailState assembles the detail-view snapshot. renderDetail and +// detailFooterLineCount both derive from this single builder so the body-height +// math they each perform reads an identical footer and can never diverge. +func (m Model) detailState() detailState { + subject, ok := m.detailSubject() + pos, total := m.stepPosition() + return detailState{ + List: m.lastSuccessfulList, + Subject: subject, + SubjectFound: ok, + Tab: m.detailTab, + BlockerSelection: m.blockerRow, + QueryOffset: m.queryOffset, + Width: m.width, + Height: m.height, + Paused: m.paused, + Refresh: computeRefreshDot(m.loading, m.consecutiveErrors, m.isReplay()), + ReadOnlyActions: m.readOnlyActionError != "", + Replay: m.isReplay(), + LastError: firstNonEmpty(m.actionError, m.lastError), + Notice: m.notice.text, + CaptureStopped: m.captureStopped, + CaptureStatus: m.captureStatusText(), + Confirm: m.confirmPrompt(), + Now: m.now(), + Interval: m.interval, + StepPos: pos, + StepTotal: total, + Target: m.target, + DisplayPreset: m.displayPreset, + Capabilities: m.capabilities, + } +} + +func (m Model) detailFooterLineCount() int { + return strings.Count(renderDetailFooter(m.detailState()), "\n") + 1 +} + +func (m Model) canChangeSort() bool { + return len(m.sortModes) > 1 +} + +func nextSort(sort live.SortMode, options []live.SortMode) (live.SortMode, bool) { + if len(options) <= 1 { + return sort, false + } + for i, option := range options { + if option == sort { + return options[(i+1)%len(options)], true + } + } + return options[0], true +} + +func (m Model) selectedConnection() (live.Connection, bool) { + if !m.hasList || len(m.lastSuccessfulList.Connections) == 0 { + return live.Connection{}, false + } + if m.selected < 0 || m.selected >= len(m.lastSuccessfulList.Connections) { + return live.Connection{}, false + } + return m.lastSuccessfulList.Connections[m.selected], true +} + +// reanchorSelection keeps the highlight on the same connection (by PID+instance) +// after a refresh instead of leaving the positional index pointing at whatever +// connection now occupies that row. Vitess processlist rows can recycle into +// Sleep quickly, so a selected work row that becomes idle falls back to nearby +// active work instead of following the stale identity into the idle pool. +func (m *Model) reanchorSelection(previous live.Connection, previousIndex int) { + if m.displayPreset == connectionDisplayProcesslist { + m.reanchorProcesslistSelection(previous, previousIndex) + return + } + for i, conn := range m.lastSuccessfulList.Connections { + if sameConnection(conn, previous) { + m.selected = i + return + } + } +} + +func (m *Model) reanchorProcesslistSelection(previous live.Connection, previousIndex int) { + for i, conn := range m.lastSuccessfulList.Connections { + if !sameConnection(conn, previous) { + continue + } + if processlistConnectionHasWork(conn) || !processlistConnectionHasWork(previous) { + m.selected = i + return + } + break + } + m.selected = nearestProcesslistWorkIndex(m.lastSuccessfulList.Connections, previousIndex) +} + +func sameConnection(a, b live.Connection) bool { + return a.PID == b.PID && a.Instance == b.Instance +} + +func nearestProcesslistWorkIndex(connections []live.Connection, index int) int { + if len(connections) == 0 { + return 0 + } + index = clampInt(index, 0, len(connections)-1) + if processlistConnectionHasWork(connections[index]) { + return index + } + for offset := 1; offset < len(connections); offset++ { + before := index - offset + if before >= 0 && processlistConnectionHasWork(connections[before]) { + return before + } + after := index + offset + if after < len(connections) && processlistConnectionHasWork(connections[after]) { + return after + } + } + return index +} + +func (m Model) detailSubject() (live.Connection, bool) { + if m.detailPID == 0 { + return live.Connection{}, false + } + for _, conn := range m.lastSuccessfulList.Connections { + if conn.PID == m.detailPID && conn.Instance == m.detailInstance { + return conn, true + } + } + return live.Connection{}, false +} + +// actionTargetConnection returns the connection that c/k/K should act on +// while the detail view is active. On the Blockers tab it resolves the +// highlighted blocker row; otherwise it returns the detail subject. +func (m Model) actionTargetConnection() (live.Connection, bool) { + subject, ok := m.detailSubject() + if !ok { + return live.Connection{}, false + } + if m.detailTab != tabBlockers { + return subject, true + } + rows := detailBlockerRows(m.lastSuccessfulList, subject) + if m.blockerRow < 0 || m.blockerRow >= len(rows) { + return subject, true + } + row := rows[m.blockerRow] + if !row.Present { + return live.Connection{}, false + } + return row.Connection, true +} + +func (m Model) fireAction(kind actionKind, target live.ActionTarget) tea.Cmd { + return func() tea.Msg { + var err error + switch kind { + case actionCancelQuery: + err = m.client.CancelQuery(m.ctx, target) + case actionTerminateTxn: + err = m.client.TerminateTransaction(m.ctx, target) + case actionTerminateConn: + err = m.client.TerminateConnection(m.ctx, target) + } + return actionResultMsg{kind: kind, err: err} + } +} + +func actionNotice(kind actionKind) string { + switch kind { + case actionCancelQuery: + return "cancel query" + case actionTerminateTxn: + return "terminate transaction" + case actionTerminateConn: + return "force terminate" + } + return "action" +} diff --git a/internal/connections/tui/model_test.go b/internal/connections/tui/model_test.go new file mode 100644 index 00000000..5e4305d4 --- /dev/null +++ b/internal/connections/tui/model_test.go @@ -0,0 +1,1895 @@ +package tui + +import ( + "bytes" + "context" + "errors" + "fmt" + "net/http" + "strings" + "testing" + "time" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/lipgloss" + qt "github.com/frankban/quicktest" + "github.com/muesli/termenv" + live "github.com/planetscale/cli/internal/connections" + "github.com/planetscale/cli/internal/connections/history" +) + +func TestModelRendersListAndRows(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + list := live.NewConnectionList(time.Unix(100, 0), []live.Connection{{ + PID: 10, + Instance: "primary", + Username: "brett", + ApplicationName: "psql", + ClientAddr: "127.0.0.1", + State: "active", + Duration: 3 * time.Second, + QueryText: "SELECT * FROM widgets", + }}, live.SortByTransactionStart) + + updated, _ = model.Update(listMsg{list: list}) + view := updated.(Model).View() + + c.Assert(view, qt.Contains, "connections 1") + c.Assert(view, qt.Contains, "10") + c.Assert(view, qt.Contains, "SELECT * FROM widgets") + c.Assert(view, qt.Contains, "q quit") +} + +func TestModelRendersTargetInTableHeader(t *testing.T) { + tests := []struct { + name string + model Model + list live.ConnectionList + want string + }{ + { + name: "postgres target", + model: NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithTarget(Target{Database: "prod", Branch: "main"}), + list: live.NewConnectionList(time.Unix(100, 0), []live.Connection{{PID: 10}}, live.SortByTransactionStart), + want: "prod / main", + }, + { + name: "vitess target from configured target", + model: NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithTarget(Target{Database: "shop", Branch: "main", Keyspace: "commerce", Shard: "-80"}). + WithConnectionView(VitessConnectionView), + list: live.NewConnectionList(time.Unix(100, 0), []live.Connection{{PID: 10}}, live.SortByDuration), + want: "shop / main / commerce / -80", + }, + { + name: "vitess target from topology", + model: NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithTarget(Target{Database: "shop", Branch: "main"}). + WithConnectionView(VitessConnectionView), + list: func() live.ConnectionList { + list := live.NewConnectionList(time.Unix(100, 0), []live.Connection{{PID: 10}}, live.SortByDuration) + list.Topology = &live.Topology{Keyspace: "commerce", Shard: "-80", Tablet: "zone1-1001"} + return list + }(), + want: "shop / main / commerce / -80", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + updated, _ := tt.model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + updated, _ = updated.(Model).Update(listMsg{list: tt.list}) + + view := updated.(Model).View() + + c.Assert(view, qt.Contains, tt.want) + c.Assert(view, qt.Contains, "connections 1") + }) + } +} + +func TestModelRendersActiveFilterInTableHeader(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithFilter("filter: instance=db-1") + + view := model.View() + + c.Assert(view, qt.Contains, "filter: instance=db-1") +} + +func TestModelClearsScreenOnResize(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + + _, cmd := model.Update(tea.WindowSizeMsg{Width: 120, Height: 30}) + + c.Assert(cmd, qt.Not(qt.IsNil)) + c.Assert(fmt.Sprintf("%T", cmd()), qt.Equals, "tea.clearScreenMsg") +} + +func TestModelRendersTargetInDetailHeader(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithTarget(Target{Database: "prod", Branch: "main"}) + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10, Instance: "primary", QueryText: "select 1"}}, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyEnter}) + + view := updated.(Model).View() + + c.Assert(view, qt.Contains, "prod / main") + c.Assert(view, qt.Contains, "pid 10") +} + +func TestModelKeepsLastListWhenRefreshFails(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10}}, live.SortByTransactionStart) + + updated, _ := model.Update(listMsg{list: list}) + updated, _ = updated.(Model).Update(listMsg{err: errors.New("connection refused")}) + view := updated.(Model).View() + + c.Assert(view, qt.Contains, "10") + c.Assert(view, qt.Contains, "error: connection refused") +} + +func TestModelShowsInitialListErrorBeforeAnySuccessfulList(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + + updated, _ := model.Update(listMsg{err: errors.New("list connections: server is warming up, please retry in a moment")}) + view := updated.(Model).View() + + c.Assert(view, qt.Not(qt.Contains), "loading live connections...") + c.Assert(view, qt.Contains, "unable to load live connections") + c.Assert(view, qt.Contains, "list connections: server is warming up, please retry in a moment") +} + +func TestModelInitialPermissionDeniedShowsAccessDeniedState(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + err := fmt.Errorf("wrapped read failure: %w", &live.HTTPError{ + Op: "list connections", + StatusCode: http.StatusForbidden, + Message: "policy denied this request", + }) + + updated, _ := model.Update(listMsg{err: err}) + view := stripANSI(updated.(Model).View()) + + c.Assert(view, qt.Contains, "connections —") + c.Assert(view, qt.Contains, "you don't have permission to view live connections") + c.Assert(view, qt.Contains, "production branches require the Analyst role or higher — ask an org admin, or use a development branch") + c.Assert(view, qt.Contains, "r refresh | space pause | ? help | q quit") + c.Assert(view, qt.Not(qt.Contains), "policy denied this request") + c.Assert(view, qt.Not(qt.Contains), "enter detail") + c.Assert(view, qt.Not(qt.Contains), "cancel query") + c.Assert(view, qt.Not(qt.Contains), "kill transaction") + c.Assert(view, qt.Not(qt.Contains), "force terminate") +} + +func TestModelActionForbiddenAfterListKeepsTableWithErrorFooter(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + QueryText: "SELECT pg_sleep(30)", + }}, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + + updated, _ = updated.(Model).Update(actionResultMsg{ + kind: actionTerminateConn, + err: &live.HTTPError{ + Op: "terminate connection", + StatusCode: http.StatusForbidden, + Message: "denied by infra policy on tablet zone1-1001", + }, + }) + view := stripANSI(updated.(Model).View()) + + c.Assert(view, qt.Contains, "connections 1") + c.Assert(view, qt.Contains, "10") + c.Assert(view, qt.Contains, "SELECT pg_sleep(30)") + c.Assert(view, qt.Contains, "error: permission denied: you don't have permission to modify live connections") + c.Assert(view, qt.Not(qt.Contains), "zone1-1001") + c.Assert(view, qt.Not(qt.Contains), "you don't have permission to view live connections") +} + +func TestModelListForbiddenAfterSuccessKeepsTableWithErrorFooter(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + QueryText: "SELECT pg_sleep(30)", + }}, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + + updated, _ = updated.(Model).Update(listMsg{ + err: &live.HTTPError{ + Op: "list connections", + StatusCode: http.StatusForbidden, + Message: "denied by infra policy on tablet zone1-1001", + }, + }) + view := stripANSI(updated.(Model).View()) + + c.Assert(view, qt.Contains, "connections 1") + c.Assert(view, qt.Contains, "10") + c.Assert(view, qt.Contains, "SELECT pg_sleep(30)") + c.Assert(view, qt.Contains, "error: permission denied: you don't have permission to view live connections") + c.Assert(view, qt.Not(qt.Contains), "zone1-1001") + c.Assert(view, qt.Not(qt.Contains), "unable to load live connections") +} + +func TestModelInitialErrorDoesNotAdvertiseRowActions(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + + updated, _ := model.Update(listMsg{err: errors.New("list connections: server is warming up")}) + view := stripANSI(updated.(Model).View()) + + c.Assert(view, qt.Contains, "connections —") + c.Assert(view, qt.Contains, "unable to load live connections") + c.Assert(view, qt.Not(qt.Contains), "enter detail") + c.Assert(view, qt.Not(qt.Contains), "cancel query") + c.Assert(view, qt.Not(qt.Contains), "kill transaction") + c.Assert(view, qt.Not(qt.Contains), "force terminate") +} + +func TestModelReadOnlyActionsHidesFooterActions(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithReadOnlyActions("replay mode cannot run live actions") + qid := "10-q" + xid := "10-x" + cid := "10-c" + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + QueryID: &qid, + TransactionID: &xid, + ConnectionID: &cid, + QueryText: "SELECT 1", + }}, live.SortByTransactionStart) + + updated, _ := model.Update(listMsg{list: list}) + tableView := stripANSI(updated.(Model).View()) + + c.Assert(tableView, qt.Contains, "r refresh") + c.Assert(tableView, qt.Not(qt.Contains), "cancel query") + c.Assert(tableView, qt.Not(qt.Contains), "kill transaction") + c.Assert(tableView, qt.Not(qt.Contains), "force terminate") + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyEnter}) + detailView := stripANSI(updated.(Model).View()) + + c.Assert(detailView, qt.Contains, "q/esc back") + c.Assert(detailView, qt.Not(qt.Contains), "cancel query") + c.Assert(detailView, qt.Not(qt.Contains), "kill transaction") + c.Assert(detailView, qt.Not(qt.Contains), "force terminate") +} + +func TestModelNoticeSurvivesRefreshUntilTTL(t *testing.T) { + c := qt.New(t) + base := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + model.now = func() time.Time { return base } + + updated, _ := model.Update(actionResultMsg{kind: actionTerminateConn}) + got := updated.(Model) + c.Assert(got.View(), qt.Contains, "force terminate sent") + noticeID := got.notice.id + + got.now = func() time.Time { return base.Add(2 * time.Second) } + updated, _ = got.Update(listMsg{list: live.NewConnectionList(base, []live.Connection{{PID: 10}}, live.SortByTransactionStart)}) + got = updated.(Model) + c.Assert(got.View(), qt.Contains, "force terminate sent") + + got.now = func() time.Time { return base.Add(6 * time.Second) } + updated, _ = got.Update(noticeTimeoutMsg{id: noticeID}) + got = updated.(Model) + c.Assert(got.View(), qt.Not(qt.Contains), "force terminate sent") +} + +func TestModelNoticeExpiresWhileReplayIsIdle(t *testing.T) { + c := qt.New(t) + base := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + h := history.NewCaptureHistory(3) + h.Push(live.NewConnectionList(base, []live.Connection{{PID: 10}}, live.SortByTransactionStart)) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0).WithCaptureHistory(h) + model.now = func() time.Time { return base } + + updated, cmd := model.Update(tea.KeyMsg{Type: tea.KeySpace}) + got := updated.(Model) + c.Assert(cmd, qt.Not(qt.IsNil)) + c.Assert(got.liveRefresh, qt.IsFalse) + c.Assert(got.notice.text, qt.Equals, "resumed") + noticeID := got.notice.id + + updated, cmd = got.Update(noticeTimeoutMsg{id: noticeID}) + got = updated.(Model) + + c.Assert(cmd, qt.IsNil) + c.Assert(got.notice.text, qt.Equals, "") +} + +func TestModelNoticeExpiryIgnoresStaleTimer(t *testing.T) { + c := qt.New(t) + first := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + second := first.Add(time.Second) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + model.now = func() time.Time { return first } + + updated, cmd := model.Update(actionResultMsg{kind: actionCancelQuery}) + got := updated.(Model) + c.Assert(cmd, qt.Not(qt.IsNil)) + c.Assert(got.notice.text, qt.Equals, "cancel query sent") + firstNoticeID := got.notice.id + + got.now = func() time.Time { return second } + updated, cmd = got.Update(actionResultMsg{kind: actionTerminateConn}) + got = updated.(Model) + c.Assert(cmd, qt.Not(qt.IsNil)) + c.Assert(got.notice.text, qt.Equals, "force terminate sent") + + updated, cmd = got.Update(noticeTimeoutMsg{id: firstNoticeID}) + got = updated.(Model) + + c.Assert(cmd, qt.IsNil) + c.Assert(got.notice.text, qt.Equals, "force terminate sent") +} + +func TestModelKeepsEmptyListAndErrorDuringContinuedContention(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + list := live.NewConnectionList(time.Now(), nil, live.SortByTransactionStart) + + updated, _ := model.Update(listMsg{list: list}) + updated, _ = updated.(Model).Update(listMsg{err: errors.New("list connections: server is warming up, please retry in a moment")}) + updated, _ = updated.(Model).Update(tickMsg(time.Now())) + view := updated.(Model).View() + + c.Assert(view, qt.Contains, "connections 0") + c.Assert(view, qt.Contains, "●") // fixed-width refresh indicator (replaces the "refreshing" text token) + c.Assert(view, qt.Contains, "no live connections") + c.Assert(view, qt.Contains, "new connections appear on the next refresh") + c.Assert(view, qt.Contains, "error: list connections: server is warming up, please retry in a moment") +} + +func TestModelKeepsPopulatedListAndErrorDuringContinuedContention(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + QueryText: "SELECT pg_sleep(120)", + }}, live.SortByTransactionStart) + + updated, _ := model.Update(listMsg{list: list}) + updated, _ = updated.(Model).Update(listMsg{err: errors.New("list connections: server is warming up, please retry in a moment")}) + updated, _ = updated.(Model).Update(tickMsg(time.Now())) + view := updated.(Model).View() + + c.Assert(view, qt.Contains, "connections 1") + c.Assert(view, qt.Contains, "●") // fixed-width refresh indicator (replaces the "refreshing" text token) + c.Assert(view, qt.Contains, "10") + c.Assert(view, qt.Contains, "SELECT pg_sleep(120)") + c.Assert(view, qt.Contains, "error: list connections: server is warming up, please retry in a moment") +} + +func TestModelCyclesSortMode(t *testing.T) { + c := qt.New(t) + now := time.Now() + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + list := live.NewConnectionList(now, []live.Connection{ + {PID: 1, Duration: time.Second}, + {PID: 2, Duration: 5 * time.Second, BlockedBy: []int{3}, State: "active"}, + {PID: 3, Duration: 2 * time.Second}, + }, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("s")}) + got := updated.(Model) + + c.Assert(got.sort, qt.Equals, live.SortByDuration) + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 2) + + updated, _ = got.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("s")}) + got = updated.(Model) + + c.Assert(got.sort, qt.Equals, live.SortByBlocked) + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 3) + + updated, _ = got.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("s")}) + + c.Assert(updated.(Model).sort, qt.Equals, live.SortByTransactionStart) +} + +func TestModelVitessConnectionViewUsesOnlyDurationSort(t *testing.T) { + c := qt.New(t) + now := time.Now() + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithConnectionView(VitessConnectionView) + list := live.NewConnectionList(now, []live.Connection{ + {PID: 1, Duration: time.Second}, + {PID: 2, Duration: 5 * time.Second}, + }, live.SortByDuration) + updated, _ := model.Update(listMsg{list: list}) + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("s")}) + got := updated.(Model) + + c.Assert(got.sort, qt.Equals, live.SortByDuration) + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 2) + c.Assert(stripANSI(got.View()), qt.Not(qt.Contains), "s sort") + + updated, _ = got.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("?")}) + c.Assert(stripANSI(updated.(Model).View()), qt.Not(qt.Contains), "s sort") +} + +func TestModelPauseKeepsTickRefreshing(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeySpace}) + updated, cmd := updated.(Model).Update(tickMsg(time.Now())) + + got := updated.(Model) + c.Assert(got.paused, qt.IsTrue) + c.Assert(got.loading, qt.IsTrue) + c.Assert(cmd, qt.Not(qt.IsNil)) + c.Assert(got.View(), qt.Contains, "paused") +} + +func TestModelClearsPauseNoticeAfterTTL(t *testing.T) { + c := qt.New(t) + base := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + model.now = func() time.Time { return base } + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeySpace}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeySpace}) + got := updated.(Model) + noticeID := got.notice.id + + got.now = func() time.Time { return base.Add(2 * time.Second) } + updated, _ = got.Update(listMsg{list: live.NewConnectionList(base, nil, live.SortByTransactionStart)}) + got = updated.(Model) + c.Assert(got.View(), qt.Contains, "resumed") + + got.now = func() time.Time { return base.Add(6 * time.Second) } + updated, _ = got.Update(noticeTimeoutMsg{id: noticeID}) + c.Assert(updated.(Model).View(), qt.Not(qt.Contains), "resumed") +} + +func TestModelArrowKeysScrollVisibleRows(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 10}) + model = updated.(Model) + list := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 101}, + {PID: 202}, + {PID: 303}, + {PID: 404}, + {PID: 505}, + {PID: 606}, + }, live.SortByTransactionStart) + updated, _ = model.Update(listMsg{list: list}) + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyDown}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyDown}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyDown}) + view := updated.(Model).View() + + c.Assert(view, qt.Not(qt.Contains), "101") + c.Assert(view, qt.Contains, "505") + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyUp}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyUp}) + view = updated.(Model).View() + + c.Assert(view, qt.Contains, "101") + c.Assert(view, qt.Not(qt.Contains), "606") +} + +func TestDetailQueryTabScrollsMultilineQuery(t *testing.T) { + c := qt.New(t) + query := "select a\nfrom t\nwhere a = 1\nand b = 2\nand c = 3\nand d = 4\norder by a\nlimit 10" + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10, Instance: "primary", QueryText: query}}, live.SortByTransactionStart) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 80, Height: 10}) + updated, _ = updated.(Model).Update(listMsg{list: list}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyEnter}) + + firstView := updated.(Model).View() + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyDown}) + secondView := updated.(Model).View() + + c.Assert(firstView, qt.Not(qt.Equals), secondView) + c.Assert(secondView, qt.Contains, "lines ") +} + +func TestDetailQueryTabRendersVerticalRecord(t *testing.T) { + c := qt.New(t) + conn := live.Connection{ + PID: 42, Instance: "cell-1", InstanceRole: "replica", State: "active", + Username: "app", ApplicationName: "worker", QueryText: "SELECT 1", + } + lines := connectionRecordLines(conn, connectionDisplayDefault, 120) + joined := strings.Join(lines, "\n") + for _, want := range []string{"pid:", "instance:", "role:", "state:", "application:", "query_id:", "connection_id:", "query:", "SELECT 1"} { + c.Assert(joined, qt.Contains, want) + } +} + +func TestDetailQueryTabScrollsToLastLine(t *testing.T) { + c := qt.New(t) + // An authored multiline query has more lines than a short viewport, forcing + // the scroll path. The final line must be reachable at max scroll. + query := "select a, b, c\nfrom t\nwhere a = 1\nand b = 2\nand c = 3\nand d = 4\nand e = 5\nand f = 6\norder by a\nlimit 10" + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10, Instance: "primary", QueryText: query}}, live.SortByTransactionStart) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 80, Height: 12}) + updated, _ = updated.(Model).Update(listMsg{list: list}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyEnter}) + + // Scroll well past the end; the clamp must land on the true max offset. + for i := 0; i < 50; i++ { + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyDown}) + } + view := updated.(Model).View() + + // "limit 10" is the last line; it must be visible at max scroll. + c.Assert(view, qt.Contains, "limit 10") +} + +func TestModelManualRefreshFetchesWhenIdle(t *testing.T) { + c := qt.New(t) + source := &clientStub{list: live.NewConnectionList(time.Now(), []live.Connection{{PID: 10}}, live.SortByDuration)} + model := NewModel(context.Background(), source, time.Second, 0) + + updated, cmd := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("r")}) + + c.Assert(updated.(Model).loading, qt.IsTrue) + c.Assert(cmd, qt.Not(qt.IsNil)) + msg := cmd().(listMsg) + c.Assert(msg.err, qt.IsNil) + c.Assert(msg.list.Connections[0].PID, qt.Equals, 10) + c.Assert(source.calls, qt.Equals, 1) +} + +func TestModelManualRefreshSkipsWhenLoading(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + model.loading = true + + _, cmd := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("r")}) + + c.Assert(cmd, qt.IsNil) +} + +func TestModelTickDoesNotStartFetchWhileLoading(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + model.loading = true + + updated, cmd := model.Update(tickMsg(time.Now())) + + c.Assert(updated.(Model).loading, qt.IsTrue) + c.Assert(cmd, qt.Not(qt.IsNil)) +} + +func TestModelFetchUsesParentContext(t *testing.T) { + c := qt.New(t) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + model := NewModel(ctx, &clientStub{}, time.Second, 0) + + msg := model.fetch()().(listMsg) + + c.Assert(errors.Is(msg.err, context.Canceled), qt.IsTrue) +} + +// clientStub stands in for *live.Client in TUI tests. It records List calls +// and the three action calls, with separate error fields so a test can fail +// just the read path or just the action path. +type clientStub struct { + list live.ConnectionList + err error + calls int + + cancelCalls int + terminateTxn int + terminateConn int + lastTarget live.ActionTarget + actionErr error +} + +func (s *clientStub) List(ctx context.Context, sort live.SortMode) (live.ConnectionList, error) { + s.calls++ + if err := ctx.Err(); err != nil { + return live.ConnectionList{}, err + } + return s.list, s.err +} + +func (s *clientStub) CancelQuery(ctx context.Context, target live.ActionTarget) error { + s.cancelCalls++ + s.lastTarget = target + return s.actionErr +} + +func (s *clientStub) TerminateTransaction(ctx context.Context, target live.ActionTarget) error { + s.terminateTxn++ + s.lastTarget = target + return s.actionErr +} + +func (s *clientStub) TerminateConnection(ctx context.Context, target live.ActionTarget) error { + s.terminateConn++ + s.lastTarget = target + return s.actionErr +} + +func vitessModelWithConnection() (Model, *clientStub) { + client := &clientStub{} + connectionID := "zone1-2001-101" + queryID := "zone1-2001-101" + model := NewModel(context.Background(), client, time.Second, 0). + WithConnectionView(VitessConnectionView) + list := live.NewConnectionList(time.Unix(100, 0), []live.Connection{{ + PID: 101, + Instance: "zone1-2001", + ConnectionID: &connectionID, + QueryID: &queryID, + QueryText: "SELECT 1", + }}, live.SortByDuration) + updated, _ := model.Update(listMsg{list: list}) + return updated.(Model), client +} + +func TestModelRendersPartialFailureBanner(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + updated, _ = model.Update(listMsg{list: live.ConnectionList{ + Instances: []live.InstanceMeta{ + {ID: "primary", Role: "primary"}, + {ID: "replica-a", Role: "replica", Error: "timeout after 2s"}, + {ID: "replica-b", Role: "replica", Error: "connection refused"}, + }, + }}) + + view := updated.(Model).View() + + c.Assert(view, qt.Contains, "2 of 3 instances unreachable") + c.Assert(view, qt.Contains, "replica-a") + c.Assert(view, qt.Contains, "replica-b") +} + +func TestModelOmitsBannerWhenAllInstancesHealthy(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + updated, _ = model.Update(listMsg{list: live.ConnectionList{ + Instances: []live.InstanceMeta{ + {ID: "primary", Role: "primary"}, + {ID: "replica-a", Role: "replica"}, + }, + }}) + + view := updated.(Model).View() + + c.Assert(view, qt.Not(qt.Contains), "instances unreachable") +} + +func TestModelRendersFreshnessRelativeAge(t *testing.T) { + c := qt.New(t) + captured := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC) + clock := captured.Add(2 * time.Second) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + model.now = func() time.Time { return clock } + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + updated, _ = model.Update(listMsg{list: live.NewConnectionList(captured, []live.Connection{{PID: 1}}, live.SortByTransactionStart)}) + + view := updated.(Model).View() + + c.Assert(view, qt.Contains, "(2s ago)") +} + +func TestModelInteractiveWritesCaptureFile(t *testing.T) { + c := qt.New(t) + var buf bytes.Buffer + writer := history.NewCaptureWriter(&buf) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0).WithCaptureWriter(writer) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10}}, live.SortByTransactionStart) + + updated, _ = model.Update(listMsg{list: list}) + _ = updated + + c.Assert(buf.String(), qt.Contains, `"pid":10`) +} + +func TestModelCaptureWriteErrorDetachesWriterAndSetsStickyError(t *testing.T) { + c := qt.New(t) + target := &countingFailingWriter{err: errors.New("disk full")} + writer := history.NewCaptureWriter(target) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0).WithCaptureWriter(writer) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10}}, live.SortByTransactionStart) + + updated, _ = model.Update(listMsg{list: list}) + got := updated.(Model) + + c.Assert(got.captureStopped, qt.Contains, "capture stopped") + c.Assert(got.captureStopped, qt.Contains, "disk full") + c.Assert(got.capture.Writer, qt.IsNil) + c.Assert(target.writes, qt.Equals, 1) + c.Assert(got.View(), qt.Contains, "capture stopped") + + // A subsequent successful list must not retry the detached writer and + // must not clear the sticky indicator. + updated, _ = got.Update(listMsg{list: list}) + got = updated.(Model) + + c.Assert(target.writes, qt.Equals, 1) + c.Assert(got.captureStopped, qt.Contains, "capture stopped") + c.Assert(got.View(), qt.Contains, "capture stopped") +} + +func TestModelToggleCaptureBackfillsHistoryAndTailsFutureSamples(t *testing.T) { + c := qt.New(t) + var buf bytes.Buffer + openCalls := 0 + control := &CaptureControl{ + Open: func() (*history.CaptureWriter, string, error) { + openCalls++ + return history.NewCaptureWriter(&buf), "trace.jsonl", nil + }, + } + model := NewModel(context.Background(), &clientStub{}, time.Second, 0).WithCaptureControl(control) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + + for _, pid := range []int{10, 11} { + next, _ := model.Update(listMsg{list: live.NewConnectionList(time.Unix(int64(pid), 0), []live.Connection{{ + PID: pid, Instance: "primary", + }}, live.SortByTransactionStart)}) + model = next.(Model) + } + c.Assert(buf.String(), qt.Equals, "") + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("C")}) + model = updated.(Model) + + c.Assert(openCalls, qt.Equals, 1) + c.Assert(buf.String(), qt.Contains, `"pid":10`) + c.Assert(buf.String(), qt.Contains, `"pid":11`) + c.Assert(model.View(), qt.Contains, "rec trace.jsonl") + + updated, _ = model.Update(listMsg{list: live.NewConnectionList(time.Unix(12, 0), []live.Connection{{ + PID: 12, Instance: "primary", + }}, live.SortByTransactionStart)}) + model = updated.(Model) + + c.Assert(buf.String(), qt.Contains, `"pid":12`) + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("C")}) + model = updated.(Model) + updated, _ = model.Update(listMsg{list: live.NewConnectionList(time.Unix(13, 0), []live.Connection{{ + PID: 13, Instance: "primary", + }}, live.SortByTransactionStart)}) + model = updated.(Model) + + c.Assert(buf.String(), qt.Not(qt.Contains), `"pid":13`) + c.Assert(model.View(), qt.Not(qt.Contains), "rec trace.jsonl") +} + +func TestModelToggleCaptureBackfillsTailOnceButWritesFutureSamples(t *testing.T) { + c := qt.New(t) + var buf bytes.Buffer + control := &CaptureControl{ + Open: func() (*history.CaptureWriter, string, error) { + return history.NewCaptureWriter(&buf), "trace.jsonl", nil + }, + } + model := NewModel(context.Background(), &clientStub{}, time.Second, 0).WithCaptureControl(control) + list := live.NewConnectionList(time.Unix(10, 0), []live.Connection{{ + PID: 10, Instance: "primary", + }}, live.SortByTransactionStart) + + updated, _ := model.Update(listMsg{list: list}) + model = updated.(Model) + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("C")}) + model = updated.(Model) + + c.Assert(strings.Count(buf.String(), `"pid":10`), qt.Equals, 1) + + updated, _ = model.Update(listMsg{list: list}) + _ = updated + + c.Assert(strings.Count(buf.String(), `"pid":10`), qt.Equals, 2) +} + +func TestModelCaptureNoticeSurvivesRefreshUntilTTL(t *testing.T) { + c := qt.New(t) + base := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + var buf bytes.Buffer + control := &CaptureControl{ + Open: func() (*history.CaptureWriter, string, error) { + return history.NewCaptureWriter(&buf), "trace.jsonl", nil + }, + } + model := NewModel(context.Background(), &clientStub{}, time.Second, 0).WithCaptureControl(control) + model.now = func() time.Time { return base } + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("C")}) + got := updated.(Model) + c.Assert(got.View(), qt.Contains, "capturing to trace.jsonl") + noticeID := got.notice.id + + got.now = func() time.Time { return base.Add(2 * time.Second) } + updated, _ = got.Update(listMsg{list: live.NewConnectionList(base, []live.Connection{{PID: 10}}, live.SortByTransactionStart)}) + got = updated.(Model) + c.Assert(got.View(), qt.Contains, "capturing to trace.jsonl") + + got.now = func() time.Time { return base.Add(6 * time.Second) } + updated, _ = got.Update(noticeTimeoutMsg{id: noticeID}) + got = updated.(Model) + c.Assert(got.View(), qt.Not(qt.Contains), "capturing to trace.jsonl") +} + +type countingFailingWriter struct { + err error + writes int +} + +func (w *countingFailingWriter) Write([]byte) (int, error) { + w.writes++ + return 0, w.err +} + +func TestModelCancelQueryDispatchesAction(t *testing.T) { + c := qt.New(t) + client := &clientStub{} + model := NewModel(context.Background(), client, time.Second, 0) + + qid := "10-7" + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, Instance: "primary", QueryID: &qid, + }}, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + updated, cmd := updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("c")}) + + got := updated.(Model) + c.Assert(cmd, qt.IsNil) + c.Assert(got.confirming, qt.IsTrue) + c.Assert(got.pendingKind, qt.Equals, actionCancelQuery) + c.Assert(got.View(), qt.Contains, "Cancel query") + c.Assert(got.View(), qt.Contains, "[y/N]") + c.Assert(client.cancelCalls, qt.Equals, 0) + + updated, cmd = got.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("y")}) + c.Assert(cmd, qt.Not(qt.IsNil)) + c.Assert(updated.(Model).confirming, qt.IsFalse) + _ = cmd() + c.Assert(client.cancelCalls, qt.Equals, 1) + c.Assert(client.lastTarget.Instance, qt.Equals, "primary") + c.Assert(client.lastTarget.PID, qt.Equals, 10) + c.Assert(client.lastTarget.QueryID, qt.Not(qt.IsNil)) + c.Assert(*client.lastTarget.QueryID, qt.Equals, "10-7") +} + +func TestModelReadOnlyRejectsActionsBeforeConfirmation(t *testing.T) { + c := qt.New(t) + client := &clientStub{} + model := NewModel(context.Background(), client, time.Second, 0).WithReadOnlyActions("not available in replay mode") + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10, Instance: "primary"}}, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + + for _, key := range []string{"c", "k", "K"} { + next, cmd := updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(key)}) + got := next.(Model) + + c.Assert(cmd, qt.IsNil) + c.Assert(got.confirming, qt.IsFalse) + c.Assert(got.actionError, qt.Equals, "not available in replay mode") + } + c.Assert(client.cancelCalls, qt.Equals, 0) + c.Assert(client.terminateTxn, qt.Equals, 0) + c.Assert(client.terminateConn, qt.Equals, 0) +} + +func TestModelHelpOpensAndClosesFromTable(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithTarget(Target{Database: "prod", Branch: "main"}) + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("?")}) + got := updated.(Model) + + c.Assert(got.View(), qt.Contains, "Reading The Table") + c.Assert(got.View(), qt.Contains, "prod / main") + + updated, _ = got.Update(tea.KeyMsg{Type: tea.KeyEsc}) + c.Assert(updated.(Model).View(), qt.Not(qt.Contains), "Reading The Table") +} + +func TestModelHelpScrollsWhenClipped(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithTarget(Target{Database: "prod", Branch: "main"}) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 80, Height: 8}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("?")}) + first := stripANSI(updated.(Model).View()) + + c.Assert(first, qt.Contains, "Reading The Table") + c.Assert(first, qt.Not(qt.Contains), "Actions") + + for i := 0; i < 16; i++ { + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyDown}) + } + second := stripANSI(updated.(Model).View()) + + c.Assert(second, qt.Contains, "Actions") + c.Assert(second, qt.Contains, "lines ") +} + +func TestModelHelpUsesPostgresActionCopy(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("?")}) + view := updated.(Model).View() + + c.Assert(view, qt.Contains, "Cancel the selected query") + c.Assert(view, qt.Contains, "Kill the selected transaction") + c.Assert(view, qt.Contains, "Force terminate the selected connection") + c.Assert(view, qt.Not(qt.Contains), "pg_cancel_query") +} + +func TestModelHelpUsesVitessActionCopy(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithConnectionView(VitessConnectionView) + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("?")}) + view := updated.(Model).View() + + c.Assert(view, qt.Contains, "Kill the selected query") + c.Assert(view, qt.Contains, "Kill the selected connection") + c.Assert(view, qt.Contains, "connection_id") + c.Assert(view, qt.Not(qt.Contains), "idle/xact") + c.Assert(view, qt.Not(qt.Contains), "pg_terminate_backend only if it is the same transaction") +} + +func TestModelHelpBlocksActionKeys(t *testing.T) { + c := qt.New(t) + client := &clientStub{} + model := NewModel(context.Background(), client, time.Second, 0) + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10, Instance: "primary"}}, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("?")}) + + updated, cmd := updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("K")}) + got := updated.(Model) + + c.Assert(cmd, qt.IsNil) + c.Assert(got.confirming, qt.IsFalse) + c.Assert(client.terminateConn, qt.Equals, 0) +} + +func TestModelHelpIgnoredWhileConfirming(t *testing.T) { + c := qt.New(t) + client := &clientStub{} + model := NewModel(context.Background(), client, time.Second, 0) + xid := "10-42" + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10, Instance: "primary", TransactionID: &xid}}, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + // Open the terminate confirm prompt, then press ? — help must stay closed and + // the confirm prompt must remain until resolved. + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("k")}) + c.Assert(updated.(Model).confirming, qt.IsTrue) + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("?")}) + got := updated.(Model) + + c.Assert(got.helpOpen, qt.IsFalse) + c.Assert(got.confirming, qt.IsTrue) + c.Assert(got.View(), qt.Not(qt.Contains), "Reading The Table") +} + +func TestModelVitessConnectionViewDisablesBlockersAndTransactionKill(t *testing.T) { + c := qt.New(t) + model, _ := vitessModelWithConnection() + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("b")}) + c.Assert(updated.(Model).detailOpen, qt.IsFalse) + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("k")}) + c.Assert(updated.(Model).confirming, qt.IsFalse) + c.Assert(updated.(Model).actionError, qt.Equals, "") +} + +func TestModelVitessConnectionViewCancelTargetsQueryID(t *testing.T) { + c := qt.New(t) + model, _ := vitessModelWithConnection() + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("c")}) + got := updated.(Model) + + c.Assert(got.confirming, qt.IsTrue) + c.Assert(got.pendingTarget.QueryID, qt.Not(qt.IsNil)) + c.Assert(*got.pendingTarget.QueryID, qt.Equals, "zone1-2001-101") +} + +func TestModelVitessConnectionViewForceTerminateTargetsConnectionID(t *testing.T) { + c := qt.New(t) + model, _ := vitessModelWithConnection() + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("K")}) + got := updated.(Model) + + c.Assert(got.confirming, qt.IsTrue) + c.Assert(got.pendingTarget.ConnectionID, qt.Not(qt.IsNil)) + c.Assert(*got.pendingTarget.ConnectionID, qt.Equals, "zone1-2001-101") +} + +func TestModelVitessConnectionViewUsesProcesslistActionLabels(t *testing.T) { + c := qt.New(t) + connectionID := "zone1-2001-101" + queryID := "zone1-2001-101" + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithConnectionView(VitessConnectionView) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 101, + Instance: "zone1-2001", + ConnectionID: &connectionID, + QueryID: &queryID, + QueryText: "SELECT 1", + }}, live.SortByDuration) + updated, _ := model.Update(listMsg{list: list}) + tableView := stripANSI(updated.(Model).View()) + + c.Assert(tableView, qt.Contains, "c KILL QUERY") + c.Assert(tableView, qt.Contains, "shift+K KILL") + c.Assert(tableView, qt.Not(qt.Contains), "cancel query") + c.Assert(tableView, qt.Not(qt.Contains), "force terminate") + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyEnter}) + detailView := stripANSI(updated.(Model).View()) + + c.Assert(detailView, qt.Contains, "c KILL QUERY") + c.Assert(detailView, qt.Contains, "shift+K KILL") + c.Assert(detailView, qt.Not(qt.Contains), "left/right tabs") +} + +func TestModelTerminateTransactionConfirmationGate(t *testing.T) { + c := qt.New(t) + client := &clientStub{} + xid := "10-42" + model := NewModel(context.Background(), client, time.Second, 0) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, Instance: "primary", TransactionID: &xid, + }}, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + + // Press k — should enter confirming state without firing. + updated, cmd := updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("k")}) + c.Assert(cmd, qt.IsNil) + c.Assert(updated.(Model).confirming, qt.IsTrue) + c.Assert(client.terminateTxn, qt.Equals, 0) + + // Press y — fires. + updated, cmd = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("y")}) + c.Assert(cmd, qt.Not(qt.IsNil)) + c.Assert(updated.(Model).confirming, qt.IsFalse) + _ = cmd() + c.Assert(client.terminateTxn, qt.Equals, 1) + c.Assert(*client.lastTarget.TransactionID, qt.Equals, "10-42") +} + +func TestModelForceTerminateConfirmationGate(t *testing.T) { + c := qt.New(t) + client := &clientStub{} + cid := "10-1779113716123456" + model := NewModel(context.Background(), client, time.Second, 0) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, Instance: "primary", ConnectionID: &cid, + }}, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + + // Press K — should enter confirming state without firing. + updated, cmd := updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("K")}) + c.Assert(cmd, qt.IsNil) + c.Assert(updated.(Model).confirming, qt.IsTrue) + c.Assert(client.terminateConn, qt.Equals, 0) + c.Assert(updated.(Model).View(), qt.Contains, "[y/N]") + c.Assert(updated.(Model).View(), qt.Not(qt.Contains), "(y/n)") + + // A random key should be a no-op (no leak to other handlers). + updated, cmd = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("x")}) + c.Assert(cmd, qt.IsNil) + c.Assert(updated.(Model).confirming, qt.IsTrue) + c.Assert(client.terminateConn, qt.Equals, 0) + + // Press y — fires. + updated, cmd = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("y")}) + c.Assert(cmd, qt.Not(qt.IsNil)) + c.Assert(updated.(Model).confirming, qt.IsFalse) + _ = cmd() + c.Assert(client.terminateConn, qt.Equals, 1) + c.Assert(client.lastTarget.ConnectionID, qt.Not(qt.IsNil)) + c.Assert(*client.lastTarget.ConnectionID, qt.Equals, "10-1779113716123456") +} + +func TestModelForceTerminateConfirmationCancel(t *testing.T) { + c := qt.New(t) + client := &clientStub{} + model := NewModel(context.Background(), client, time.Second, 0) + cid := "10-c" + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10, ConnectionID: &cid}}, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("K")}) + updated, cmd := updated.(Model).Update(tea.KeyMsg{Type: tea.KeyEnter}) + + c.Assert(cmd, qt.Not(qt.IsNil)) + c.Assert(updated.(Model).confirming, qt.IsFalse) + c.Assert(client.terminateConn, qt.Equals, 0) + c.Assert(updated.(Model).notice.text, qt.Equals, "force terminate cancelled") +} + +func TestModelCancelQueryConfirmationCancel(t *testing.T) { + c := qt.New(t) + client := &clientStub{} + model := NewModel(context.Background(), client, time.Second, 0) + qid := "10-q" + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10, QueryID: &qid}}, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("c")}) + updated, cmd := updated.(Model).Update(tea.KeyMsg{Type: tea.KeyEnter}) + + c.Assert(cmd, qt.Not(qt.IsNil)) + c.Assert(updated.(Model).confirming, qt.IsFalse) + c.Assert(client.cancelCalls, qt.Equals, 0) + c.Assert(updated.(Model).notice.text, qt.Equals, "cancel query cancelled") +} + +// Universal-quit shortcuts: ctrl+c and q must still quit during confirmation +// so an operator can always escape the gate without knowing n/esc. +func TestModelForceTerminateConfirmationHonorsUniversalQuit(t *testing.T) { + c := qt.New(t) + client := &clientStub{} + + for _, key := range []tea.KeyMsg{ + {Type: tea.KeyCtrlC}, + {Type: tea.KeyRunes, Runes: []rune("q")}, + } { + model := NewModel(context.Background(), client, time.Second, 0) + cid := "10-c" + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10, ConnectionID: &cid}}, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("K")}) + c.Assert(updated.(Model).confirming, qt.IsTrue) + + _, cmd := updated.(Model).Update(key) + + c.Assert(cmd, qt.Not(qt.IsNil), qt.Commentf("key %v should produce tea.Quit during confirmation", key)) + msg := cmd() + _, ok := msg.(tea.QuitMsg) + c.Assert(ok, qt.IsTrue, qt.Commentf("key %v should fire tea.QuitMsg, got %T", key, msg)) + c.Assert(client.terminateConn, qt.Equals, 0, qt.Commentf("key %v must not fire the destructive action", key)) + } +} + +// startConfirm() returns silently when selectedConnection() reports no +// selection. Pin that branch so a future refactor can't accidentally enter +// confirming state on an empty list. +func TestModelForceTerminateWithNoSelectionIsNoOp(t *testing.T) { + c := qt.New(t) + client := &clientStub{} + model := NewModel(context.Background(), client, time.Second, 0) + // No listMsg dispatched: hasList is false. + + updated, cmd := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("K")}) + + c.Assert(cmd, qt.IsNil) + c.Assert(updated.(Model).confirming, qt.IsFalse) + c.Assert(client.terminateConn, qt.Equals, 0) +} + +// The dispatch tests above verify that the right client method is called but +// don't round-trip the resulting actionResultMsg back through Update. These +// two tests pin the Update-side branches: success sets a notice, error sets +// actionError without preserving a stale notice. + +func TestModelActionResultSuccessSetsNotice(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + + cases := []struct { + kind actionKind + want string + }{ + {actionCancelQuery, "cancel query sent"}, + {actionTerminateTxn, "terminate transaction sent"}, + {actionTerminateConn, "force terminate sent"}, + } + + for _, tc := range cases { + updated, cmd := model.Update(actionResultMsg{kind: tc.kind}) + + c.Assert(cmd, qt.Not(qt.IsNil), qt.Commentf("kind %d should schedule notice expiry", tc.kind)) + got := updated.(Model) + c.Assert(got.notice.text, qt.Equals, tc.want) + c.Assert(got.lastError, qt.Equals, "") + } +} + +func TestModelActionResultErrorSetsActionError(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + + updated, cmd := model.Update(actionResultMsg{kind: actionCancelQuery, err: errors.New("server boom")}) + + c.Assert(cmd, qt.IsNil) + got := updated.(Model) + // Action failures go to the sticky actionError (survives refresh), not the + // refresh-scoped lastError. + c.Assert(got.actionError, qt.Equals, "server boom") + c.Assert(got.lastError, qt.Equals, "") + c.Assert(got.notice.text, qt.Equals, "") + + // A successful auto-refresh must NOT clear the action error... + updated, _ = got.Update(listMsg{list: live.NewConnectionList(time.Now(), nil, live.SortByTransactionStart)}) + c.Assert(updated.(Model).actionError, qt.Equals, "server boom") + // ...but the operator's next keystroke does. + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("s")}) + c.Assert(updated.(Model).actionError, qt.Equals, "") +} + +func TestModelActionErrorClearsStaleNotice(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(actionResultMsg{kind: actionCancelQuery}) + model = updated.(Model) + c.Assert(model.notice.text, qt.Equals, "cancel query sent") + + updated, _ = model.Update(actionResultMsg{kind: actionCancelQuery, err: errors.New("server boom")}) + got := updated.(Model) + + c.Assert(got.actionError, qt.Equals, "server boom") + c.Assert(got.notice.text, qt.Equals, "") +} + +// Pressing a destructive action on a row missing the required ID surfaces an +// immediate error instead of prompting y/N. +func TestMissingActionIDSkipsConfirm(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + // No QueryID on the connection → cancel-query has nothing to act on. + conn := live.Connection{PID: 7, Instance: "primary"} + + updated, cmd := model.startConfirm(actionCancelQuery, conn) + got := updated.(Model) + + c.Assert(cmd, qt.IsNil) + c.Assert(got.confirming, qt.IsFalse) // no destructive prompt raised + c.Assert(got.actionError, qt.Contains, "no active query to cancel") +} + +func TestDefaultConnectionCapabilitiesKeepsBlockersAndMissingIDs(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + list := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 10, Instance: "primary", BlockedBy: []int{20}, QueryText: "blocked"}, + {PID: 20, Instance: "primary", QueryText: "blocker"}, + }, live.SortByTransactionStart) + updated, _ := model.Update(listMsg{list: list}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("b")}) + got := updated.(Model) + + c.Assert(got.detailOpen, qt.IsTrue) + c.Assert(got.detailTab, qt.Equals, tabBlockers) + c.Assert(got.View(), qt.Contains, "BLOCKED BY") + c.Assert(got.View(), qt.Contains, "blockers") + + updated, cmd := model.startConfirm(actionTerminateTxn, live.Connection{PID: 7, Instance: "primary"}) + got = updated.(Model) + c.Assert(cmd, qt.IsNil) + c.Assert(got.confirming, qt.IsFalse) + c.Assert(got.actionError, qt.Contains, "no open transaction to terminate") + + support := DefaultConnectionCapabilities() + c.Assert(support.missingActionID(actionCancelQuery, live.ActionTarget{}), qt.Contains, "no active query to cancel") + c.Assert(support.missingActionID(actionTerminateTxn, live.ActionTarget{}), qt.Contains, "no open transaction to terminate") +} + +func TestModelStepStepsBackThroughHistory(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + + for i, pid := range []int{10, 11, 12} { + list := live.NewConnectionList(time.Unix(int64(100+i), 0), []live.Connection{{ + PID: pid, Instance: "primary", + }}, live.SortByTransactionStart) + next, _ := model.Update(listMsg{list: list}) + model = next.(Model) + } + + c.Assert(model.lastSuccessfulList.Connections[0].PID, qt.Equals, 12) + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("[")}) + got := updated.(Model) + + c.Assert(got.following, qt.IsFalse) + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 11) + c.Assert(got.View(), qt.Contains, "step 2/3") +} + +func TestModelStepJumpKeysWalkToOldestAndLatest(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + + for i, pid := range []int{10, 11, 12} { + list := live.NewConnectionList(time.Unix(int64(100+i), 0), []live.Connection{{ + PID: pid, Instance: "primary", + }}, live.SortByTransactionStart) + next, _ := model.Update(listMsg{list: list}) + model = next.(Model) + } + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("{")}) + got := updated.(Model) + c.Assert(got.following, qt.IsFalse) + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 10) + c.Assert(got.View(), qt.Contains, "step 1/3") + + updated, _ = got.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("}")}) + got = updated.(Model) + c.Assert(got.following, qt.IsTrue) + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 12) + pos, total := got.stepPosition() + c.Assert(pos, qt.Equals, 0) + c.Assert(total, qt.Equals, 0) +} + +func TestModelStepHoldsViewWhenNewSamplesArrive(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + + for i, pid := range []int{10, 11} { + list := live.NewConnectionList(time.Unix(int64(100+i), 0), []live.Connection{{ + PID: pid, Instance: "primary", + }}, live.SortByTransactionStart) + next, _ := model.Update(listMsg{list: list}) + model = next.(Model) + } + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("[")}) + got := updated.(Model) + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 10) + + newest := live.NewConnectionList(time.Unix(102, 0), []live.Connection{{ + PID: 12, Instance: "primary", + }}, live.SortByTransactionStart) + updated, _ = got.Update(listMsg{list: newest}) + got = updated.(Model) + + c.Assert(got.following, qt.IsFalse) + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 10) + c.Assert(got.View(), qt.Contains, "step 1/3") +} + +func TestModelStepPositionHoldsUnderEviction(t *testing.T) { + c := qt.New(t) + h := history.NewCaptureHistory(3) + for _, pid := range []int{10, 11, 12} { + h.Push(live.NewConnectionList(time.Unix(int64(pid), 0), []live.Connection{{ + PID: pid, Instance: "primary", + }}, live.SortByTransactionStart)) + } + + model := NewModel(context.Background(), &clientStub{}, time.Second, 0).WithCaptureHistory(h) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("[")}) + model = updated.(Model) + c.Assert(model.lastSuccessfulList.Connections[0].PID, qt.Equals, 11) + pos, total := model.stepPosition() + c.Assert(pos, qt.Equals, 2) + c.Assert(total, qt.Equals, 3) + + updated, _ = model.Update(listMsg{list: live.NewConnectionList(time.Unix(13, 0), []live.Connection{{ + PID: 13, Instance: "primary", + }}, live.SortByTransactionStart)}) + model = updated.(Model) + + c.Assert(model.lastSuccessfulList.Connections[0].PID, qt.Equals, 11) + pos, total = model.stepPosition() + c.Assert(pos, qt.Equals, 2) + c.Assert(total, qt.Equals, 3) +} + +func TestModelJumpToLatestCatchesUpAfterPausedSamples(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + + initial := live.NewConnectionList(time.Unix(100, 0), []live.Connection{{ + PID: 10, Instance: "primary", + }}, live.SortByTransactionStart) + updated, _ = model.Update(listMsg{list: initial}) + model = updated.(Model) + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeySpace}) + model = updated.(Model) + + newest := live.NewConnectionList(time.Unix(101, 0), []live.Connection{{ + PID: 11, Instance: "primary", + }}, live.SortByTransactionStart) + updated, _ = model.Update(listMsg{list: newest}) + model = updated.(Model) + + c.Assert(model.lastSuccessfulList.Connections[0].PID, qt.Equals, 10) + c.Assert(model.View(), qt.Contains, "step 1/2") + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("}")}) + got := updated.(Model) + + c.Assert(got.following, qt.IsTrue) + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 11) + pos, total := got.stepPosition() + c.Assert(pos, qt.Equals, 0) + c.Assert(total, qt.Equals, 0) +} + +func TestModelJumpLatestResumesLiveFollowOutsideReplay(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + + for i, pid := range []int{10, 11} { + updated, _ := model.Update(listMsg{list: live.NewConnectionList(time.Unix(int64(100+i), 0), []live.Connection{{ + PID: pid, Instance: "primary", + }}, live.SortByTransactionStart)}) + model = updated.(Model) + } + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeySpace}) + model = updated.(Model) + c.Assert(model.paused, qt.IsTrue) + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("[")}) + model = updated.(Model) + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("}")}) + got := updated.(Model) + + c.Assert(got.paused, qt.IsFalse) + c.Assert(got.following, qt.IsTrue) + c.Assert(got.liveRefresh, qt.IsTrue) + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 11) +} + +func TestModelJumpLatestInReplayStaysReadOnlyAndPaused(t *testing.T) { + c := qt.New(t) + h := history.NewCaptureHistory(3) + for i, pid := range []int{10, 11} { + h.Push(live.NewConnectionList(time.Unix(int64(100+i), 0), []live.Connection{{ + PID: pid, Instance: "primary", + }}, live.SortByTransactionStart)) + } + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithCaptureHistory(h). + WithReadOnlyActions("not available in replay mode") + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("[")}) + model = updated.(Model) + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("}")}) + got := updated.(Model) + + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 11) + c.Assert(got.paused, qt.IsTrue) + c.Assert(got.following, qt.IsFalse) + c.Assert(got.liveRefresh, qt.IsFalse) + c.Assert(got.readOnlyActionError, qt.Equals, "not available in replay mode") +} + +func TestModelStepForwardUsesSamplesCollectedWhilePaused(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + + initial := live.NewConnectionList(time.Unix(100, 0), []live.Connection{{ + PID: 10, Instance: "primary", + }}, live.SortByTransactionStart) + updated, _ = model.Update(listMsg{list: initial}) + model = updated.(Model) + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeySpace}) + model = updated.(Model) + + for i, pid := range []int{11, 12} { + list := live.NewConnectionList(time.Unix(int64(101+i), 0), []live.Connection{{ + PID: pid, Instance: "primary", + }}, live.SortByTransactionStart) + next, _ := model.Update(listMsg{list: list}) + model = next.(Model) + } + + c.Assert(model.lastSuccessfulList.Connections[0].PID, qt.Equals, 10) + c.Assert(model.View(), qt.Contains, "step 1/3") + + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("]")}) + got := updated.(Model) + + c.Assert(got.following, qt.IsFalse) + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 11) + c.Assert(got.View(), qt.Contains, "step 2/3") +} + +func TestModelWithCaptureHistoryDoesNotFetchOnInit(t *testing.T) { + c := qt.New(t) + h := history.NewCaptureHistory(3) + for _, pid := range []int{10, 11, 12} { + h.Push(live.NewConnectionList(time.Unix(int64(pid), 0), []live.Connection{{ + PID: pid, Instance: "primary", + }}, live.SortByTransactionStart)) + } + client := &clientStub{list: live.NewConnectionList(time.Unix(99, 0), []live.Connection{{ + PID: 99, Instance: "primary", + }}, live.SortByTransactionStart)} + + model := NewModel(context.Background(), client, time.Second, 0).WithCaptureHistory(h) + + c.Assert(model.Init(), qt.IsNil) + c.Assert(client.calls, qt.Equals, 0) +} + +func TestModelWithCaptureHistoryDoesNotRefreshOnTick(t *testing.T) { + c := qt.New(t) + h := history.NewCaptureHistory(3) + for _, pid := range []int{10, 11, 12} { + h.Push(live.NewConnectionList(time.Unix(int64(pid), 0), []live.Connection{{ + PID: pid, Instance: "primary", + }}, live.SortByTransactionStart)) + } + client := &clientStub{list: live.NewConnectionList(time.Unix(99, 0), []live.Connection{{ + PID: 99, Instance: "primary", + }}, live.SortByTransactionStart)} + model := NewModel(context.Background(), client, time.Second, 0).WithCaptureHistory(h) + + updated, cmd := model.Update(tickMsg(time.Now())) + got := updated.(Model) + + c.Assert(cmd, qt.IsNil) + c.Assert(got.loading, qt.IsFalse) + c.Assert(got.lastSuccessfulList.Connections[0].PID, qt.Equals, 12) + c.Assert(client.calls, qt.Equals, 0) +} + +func TestModelReplayFooterHidesRefreshAndCaptureActions(t *testing.T) { + c := qt.New(t) + h := history.NewCaptureHistory(3) + h.Push(live.NewConnectionList(time.Unix(10, 0), []live.Connection{{ + PID: 10, Instance: "primary", QueryText: "SELECT 1", + }}, live.SortByTransactionStart)) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithCaptureHistory(h). + WithReadOnlyActions("not available in replay mode") + + tableView := stripANSI(model.View()) + + c.Assert(tableView, qt.Contains, "[ ] { } step history") + c.Assert(tableView, qt.Contains, "enter/v detail") + c.Assert(tableView, qt.Not(qt.Contains), "r refresh") + c.Assert(tableView, qt.Not(qt.Contains), "C capture") + + updated, _ := model.Update(tea.KeyMsg{Type: tea.KeyEnter}) + detailView := stripANSI(updated.(Model).View()) + + c.Assert(detailView, qt.Contains, "q/esc back") + c.Assert(detailView, qt.Not(qt.Contains), "r refresh") + c.Assert(detailView, qt.Not(qt.Contains), "C capture") +} + +func TestModelEmptyReplayFrameAdvertisesStepHistoryWithoutRowActions(t *testing.T) { + c := qt.New(t) + h := history.NewCaptureHistory(3) + h.Push(live.NewConnectionList(time.Unix(10, 0), []live.Connection{{ + PID: 10, Instance: "primary", QueryText: "SELECT 1", + }}, live.SortByTransactionStart)) + h.Push(live.NewConnectionList(time.Unix(11, 0), nil, live.SortByTransactionStart)) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithCaptureHistory(h). + WithReadOnlyActions("not available in replay mode") + + tableView := stripANSI(model.View()) + + c.Assert(tableView, qt.Contains, "connections 0") + c.Assert(tableView, qt.Contains, "[ ] { } step history") + c.Assert(tableView, qt.Not(qt.Contains), "enter detail") + c.Assert(tableView, qt.Not(qt.Contains), "cancel query") + c.Assert(tableView, qt.Not(qt.Contains), "kill transaction") + c.Assert(tableView, qt.Not(qt.Contains), "force terminate") +} + +func TestModelStepHelpIsAdvertised(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + list := live.NewConnectionList(time.Unix(100, 0), []live.Connection{{ + PID: 10, Instance: "primary", + }}, live.SortByTransactionStart) + updated, _ = updated.(Model).Update(listMsg{list: list}) + view := updated.(Model).View() + + c.Assert(view, qt.Contains, "step history") +} + +// Selection must follow the connection (by PID+instance) across a refresh that +// reorders/changes the list, not stay on a positional index. +func TestSelectionReanchorsByIdentityAcrossRefresh(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + // Sort by duration so order is deterministic by Duration desc. + first := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 10, Instance: "primary", Duration: 30 * time.Second}, + {PID: 20, Instance: "primary", Duration: 20 * time.Second}, + {PID: 30, Instance: "primary", Duration: 10 * time.Second}, + }, live.SortByDuration) + updated, _ := model.Update(listMsg{list: first}) + // Select the middle row (PID 20). + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyDown}) + got := updated.(Model) + sel, ok := got.selectedConnection() + c.Assert(ok, qt.IsTrue) + c.Assert(sel.PID, qt.Equals, 20) + + // New snapshot reorders: PID 20 is now at the TOP (longest duration). + second := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 20, Instance: "primary", Duration: 40 * time.Second}, + {PID: 10, Instance: "primary", Duration: 30 * time.Second}, + {PID: 30, Instance: "primary", Duration: 10 * time.Second}, + }, live.SortByDuration) + updated, _ = got.Update(listMsg{list: second}) + got = updated.(Model) + + sel, ok = got.selectedConnection() + c.Assert(ok, qt.IsTrue) + c.Assert(sel.PID, qt.Equals, 20) // still PID 20, even though it moved from row 1 to row 0 +} + +func vitessModelWithSelectedIndex(t *testing.T, selected int, conns []live.Connection) Model { + t.Helper() + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithConnectionView(VitessConnectionView) + updated, _ := model.Update(listMsg{list: live.NewConnectionList(time.Unix(100, 0), conns, live.SortByDuration)}) + for i := 0; i < selected; i++ { + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyDown}) + } + return updated.(Model) +} + +func updateModelWithVitessList(t *testing.T, model Model, conns []live.Connection) Model { + t.Helper() + updated, _ := model.Update(listMsg{list: live.NewConnectionList(time.Unix(101, 0), conns, live.SortByDuration)}) + return updated.(Model) +} + +func TestVitessSelectionStopsFollowingQueryRowThatBecomesSleep(t *testing.T) { + c := qt.New(t) + model := vitessModelWithSelectedIndex(t, 2, []live.Connection{ + {PID: 10, Instance: "tablet", State: "Query/update", Duration: 4 * time.Second, QueryText: "INSERT 1"}, + {PID: 20, Instance: "tablet", State: "Query/update", Duration: 3 * time.Second, QueryText: "INSERT 2"}, + {PID: 30, Instance: "tablet", State: "Query/update", Duration: 2 * time.Second, QueryText: "INSERT 3"}, + {PID: 40, Instance: "tablet", State: "Sleep"}, + }) + sel, ok := model.selectedConnection() + c.Assert(ok, qt.IsTrue) + c.Assert(sel.PID, qt.Equals, 30) + + model = updateModelWithVitessList(t, model, []live.Connection{ + {PID: 50, Instance: "tablet", State: "Query/update", Duration: 5 * time.Second, QueryText: "INSERT 4"}, + {PID: 60, Instance: "tablet", State: "Query/update", Duration: 1 * time.Second, QueryText: "INSERT 5"}, + {PID: 30, Instance: "tablet", State: "Sleep"}, + {PID: 40, Instance: "tablet", State: "Sleep"}, + }) + + sel, ok = model.selectedConnection() + c.Assert(ok, qt.IsTrue) + c.Assert(sel.PID, qt.Equals, 60) +} + +func TestVitessSelectionKeepsSleepingIdentityThatBecomesActive(t *testing.T) { + c := qt.New(t) + model := vitessModelWithSelectedIndex(t, 2, []live.Connection{ + {PID: 10, Instance: "tablet", State: "Query/update", Duration: 2 * time.Second, QueryText: "INSERT 1"}, + {PID: 20, Instance: "tablet", State: "Sleep"}, + {PID: 30, Instance: "tablet", State: "Sleep"}, + }) + + model = updateModelWithVitessList(t, model, []live.Connection{ + {PID: 30, Instance: "tablet", State: "Query/update", Duration: 5 * time.Second, QueryText: "INSERT 2"}, + {PID: 20, Instance: "tablet", State: "Sleep"}, + }) + + sel, ok := model.selectedConnection() + c.Assert(ok, qt.IsTrue) + c.Assert(sel.PID, qt.Equals, 30) +} + +func TestVitessSelectionKeepsSleepingIdentityThatStaysSleep(t *testing.T) { + c := qt.New(t) + model := vitessModelWithSelectedIndex(t, 2, []live.Connection{ + {PID: 10, Instance: "tablet", State: "Query/update", Duration: 2 * time.Second, QueryText: "INSERT 1"}, + {PID: 20, Instance: "tablet", State: "Sleep"}, + {PID: 30, Instance: "tablet", State: "Sleep"}, + }) + + model = updateModelWithVitessList(t, model, []live.Connection{ + {PID: 40, Instance: "tablet", State: "Query/update", Duration: 3 * time.Second, QueryText: "INSERT 2"}, + {PID: 20, Instance: "tablet", State: "Sleep"}, + {PID: 30, Instance: "tablet", State: "Sleep"}, + }) + + sel, ok := model.selectedConnection() + c.Assert(ok, qt.IsTrue) + c.Assert(sel.PID, qt.Equals, 30) +} + +func TestVitessSelectionFallsBackToNearestActiveRowWhenSelectedConnectionEnds(t *testing.T) { + c := qt.New(t) + model := vitessModelWithSelectedIndex(t, 2, []live.Connection{ + {PID: 10, Instance: "tablet", State: "Query/update", Duration: 4 * time.Second, QueryText: "INSERT 1"}, + {PID: 20, Instance: "tablet", State: "Query/update", Duration: 3 * time.Second, QueryText: "INSERT 2"}, + {PID: 30, Instance: "tablet", State: "Query/update", Duration: 2 * time.Second, QueryText: "INSERT 3"}, + {PID: 40, Instance: "tablet", State: "Sleep"}, + }) + + model = updateModelWithVitessList(t, model, []live.Connection{ + {PID: 50, Instance: "tablet", State: "Query/update", Duration: 5 * time.Second, QueryText: "INSERT 4"}, + {PID: 60, Instance: "tablet", State: "Query/update", Duration: 1 * time.Second, QueryText: "INSERT 5"}, + {PID: 40, Instance: "tablet", State: "Sleep"}, + {PID: 70, Instance: "tablet", State: "Sleep"}, + }) + + sel, ok := model.selectedConnection() + c.Assert(ok, qt.IsTrue) + c.Assert(sel.PID, qt.Equals, 60) +} + +func TestVitessSelectionKeepsIndexWhenOnlySleepingRowsRemain(t *testing.T) { + c := qt.New(t) + model := vitessModelWithSelectedIndex(t, 2, []live.Connection{ + {PID: 10, Instance: "tablet", State: "Query/update", Duration: 4 * time.Second, QueryText: "INSERT 1"}, + {PID: 20, Instance: "tablet", State: "Query/update", Duration: 3 * time.Second, QueryText: "INSERT 2"}, + {PID: 30, Instance: "tablet", State: "Query/update", Duration: 2 * time.Second, QueryText: "INSERT 3"}, + }) + + model = updateModelWithVitessList(t, model, []live.Connection{ + {PID: 40, Instance: "tablet", State: "Sleep"}, + {PID: 50, Instance: "tablet", State: "Sleep"}, + {PID: 60, Instance: "tablet", State: "Sleep"}, + {PID: 70, Instance: "tablet", State: "Sleep"}, + }) + + sel, ok := model.selectedConnection() + c.Assert(ok, qt.IsTrue) + c.Assert(sel.PID, qt.Equals, 60) +} + +func TestRefreshDotReflectsLiveStateWhilePaused(t *testing.T) { + c := qt.New(t) + defer lipgloss.SetColorProfile(termenv.Ascii) + lipgloss.SetColorProfile(termenv.ANSI256) + prevBG := lipgloss.HasDarkBackground() + defer lipgloss.SetHasDarkBackground(prevBG) + lipgloss.SetHasDarkBackground(true) + + const ( + dimDot = "\x1b[38;5;240m●" + redDot = "\x1b[1;38;5;196m●" + ) + + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 180, Height: 24}) + model = updated.(Model) + // Two successful samples → history to step through, and a healthy dim dot. + for _, ts := range []int64{100, 101} { + list := live.NewConnectionList(time.Unix(ts, 0), []live.Connection{{PID: 10, State: "active"}}, live.SortByTransactionStart) + updated, _ = model.Update(listMsg{list: list}) + model = updated.(Model) + } + c.Assert(model.View(), qt.Contains, dimDot) + + // Two consecutive failed fetches → sustained failure → red dot. + updated, _ = model.Update(listMsg{err: errors.New("503 from api-bb")}) + model = updated.(Model) + updated, _ = model.Update(listMsg{err: errors.New("503 from api-bb")}) + model = updated.(Model) + c.Assert(model.View(), qt.Contains, redDot) + + // Pause and step back to a healthy historical frame; the dot must stay red. + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeySpace}) + model = updated.(Model) + updated, _ = model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune{'['}}) + model = updated.(Model) + view := model.View() + c.Assert(view, qt.Contains, "paused") + c.Assert(view, qt.Contains, "step") + c.Assert(view, qt.Contains, redDot) + + // A successful refresh clears the streak → back to the dim idle dot. + list := live.NewConnectionList(time.Unix(102, 0), []live.Connection{{PID: 10, State: "active"}}, live.SortByTransactionStart) + updated, _ = model.Update(listMsg{list: list}) + c.Assert(updated.(Model).View(), qt.Contains, dimDot) +} + +func TestDetailRefreshIgnoredInReplay(t *testing.T) { + c := qt.New(t) + h := history.NewCaptureHistory(3) + h.Push(live.NewConnectionList(time.Unix(100, 0), []live.Connection{{ + PID: 10, Instance: "primary", + }}, live.SortByTransactionStart)) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithCaptureHistory(h). + WithReadOnlyActions("not available in replay mode") + model.detailOpen = true + model.detailInstance = "primary" + model.detailPID = 10 + + updated, cmd := model.Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune("r")}) + got := updated.(Model) + + c.Assert(cmd, qt.IsNil) + c.Assert(got.loading, qt.IsFalse) + c.Assert(got.samples.Len(), qt.Equals, 1) +} + +func TestModelInstanceGoneMidSessionRewordsError(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ := model.Update(listMsg{list: live.NewConnectionList(time.Unix(100, 0), []live.Connection{{ + PID: 10, Instance: "primary", + }}, live.SortByTransactionStart)}) + model = updated.(Model) + + updated, _ = model.Update(listMsg{err: &live.UnknownInstanceError{Instance: "replica-1", Valid: []string{"primary"}}}) + got := updated.(Model) + c.Assert(got.lastError, qt.Equals, `instance "replica-1" is no longer in the branch's instance set`) + + fresh := NewModel(context.Background(), &clientStub{}, time.Second, 0) + updated, _ = fresh.Update(listMsg{err: &live.UnknownInstanceError{Instance: "replica-1", Valid: []string{"primary"}}}) + c.Assert(updated.(Model).lastError, qt.Contains, "unknown instance") +} + +func TestModelPreListPausedFooterHidesRefresh(t *testing.T) { + c := qt.New(t) + model := NewModel(context.Background(), &clientStub{}, time.Second, 0). + WithTarget(Target{Database: "prod", Branch: "main"}) + updated, _ := model.Update(tea.WindowSizeMsg{Width: 120, Height: 30}) + updated, _ = updated.(Model).Update(tea.KeyMsg{Type: tea.KeyRunes, Runes: []rune(" ")}) + view := stripANSI(updated.(Model).View()) + + c.Assert(view, qt.Contains, "space resume") + c.Assert(view, qt.Not(qt.Contains), "r refresh") +} diff --git a/internal/connections/tui/query_format.go b/internal/connections/tui/query_format.go new file mode 100644 index 00000000..1f742289 --- /dev/null +++ b/internal/connections/tui/query_format.go @@ -0,0 +1,134 @@ +package tui + +import ( + "regexp" + "strings" +) + +var ( + sqlClauseBoundary = regexp.MustCompile(`(?i)\s+(from|left\s+join|right\s+join|inner\s+join|outer\s+join|cross\s+join|full\s+join|join|where|group\s+by|having|order\s+by|limit|offset|returning|set|values|union\s+all|union|intersect|except)\s+`) + sqlBoolBoundary = regexp.MustCompile(`(?i)\s+(and|or)\s+`) + dollarQuote = regexp.MustCompile(`\$[A-Za-z_][A-Za-z0-9_]*\$|\$\$`) +) + +// This is light display formatting for pg_stat_activity text, not SQL parsing. +// It turns common generated one-liners such as +// "select id from t where owner_id = $1 order by id limit 100" into a short +// vertical scan, and leaves risky input such as quoted text, comments, +// parentheses, dollar quotes, and authored multiline line breaks alone. +// formatQueryForDisplay normalizes a query into display lines. Authored +// newlines are preserved, common one-line statements are split at major SQL +// clauses, and uncertain input falls back to the original sanitized text. +func formatQueryForDisplay(query string) []string { + raw := strings.TrimSpace(query) + if raw == "" { + return nil + } + if strings.Contains(raw, "\n") { + raw = strings.TrimSuffix(raw, "\n") + lines := strings.Split(raw, "\n") + for i, line := range lines { + lines[i] = sanitizeFooterText(line) + } + return lines + } + if !canFormatOneLineQuery(raw) { + return []string{sanitizeFooterText(raw)} + } + query = strings.Join(strings.Fields(sanitizeFooterText(raw)), " ") + return formatOneLineQuery(query) +} + +func canFormatOneLineQuery(query string) bool { + lower := strings.ToLower(query) + if strings.ContainsAny(query, "'\"`()") { + return false + } + if strings.Contains(lower, "--") || strings.Contains(lower, "#") || strings.Contains(lower, "/*") || strings.Contains(lower, "*/") { + return false + } + if strings.Contains(lower, " between ") || dollarQuote.MatchString(query) { + return false + } + return len(sqlClauseBoundary.FindAllStringSubmatchIndex(query, -1)) > 0 +} + +func formatOneLineQuery(query string) []string { + matches := sqlClauseBoundary.FindAllStringSubmatchIndex(query, -1) + if len(matches) == 0 { + return []string{query} + } + + lines := []string{} + first := strings.TrimSpace(query[:matches[0][0]]) + if first != "" { + lines = append(lines, first) + } + for i, match := range matches { + end := len(query) + if i+1 < len(matches) { + end = matches[i+1][0] + } + clause := strings.TrimSpace(query[match[2]:end]) + if strings.HasPrefix(strings.ToLower(clause), "where ") { + lines = append(lines, splitWhereClause(clause)...) + continue + } + lines = append(lines, clause) + } + return lines +} + +func splitWhereClause(clause string) []string { + space := strings.IndexByte(clause, ' ') + if space < 0 { + return []string{clause} + } + prefix := clause[:space] + body := strings.TrimSpace(clause[space+1:]) + matches := sqlBoolBoundary.FindAllStringSubmatchIndex(body, -1) + if len(matches) == 0 { + return []string{clause} + } + + lines := []string{prefix + " " + strings.TrimSpace(body[:matches[0][0]])} + for i, match := range matches { + end := len(body) + if i+1 < len(matches) { + end = matches[i+1][0] + } + operator := body[match[2]:match[3]] + condition := strings.TrimSpace(body[match[1]:end]) + lines = append(lines, " "+operator+" "+condition) + } + return lines +} + +// queryDisplayLines is the normalized-then-wrapped line set actually shown in +// the Query tab. Both the scroll clamp and the renderer derive their bounds +// from this exact slice, so they can never disagree about how many lines exist. +func queryDisplayLines(query string, width int) []string { + formatted := formatQueryForDisplay(query) + if len(formatted) == 0 { + return nil + } + wrapped := make([]string, 0, len(formatted)) + for _, line := range formatted { + if line == "" { + wrapped = append(wrapped, "") + continue + } + wrapped = append(wrapped, wrapLines(line, width)...) + } + return wrapped +} + +// maxQueryOffset is the largest scroll offset that still fills the viewport. +// bodyHeight includes the row reserved for the "query lines X-Y/Z" indicator, +// so the visible query rows are bodyHeight-1 and the last line is reachable. +func maxQueryOffset(totalLines, bodyHeight int) int { + if bodyHeight <= 1 || totalLines <= bodyHeight-1 { + return 0 + } + return totalLines - (bodyHeight - 1) +} diff --git a/internal/connections/tui/query_format_test.go b/internal/connections/tui/query_format_test.go new file mode 100644 index 00000000..0057966f --- /dev/null +++ b/internal/connections/tui/query_format_test.go @@ -0,0 +1,234 @@ +package tui + +import ( + "testing" + + qt "github.com/frankban/quicktest" +) + +func TestFormatQueryForDisplayFormatsCommonSingleLineQuery(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select id, account_id from invoices where account_id = $1 and status = $2 order by created_at desc limit 200") + + c.Assert(got, qt.DeepEquals, []string{ + "select id, account_id", + "from invoices", + "where account_id = $1", + " and status = $2", + "order by created_at desc", + "limit 200", + }) +} + +func TestFormatQueryForDisplayFormatsJoinQuery(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select invoices.id, accounts.name from invoices join accounts on accounts.id = invoices.account_id where invoices.status = $1 order by invoices.created_at desc limit 100") + + c.Assert(got, qt.DeepEquals, []string{ + "select invoices.id, accounts.name", + "from invoices", + "join accounts on accounts.id = invoices.account_id", + "where invoices.status = $1", + "order by invoices.created_at desc", + "limit 100", + }) +} + +func TestFormatQueryForDisplayFormatsDMLBoundaries(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("update users set name = $1 where id = $2 returning id") + + c.Assert(got, qt.DeepEquals, []string{ + "update users", + "set name = $1", + "where id = $2", + "returning id", + }) +} + +func TestFormatQueryForDisplayFormatsDeleteQuery(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("delete from sessions where expires_at < now") + + c.Assert(got, qt.DeepEquals, []string{ + "delete", + "from sessions", + "where expires_at < now", + }) +} + +func TestFormatQueryForDisplayFormatsUnionAllAndOffset(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select id from archived_events union all select id from events order by id offset 100 limit 50") + + c.Assert(got, qt.DeepEquals, []string{ + "select id", + "from archived_events", + "union all select id", + "from events", + "order by id", + "offset 100", + "limit 50", + }) +} + +func TestFormatQueryForDisplayMatchesClausesCaseInsensitively(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("SELECT id FROM invoices WHERE status = $1 AND archived = false ORDER BY id LIMIT 10") + + c.Assert(got, qt.DeepEquals, []string{ + "SELECT id", + "FROM invoices", + "WHERE status = $1", + " AND archived = false", + "ORDER BY id", + "LIMIT 10", + }) +} + +func TestFormatQueryForDisplayDoesNotSplitAndInsideQuotedLiteral(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select 'research and development' from dual") + + c.Assert(got, qt.DeepEquals, []string{"select 'research and development' from dual"}) +} + +func TestFormatQueryForDisplayPreservesWhitespaceInsideQuotedLiteral(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select 'a b' from t") + + c.Assert(got, qt.DeepEquals, []string{"select 'a b' from t"}) +} + +func TestFormatQueryForDisplayDoesNotSplitFromInsideQuotedLiteral(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select * from notes where note = 'copied from admin'") + + c.Assert(got, qt.DeepEquals, []string{"select * from notes where note = 'copied from admin'"}) +} + +func TestFormatQueryForDisplayDoesNotSplitInsideDoubleQuotedText(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay(`select "research and development" from dual`) + + c.Assert(got, qt.DeepEquals, []string{`select "research and development" from dual`}) +} + +func TestFormatQueryForDisplayDoesNotSplitInsideBacktickQuotedText(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select `from` from t") + + c.Assert(got, qt.DeepEquals, []string{"select `from` from t"}) +} + +func TestFormatQueryForDisplayDoesNotSplitInsideDollarQuotedText(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select $$ where and $$ as text from t") + + c.Assert(got, qt.DeepEquals, []string{"select $$ where and $$ as text from t"}) +} + +func TestFormatQueryForDisplayDoesNotSplitInsideTaggedDollarQuotedText(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select $tag$ from hidden $tag$ from t") + + c.Assert(got, qt.DeepEquals, []string{"select $tag$ from hidden $tag$ from t"}) +} + +func TestFormatQueryForDisplayDoesNotSplitClausesInsideLineComment(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select 1 -- from hidden where id = 1") + + c.Assert(got, qt.DeepEquals, []string{"select 1 -- from hidden where id = 1"}) +} + +func TestFormatQueryForDisplayPreservesWhitespaceInsideLineComment(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select 1 -- from hidden") + + c.Assert(got, qt.DeepEquals, []string{"select 1 -- from hidden"}) +} + +func TestFormatQueryForDisplayDoesNotSplitClausesInsideHashComment(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select 1 # from hidden where id = 1") + + c.Assert(got, qt.DeepEquals, []string{"select 1 # from hidden where id = 1"}) +} + +func TestFormatQueryForDisplayDoesNotSplitClausesInsideBlockComment(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select 1 /* from hidden */ where id = 1") + + c.Assert(got, qt.DeepEquals, []string{"select 1 /* from hidden */ where id = 1"}) +} + +func TestFormatQueryForDisplayDoesNotSplitClausesInsideFunctionExpression(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select extract(day from created_at) from t") + + c.Assert(got, qt.DeepEquals, []string{"select extract(day from created_at) from t"}) +} + +func TestFormatQueryForDisplayDoesNotSplitClausesInsideWindowExpression(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select row_number() over (order by created_at) from t") + + c.Assert(got, qt.DeepEquals, []string{"select row_number() over (order by created_at) from t"}) +} + +func TestFormatQueryForDisplayDoesNotSplitBetweenCondition(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select * from t where created_at between $1 and $2") + + c.Assert(got, qt.DeepEquals, []string{"select * from t where created_at between $1 and $2"}) +} + +func TestFormatQueryForDisplayDoesNotSplitNotBetweenCondition(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select * from t where created_at not between $1 and $2") + + c.Assert(got, qt.DeepEquals, []string{"select * from t where created_at not between $1 and $2"}) +} + +func TestFormatQueryForDisplayFallsBackForCTE(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("with recent as (select * from t) select * from recent where id = $1") + + c.Assert(got, qt.DeepEquals, []string{"with recent as (select * from t) select * from recent where id = $1"}) +} + +func TestFormatQueryForDisplayFallsBackForMultilineInput(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select 1\nselect 2") + + c.Assert(got, qt.DeepEquals, []string{"select 1", "select 2"}) +} + +func TestFormatQueryForDisplayPreservesMultilineIndentation(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select 1\n from t\nwhere a = 1") + + c.Assert(got, qt.DeepEquals, []string{"select 1", " from t", "where a = 1"}) +} + +func TestFormatQueryForDisplayDropsSingleTerminalNewline(t *testing.T) { + c := qt.New(t) + got := formatQueryForDisplay("select 1\n from t\n") + + c.Assert(got, qt.DeepEquals, []string{"select 1", " from t"}) +} + +func TestQueryDisplayLinesPreservesAuthoredBlankLines(t *testing.T) { + c := qt.New(t) + got := queryDisplayLines("select 1\n\nselect 2", 80) + + c.Assert(got, qt.DeepEquals, []string{"select 1", "", "select 2"}) +} + +func TestQueryDisplayLinesDropsSingleTerminalNewline(t *testing.T) { + c := qt.New(t) + got := queryDisplayLines("select 1\n from t\n", 80) + + c.Assert(got, qt.DeepEquals, []string{"select 1", " from t"}) +} diff --git a/internal/connections/tui/styles.go b/internal/connections/tui/styles.go new file mode 100644 index 00000000..b9a61eb0 --- /dev/null +++ b/internal/connections/tui/styles.go @@ -0,0 +1,103 @@ +package tui + +import "github.com/charmbracelet/lipgloss" + +const ( + colorHeaderBlue = "39" + colorHeaderBlueLight = "25" + colorSelectedForeground = "236" + colorSelectedForegroundLight = "236" + colorSelectedBackground = "253" + colorSelectedBackgroundLight = "254" + colorErrorRed = "196" + colorErrorRedLight = "160" + colorMutedGray = "245" + colorMutedGrayLight = "240" + colorStaleYellow = "214" + colorStaleYellowLight = "130" + colorActiveGreen = "40" + colorActiveGreenLight = "28" + colorIdleGray = "245" + colorIdleGrayLight = "240" + colorXactYellow = "214" + colorXactYellowLight = "130" + colorReplicaBlue = "117" // pastel blue; readable beside muted blocking-tier rows + colorReplicaBlueLight = "25" + colorPausedBackground = "229" + colorPausedBackgroundLight = "230" + colorPausedForeground = "236" + colorPausedForegroundLight = "236" + colorBlocksOneRowBackground = "254" // subtle hint + colorBlocksOneLight = "250" // subtle but visible on a white terminal + colorBlocksTwoRowBackground = "253" // noticeable step up + colorBlocksTwoLight = "248" // darker than the one-session light tier so the step-up reads on white + colorBlocksManyRowBackground = "222" // amber highlight distinct from gray tiers without dominating the screen during real incidents + colorBlocksManyLight = "179" + colorBlocksText = "236" + colorBlocksTextLight = "236" + colorBlockingCountBackground = "63" + colorBlockingCountLight = "230" + colorRefreshActive = "39" // cyan dot: a fetch is in flight or recently retried + colorRefreshActiveLight = "25" + colorRefreshIdle = "240" // dim dot when idle — fixed width, no header reflow + colorRefreshIdleLight = "250" + colorRefreshError = "196" // red dot: refresh is continuously failing (sustained 503s/timeouts) + colorRefreshErrorLight = "160" + colorHelpKey = "245" + colorHelpKeyLight = "240" + colorHelpDesc = "245" + colorHelpDescLight = "242" + colorHelpSeparator = "238" + colorHelpSeparatorLight = "247" +) + +var ( + headerStyle = lipgloss.NewStyle().Bold(true).Foreground(adaptiveColor(colorHeaderBlueLight, colorHeaderBlue)) + errorStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorErrorRedLight, colorErrorRed)) + mutedStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorMutedGrayLight, colorMutedGray)) + selectedRowStyle = lipgloss.NewStyle(). + Bold(true). + Foreground(adaptiveColor(colorSelectedForegroundLight, colorSelectedForeground)). + Background(adaptiveColor(colorSelectedBackgroundLight, colorSelectedBackground)) + replicaRowStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorReplicaBlueLight, colorReplicaBlue)) + staleStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorStaleYellowLight, colorStaleYellow)) + veryStaleStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorErrorRedLight, colorErrorRed)).Bold(true) + stateActiveStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorActiveGreenLight, colorActiveGreen)).Bold(true) + stateIdleStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorIdleGrayLight, colorIdleGray)) + stateXactStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorXactYellowLight, colorXactYellow)).Bold(true) + bannerStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorErrorRedLight, colorErrorRed)).Bold(true) + tabActiveStyle = lipgloss.NewStyle().Bold(true).Underline(true).Foreground(adaptiveColor(colorHeaderBlueLight, colorHeaderBlue)) + pausedStyle = lipgloss.NewStyle(). + Background(adaptiveColor(colorPausedBackgroundLight, colorPausedBackground)). + Foreground(adaptiveColor(colorPausedForegroundLight, colorPausedForeground)). + Bold(true) + + rowBlocksOneSessionStyle = lipgloss.NewStyle(). + Foreground(adaptiveColor(colorBlocksTextLight, colorBlocksText)). + Background(adaptiveColor(colorBlocksOneLight, colorBlocksOneRowBackground)) + rowBlocksTwoSessionsStyle = lipgloss.NewStyle(). + Foreground(adaptiveColor(colorBlocksTextLight, colorBlocksText)). + Background(adaptiveColor(colorBlocksTwoLight, colorBlocksTwoRowBackground)). + Bold(true) + rowBlocksManySessionsStyle = lipgloss.NewStyle(). + Foreground(adaptiveColor(colorBlocksTextLight, colorBlocksText)). + Background(adaptiveColor(colorBlocksManyLight, colorBlocksManyRowBackground)). + Bold(true) + // The blocking-count is a bold amber foreground digit (an alarm-family hue), + // not a filled background badge. A filled cell read as a terminal + // text-selection rectangle and the violet hue didn't map to "alarm" + // Amber-on-default keeps it legible on both terminal backgrounds. + blockingCountBadgeStyle = lipgloss.NewStyle(). + Foreground(adaptiveColor(colorStaleYellowLight, colorStaleYellow)). + Bold(true) + refreshActiveStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorRefreshActiveLight, colorRefreshActive)).Bold(true) + refreshIdleStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorRefreshIdleLight, colorRefreshIdle)) + refreshErrorStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorRefreshErrorLight, colorRefreshError)).Bold(true) + helpKeyStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorHelpKeyLight, colorHelpKey)) + helpDescStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorHelpDescLight, colorHelpDesc)) + helpSeparatorStyle = lipgloss.NewStyle().Foreground(adaptiveColor(colorHelpSeparatorLight, colorHelpSeparator)) +) + +func adaptiveColor(light, dark string) lipgloss.AdaptiveColor { + return lipgloss.AdaptiveColor{Light: light, Dark: dark} +} diff --git a/internal/connections/tui/styles_test.go b/internal/connections/tui/styles_test.go new file mode 100644 index 00000000..7aab5898 --- /dev/null +++ b/internal/connections/tui/styles_test.go @@ -0,0 +1,188 @@ +package tui + +import ( + "math" + "strconv" + "testing" + + "github.com/charmbracelet/lipgloss" + qt "github.com/frankban/quicktest" + "github.com/muesli/termenv" +) + +func TestStylesAdaptToLightBackgrounds(t *testing.T) { + prevProfile := lipgloss.ColorProfile() + prevBackground := lipgloss.HasDarkBackground() + lipgloss.SetColorProfile(termenv.ANSI256) + defer lipgloss.SetColorProfile(prevProfile) + defer lipgloss.SetHasDarkBackground(prevBackground) + + tests := []struct { + name string + style lipgloss.Style + }{ + {name: "header", style: headerStyle}, + {name: "muted", style: mutedStyle}, + {name: "replica row", style: replicaRowStyle}, + {name: "stale", style: staleStyle}, + {name: "active state", style: stateActiveStyle}, + {name: "idle state", style: stateIdleStyle}, + {name: "transaction state", style: stateXactStyle}, + {name: "paused badge", style: pausedStyle}, + {name: "one blocker row", style: rowBlocksOneSessionStyle}, + {name: "two blocker row", style: rowBlocksTwoSessionsStyle}, + {name: "many blocker row", style: rowBlocksManySessionsStyle}, + {name: "blocking count badge", style: blockingCountBadgeStyle}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + lipgloss.SetHasDarkBackground(true) + dark := tt.style.Render("sample") + lipgloss.SetHasDarkBackground(false) + light := tt.style.Render("sample") + + c.Assert(stripANSI(light), qt.Equals, "sample") + c.Assert(light, qt.Not(qt.Equals), dark) + }) + } +} + +func TestSelectedRowHasBackgroundOnBothTerminalBackgrounds(t *testing.T) { + c := qt.New(t) + prevProfile := lipgloss.ColorProfile() + prevBackground := lipgloss.HasDarkBackground() + lipgloss.SetColorProfile(termenv.ANSI256) + defer lipgloss.SetColorProfile(prevProfile) + defer lipgloss.SetHasDarkBackground(prevBackground) + + lipgloss.SetHasDarkBackground(true) + dark := selectedRowStyle.Render("sample") + lipgloss.SetHasDarkBackground(false) + light := selectedRowStyle.Render("sample") + + c.Assert(dark, qt.Contains, "48;5;") + c.Assert(dark, qt.Contains, "38;5;") + c.Assert(stripANSI(dark), qt.Equals, "sample") + c.Assert(light, qt.Contains, "48;5;") + c.Assert(light, qt.Contains, "38;5;") + c.Assert(stripANSI(light), qt.Equals, "sample") +} + +func TestDarkVariantHighlightsRemainStyledForLightFallback(t *testing.T) { + c := qt.New(t) + prevProfile := lipgloss.ColorProfile() + prevBackground := lipgloss.HasDarkBackground() + lipgloss.SetColorProfile(termenv.ANSI256) + defer lipgloss.SetColorProfile(prevProfile) + defer lipgloss.SetHasDarkBackground(prevBackground) + + lipgloss.SetHasDarkBackground(true) + + for _, tt := range []struct { + name string + rendered string + plain string + }{ + {name: "paused", rendered: pausedStyle.Render("paused"), plain: "paused"}, + {name: "blocking", rendered: rowBlocksTwoSessionsStyle.Render("sample"), plain: "sample"}, + } { + t.Run(tt.name, func(t *testing.T) { + c.Assert(tt.rendered, qt.Contains, "38;5;") + c.Assert(tt.rendered, qt.Contains, "48;5;") + c.Assert(stripANSI(tt.rendered), qt.Equals, tt.plain) + }) + } +} + +func TestBlockingRowsSetForegroundForLightFallback(t *testing.T) { + c := qt.New(t) + prevProfile := lipgloss.ColorProfile() + prevBackground := lipgloss.HasDarkBackground() + lipgloss.SetColorProfile(termenv.ANSI256) + defer lipgloss.SetColorProfile(prevProfile) + defer lipgloss.SetHasDarkBackground(prevBackground) + + lipgloss.SetHasDarkBackground(true) + + c.Assert(rowBlocksOneSessionStyle.Render("sample"), qt.Contains, "38;5;") + c.Assert(rowBlocksTwoSessionsStyle.Render("sample"), qt.Contains, "38;5;") + c.Assert(rowBlocksManySessionsStyle.Render("sample"), qt.Contains, "38;5;") +} + +func TestLightBlockingBackgroundsContrastWithWhiteSurface(t *testing.T) { + for _, tt := range []struct { + name string + color string + contrast float64 + }{ + {name: "one blocker", color: colorBlocksOneLight, contrast: 1.8}, + {name: "two blockers", color: colorBlocksTwoLight, contrast: 2.1}, + {name: "many blockers", color: colorBlocksManyLight, contrast: 2.0}, + } { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + c.Assert(contrastAgainstWhite(tt.color) >= tt.contrast, qt.IsTrue) + }) + } +} + +func TestOneSessionBlockerVisibleOnLightBackground(t *testing.T) { + c := qt.New(t) + prevProfile := lipgloss.ColorProfile() + prevBackground := lipgloss.HasDarkBackground() + lipgloss.SetColorProfile(termenv.ANSI256) + defer lipgloss.SetColorProfile(prevProfile) + defer lipgloss.SetHasDarkBackground(prevBackground) + + lipgloss.SetHasDarkBackground(false) + + one := rowBlocksOneSessionStyle.Render("sample") + two := rowBlocksTwoSessionsStyle.Render("sample") + + c.Assert(stripANSI(one), qt.Equals, "sample") + c.Assert(one, qt.Contains, "48;5;") + c.Assert(one, qt.Not(qt.Equals), two) +} + +func contrastAgainstWhite(code string) float64 { + r, g, b := ansi256RGB(code) + l := relativeLuminance(r, g, b) + return 1.05 / (l + 0.05) +} + +func ansi256RGB(code string) (float64, float64, float64) { + n, err := strconv.Atoi(code) + if err != nil { + panic(err) + } + if n >= 232 { + gray := float64(8 + (n-232)*10) + return gray, gray, gray + } + if n >= 16 { + levels := []float64{0, 95, 135, 175, 215, 255} + n -= 16 + return levels[n/36], levels[(n/6)%6], levels[n%6] + } + basic := [16][3]float64{ + {0, 0, 0}, {128, 0, 0}, {0, 128, 0}, {128, 128, 0}, + {0, 0, 128}, {128, 0, 128}, {0, 128, 128}, {192, 192, 192}, + {128, 128, 128}, {255, 0, 0}, {0, 255, 0}, {255, 255, 0}, + {0, 0, 255}, {255, 0, 255}, {0, 255, 255}, {255, 255, 255}, + } + return basic[n][0], basic[n][1], basic[n][2] +} + +func relativeLuminance(r, g, b float64) float64 { + return 0.2126*linearizedColor(r) + 0.7152*linearizedColor(g) + 0.0722*linearizedColor(b) +} + +func linearizedColor(v float64) float64 { + v /= 255 + if v <= 0.03928 { + return v / 12.92 + } + return math.Pow((v+0.055)/1.055, 2.4) +} diff --git a/internal/connections/tui/table.go b/internal/connections/tui/table.go new file mode 100644 index 00000000..d30b1c4f --- /dev/null +++ b/internal/connections/tui/table.go @@ -0,0 +1,1355 @@ +package tui + +import ( + "fmt" + "strings" + "time" + + "github.com/charmbracelet/bubbles/help" + "github.com/charmbracelet/bubbles/key" + "github.com/charmbracelet/lipgloss" + lgtable "github.com/charmbracelet/lipgloss/table" + "github.com/charmbracelet/x/ansi" + live "github.com/planetscale/cli/internal/connections" +) + +type tableState struct { + List live.ConnectionList + HasList bool + Sort live.SortMode + CanSort bool + Selected int + ViewportStart int + Width int + Height int + Paused bool + Refresh refreshDotState + ReadOnlyActions bool + Replay bool + LastError string + AccessDenied bool + CanStepHistory bool + Notice string + CaptureStopped string // sticky reason when capture writer detaches on error; "" when capture is healthy or absent + CaptureStatus string + Confirm string + Now time.Time + Interval time.Duration + StepPos int // 1-based position when stepping; 0 means following live + StepTotal int // total samples held in history + Target Target + Filter string // active row filter chip, e.g. "filter: role=primary"; empty when none + DisplayPreset connectionDisplayPreset + Capabilities ConnectionCapabilities +} + +type freshnessTier int + +const ( + freshnessFresh freshnessTier = iota + freshnessStale + freshnessVeryStale +) + +func freshnessTierFor(captured, now time.Time, interval time.Duration) freshnessTier { + if interval <= 0 || captured.IsZero() { + return freshnessFresh + } + age := now.Sub(captured) + if age >= 10*interval { + return freshnessVeryStale + } + if age >= 3*interval { + return freshnessStale + } + return freshnessFresh +} + +func instanceRoleMarker(role string) string { + switch role { + case "primary": + return "" + case "replica": + return "R" + default: + return "" + } +} + +const ( + defaultTableWidth = 120 + queryPreviewLimit = 160 + viewportEdgeRowThreshold = 2 + startColumnMinWidth = 140 + processlistTabletMinWidth = 120 + // Width budget for the APP column on a 120-col terminal. + appColumnTarget = 14 + // Minimum rendered width of the BLOCK column. Matches the "BLOCK" header so + // the column never narrows below the header when downstream counts are + // single-digit (Y, N, Y 3 are otherwise 1–3 chars wide and would let + // lipgloss squeeze the column past the header). + blockColumnMinWidth = 5 + // Stable minimum widths for the STATE and WAIT columns. lipgloss sizes a + // column to its widest cell, so when a wide value (e.g. "idle/xact", + // "Lock/transactionid") appears or disappears between refreshes the column — + // and everything to its right — jitters horizontally. Padding each cell to a + // fixed minimum that fits the common values pins the column width so the live + // view stays stable across refreshes. Widths cover "idle/xact" (9) + // and "Lock/transactionid" (18); the rare wider values still expand. + stateColumnMinWidth = 9 + waitColumnMinWidth = 18 +) + +const ( + connectionColState = 2 + connectionColBlock = 3 +) + +type bindings struct { + Navigate key.Binding + Refresh key.Binding + Pause key.Binding + Capture key.Binding + Sort key.Binding + Step key.Binding + Detail key.Binding + Cancel key.Binding + TerminateTxn key.Binding + TerminateConn key.Binding + Help key.Binding + Quit key.Binding +} + +type actionLabels struct { + cancel string + terminateConn string +} + +func actionHelpLabels(display connectionDisplayPreset) actionLabels { + if display == connectionDisplayProcesslist { + return actionLabels{cancel: "KILL QUERY", terminateConn: "KILL"} + } + return actionLabels{cancel: "cancel query", terminateConn: "force terminate"} +} + +func (b bindings) ShortHelp() []key.Binding { + return b.ShortHelpFor(false, false, connectionDisplayDefault, DefaultConnectionCapabilities(), true, true, true, true, false) +} + +func (b bindings) ShortHelpFor(readOnlyActions, replay bool, display connectionDisplayPreset, support ConnectionCapabilities, canSort, hasList, canSelectRow, canStepHistory, paused bool) []key.Binding { + b = b.withActionLabels(display) + b = b.withPauseLabel(paused) + if !hasList { + help := []key.Binding{} + if !replay && !paused { + help = append(help, b.Refresh) + } + help = append(help, b.Pause) + return append(help, b.Help, b.Quit) + } + + help := []key.Binding{} + if canSelectRow { + help = append(help, b.Navigate) + } + if !replay && !paused { + help = append(help, b.Refresh) + } + help = append(help, b.Pause) + if !replay { + help = append(help, b.Capture) + } + if canSort { + help = append(help, b.Sort) + } + if canStepHistory { + help = append(help, b.Step) + } + if canSelectRow { + help = append(help, b.Detail) + } + if canSelectRow && !readOnlyActions { + support = support.effective() + if support.supports(actionCancelQuery) { + help = append(help, b.Cancel) + } + if support.supports(actionTerminateTxn) { + help = append(help, b.TerminateTxn) + } + if support.supports(actionTerminateConn) { + help = append(help, b.TerminateConn) + } + } + return append(help, b.Help, b.Quit) +} + +func (b bindings) withActionLabels(display connectionDisplayPreset) bindings { + labels := actionHelpLabels(display) + b.Cancel = key.NewBinding(key.WithKeys("c"), key.WithHelp("c", labels.cancel)) + b.TerminateConn = key.NewBinding(key.WithKeys("K"), key.WithHelp("shift+K", labels.terminateConn)) + return b +} + +func (b bindings) withPauseLabel(paused bool) bindings { + label := "pause" + if paused { + label = "resume" + } + b.Pause = key.NewBinding(key.WithKeys("space"), key.WithHelp("space", label)) + return b +} + +// renderTable returns the complete terminal frame for the table view. Sizes +// the connection-body section so total output never exceeds state.Height — +// otherwise an overflowing alt-screen frame pushes the header off the top. +func renderTable(state tableState) string { + headerLines := []string{renderHeader(state), ""} + var bannerLines []string + if banner := renderInstanceFailureBanner(state.List.Instances); banner != "" { + bannerLines = []string{banner, ""} + } + footerLines := strings.Split(renderFooter(state), "\n") + bodyAvail := state.Height - len(headerLines) - len(bannerLines) - len(footerLines) + + var bodyLines []string + switch { + case !state.HasList && state.AccessDenied: + bodyLines = []string{ + errorStyle.Render("you don't have permission to view live connections"), + mutedStyle.Render("production branches require the Analyst role or higher — ask an org admin, or use a development branch"), + "", + } + case !state.HasList && state.LastError != "": + bodyLines = []string{ + errorStyle.Render("unable to load live connections"), + errorStyle.Render(clipLine(sanitizeFooterText(state.LastError), tableWidth(state.Width))), + "", + } + case !state.HasList: + bodyLines = []string{mutedStyle.Render("loading live connections..."), ""} + case len(state.List.Connections) == 0: + bodyLines = []string{ + mutedStyle.Render("no live connections"), + mutedStyle.Render("new connections appear on the next refresh"), + "", + } + default: + bodyLines = strings.Split(renderConnectionTable(state, bodyAvail), "\n") + } + + if state.Height > 0 { + for len(bodyLines) < bodyAvail { + bodyLines = append(bodyLines, "") + } + if bodyAvail > 0 && len(bodyLines) > bodyAvail { + bodyLines = bodyLines[:bodyAvail] + } else if bodyAvail <= 0 { + bodyLines = nil + } + } + + all := make([]string, 0, len(headerLines)+len(bannerLines)+len(bodyLines)+len(footerLines)) + all = append(all, headerLines...) + all = append(all, bannerLines...) + all = append(all, bodyLines...) + all = append(all, footerLines...) + if state.Height > 0 && len(all) > state.Height { + all = all[:state.Height] + } + return strings.Join(all, "\n") +} + +// headerPart is one ` | `-delimited header token in the fitted summary line. +type headerPart struct { + text string +} + +type headerOptions struct { + showSort bool + showRecOff bool + compactCount bool + targetMode headerTargetMode + compactFresh bool +} + +type headerTargetMode int + +const ( + headerTargetFull headerTargetMode = iota + headerTargetCompactShard + headerTargetEllipsizeWithShard + headerTargetCompactFresh + headerTargetNoShard +) + +const minHeaderDatabaseWidth = 12 + +// renderHeader returns the summary line at the top of the table, fitted to the +// viewport width. The staged fallbacks compact the target, count, sort, and +// capture tokens before protected freshness, pause, and count context is lost. +func renderHeader(state tableState) string { + width := tableWidth(state.Width) + stages := []headerOptions{ + {showSort: true, showRecOff: true, targetMode: headerTargetFull}, + {showSort: true, targetMode: headerTargetFull}, + {targetMode: headerTargetFull}, + {compactCount: true, targetMode: headerTargetFull}, + {compactCount: true, targetMode: headerTargetCompactShard}, + {compactCount: true, targetMode: headerTargetEllipsizeWithShard}, + {compactCount: true, targetMode: headerTargetCompactFresh, compactFresh: true}, + {compactCount: true, targetMode: headerTargetNoShard, compactFresh: true}, + } + if state.CaptureStatus != "rec off" { + for i := range stages { + stages[i].showRecOff = true + } + } + for _, opts := range stages { + line := renderHeaderWithOptions(state, width, opts) + if width <= 0 || ansi.StringWidth(line) <= width { + return clipLine(line, width) + } + } + return clipLine(renderHeaderWithOptions(state, width, stages[len(stages)-1]), width) +} + +func renderHeaderWithOptions(state tableState, width int, opts headerOptions) string { + var parts []headerPart + if target := renderHeaderTarget(displayTarget(state.Target, state.List), width, opts.targetMode, state, opts); target != "" { + parts = append(parts, headerPart{text: target}) + } + parts = append(parts, + headerPart{text: connectionCountHeaderText(state, opts.compactCount)}, + ) + if state.Filter != "" { + parts = append(parts, headerPart{text: headerFilterText(state.Filter, width)}) + } + if opts.showSort { + parts = append(parts, headerPart{text: sortHeaderText(state)}) + } + if state.Paused { + parts = append(parts, headerPart{text: pausedStyle.Render("paused")}) + } + if state.StepPos > 0 { + parts = append(parts, headerPart{text: fmt.Sprintf("step %d/%d", state.StepPos, state.StepTotal)}) + } + if state.CaptureStopped != "" { + parts = append(parts, headerPart{text: errorStyle.Render(state.CaptureStopped)}) + } + if state.CaptureStatus != "" && (state.CaptureStatus != "rec off" || opts.showRecOff) { + parts = append(parts, headerPart{text: state.CaptureStatus}) + } + if token := renderCapturedHeaderToken(state, opts.compactFresh); token != "" { + parts = append(parts, headerPart{text: token}) + } + return joinHeaderParts(parts) +} + +func sortHeaderText(state tableState) string { + if !state.CanSort { + return fmt.Sprintf("sorted by %s", state.Sort) + } + return fmt.Sprintf("sort %s", state.Sort) +} + +func connectionCountText(state tableState) string { + if !state.HasList { + return "connections —" + } + return fmt.Sprintf("connections %d", len(state.List.Connections)) +} + +func connectionCountHeaderText(state tableState, compact bool) string { + if !compact { + return refreshIndicator(state.Refresh) + " " + connectionCountText(state) + } + if !state.HasList { + return refreshIndicator(state.Refresh) + " —" + } + return fmt.Sprintf("%s %d", refreshIndicator(state.Refresh), len(state.List.Connections)) +} + +// refreshDotState is the health of the live refresh, surfaced by the header dot. +type refreshDotState int + +const ( + refreshDotIdle refreshDotState = iota // dim: nothing pending, last refresh succeeded + refreshDotPending // cyan: a fetch is in flight, or one recently failed but not for long + refreshDotFailing // red: refresh is continuously failing (sustained 503s/timeouts) + refreshDotHidden // replay: there is no live refresh, so show no dot +) + +// refreshDotFailThreshold is the count of consecutive failed list fetches after +// which the indicator escalates from "recently retried" (cyan) to "continuously +// failing" (red). One failure still reads as a transient blip; two in a row +// reads as a real, ongoing outage. +const refreshDotFailThreshold = 2 + +// computeRefreshDot maps the live refresh state to a dot state. It depends only +// on the live fetch/error bookkeeping — never on the paused/stepped cursor — so +// the dot keeps reflecting the most recent refresh attempt while the operator +// holds or steps through history. +func computeRefreshDot(loading bool, consecutiveErrors int, replay bool) refreshDotState { + if replay { + return refreshDotHidden + } + if consecutiveErrors >= refreshDotFailThreshold { + return refreshDotFailing + } + if loading || consecutiveErrors > 0 { + return refreshDotPending + } + return refreshDotIdle +} + +// refreshIndicator returns a fixed-width status dot: cyan while a fetch is in +// flight or recently retried, red when the refresh is continuously failing, dim +// when idle, and a blank (no dot) in replay mode. The glyph is always one cell +// wide, so toggling it never reflows the header. +func refreshIndicator(state refreshDotState) string { + switch state { + case refreshDotHidden: + return " " + case refreshDotFailing: + return refreshErrorStyle.Render("●") + case refreshDotPending: + return refreshActiveStyle.Render("●") + default: + return refreshIdleStyle.Render("●") + } +} + +func headerFilterText(filter string, width int) string { + maxWidth := width / 4 + if maxWidth < 12 { + maxWidth = min(width, 12) + } + if maxWidth > 32 { + maxWidth = 32 + } + return clipLine(filter, maxWidth) +} + +func joinHeaderParts(parts []headerPart) string { + texts := make([]string, 0, len(parts)) + for _, p := range parts { + texts = append(texts, p.text) + } + return strings.Join(texts, " | ") +} + +func renderTarget(target Target) string { + var parts []string + if target.Database != "" { + parts = append(parts, target.Database) + } + if target.Branch != "" { + parts = append(parts, target.Branch) + } + if target.Keyspace != "" { + parts = append(parts, target.Keyspace) + } + if target.Shard != "" { + parts = append(parts, target.Shard) + } + return strings.Join(parts, " / ") +} + +func renderHeaderTarget(target Target, width int, mode headerTargetMode, state tableState, opts headerOptions) string { + if mode == headerTargetFull { + return renderTarget(target) + } + rendered := renderHeaderTargetText(target, mode, 0) + if mode != headerTargetEllipsizeWithShard && mode != headerTargetCompactFresh && mode != headerTargetNoShard { + return rendered + } + budget := headerTargetBudget(rendered, state, opts, width) + if budget <= 0 { + return "" + } + return renderHeaderTargetText(target, mode, budget) +} + +func renderHeaderTargetText(target Target, mode headerTargetMode, budget int) string { + keyspaceShard := compactKeyspaceShard(target) + includeKeyspaceShard := keyspaceShard != "" && mode != headerTargetNoShard + database := target.Database + if budget > 0 { + suffix := targetSuffixText(target.Branch, keyspaceShard, includeKeyspaceShard) + databaseBudget := budget + if suffix != "" && database != "" { + databaseBudget -= ansi.StringWidth(" / " + suffix) + } + if includeKeyspaceShard && database != "" && databaseBudget < minHeaderDatabaseWidth { + return renderHeaderTargetText(target, headerTargetCompactShard, 0) + } + if databaseBudget <= 0 { + if !includeKeyspaceShard && mode == headerTargetEllipsizeWithShard { + return renderTarget(target) + } + database = "" + if ansi.StringWidth(suffix) > budget { + return "" + } + } else { + database = clipLine(database, databaseBudget) + } + } + parts := []string{} + if database != "" { + parts = append(parts, database) + } + if target.Branch != "" { + parts = append(parts, target.Branch) + } + if includeKeyspaceShard { + parts = append(parts, keyspaceShard) + } + return strings.Join(parts, " / ") +} + +func compactKeyspaceShard(target Target) string { + switch { + case target.Keyspace != "" && target.Shard != "": + return target.Keyspace + "/" + target.Shard + case target.Keyspace != "": + return target.Keyspace + default: + return target.Shard + } +} + +func targetSuffixText(branch, keyspaceShard string, includeKeyspaceShard bool) string { + var parts []string + if branch != "" { + parts = append(parts, branch) + } + if includeKeyspaceShard && keyspaceShard != "" { + parts = append(parts, keyspaceShard) + } + return strings.Join(parts, " / ") +} + +func headerTargetBudget(target string, state tableState, opts headerOptions, width int) int { + if width <= 0 || target == "" { + return width + } + placeholder := "\x00" + line := renderHeaderWithTarget(state, opts, placeholder) + withoutTarget := strings.Replace(line, placeholder, "", 1) + return min(ansi.StringWidth(target), width-ansi.StringWidth(withoutTarget)) +} + +func renderHeaderWithTarget(state tableState, opts headerOptions, target string) string { + var parts []headerPart + if target != "" { + parts = append(parts, headerPart{text: target}) + } + parts = append(parts, headerPart{text: connectionCountHeaderText(state, opts.compactCount)}) + if state.Filter != "" { + parts = append(parts, headerPart{text: headerFilterText(state.Filter, tableWidth(state.Width))}) + } + if opts.showSort { + parts = append(parts, headerPart{text: sortHeaderText(state)}) + } + if state.Paused { + parts = append(parts, headerPart{text: pausedStyle.Render("paused")}) + } + if state.StepPos > 0 { + parts = append(parts, headerPart{text: fmt.Sprintf("step %d/%d", state.StepPos, state.StepTotal)}) + } + if state.CaptureStopped != "" { + parts = append(parts, headerPart{text: errorStyle.Render(state.CaptureStopped)}) + } + if state.CaptureStatus != "" && (state.CaptureStatus != "rec off" || opts.showRecOff) { + parts = append(parts, headerPart{text: state.CaptureStatus}) + } + if token := renderCapturedHeaderToken(state, opts.compactFresh); token != "" { + parts = append(parts, headerPart{text: token}) + } + return joinHeaderParts(parts) +} + +func displayTarget(target Target, list live.ConnectionList) Target { + if list.Topology == nil { + return target + } + if target.Keyspace == "" { + target.Keyspace = list.Topology.Keyspace + } + if target.Shard == "" { + target.Shard = list.Topology.Shard + } + return target +} + +func renderCapturedToken(state tableState) string { + return capturedToken(state.List.CapturedAt, state.Now, state.Interval) +} + +func renderCapturedHeaderToken(state tableState, compact bool) string { + if compact { + return capturedCompactToken(state.List.CapturedAt, state.Now, state.Interval) + } + return renderCapturedToken(state) +} + +func capturedCompactToken(captured, now time.Time, interval time.Duration) string { + if captured.IsZero() { + return "" + } + if now.IsZero() { + return "captured " + formatCapturedAbsolute(captured, now) + } + age := now.Sub(captured).Truncate(time.Second) + if age < 0 { + age = 0 + } + relative := fmt.Sprintf("(%ds)", int(age.Seconds())) + switch freshnessTierFor(captured, now, interval) { + case freshnessVeryStale: + relative = veryStaleStyle.Render(relative) + case freshnessStale: + relative = staleStyle.Render(relative) + } + return relative +} + +// capturedToken renders the absolute capture time plus a relative-age suffix +// that tints as the sample goes stale. Age is shown in every mode, including +// paused; the separate "paused" chip in the header communicates the freeze. +func capturedToken(captured, now time.Time, interval time.Duration) string { + if captured.IsZero() { + return "" + } + // Absolute timestamp stays unstyled so operators have a stable wall-clock + // reference; only the relative-age suffix tints to signal staleness. + absolute := "captured " + formatCapturedAbsolute(captured, now) + if now.IsZero() { + return absolute + } + age := now.Sub(captured).Truncate(time.Second) + if age < 0 { + age = 0 + } + relative := fmt.Sprintf("(%ds ago)", int(age.Seconds())) + switch freshnessTierFor(captured, now, interval) { + case freshnessVeryStale: + relative = veryStaleStyle.Render(relative) + case freshnessStale: + relative = staleStyle.Render(relative) + } + return absolute + " " + relative +} + +// formatCapturedAbsolute renders the absolute timestamp shown in the header. +// Same calendar day as `now`: time-only. Different day: include the date. The +// timezone is always the operator's local — operators looking at this header +// are debugging on their own clock. +func formatCapturedAbsolute(captured, now time.Time) string { + captured = captured.Local() + if !now.IsZero() && sameLocalDay(captured, now.Local()) { + return captured.Format("15:04:05") + } + return captured.Format("2006-01-02 15:04:05") +} + +func sameLocalDay(a, b time.Time) bool { + ay, am, ad := a.Date() + by, bm, bd := b.Date() + return ay == by && am == bm && ad == bd +} + +func renderInstanceFailureBanner(instances []live.InstanceMeta) string { + if len(instances) == 0 { + return "" + } + var failed []string + for _, inst := range instances { + if inst.Error != "" { + failed = append(failed, inst.ID) + } + } + if len(failed) == 0 { + return "" + } + return bannerStyle.Render(fmt.Sprintf( + "%d of %d instances unreachable: %s", + len(failed), len(instances), strings.Join(failed, ", "), + )) +} + +// renderFooter returns the command/status line at the bottom of the table. +func renderFooter(state tableState) string { + lines := []string{} + if selectedStatus := renderSelectedStatus(state); selectedStatus != "" { + lines = append(lines, selectedStatus) + } + + switch { + case state.Confirm != "": + // Confirmation prompt replaces the help line: appending would push the + // prompt past the terminal width on standard layouts, and only y/n/esc/ + // q/ctrl+c are valid keys while confirming anyway. + lines = append(lines, errorStyle.Render(clipLine(state.Confirm, tableWidth(state.Width)))) + default: + if state.LastError != "" && state.HasList { + lines = append(lines, errorStyle.Render(clipLine("error: "+sanitizeFooterText(state.LastError), tableWidth(state.Width)))) + } else if state.Notice != "" { + lines = append(lines, clipLine("status: "+state.Notice, tableWidth(state.Width))) + } + lines = append(lines, renderHelpFor(state.Width, state.ReadOnlyActions, state.Replay, state.DisplayPreset, state.Capabilities, state.CanSort, state.HasList, state.HasList && len(state.List.Connections) > 0, state.CanStepHistory, state.Paused)) + } + return strings.Join(lines, "\n") +} + +func renderSelectedStatus(state tableState) string { + if !state.HasList || state.Selected < 0 || state.Selected >= len(state.List.Connections) { + return "" + } + + conn := state.List.Connections[state.Selected] + status := fmt.Sprintf("selected pid %d", conn.PID) + query := strings.Join(strings.Fields(conn.QueryText), " ") + if query != "" { + status += " | " + query + } + return clipLine(status, tableWidth(state.Width)) +} + +// queryPreview returns the collapsed query text shown in the final table column. +func queryPreview(query string) string { + collapsed := strings.Join(strings.Fields(query), " ") + return truncateRunes(collapsed, queryPreviewLimit) +} + +// renderConnectionTable returns the visible slice of connection rows. bodyAvail +// is the terminal row count reserved for the table body (header row + rows). +func renderConnectionTable(state tableState, bodyAvail int) string { + width := tableWidth(state.Width) + connections := state.List.Connections + visibleRows := visibleRowCount(len(connections), bodyAvail) + start := viewportStartForSelection(state.ViewportStart, state.Selected, len(connections), visibleRows) + visible := visibleConnections(connections, start, visibleRows) + selectedInSlice := state.Selected - start + counts := live.BlockingCounts(connections) + headers, rows := buildConnectionRowsForDisplay(state.DisplayPreset, visible, counts, width, selectedInSlice) + if state.DisplayPreset == connectionDisplayProcesslist { + return renderProcesslistConnectionTable(headers, rows, visible, selectedInSlice, width) + } + + return lgtable.New(). + Border(lipgloss.HiddenBorder()). + BorderTop(false). + BorderBottom(false). + BorderLeft(false). + BorderRight(false). + BorderHeader(false). + BorderColumn(false). + Headers(headers...). + Rows(rows...). + Width(width). + Wrap(false). + StyleFunc(func(row, col int) lipgloss.Style { + if row == lgtable.HeaderRow { + return tableCellStyle(headerStyle, headers, col) + } + if row >= 0 && row < len(visible) { + conn := visible[row] + style := connectionRowStyle(conn, counts[conn.PID], row == selectedInSlice) + style = connectionColumnStyleForDisplay(state.DisplayPreset, style, conn, counts[conn.PID], col, headers) + return tableCellStyle(style, headers, col) + } + return lipgloss.NewStyle() + }). + Render() +} + +func renderProcesslistConnectionTable(headers []string, rows [][]string, connections []live.Connection, selectedInSlice int, width int) string { + if len(headers) == 0 { + return "" + } + columnWidths := processlistColumnWidths(headers, rows) + lines := []string{renderProcesslistRow(headers, headers, nil, headerStyle, columnWidths, width)} + for i, row := range rows { + style := lipgloss.NewStyle() + if i >= 0 && i < len(connections) { + conn := connections[i] + style = connectionRowStyle(conn, 0, i == selectedInSlice) + lines = append(lines, renderProcesslistRow(row, headers, &conn, style, columnWidths, width)) + continue + } + lines = append(lines, renderProcesslistRow(row, headers, nil, style, columnWidths, width)) + } + return strings.Join(lines, "\n") +} + +func processlistColumnWidths(headers []string, rows [][]string) []int { + widths := make([]int, len(headers)) + for i, header := range headers { + widths[i] = ansi.StringWidth(header) + } + for _, row := range rows { + for i, cell := range row { + if i < len(widths) && i < len(row)-1 { + widths[i] = max(widths[i], ansi.StringWidth(cell)) + } + } + } + return widths +} + +func renderProcesslistRow(cells []string, headers []string, conn *live.Connection, base lipgloss.Style, widths []int, width int) string { + var line strings.Builder + for i, cell := range cells { + style := base + if conn != nil && i < len(headers) && headers[i] == "STATE" { + style = processlistStateStyleFor(style, conn.State) + } + if i < len(cells)-1 && i < len(widths) { + cell = padCellToWidth(cell, widths[i]) + cell += processlistCellPadding(i) + } + line.WriteString(style.Render(cell)) + } + return clipLine(line.String(), width) +} + +func padCellToWidth(text string, width int) string { + if pad := width - ansi.StringWidth(text); pad > 0 { + text += strings.Repeat(" ", pad) + } + return text +} + +func processlistCellPadding(index int) string { + if index == 0 { + return " " + } + return " " +} + +func connectionRowStyle(conn live.Connection, blockCount int, selected bool) lipgloss.Style { + style := blockingRowStyle(blockCount) + if conn.InstanceRole == "replica" { + style = style.Inherit(replicaRowStyle) + } + if selected { + return style.Inherit(selectedRowStyle) + } + return style +} + +func tableCellStyle(style lipgloss.Style, headers []string, col int) lipgloss.Style { + if col < len(headers)-1 { + return style.PaddingRight(2) + } + return style +} + +func connectionColumnStyleForDisplay(display connectionDisplayPreset, style lipgloss.Style, conn live.Connection, blockCount int, col int, headers []string) lipgloss.Style { + if display == connectionDisplayProcesslist { + if col >= 0 && col < len(headers) && headers[col] == "STATE" { + return processlistStateStyleFor(style, conn.State) + } + return style + } + switch col { + case connectionColState: + return stateStyleFor(style, conn.State) + case connectionColBlock: + if blockCount > 0 { + return style.Foreground(adaptiveColor(colorStaleYellowLight, colorStaleYellow)).Bold(true) + } + if len(conn.BlockedBy) == 0 { + return style.Inherit(mutedStyle) + } + } + return style +} + +func stateStyleFor(style lipgloss.Style, state string) lipgloss.Style { + switch stateText(state) { + case "active": + return style.Foreground(adaptiveColor(colorActiveGreenLight, colorActiveGreen)).Bold(true) + case "idle": + return style.Foreground(adaptiveColor(colorIdleGrayLight, colorIdleGray)) + case "idle/xact", "idle/xact (aborted)": + return style.Foreground(adaptiveColor(colorXactYellowLight, colorXactYellow)).Bold(true) + default: + return style + } +} + +// buildConnectionRows returns the table headers and cell values for visible +// connections. +func buildConnectionRows(connections []live.Connection, counts map[int]int, width int, selectedInSlice int) ([]string, [][]string) { + return buildConnectionRowsForDisplay(connectionDisplayDefault, connections, counts, width, selectedInSlice) +} + +func buildConnectionRowsForDisplay(display connectionDisplayPreset, connections []live.Connection, counts map[int]int, width int, selectedInSlice int) ([]string, [][]string) { + if display == connectionDisplayProcesslist { + return buildProcesslistConnectionRows(connections, width, selectedInSlice) + } + + includeStart := width >= startColumnMinWidth + headers := []string{"", "PID", "STATE", "BLOCK", "WAIT", "DURATION", "APP"} + if includeStart { + headers = append(headers, "START") + } + headers = append(headers, "QUERY") + + rows := make([][]string, 0, len(connections)) + for i, conn := range connections { + marker := rowMarker(conn.InstanceRole, i == selectedInSlice) + row := []string{ + marker, + fmt.Sprint(conn.PID), + stateCell(conn.State), + blockedText(conn, counts[conn.PID]), + waitTextForWidth(conn, width), + formatDuration(conn.Duration), + appTextForWidth(conn.ApplicationName, width), + } + if includeStart { + row = append(row, formatTime(startTime(conn))) + } + row = append(row, queryPreview(conn.QueryText)) + rows = append(rows, row) + } + return headers, rows +} + +func buildProcesslistConnectionRows(connections []live.Connection, width int, selectedInSlice int) ([]string, [][]string) { + includeTablet := width >= processlistTabletMinWidth + headers := []string{"", "PID"} + if includeTablet { + headers = append(headers, "TABLET") + } + headers = append(headers, "STATE", "DURATION", "USER", "DB", "QUERY") + rows := make([][]string, 0, len(connections)) + for i, conn := range connections { + row := []string{ + processlistRowMarker(i == selectedInSlice), + fmt.Sprint(conn.PID), + } + if includeTablet { + row = append(row, emptyDash(conn.Instance)) + } + row = append(row, + emptyDash(processlistStateText(conn.State)), + processlistDuration(conn), + appText(conn.Username), + emptyDash(conn.DatabaseName), + queryPreview(conn.QueryText), + ) + rows = append(rows, row) + } + return headers, rows +} + +func rowMarker(role string, selected bool) string { + cursor := " " + if selected { + cursor = "▶" + } + roleMarker := instanceRoleMarker(role) + if roleMarker == "" { + roleMarker = " " + } + return cursor + roleMarker +} + +func processlistRowMarker(selected bool) string { + if selected { + return "▶" + } + return " " +} + +// stateText returns the abbreviated render of a connection's Postgres state. +// "idle in transaction" is the operator-critical state (holds locks), so the +// abbreviation must remain distinguishable from plain "idle". +func stateText(state string) string { + switch state { + case "idle in transaction": + return "idle/xact" + case "idle in transaction (aborted)": + return "idle/xact (aborted)" + default: + return state + } +} + +func stateCell(state string) string { + text := emptyDash(stateText(state)) + return padCell(text, stateColumnMinWidth) +} + +func processlistStateText(state string) string { + command, _, ok := strings.Cut(state, "/") + if ok { + return command + } + return state +} + +func processlistDuration(conn live.Connection) string { + if conn.Duration <= 0 && processlistConnectionHasWork(conn) { + return "00:00" + } + return formatDuration(conn.Duration) +} + +func processlistConnectionHasWork(conn live.Connection) bool { + state := strings.ToLower(strings.TrimSpace(processlistStateText(conn.State))) + if state == "sleep" || state == "idle" { + return false + } + return state != "" || strings.TrimSpace(conn.QueryText) != "" +} + +func processlistStateStyleFor(style lipgloss.Style, state string) lipgloss.Style { + switch strings.ToLower(processlistStateText(state)) { + case "query": + return style.Foreground(adaptiveColor(colorActiveGreenLight, colorActiveGreen)).Bold(true) + case "sleep": + return style.Foreground(adaptiveColor(colorIdleGrayLight, colorIdleGray)) + default: + return style + } +} + +// padCell right-pads a cell to a minimum display width, pinning the column so it +// doesn't jitter as values change between refreshes. lipgloss sizes columns to +// the widest cell, so a constant trailing pad is an effective minimum. +func padCell(text string, minWidth int) string { + if pad := minWidth - lipgloss.Width(text); pad > 0 { + text += strings.Repeat(" ", pad) + } + return text +} + +// waitText returns the combined wait-cause shown in the WAIT column. When the +// type and event are both populated, render "type/event" so operators see the +// discriminating suffix (e.g. "Lock/transactionid"). When only one half is +// populated, fall back to that half. +func waitText(conn live.Connection) string { + switch { + case conn.WaitEventType != "" && conn.WaitEvent != "": + return conn.WaitEventType + "/" + conn.WaitEvent + case conn.WaitEvent != "": + return conn.WaitEvent + case conn.WaitEventType != "": + return conn.WaitEventType + default: + return "-" + } +} + +func waitTextForWidth(conn live.Connection, width int) string { + wait := waitText(conn) + if width < 100 { + return clipLine(wait, 10) + } + return padCell(wait, waitColumnMinWidth) +} + +func appText(name string) string { + return appTextForWidth(name, 0) +} + +func appTextForWidth(name string, width int) string { + if name == "" { + return "-" + } + limit := appColumnTarget + if width >= 220 { + limit = 24 + } + runes := []rune(name) + if len(runes) <= limit { + return name + } + keep := limit - 1 + if keep > 0 && runes[keep-1] == '_' { + keep++ + } + return string(runes[:keep]) + "…" +} + +// blockedText renders the BLOCK column. The leading character carries the alarm +// signal: a digit if this connection is blocking N downstream sessions (the +// operationally critical case — root blockers show "3" not "N 3"), "W" if it's +// waiting on a lock with no downstream, "-" if quiet. A trailing " W" suffix +// indicates the connection is both blocking and waiting (chain victim). +// Right-padded so the rendered column never narrows below the BLOCK header +// width — lipgloss sizes columns to max cell width, so trailing whitespace +// here effectively sets a minimum. +// +// 3 W — blocking 3 downstream AND waiting on a lock +// 3 — blocking 3 downstream (root blocker) +// W — waiting on a lock, blocking nothing +// - — quiet +func blockedText(conn live.Connection, count int) string { + waiting := len(conn.BlockedBy) > 0 + var text string + switch { + case count > 0 && waiting: + text = fmt.Sprintf("%d W", count) + case count > 0: + text = fmt.Sprint(count) + case waiting: + text = "W" + default: + text = "-" + } + return padCell(text, blockColumnMinWidth) +} + +// blockingRowStyle shades rows that block downstream connections. +func blockingRowStyle(depth int) lipgloss.Style { + switch { + case depth >= 3: + return rowBlocksManySessionsStyle + case depth == 2: + return rowBlocksTwoSessionsStyle + case depth == 1: + return rowBlocksOneSessionStyle + default: + return lipgloss.NewStyle() + } +} + +// formatDuration returns the compact duration text shown in the DURATION column. +func formatDuration(d time.Duration) string { + if d <= 0 { + return "-" + } + d = d.Truncate(time.Second) + minutes := int(d.Minutes()) + seconds := int(d.Seconds()) % 60 + if minutes >= 60 { + hours := minutes / 60 + minutes = minutes % 60 + return fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) + } + return fmt.Sprintf("%02d:%02d", minutes, seconds) +} + +// startTime returns the transaction start or query start shown in the START column. +func startTime(conn live.Connection) *time.Time { + if conn.XactStart != nil { + return conn.XactStart + } + return conn.QueryStart +} + +func formatTime(t *time.Time) string { + if t == nil { + return "-" + } + return t.Format("15:04:05") +} + +func emptyDash(value string) string { + if strings.TrimSpace(value) == "" { + return "-" + } + return value +} + +// visibleRowCount returns how many connection rows fit in bodyAvail terminal +// rows after reserving one row for the table header. Floors at 2 so navigation +// keeps some context even in tiny terminals. +func visibleRowCount(totalRows, bodyAvail int) int { + rows := bodyAvail - 1 + if rows < 2 { + rows = 2 + } + if totalRows > 0 && rows > totalRows { + rows = totalRows + } + return rows +} + +// bodyHeight returns the terminal rows available for the connection-body +// section (table header + rows). Mirrors the chrome accounting in renderTable +// so the model can clamp viewport scrolling to the same visible-row count the +// renderer uses. +func bodyHeight(state tableState) int { + headerLines := 2 + bannerLines := 0 + if renderInstanceFailureBanner(state.List.Instances) != "" { + bannerLines = 2 + } + footerLines := strings.Count(renderFooter(state), "\n") + 1 + return state.Height - headerLines - bannerLines - footerLines +} + +func visibleConnections(connections []live.Connection, start, height int) []live.Connection { + start = clampViewportStart(start, len(connections), height) + end := start + height + if end > len(connections) { + end = len(connections) + } + return connections[start:end] +} + +func viewportStartForSelection(currentStart, selected, totalRows, visibleRows int) int { + currentStart = clampViewportStart(currentStart, totalRows, visibleRows) + if totalRows <= visibleRows { + return 0 + } + selected = clampInt(selected, 0, totalRows-1) + row := selected - currentStart + if selected < currentStart || selected >= currentStart+visibleRows { + return centeredViewportStart(selected, totalRows, visibleRows) + } + if row < viewportEdgeRowThreshold && currentStart > 0 { + return centeredViewportStart(selected, totalRows, visibleRows) + } + if row >= visibleRows-viewportEdgeRowThreshold && currentStart < totalRows-visibleRows { + return centeredViewportStart(selected, totalRows, visibleRows) + } + return currentStart +} + +func centeredViewportStart(selected, totalRows, visibleRows int) int { + return clampViewportStart(selected-(visibleRows/2), totalRows, visibleRows) +} + +func clampViewportStart(start, totalRows, visibleRows int) int { + if totalRows <= 0 || visibleRows <= 0 || totalRows <= visibleRows { + return 0 + } + return clampInt(start, 0, totalRows-visibleRows) +} + +func clampInt(value, minValue, maxValue int) int { + if value < minValue { + return minValue + } + if value > maxValue { + return maxValue + } + return value +} + +func tableWidth(width int) int { + if width <= 0 { + return defaultTableWidth + } + return width +} + +func renderHelp(width int) string { + return renderHelpFor(width, false, false, connectionDisplayDefault, DefaultConnectionCapabilities(), true, true, true, true, false) +} + +func renderHelpFor(width int, readOnlyActions, replay bool, display connectionDisplayPreset, support ConnectionCapabilities, canSort, hasList, canSelectRow, canStepHistory, paused bool) string { + return renderWrappedHelp(defaultBindings().ShortHelpFor(readOnlyActions, replay, display, support, canSort, hasList, canSelectRow, canStepHistory, paused), " | ", tableWidth(width)) +} + +// renderWrappedHelp lays out key hints across as many lines as the width +// requires, wrapping rather than truncating so no hint vanishes off the right +// edge at narrow widths. +func renderWrappedHelp(bindings []key.Binding, separator string, width int) string { + helpModel := help.New() + helpModel.Width = width + helpModel.ShortSeparator = separator + helpModel.Styles.ShortKey = helpKeyStyle + helpModel.Styles.ShortDesc = helpDescStyle + helpModel.Styles.ShortSeparator = helpSeparatorStyle + rows := wrapHelpBindings(helpModel, bindings, width) + lines := make([]string, 0, len(rows)) + for _, row := range rows { + lines = append(lines, helpModel.ShortHelpView(row)) + } + return strings.Join(lines, "\n") +} + +func wrapHelpBindings(helpModel help.Model, bindings []key.Binding, width int) [][]key.Binding { + rows := [][]key.Binding{} + row := []key.Binding{} + + for _, binding := range bindings { + if !binding.Enabled() { + continue + } + + candidate := make([]key.Binding, 0, len(row)+1) + candidate = append(candidate, row...) + candidate = append(candidate, binding) + if len(row) > 0 && lipgloss.Width(renderUnwrappedHelp(helpModel, candidate)) > width { + rows = append(rows, row) + row = []key.Binding{binding} + continue + } + row = candidate + } + + if len(row) > 0 { + rows = append(rows, row) + } + return rows +} + +func renderUnwrappedHelp(helpModel help.Model, bindings []key.Binding) string { + helpModel.Width = 0 + return helpModel.ShortHelpView(bindings) +} + +func defaultBindings() bindings { + return bindings{ + Navigate: key.NewBinding(key.WithKeys("up", "down"), key.WithHelp("up/down", "select")), + Refresh: key.NewBinding(key.WithKeys("r"), key.WithHelp("r", "refresh")), + Pause: key.NewBinding(key.WithKeys("space"), key.WithHelp("space", "pause")), + Capture: key.NewBinding(key.WithKeys("C"), key.WithHelp("C", "capture")), + Sort: key.NewBinding(key.WithKeys("s"), key.WithHelp("s", "sort")), + Step: key.NewBinding(key.WithKeys("[", "]", "{", "}"), key.WithHelp("[ ] { }", "step history")), + Detail: key.NewBinding(key.WithKeys("enter", "v"), key.WithHelp("enter/v", "detail")), + Cancel: key.NewBinding(key.WithKeys("c"), key.WithHelp("c", "cancel query")), + TerminateTxn: key.NewBinding(key.WithKeys("k"), key.WithHelp("k", "kill transaction")), + TerminateConn: key.NewBinding(key.WithKeys("K"), key.WithHelp("shift+K", "force terminate")), + Help: key.NewBinding(key.WithKeys("?"), key.WithHelp("?", "help")), + Quit: key.NewBinding(key.WithKeys("q"), key.WithHelp("q", "quit")), + } +} + +// clipLine clips a rendered terminal line to the current viewport width. +// Width is measured in visible cells, so ANSI styling (e.g. a bold/underlined +// active tab) does not count against the budget or get truncated mid-escape. +func clipLine(line string, width int) string { + if width <= 0 { + return line + } + if ansi.StringWidth(line) <= width { + return line + } + if width <= 1 { + return ansi.Truncate(line, width, "") + } + return ansi.Truncate(line, width, "…") +} + +// truncateRunes clips user-controlled text without splitting UTF-8 runes. +func truncateRunes(value string, limit int) string { + if limit < 0 { + return "" + } + runes := []rune(value) + if len(runes) <= limit { + return value + } + return string(runes[:limit]) +} + +func sanitizeFooterText(value string) string { + return strings.Map(func(r rune) rune { + if r < 0x20 || r == 0x7f { + return -1 + } + return r + }, value) +} diff --git a/internal/connections/tui/table_test.go b/internal/connections/tui/table_test.go new file mode 100644 index 00000000..a05ac05d --- /dev/null +++ b/internal/connections/tui/table_test.go @@ -0,0 +1,1118 @@ +package tui + +import ( + "regexp" + "strings" + "testing" + "time" + + "github.com/charmbracelet/lipgloss" + "github.com/charmbracelet/x/ansi" + qt "github.com/frankban/quicktest" + "github.com/muesli/termenv" + live "github.com/planetscale/cli/internal/connections" +) + +var ansiRE = regexp.MustCompile(`\x1b\[[0-9;]*m`) + +var tableRenderTestTime = time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC) + +func stripANSI(s string) string { + return ansiRE.ReplaceAllString(s, "") +} + +func indexOf(slice []string, target string) int { + for i, s := range slice { + if s == target { + return i + } + } + return -1 +} + +func assertWidthAtMost(c *qt.C, value string, width int) { + c.Helper() + got := ansi.StringWidth(value) + c.Assert(got <= width, qt.IsTrue, qt.Commentf("width=%d limit=%d value=%q", got, width, value)) +} + +func TestRenderTableEmptyAndLoadingStates(t *testing.T) { + tests := []struct { + name string + state tableState + want string + }{ + { + name: "pre list", + state: tableState{Width: 100, Height: 6}, + want: "loading live connections...", + }, + { + name: "empty list", + state: tableState{ + HasList: true, + Width: 100, + Height: 6, + }, + want: "no live connections", + }, + { + name: "short body still shows rows", + state: tableState{ + List: live.NewConnectionList(tableRenderTestTime, []live.Connection{{PID: 1}, {PID: 2}, {PID: 3}}, live.SortByTransactionStart), + HasList: true, + Width: 100, + Height: 4, + }, + want: "1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + c.Assert(renderTable(tt.state), qt.Contains, tt.want) + }) + } +} + +func TestFooterHidesRowActionsOnEmptyList(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), nil, live.SortByTransactionStart) + rendered := stripANSI(renderTable(tableState{ + List: list, + HasList: true, + Width: 180, + Height: 8, + })) + + c.Assert(rendered, qt.Contains, "connections 0") + c.Assert(rendered, qt.Contains, "no live connections") + c.Assert(rendered, qt.Contains, "new connections appear on the next refresh") + c.Assert(rendered, qt.Not(qt.Contains), "enter detail") + c.Assert(rendered, qt.Not(qt.Contains), "cancel query") + c.Assert(rendered, qt.Not(qt.Contains), "kill transaction") + c.Assert(rendered, qt.Not(qt.Contains), "force terminate") +} + +func TestRenderTableRendersNonEmptyList(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + State: "active", + ApplicationName: "writer", + QueryText: "SELECT * FROM widgets", + }}, live.SortByTransactionStart) + + rendered := renderTable(tableState{ + List: list, + HasList: true, + Width: 180, + Height: 6, + }) + + c.Assert(rendered, qt.Contains, "connections 1") + c.Assert(rendered, qt.Contains, "PID") + c.Assert(rendered, qt.Contains, "SELECT * FROM widgets") +} + +func TestRenderTableInitialErrorDoesNotRepeatInFooter(t *testing.T) { + c := qt.New(t) + rendered := stripANSI(renderTable(tableState{ + LastError: "instances unreachable", + Width: 120, + Height: 8, + })) + + c.Assert(strings.Count(rendered, "instances unreachable"), qt.Equals, 1) + c.Assert(rendered, qt.Not(qt.Contains), "error: instances unreachable") +} + +func TestRenderTableFitsTerminalHeightWithBanner(t *testing.T) { + c := qt.New(t) + conns := make([]live.Connection, 0, 8) + for i := 0; i < 8; i++ { + conns = append(conns, live.Connection{PID: i + 1, QueryText: "SELECT 1"}) + } + list := live.NewConnectionList(time.Now(), conns, live.SortByTransactionStart) + list.Instances = []live.InstanceMeta{ + {ID: "primary", Role: "primary"}, + {ID: "replica-1", Role: "replica"}, + {ID: "replica-2", Role: "replica", Error: "timeout"}, + } + + rendered := renderTable(tableState{ + List: list, + HasList: true, + Width: 180, + Height: 14, + }) + lines := strings.Split(rendered, "\n") + + c.Assert(len(lines), qt.Equals, 14, qt.Commentf("rendered %d lines for height 14:\n%s", len(lines), rendered)) + c.Assert(lines[0], qt.Contains, "connections 8") + c.Assert(lines[len(lines)-1], qt.Contains, "q") + c.Assert(lines[len(lines)-1], qt.Contains, "quit") + c.Assert(rendered, qt.Not(qt.Contains), "role marker") + c.Assert(rendered, qt.Not(qt.Contains), "BLOCK: digit") +} + +func TestRenderTablePinsFooterToViewportBottom(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + State: "active", + QueryText: "SELECT * FROM widgets", + }}, live.SortByTransactionStart) + + rendered := renderTable(tableState{ + List: list, + HasList: true, + Width: 180, + Height: 10, + }) + lines := strings.Split(rendered, "\n") + + c.Assert(lines[len(lines)-1], qt.Contains, "q") + c.Assert(lines[len(lines)-1], qt.Contains, "quit") + c.Assert(rendered, qt.Not(qt.Contains), "role marker") + c.Assert(rendered, qt.Not(qt.Contains), "BLOCK: digit") +} + +func TestRenderFooterOmitsInstanceRoleLegend(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10}}, live.SortByTransactionStart) + got := renderFooter(tableState{List: list, HasList: true, Width: 60, Height: 24}) + + c.Assert(stripANSI(got), qt.Not(qt.Contains), "role marker") + c.Assert(stripANSI(got), qt.Not(qt.Contains), "W=waiting on lock") + for _, line := range strings.Split(stripANSI(got), "\n") { + c.Assert(lipgloss.Width(line) <= 60, qt.IsTrue) + } +} + +func TestRenderFooterStaysWithinVeryNarrowWidthWithoutLegend(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{PID: 10}}, live.SortByTransactionStart) + got := renderFooter(tableState{List: list, HasList: true, Width: 8, Height: 24}) + + for _, line := range strings.Split(stripANSI(got), "\n") { + c.Assert(lipgloss.Width(line) <= 8, qt.IsTrue) + } +} + +func TestRenderTableShowsSelectedQueryStatus(t *testing.T) { + c := qt.New(t) + query := "SELECT id, owner_id FROM public.events WHERE owner_id = $1 ORDER BY id LIMIT 300" + list := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 10, QueryText: "SELECT 1"}, + {PID: 20, QueryText: query}, + }, live.SortByTransactionStart) + + rendered := renderTable(tableState{ + List: list, + HasList: true, + Selected: 1, + Width: 140, + Height: 10, + }) + + c.Assert(selectedStatusLine(c, rendered), qt.Equals, "selected pid 20 | "+query) +} + +func TestRenderTableClipsSelectedQueryStatus(t *testing.T) { + c := qt.New(t) + query := "SELECT id, owner_id, created_at FROM public.events WHERE owner_id = $1 ORDER BY created_at DESC LIMIT 300" + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 39123, + QueryText: query, + }}, live.SortByTransactionStart) + + rendered := renderTable(tableState{ + List: list, + HasList: true, + Selected: 0, + Width: 72, + Height: 10, + }) + status := selectedStatusLine(c, rendered) + + c.Assert(lipgloss.Width(status) <= 72, qt.IsTrue) + c.Assert(status, qt.Contains, "selected pid 39123 | SELECT id") + c.Assert(status, qt.Contains, "…") +} + +func TestBuildConnectionRowsUsesCoreColumns(t *testing.T) { + c := qt.New(t) + + headers, rows := buildConnectionRows([]live.Connection{{PID: 10}}, nil, 200, -1) + + c.Assert(headers, qt.DeepEquals, []string{"", "PID", "STATE", "BLOCK", "WAIT", "DURATION", "APP", "START", "QUERY"}) + c.Assert(rows, qt.HasLen, 1) + c.Assert(rows[0], qt.HasLen, len(headers)) +} + +func TestBuildConnectionRowsProcesslistDisplay(t *testing.T) { + c := qt.New(t) + headers, rows := buildConnectionRowsForDisplay(connectionDisplayProcesslist, []live.Connection{{ + Instance: "zone1-2001", + PID: 101, + State: "Query/executing", + Duration: 42 * time.Second, + Username: "vt_app", + DatabaseName: "checkout", + QueryText: "SELECT 1", + }}, nil, 120, 0) + + c.Assert(headers, qt.DeepEquals, []string{"", "PID", "TABLET", "STATE", "DURATION", "USER", "DB", "QUERY"}) + c.Assert(stripANSI(rows[0][1]), qt.Equals, "101") + c.Assert(stripANSI(rows[0][2]), qt.Equals, "zone1-2001") + c.Assert(stripANSI(rows[0][3]), qt.Equals, "Query") + c.Assert(stripANSI(rows[0][5]), qt.Equals, "vt_app") + c.Assert(stripANSI(rows[0][6]), qt.Equals, "checkout") + c.Assert(strings.Join(headers, ","), qt.Not(qt.Contains), "BLOCK") + c.Assert(strings.Join(headers, ","), qt.Not(qt.Contains), "WAIT") + c.Assert(strings.Join(headers, ","), qt.Not(qt.Contains), "APP") +} + +func TestBuildConnectionRowsProcesslistDisplayShowsZeroDurationActiveQuery(t *testing.T) { + c := qt.New(t) + headers, rows := buildConnectionRowsForDisplay(connectionDisplayProcesslist, []live.Connection{{ + PID: 101, + State: "Query/update", + QueryText: "INSERT INTO events VALUES (1)", + }}, nil, 120, 0) + + durationIdx := indexOf(headers, "DURATION") + c.Assert(durationIdx, qt.Not(qt.Equals), -1) + c.Assert(stripANSI(rows[0][durationIdx]), qt.Equals, "00:00") +} + +func TestBuildConnectionRowsWidensAppColumnWhenSpaceAllows(t *testing.T) { + c := qt.New(t) + + headers, rows := buildConnectionRows([]live.Connection{{ + PID: 10, + ApplicationName: "interactive_client_47", + QueryText: "SELECT 1", + }}, nil, 260, -1) + + appIdx := indexOf(headers, "APP") + c.Assert(appIdx, qt.Not(qt.Equals), -1) + c.Assert(stripANSI(rows[0][appIdx]), qt.Equals, "interactive_client_47") +} + +func TestRenderConnectionTableProcesslistDisplayStaysLeftAligned(t *testing.T) { + c := qt.New(t) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + Instance: "zone1-2001", + PID: 101, + State: "Query/executing", + Duration: 42 * time.Second, + Username: "vt_app", + DatabaseName: "checkout", + QueryText: "SELECT 1", + }}, live.SortByDuration) + + rendered := stripANSI(renderConnectionTable(tableState{ + List: list, + HasList: true, + Selected: 0, + Width: 200, + DisplayPreset: connectionDisplayProcesslist, + }, 10)) + row := lineContaining(c, rendered, "SELECT 1") + + c.Assert(strings.Index(row, "101") <= 4, qt.IsTrue, qt.Commentf("row = %q", row)) + c.Assert(strings.HasPrefix(row, "▶"), qt.IsTrue, qt.Commentf("row = %q", row)) +} + +func TestRenderConnectionTableProcesslistDisplayStylesStates(t *testing.T) { + c := qt.New(t) + prev := lipgloss.ColorProfile() + lipgloss.SetColorProfile(termenv.ANSI256) + defer lipgloss.SetColorProfile(prev) + list := live.NewConnectionList(time.Now(), []live.Connection{ + {PID: 101, State: "Query/executing", QueryText: "SELECT 1"}, + {PID: 102, State: "Sleep"}, + }, live.SortByDuration) + + rendered := renderConnectionTable(tableState{ + List: list, + HasList: true, + Selected: -1, + Width: 120, + DisplayPreset: connectionDisplayProcesslist, + }, 10) + queryRow := lineContaining(c, rendered, "SELECT 1") + sleepRow := lineContaining(c, rendered, "102") + + c.Assert(queryRow, qt.Contains, "\x1b[") + c.Assert(sleepRow, qt.Contains, "\x1b[") +} + +func TestProcesslistHighlightedRowsRenderContiguousBackground(t *testing.T) { + c := qt.New(t) + prevProfile := lipgloss.ColorProfile() + prevBackground := lipgloss.HasDarkBackground() + lipgloss.SetColorProfile(termenv.ANSI256) + defer lipgloss.SetColorProfile(prevProfile) + defer lipgloss.SetHasDarkBackground(prevBackground) + + lipgloss.SetHasDarkBackground(true) + list := live.NewConnectionList(tableRenderTestTime, []live.Connection{{ + PID: 101, + Instance: "zone1-2001", + State: "Query/update", + Username: "vt_app", + DatabaseName: "checkout", + QueryText: "INSERT INTO events VALUES (1)", + }}, live.SortByDuration) + + rendered := renderConnectionTable(tableState{ + List: list, + HasList: true, + Selected: 0, + Width: 160, + DisplayPreset: connectionDisplayProcesslist, + }, 10) + selectedRow := lineContaining(c, rendered, "INSERT INTO events") + + c.Assert(selectedRow, qt.Not(qt.Contains), "\x1b[0m \x1b[", qt.Commentf("row = %q", selectedRow)) +} + +func TestBuildConnectionRowsFormatsWaitAndStart(t *testing.T) { + c := qt.New(t) + xactStart := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC) + queryStart := time.Date(2026, 4, 29, 12, 1, 0, 0, time.UTC) + + _, rows := buildConnectionRows([]live.Connection{{ + PID: 10, + InstanceRole: "primary", + State: "active", + WaitEventType: "Lock", + WaitEvent: "tuple", + ApplicationName: "writer", + XactStart: &xactStart, + QueryStart: &queryStart, + QueryText: "SELECT * FROM widgets", + }}, nil, 200, -1) + + // STATE and WAIT are right-padded to a stable min width; trim for content + // comparison. + got := rows[0] + got[2] = strings.TrimRight(got[2], " ") + got[4] = strings.TrimRight(got[4], " ") + c.Assert(got, qt.DeepEquals, []string{ + " ", + "10", + "active", + "- ", + "Lock/tuple", + "-", + "writer", + "12:00:00", + "SELECT * FROM widgets", + }) +} + +func TestBuildConnectionRowsRoleMarkerLeftmostColumn(t *testing.T) { + c := qt.New(t) + connections := []live.Connection{ + {PID: 1, Instance: "primary", InstanceRole: "primary", State: "active"}, + {PID: 2, Instance: "replica-a", InstanceRole: "replica", State: "idle"}, + {PID: 3, Instance: "ghost", InstanceRole: "", State: "idle"}, + } + + _, rows := buildConnectionRows(connections, nil, 200, -1) + + c.Assert(rows[0][0], qt.Equals, " ") + c.Assert(rows[1][0], qt.Equals, " R") + c.Assert(rows[2][0], qt.Equals, " ") +} + +func TestBuildConnectionRowsWaitColumnCombinesTypeAndEvent(t *testing.T) { + c := qt.New(t) + + connections := []live.Connection{ + {PID: 1, WaitEventType: "Lock", WaitEvent: "transactionid"}, + {PID: 2, WaitEventType: "Client", WaitEvent: "ClientRead"}, + {PID: 3, WaitEvent: "ClientRead"}, + {PID: 4, WaitEventType: "IPC"}, + {PID: 5}, + } + headers, rows := buildConnectionRows(connections, nil, 200, -1) + + waitIdx := indexOf(headers, "WAIT") + c.Assert(waitIdx, qt.Not(qt.Equals), -1) + // WAIT cells are right-padded to a stable min width; compare trimmed. + c.Assert(strings.TrimRight(rows[0][waitIdx], " "), qt.Equals, "Lock/transactionid") + c.Assert(strings.TrimRight(rows[1][waitIdx], " "), qt.Equals, "Client/ClientRead") + c.Assert(strings.TrimRight(rows[2][waitIdx], " "), qt.Equals, "ClientRead") + c.Assert(strings.TrimRight(rows[3][waitIdx], " "), qt.Equals, "IPC") + c.Assert(strings.TrimRight(rows[4][waitIdx], " "), qt.Equals, "-") + + wait := waitTextForWidth(live.Connection{ + WaitEventType: "Client", + WaitEvent: "ClientRead", + }, 80) + c.Assert(strings.Contains(wait, "…"), qt.IsTrue) + c.Assert(strings.Contains(wait, "..."), qt.IsFalse) +} + +func TestBuildConnectionRowsNarrowKeepsWaitOverLowerPriorityColumns(t *testing.T) { + c := qt.New(t) + headers, rows := buildConnectionRows([]live.Connection{{ + PID: 10, + State: "idle in transaction", + Duration: 90 * time.Second, + WaitEventType: "Client", + WaitEvent: "ClientRead", + }}, nil, 80, -1) + + c.Assert(indexOf(headers, "STATE"), qt.Not(qt.Equals), -1) + c.Assert(indexOf(headers, "DURATION"), qt.Not(qt.Equals), -1) + c.Assert(indexOf(headers, "BLOCK"), qt.Not(qt.Equals), -1) + waitIdx := indexOf(headers, "WAIT") + c.Assert(waitIdx, qt.Not(qt.Equals), -1) + assertWidthAtMost(c, rows[0][waitIdx], 10) +} + +func TestStateTextAbbreviatesIdleInTransaction(t *testing.T) { + c := qt.New(t) + c.Assert(stateText("idle"), qt.Equals, "idle") + c.Assert(stateText("active"), qt.Equals, "active") + c.Assert(stateText("idle in transaction"), qt.Equals, "idle/xact") + c.Assert(stateText("idle in transaction (aborted)"), qt.Equals, "idle/xact (aborted)") + c.Assert(stateText(""), qt.Equals, "") +} + +func TestHighlightedRowsRenderContiguousBackground(t *testing.T) { + c := qt.New(t) + prevProfile := lipgloss.ColorProfile() + prevBackground := lipgloss.HasDarkBackground() + lipgloss.SetColorProfile(termenv.ANSI256) + defer lipgloss.SetColorProfile(prevProfile) + defer lipgloss.SetHasDarkBackground(prevBackground) + + lipgloss.SetHasDarkBackground(true) + list := live.NewConnectionList(time.Now(), []live.Connection{{ + PID: 10, + QueryText: "SELECT 1", + }}, live.SortByTransactionStart) + + rendered := renderConnectionTable(tableState{List: list, HasList: true, Selected: 0, Width: 120}, 10) + selectedRow := lineContaining(c, rendered, "SELECT 1") + + c.Assert(selectedRow, qt.Not(qt.Contains), "\x1b[0m \x1b[") +} + +func lineContaining(c *qt.C, text, needle string) string { + c.Helper() + for _, line := range strings.Split(text, "\n") { + if strings.Contains(line, needle) { + return line + } + } + c.Fatalf("no line containing %q in:\n%s", needle, text) + return "" +} + +func TestAppTextRightTruncatesLongNames(t *testing.T) { + c := qt.New(t) + c.Assert(appText(""), qt.Equals, "-") + c.Assert(appText("psql"), qt.Equals, "psql") + c.Assert(appText("owner_writer_226"), qt.Equals, "owner_writer_2…") + c.Assert(appText("interactive_client_47"), qt.Equals, "interactive_c…") + c.Assert(appText("xxxxxxxxxxxxxx"), qt.Equals, "xxxxxxxxxxxxxx") +} + +func TestBuildConnectionRowsMarksSelectedRow(t *testing.T) { + c := qt.New(t) + connections := []live.Connection{ + {PID: 10, InstanceRole: "primary"}, + {PID: 20, InstanceRole: "replica"}, + {PID: 30, InstanceRole: "unknown"}, + } + + headers, rows := buildConnectionRows(connections, nil, 200, 1) + + markerIdx := indexOf(headers, "") + pidIdx := indexOf(headers, "PID") + c.Assert(markerIdx, qt.Equals, 0) + c.Assert(pidIdx, qt.Equals, 1) + c.Assert(rows[0][markerIdx], qt.Equals, " ") + c.Assert(rows[0][pidIdx], qt.Equals, "10") + c.Assert(rows[1][markerIdx], qt.Equals, "▶R") + c.Assert(rows[1][pidIdx], qt.Equals, "20") + c.Assert(rows[2][markerIdx], qt.Equals, " ") + c.Assert(rows[2][pidIdx], qt.Equals, "30") + + _, rows = buildConnectionRows(connections, nil, 200, 2) + c.Assert(rows[2][markerIdx], qt.Equals, "▶ ") + c.Assert(rows[2][pidIdx], qt.Equals, "30") +} + +func TestBuildConnectionRowsSeparatesSelectedReplicaMarker(t *testing.T) { + c := qt.New(t) + _, rows := buildConnectionRows([]live.Connection{{PID: 10, InstanceRole: "replica"}}, nil, 120, 0) + + c.Assert(stripANSI(rows[0][0]), qt.Equals, "▶R") + c.Assert(lipgloss.Width(stripANSI(rows[0][0])), qt.Equals, 2) +} + +func TestBlockedTextRendersBlockedAndDownstream(t *testing.T) { + c := qt.New(t) + prev := lipgloss.ColorProfile() + lipgloss.SetColorProfile(termenv.ANSI256) + defer lipgloss.SetColorProfile(prev) + + // Count-first encoding so the alarm digit leads. "W" appears alone for + // pure-victim sessions, as a suffix when the connection both blocks and + // waits. All cells right-padded to blockColumnMinWidth so the column + // renders at least as wide as the BLOCK header. + c.Assert(blockedText(live.Connection{}, 0), qt.Equals, "- ") + c.Assert(blockedText(live.Connection{BlockedBy: []int{99}}, 0), qt.Equals, "W ") + c.Assert(blockedText(live.Connection{}, 3), qt.Equals, "3 ") + c.Assert(blockedText(live.Connection{BlockedBy: []int{99}}, 3), qt.Equals, "3 W ") +} + +func TestBuildConnectionRowsDropsStartOnNarrowTerminals(t *testing.T) { + c := qt.New(t) + conns := []live.Connection{{PID: 1}} + + wideHeaders, _ := buildConnectionRows(conns, nil, 180, -1) + narrowHeaders, _ := buildConnectionRows(conns, nil, 120, -1) + + c.Assert(indexOf(wideHeaders, "START"), qt.Not(qt.Equals), -1) + c.Assert(indexOf(narrowHeaders, "START"), qt.Equals, -1) +} + +func TestRenderCapturedToken(t *testing.T) { + tests := []struct { + name string + state tableState + contains []string + omits []string + }{ + { + name: "same day omits date", + state: tableState{ + List: live.ConnectionList{CapturedAt: time.Date(2026, 5, 8, 21, 18, 14, 0, time.Local)}, + Now: time.Date(2026, 5, 8, 21, 18, 14, 0, time.Local), + Interval: time.Second, + }, + contains: []string{"21:18:14"}, + omits: []string{"2026-05-08"}, + }, + { + name: "different day includes date", + state: tableState{ + List: live.ConnectionList{CapturedAt: time.Date(2026, 5, 7, 21, 18, 14, 0, time.Local)}, + Now: time.Date(2026, 5, 8, 9, 0, 0, 0, time.Local), + Interval: time.Second, + }, + contains: []string{"2026-05-07 21:18:14"}, + }, + { + name: "paused keeps age", + state: tableState{ + List: live.ConnectionList{CapturedAt: tableRenderTestTime}, + Now: tableRenderTestTime.Add(30 * time.Second), + Interval: time.Second, + Paused: true, + }, + contains: []string{"captured ", "(30s ago)"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := qt.New(t) + got := stripANSI(renderCapturedToken(tt.state)) + for _, want := range tt.contains { + c.Assert(got, qt.Contains, want) + } + for _, notWant := range tt.omits { + c.Assert(got, qt.Not(qt.Contains), notWant) + } + }) + } +} + +func TestRenderHeaderStylesPausedToken(t *testing.T) { + c := qt.New(t) + prev := lipgloss.ColorProfile() + lipgloss.SetColorProfile(termenv.ANSI256) + defer lipgloss.SetColorProfile(prev) + + got := renderHeader(tableState{Paused: true, Sort: live.SortByTransactionStart}) + + c.Assert(stripANSI(got), qt.Contains, "paused") + c.Assert(got, qt.Not(qt.Equals), stripANSI(got)) +} + +func TestRenderHeaderKeepsFreshnessAndPausedStateWhenNarrow(t *testing.T) { + c := qt.New(t) + + now := tableRenderTestTime.Add(5 * time.Second) + state := tableState{ + Target: Target{ + Database: "kind-live-connections", + Branch: "main", + }, + List: live.NewConnectionList(tableRenderTestTime, []live.Connection{{PID: 1}, {PID: 2}, {PID: 3}, {PID: 4}}, live.SortByTransactionStart), + HasList: true, + Sort: live.SortByTransactionStart, + Now: now, + Interval: time.Second, + CaptureStatus: "rec off", + Paused: true, + StepPos: 237, + StepTotal: 300, + Width: 80, + } + + header := stripANSI(renderHeader(state)) + + assertWidthAtMost(c, header, 80) + c.Assert(strings.Contains(header, "captured "), qt.IsTrue) + c.Assert(strings.Contains(header, "ago)"), qt.IsTrue) + c.Assert(strings.Contains(header, "paused"), qt.IsTrue) + c.Assert(strings.Contains(header, "step 237/300"), qt.IsTrue) + c.Assert(strings.Contains(header, "…"), qt.IsTrue) +} + +func TestRenderHeaderKeepsCompactVitessTargetWhenNarrow(t *testing.T) { + c := qt.New(t) + captured := time.Date(2026, 4, 29, 1, 3, 9, 0, time.Local) + state := tableState{ + Target: Target{ + Database: "kind-live-connections-mysql", + Branch: "main", + Keyspace: "planetscale", + Shard: "-", + }, + List: live.NewConnectionList(captured, make([]live.Connection, 14), live.SortByDuration), + HasList: true, + Sort: live.SortByDuration, + CanSort: false, + Now: captured, + Interval: time.Second, + CaptureStatus: "rec off", + Width: 80, + } + + header := stripANSI(renderHeader(state)) + + assertWidthAtMost(c, header, 80) + c.Assert(header, qt.Equals, "kind-live-connection… / main / planetscale/- | ● 14 | captured 01:03:09 (0s ago)") + + state.Width = 40 + header = stripANSI(renderHeader(state)) + + assertWidthAtMost(c, header, 40) + c.Assert(header, qt.Contains, " / main | ● 14 | (0s)") + c.Assert(header, qt.Contains, "…") + c.Assert(header, qt.Not(qt.Contains), "planetscale") +} + +func TestRenderHeaderVeryNarrowKeepsProtectedTokens(t *testing.T) { + c := qt.New(t) + captured := time.Date(2026, 4, 29, 1, 3, 9, 0, time.Local) + state := tableState{ + Target: Target{ + Database: "kind-live-connections-mysql", + Branch: "very-long-branch-name", + Keyspace: "planetscale", + Shard: "-", + }, + List: live.NewConnectionList(captured, make([]live.Connection, 14), live.SortByDuration), + HasList: true, + Sort: live.SortByDuration, + CanSort: false, + Now: captured, + Interval: time.Second, + Paused: true, + Width: 24, + } + + header := stripANSI(renderHeader(state)) + + assertWidthAtMost(c, header, 24) + c.Assert(header, qt.Contains, "● 14") + c.Assert(header, qt.Contains, "paused") + c.Assert(header, qt.Contains, "(0s)") +} + +func TestRenderHeaderNarrowPostgresKeepsTarget(t *testing.T) { + c := qt.New(t) + captured := time.Date(2026, 4, 29, 1, 3, 9, 0, time.Local) + state := tableState{ + Target: Target{ + Database: "kind-live-connections-postgres", + Branch: "main", + }, + List: live.NewConnectionList(captured, []live.Connection{{PID: 1}, {PID: 2}}, live.SortByTransactionStart), + HasList: true, + Sort: live.SortByTransactionStart, + CanSort: true, + Now: captured, + Interval: time.Second, + CaptureStatus: "rec off", + } + + for _, width := range []int{40, 80, 120} { + state.Width = width + header := stripANSI(renderHeader(state)) + + assertWidthAtMost(c, header, width) + c.Assert(header, qt.Contains, " / main") + if width < 120 { + c.Assert(header, qt.Contains, "● 2") + } else { + c.Assert(header, qt.Contains, "● connections 2") + } + c.Assert(header, qt.Contains, "(0s") + } +} + +func TestRenderHeaderLabelsSingleSortAsStatic(t *testing.T) { + c := qt.New(t) + state := tableState{ + List: live.NewConnectionList(tableRenderTestTime, []live.Connection{{PID: 1}}, live.SortByDuration), + HasList: true, + Sort: live.SortByDuration, + CanSort: false, + Width: 120, + } + + header := stripANSI(renderHeader(state)) + + c.Assert(header, qt.Contains, "sorted by duration") + c.Assert(header, qt.Not(qt.Contains), "sort duration") +} + +func TestRenderHeaderBoundsLongFilterChipWhenNarrow(t *testing.T) { + c := qt.New(t) + + now := tableRenderTestTime.Add(5 * time.Second) + state := tableState{ + List: live.NewConnectionList(tableRenderTestTime, []live.Connection{{PID: 1}}, live.SortByTransactionStart), + HasList: true, + Sort: live.SortByTransactionStart, + Now: now, + Interval: time.Second, + Filter: "filter: instance qa-replica-with-a-very-long-generated-name", + Paused: true, + StepPos: 237, + StepTotal: 300, + Width: 80, + } + + header := stripANSI(renderHeader(state)) + + assertWidthAtMost(c, header, 80) + c.Assert(strings.Contains(header, "filter:"), qt.IsTrue) + c.Assert(strings.Contains(header, "captured "), qt.IsTrue) + c.Assert(strings.Contains(header, "paused"), qt.IsTrue) + c.Assert(strings.Contains(header, "step 237/300"), qt.IsTrue) + c.Assert(strings.Contains(header, "…"), qt.IsTrue) +} + +func TestFreshnessTierFor(t *testing.T) { + c := qt.New(t) + captured := time.Date(2026, 4, 29, 12, 0, 0, 0, time.UTC) + interval := time.Second + + c.Assert(freshnessTierFor(captured, captured.Add(2*time.Second), interval), qt.Equals, freshnessFresh) + c.Assert(freshnessTierFor(captured, captured.Add(3*time.Second-time.Millisecond), interval), qt.Equals, freshnessFresh) + c.Assert(freshnessTierFor(captured, captured.Add(3*time.Second), interval), qt.Equals, freshnessStale) + c.Assert(freshnessTierFor(captured, captured.Add(9*time.Second), interval), qt.Equals, freshnessStale) + c.Assert(freshnessTierFor(captured, captured.Add(10*time.Second), interval), qt.Equals, freshnessVeryStale) + c.Assert(freshnessTierFor(time.Time{}, captured, interval), qt.Equals, freshnessFresh) + c.Assert(freshnessTierFor(captured, captured.Add(time.Hour), 0), qt.Equals, freshnessFresh) +} + +func TestRenderHelpIncludesCurrentBindings(t *testing.T) { + c := qt.New(t) + + rendered := renderHelp(180) + + for _, binding := range defaultBindings().ShortHelp() { + help := binding.Help() + c.Assert(rendered, qt.Contains, help.Key) + c.Assert(rendered, qt.Contains, help.Desc) + } + + bindings := defaultBindings() + c.Assert(bindings.Navigate.Help().Desc, qt.Equals, "select") + c.Assert(bindings.Cancel.Help().Desc, qt.Equals, "cancel query") + c.Assert(bindings.TerminateTxn.Help().Desc, qt.Equals, "kill transaction") + c.Assert(bindings.TerminateConn.Help().Desc, qt.Equals, "force terminate") +} + +func TestRenderHelpDocumentsVDetailAlias(t *testing.T) { + c := qt.New(t) + + footer := stripANSI(renderHelp(180)) + modal := stripANSI(renderHelpModal(helpState{ + Target: Target{Database: "prod", Branch: "main"}, + Width: 120, + Height: 40, + CanSort: true, + Capabilities: DefaultConnectionCapabilities(), + })) + + c.Assert(footer, qt.Contains, "enter/v detail") + c.Assert(modal, qt.Contains, "enter/v detail") +} + +func TestRenderHelpModalHidesRefreshWhilePaused(t *testing.T) { + c := qt.New(t) + modal := stripANSI(renderHelpModal(helpState{ + Target: Target{Database: "prod", Branch: "main"}, + Width: 120, + Height: 40, + Paused: true, + Capabilities: DefaultConnectionCapabilities(), + })) + + c.Assert(modal, qt.Contains, "space resume") + c.Assert(modal, qt.Not(qt.Contains), "r refresh") +} + +func TestRenderHelpShowsShiftKForForceTerminate(t *testing.T) { + c := qt.New(t) + footer := stripANSI(renderHelp(160)) + + c.Assert(footer, qt.Contains, "shift+K force terminate") +} + +func TestRenderHelpHidesUnavailableActions(t *testing.T) { + c := qt.New(t) + footer := stripANSI(renderHelpFor(160, false, false, connectionDisplayProcesslist, ConnectionCapabilities{ + CancelQuery: ActionTargetQueryID, + TerminateConnection: ActionTargetConnectionID, + ShowBlockers: false, + }, true, true, true, true, false)) + + c.Assert(footer, qt.Contains, "c KILL QUERY") + c.Assert(footer, qt.Contains, "shift+K KILL") + c.Assert(footer, qt.Not(qt.Contains), "kill transaction") +} + +func TestRenderHelpUsesOperatorActionCopy(t *testing.T) { + c := qt.New(t) + + help := stripANSI(renderHelpModal(helpState{ + Target: Target{Database: "prod", Branch: "main"}, + Width: 120, + Height: 40, + CanSort: true, + Capabilities: DefaultConnectionCapabilities(), + })) + + c.Assert(help, qt.Contains, "Actions") + c.Assert(help, qt.Not(qt.Contains), "Safe Actions") + c.Assert(help, qt.Contains, "c Cancel the selected query (pg_cancel_backend)") + c.Assert(help, qt.Contains, "k Kill the selected transaction (pg_terminate_backend)") + c.Assert(help, qt.Contains, "K Force terminate the selected connection (pg_terminate_backend)") + c.Assert(help, qt.Contains, "c, k, and K require confirmation. Replay mode blocks backend actions.") +} + +func TestVisibleConnectionsLimitsRowsFromTop(t *testing.T) { + c := qt.New(t) + connections := []live.Connection{{PID: 1}, {PID: 2}, {PID: 3}, {PID: 4}, {PID: 5}} + + visible := visibleConnections(connections, 0, 3) + + c.Assert(visible, qt.DeepEquals, []live.Connection{{PID: 1}, {PID: 2}, {PID: 3}}) +} + +func TestViewportStartForSelectionCentersNearViewportEdges(t *testing.T) { + c := qt.New(t) + connections := []live.Connection{{PID: 1}, {PID: 2}, {PID: 3}, {PID: 4}, {PID: 5}} + start := viewportStartForSelection(0, 2, len(connections), 3) + + c.Assert(visibleConnections(connections, start, 3), qt.DeepEquals, []live.Connection{{PID: 2}, {PID: 3}, {PID: 4}}) +} + +func TestFormatDuration(t *testing.T) { + c := qt.New(t) + + c.Assert(formatDuration(0), qt.Equals, "-") + c.Assert(formatDuration(250*time.Millisecond), qt.Equals, "00:00") + c.Assert(formatDuration(1500*time.Millisecond), qt.Equals, "00:01") + c.Assert(formatDuration(90*time.Second), qt.Equals, "01:30") +} + +func TestEmptyDash(t *testing.T) { + c := qt.New(t) + + c.Assert(emptyDash(" \t "), qt.Equals, "-") + c.Assert(emptyDash("app"), qt.Equals, "app") +} + +func TestClipLine(t *testing.T) { + c := qt.New(t) + + c.Assert(clipLine("abcdef", 0), qt.Equals, "abcdef") + c.Assert(clipLine("abcdef", 1), qt.Equals, "a") + c.Assert(clipLine("abcdef", 3), qt.Equals, "ab…") + c.Assert(clipLine("abcdef", 5), qt.Equals, "abcd…") + + clipped := clipLine("the quick brown fox jumps", 12) + c.Assert(strings.Contains(clipped, "…"), qt.IsTrue) + c.Assert(strings.Contains(clipped, "..."), qt.IsFalse) +} + +func TestRenderDetailTabsHaveEqualWidths(t *testing.T) { + c := qt.New(t) + + prev := lipgloss.ColorProfile() + lipgloss.SetColorProfile(termenv.ANSI256) + defer lipgloss.SetColorProfile(prev) + + query := renderDetailTabs(tabQuery, DefaultConnectionCapabilities()) + blockers := renderDetailTabs(tabBlockers, DefaultConnectionCapabilities()) + c.Assert(ansi.StringWidth(blockers), qt.Equals, ansi.StringWidth(query)) + c.Assert(ansi.StringWidth(clipLine(blockers, 200)), qt.Equals, ansi.StringWidth(blockers)) + c.Assert(ansi.StringWidth(clipLine(query, 200)), qt.Equals, ansi.StringWidth(query)) +} + +func TestRenderFooterShowsConfirmPrompt(t *testing.T) { + c := qt.New(t) + state := tableState{ + HasList: true, + Width: 180, + Height: 24, + Interval: time.Second, + Confirm: "Force terminate PID 10 on primary? (y/n)", + } + got := renderFooter(state) + c.Assert(got, qt.Contains, "Force terminate PID 10") +} + +func TestRenderFooterKeepsLongErrorVisible(t *testing.T) { + c := qt.New(t) + state := tableState{ + List: live.NewConnectionList(time.Now(), []live.Connection{{PID: 10}}, live.SortByTransactionStart), + HasList: true, + Selected: -1, + Width: 80, + Height: 24, + Interval: time.Second, + LastError: "list connections: server is warming up, please retry in a moment", + } + + got := renderFooter(state) + lines := strings.Split(stripANSI(got), "\n") + + c.Assert(lines[0], qt.Equals, "error: list connections: server is warming up, please retry in a moment") + c.Assert(lipgloss.Width(lines[0]) <= 80, qt.IsTrue) + c.Assert(got, qt.Contains, "up/down") +} + +func TestRenderFooterShowsResumeWhenPaused(t *testing.T) { + c := qt.New(t) + state := tableState{ + List: live.NewConnectionList(time.Now(), []live.Connection{{PID: 10}}, live.SortByTransactionStart), + HasList: true, + Paused: true, + Width: 120, + Height: 24, + } + + got := stripANSI(renderFooter(state)) + + c.Assert(got, qt.Contains, "space resume") + c.Assert(got, qt.Not(qt.Contains), "space pause") + c.Assert(got, qt.Not(qt.Contains), "r refresh") +} + +func TestRenderFooterClipsLongConfirmPrompt(t *testing.T) { + c := qt.New(t) + state := tableState{ + HasList: true, + Width: 72, + Height: 24, + Interval: time.Second, + Confirm: "Force terminate PID 39123 on hzi-4gomesickgvbywmt-cell1-1486496651-c9099829? (y/n)", + } + + got := renderFooter(state) + lines := strings.Split(stripANSI(got), "\n") + + c.Assert(lines[0], qt.Contains, "Force terminate PID 39123") + c.Assert(lines[0], qt.Contains, "…") + c.Assert(lipgloss.Width(lines[0]) <= 72, qt.IsTrue) +} + +func selectedStatusLine(c *qt.C, rendered string) string { + for _, line := range strings.Split(rendered, "\n") { + if strings.HasPrefix(line, "selected pid ") { + return line + } + } + c.Fatalf("rendered output does not contain selected status:\n%s", rendered) + return "" +} + +// The refresh indicator is a fixed-width dot whose style (not width) changes +// with loading, so the header never reflows on refresh. +func TestRefreshIndicatorFixedWidth(t *testing.T) { + c := qt.New(t) + defer lipgloss.SetColorProfile(termenv.Ascii) + lipgloss.SetColorProfile(termenv.ANSI256) + + pending := refreshIndicator(refreshDotPending) + idle := refreshIndicator(refreshDotIdle) + failing := refreshIndicator(refreshDotFailing) + hidden := refreshIndicator(refreshDotHidden) + + // Every state renders one cell wide (constant width → no header reflow), + // including replay's blank dot. + for _, s := range []string{pending, idle, failing, hidden} { + c.Assert(ansi.StringWidth(s), qt.Equals, 1) + } + c.Assert(pending, qt.Contains, "●") + c.Assert(idle, qt.Contains, "●") + c.Assert(failing, qt.Contains, "●") + // ...but the styling differs so each live state is distinguishable. + c.Assert(pending, qt.Not(qt.Equals), idle) + c.Assert(failing, qt.Not(qt.Equals), pending) + c.Assert(failing, qt.Not(qt.Equals), idle) +} + +func TestRenderHelpVitessKeepsKillNames(t *testing.T) { + c := qt.New(t) + + help := stripANSI(renderHelpModal(helpState{ + Target: Target{Database: "prod", Branch: "main"}, + Width: 120, + Height: 40, + Capabilities: ConnectionCapabilities{ + CancelQuery: ActionTargetQueryID, + TerminateConnection: ActionTargetConnectionID, + configured: true, + }, + })) + + c.Assert(help, qt.Contains, "c Kill the selected query (KILL QUERY)") + c.Assert(help, qt.Contains, "K Kill the selected connection (KILL)") +} + +func TestRenderHelpModalReplayHidesRefreshAndCapture(t *testing.T) { + c := qt.New(t) + modal := stripANSI(renderHelpModal(helpState{ + Target: Target{Database: "prod", Branch: "main"}, + Width: 120, + Height: 40, + Replay: true, + })) + + c.Assert(modal, qt.Not(qt.Contains), "r refresh") + c.Assert(modal, qt.Not(qt.Contains), "C capture") + c.Assert(modal, qt.Contains, "space pause") +}