diff --git a/README.md b/README.md index 4b50d7fd..7a1063a5 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ With lazyssh, you can quickly navigate, connect, manage, and transfer files betw ## โœจ Features ### Server Management + - ๐Ÿ“œ Read & display servers from your `~/.ssh/config` in a scrollable list. - โž• Add a new server from the UI with comprehensive SSH configuration options. - โœ Edit existing server entries directly from the UI with a tabbed interface. @@ -21,12 +22,14 @@ With lazyssh, you can quickly navigate, connect, manage, and transfer files betw - ๐Ÿ“ Ping server to check status. ### Quick Server Navigation + - ๐Ÿ” Fuzzy search by alias, IP, or tags. - ๐Ÿ–ฅ Oneโ€‘keypress SSH into the selected server (Enter). - ๐Ÿท Tag servers (e.g., prod, dev, test) for quick filtering. - โ†•๏ธ Sort by alias or last SSH (toggle + reverse). ### Advanced SSH Configuration + - ๐Ÿ”— Port forwarding (LocalForward, RemoteForward, DynamicForward). - ๐Ÿš€ Connection multiplexing for faster subsequent connections. - ๐Ÿ” Advanced authentication options (public key, password, agent forwarding). @@ -35,17 +38,19 @@ With lazyssh, you can quickly navigate, connect, manage, and transfer files betw - โš™๏ธ Extensive SSH config options organized in tabbed interface. ### Key Management + - ๐Ÿ”‘ SSH key autocomplete with automatic detection of available keys. - ๐Ÿ“ Smart key selection with support for multiple keys. - ### Upcoming + - ๐Ÿ“ Copy files between local and servers with an easy picker UI. - ๐Ÿ”‘ SSH Key Deployment Features: - - Use default local public key (`~/.ssh/id_ed25519.pub` or `~/.ssh/id_rsa.pub`) - - Paste custom public keys manually - - Generate new keypairs and deploy them - - Automatically append keys to `~/.ssh/authorized_keys` with correct permissions + - Use default local public key (`~/.ssh/id_ed25519.pub` or `~/.ssh/id_rsa.pub`) + - Paste custom public keys manually + - Generate new keypairs and deploy them + - Automatically append keys to `~/.ssh/authorized_keys` with correct permissions + --- ## ๐Ÿ” Security Notice @@ -63,7 +68,6 @@ It is simply a UI/TUI wrapper around your existing `~/.ssh/config` file. - File permissions on your SSH config are preserved to ensure security. - ## ๐Ÿ›ก๏ธ Config Safety: Nonโ€‘destructive writes and backups - Nonโ€‘destructive edits: lazyssh only writes the minimal required changes to your ~/.ssh/config. It uses a parser that preserves existing comments, spacing, order, and any settings it didnโ€™t touch. Your handcrafted comments and formatting remain intact. @@ -72,11 +76,27 @@ It is simply a UI/TUI wrapper around your existing `~/.ssh/config` file. - Oneโ€‘time original backup: before lazyssh makes its first change, it creates a single snapshot named config.original.backup beside your SSH config. If this file is present, it will never be recreated or overwritten. - Rolling backups: on every subsequent save, lazyssh also creates a timestamped backup named like: ~/.ssh/config--lazyssh.backup. The app keeps at most 10 of these backups, automatically removing the oldest ones. +## ๐Ÿ“‚ SSH Config `Include` Support + +lazyssh honours top-level `Include` directives in your `~/.ssh/config`. Hosts defined in included files (e.g. `~/.ssh/config.d/work`) appear in the server list alongside hosts defined in the main config. + +- **Reads:** all `Include`d files are parsed in OpenSSH precedence order. When the same alias is defined in more than one file, the first definition wins (matching OpenSSH semantics) but every source file is recorded so the UI can prompt on edit. +- **Writes route back to the source file:** editing or deleting a host modifies whichever file actually defines it. Other files are never touched, and only the file that changed is re-serialized โ€” preserving handcrafted formatting elsewhere. +- **Ambiguity prompt:** if the same alias is defined in multiple included files, the first edit/delete shows a modal asking which file to write to. Your choice is remembered in `~/.lazyssh/metadata.json` (per-alias `file` field), so subsequent edits go straight through without re-prompting. +- **New hosts always go to the main config.** This keeps `Include`d files clean and predictable; you can move a host between files manually if you want it to live elsewhere. +- **Per-file backups:** rolling backups (`--lazyssh.backup`) and the one-time `.original.backup` are created alongside each included file the first time lazyssh writes to it. + +### Include support Limitations + +- `Include` directives **inside** `Host`/`Match` blocks are ignored. Only top-level Includes are honored. +- `Match` directives are not modelled as host entries. + ## ๐Ÿ“ท Screenshots
### ๐Ÿš€ Startup + App starting splash/loader Clean loading screen when launching the app @@ -84,6 +104,7 @@ Clean loading screen when launching the app --- ### ๐Ÿ“‹ Server Management Dashboard + Server list view Main dashboard displaying all configured servers with status indicators, pinned favorites at the top, and easy navigation @@ -91,6 +112,7 @@ Main dashboard displaying all configured servers with status indicators, pinned --- ### ๐Ÿ”Ž Search + Fuzzy search servers Fuzzy search functionality to quickly find servers by name, IP address, or tags @@ -98,9 +120,11 @@ Fuzzy search functionality to quickly find servers by name, IP address, or tags --- ### โž• Add/Edit Server + Add a new server Tabbed interface for managing SSH connections with extensive configuration options organized into: + - **Basic** - Host, user, port, keys, tags - **Connection** - Proxy, timeouts, multiplexing, canonicalization - **Forwarding** - Port forwarding, X11, agent @@ -110,6 +134,7 @@ Tabbed interface for managing SSH connections with extensive configuration optio --- ### ๐Ÿ” Connect to server + SSH connection details SSH into the selected server @@ -180,12 +205,12 @@ make run | q | Quit | **In Server Form:** -| Key | Action | +| Key | Action | | ------ | -------------------- | -| Ctrl+H | Previous tab | -| Ctrl+L | Next tab | -| Ctrl+S | Save | -| Esc | Cancel | +| Ctrl+H | Previous tab | +| Ctrl+L | Next tab | +| Ctrl+S | Save | +| Esc | Cancel | Tip: The hint bar at the top of the list shows the most useful shortcuts. @@ -205,10 +230,11 @@ We love seeing the community make Lazyssh better ๐Ÿš€ This repository enforces semantic PR titles via an automated GitHub Action. Please format your PR title as: - type(scope): short descriptive subject -Notes: + Notes: - Scope is optional and should be one of: ui, cli, config, parser. Allowed types in this repo: + - feat: a new feature - fix: a bug fix - improve: quality or UX improvements that are not a refactor or perf @@ -220,6 +246,7 @@ Allowed types in this repo: - revert: reverts a previous commit Examples: + - feat(ui): add server pinning and sorting options - fix(parser): handle comments at end of Host blocks - improve(cli): show friendly error when ssh binary missing @@ -239,11 +266,9 @@ If you find Lazyssh useful, please consider giving the repo a **star** โญ๏ธ an
- --- ## ๐Ÿ™ Acknowledgments - Built with [tview](https://github.com/rivo/tview) and [tcell](https://github.com/gdamore/tcell). - Inspired by [k9s](https://github.com/derailed/k9s) and [lazydocker](https://github.com/jesseduffield/lazydocker). - diff --git a/internal/adapters/data/ssh_config_file/backup.go b/internal/adapters/data/ssh_config_file/backup.go index 2f5f1564..c752170b 100644 --- a/internal/adapters/data/ssh_config_file/backup.go +++ b/internal/adapters/data/ssh_config_file/backup.go @@ -24,26 +24,28 @@ import ( "time" ) -// createBackup creates a timestamped backup of the current config file -func (r *Repository) createBackup() error { - if _, err := r.fileSystem.Stat(r.configPath); os.IsNotExist(err) { +// createBackupFor creates a timestamped backup of the given config file and +// prunes older backups for that file beyond MaxBackups. +func (r *Repository) createBackupFor(path string) error { + if _, err := r.fileSystem.Stat(path); os.IsNotExist(err) { return nil } else if err != nil { return fmt.Errorf("failed to check if config file exists: %w", err) } timestamp := time.Now().UnixMilli() - backupPath := fmt.Sprintf("%s-%d-%s", r.configPath, timestamp, BackupSuffix) + backupPath := fmt.Sprintf("%s-%d-%s", path, timestamp, BackupSuffix) - if err := r.copyFile(r.configPath, backupPath); err != nil { + if err := r.copyFile(path, backupPath); err != nil { return fmt.Errorf("failed to copy config to backup: %w", err) } r.logger.Infof("Created backup: %s", backupPath) - configDir := filepath.Dir(r.configPath) + configDir := filepath.Dir(path) + baseName := filepath.Base(path) - backupFiles, err := r.findBackupFiles(configDir) + backupFiles, err := r.findBackupFilesFor(configDir, baseName) if err != nil { return err } @@ -102,41 +104,45 @@ func (r *Repository) copyFile(src, dst string) error { return destFile.Sync() } -// findBackupFiles finds all backup files for the given config file -func (r *Repository) findBackupFiles(dir string) ([]os.FileInfo, error) { +// findBackupFilesFor finds rolling backup files for the named config file in +// dir. Backups are recognized by `--`. +func (r *Repository) findBackupFilesFor(dir, baseName string) ([]os.FileInfo, error) { entries, err := r.fileSystem.ReadDir(dir) if err != nil { return nil, err } - var backupFiles []os.FileInfo + prefix := baseName + "-" + backupFiles := make([]os.FileInfo, 0, len(entries)) for _, entry := range entries { name := entry.Name() - if strings.HasSuffix(name, BackupSuffix) { - info, err := entry.Info() - if err != nil { - r.logger.Warnf("failed to get info for backup file %s: %v", name, err) - continue - } - backupFiles = append(backupFiles, info) + if !strings.HasPrefix(name, prefix) || !strings.HasSuffix(name, BackupSuffix) { + continue + } + info, err := entry.Info() + if err != nil { + r.logger.Warnf("failed to get info for backup file %s: %v", name, err) + continue } + backupFiles = append(backupFiles, info) } return backupFiles, nil } -// createOriginalBackupIfNeeded creates a one-time original backup of the current SSH config. -func (r *Repository) createOriginalBackupIfNeeded() error { - // If no SSH config file, nothing to do. - if _, err := r.fileSystem.Stat(r.configPath); os.IsNotExist(err) { +// createOriginalBackupForIfNeeded creates a one-time original backup of the +// given config file (next to it, named `.original.backup`). +func (r *Repository) createOriginalBackupForIfNeeded(path string) error { + if _, err := r.fileSystem.Stat(path); os.IsNotExist(err) { return nil } else if err != nil { return fmt.Errorf("failed to check if config file exists: %w", err) } - configDir := filepath.Dir(r.configPath) - originalBackupPath := filepath.Join(configDir, OriginalBackupName) + configDir := filepath.Dir(path) + baseName := filepath.Base(path) + originalBackupPath := filepath.Join(configDir, baseName+".original.backup") if _, err := r.fileSystem.Stat(originalBackupPath); err == nil { return nil @@ -144,7 +150,7 @@ func (r *Repository) createOriginalBackupIfNeeded() error { return fmt.Errorf("failed to check if original backup exists: %w", err) } - if err := r.copyFile(r.configPath, originalBackupPath); err != nil { + if err := r.copyFile(path, originalBackupPath); err != nil { return fmt.Errorf("failed to create original backup: %w", err) } diff --git a/internal/adapters/data/ssh_config_file/config_io.go b/internal/adapters/data/ssh_config_file/config_io.go index b5a5da1c..d7225982 100644 --- a/internal/adapters/data/ssh_config_file/config_io.go +++ b/internal/adapters/data/ssh_config_file/config_io.go @@ -23,37 +23,42 @@ import ( "github.com/kevinburke/ssh_config" ) -// loadConfig reads and parses the SSH config file. -// If the file does not exist, it returns an empty config without error to support first-run behavior. -func (r *Repository) loadConfig() (*ssh_config.Config, error) { - file, err := r.fileSystem.Open(r.configPath) +// loadConfig reads and parses the SSH config file plus every file pulled in +// via top-level `Include` directives. Returns a loadedConfig containing all +// per-file parses in OpenSSH precedence order (main first). +func (r *Repository) loadConfig() (*loadedConfig, error) { + lc, err := r.resolveIncludes(r.configPath) if err != nil { - if r.fileSystem.IsNotExist(err) { - return &ssh_config.Config{Hosts: []*ssh_config.Host{}}, nil - } - return nil, fmt.Errorf("failed to open config file: %w", err) + return nil, fmt.Errorf("failed to load config: %w", err) } - defer func() { - if cerr := file.Close(); cerr != nil { - r.logger.Warnf("failed to close config file: %v", cerr) - } - }() + return lc, nil +} - cfg, err := ssh_config.Decode(file) - if err != nil { - return nil, fmt.Errorf("failed to decode config: %w", err) +// saveFiles writes only the entries of lc whose paths appear in dirty back to +// disk. Each file gets its own atomic temp+rename and its own rolling backup. +func (r *Repository) saveFiles(lc *loadedConfig, dirty []string) error { + dirtySet := make(map[string]bool, len(dirty)) + for _, p := range dirty { + dirtySet[p] = true } - return cfg, nil + for _, f := range lc.files { + if !dirtySet[f.path] { + continue + } + if err := r.writeOneFile(f.path, f.cfg); err != nil { + return err + } + } + return nil } -// saveConfig writes the SSH config back to the file with atomic operations and backup management. -func (r *Repository) saveConfig(cfg *ssh_config.Config) error { - configDir := filepath.Dir(r.configPath) +func (r *Repository) writeOneFile(path string, cfg *ssh_config.Config) error { + configDir := filepath.Dir(path) tempFile, err := r.createTempFile(configDir) if err != nil { - return fmt.Errorf("failed to create temporary file: %w", err) + return fmt.Errorf("failed to create temporary file for %s: %w", path, err) } defer func() { @@ -66,24 +71,29 @@ func (r *Repository) saveConfig(cfg *ssh_config.Config) error { return fmt.Errorf("failed to write config to temporary file: %w", err) } - // Ensure a one-time original backup exists before any modifications managed by lazyssh. - if err := r.createOriginalBackupIfNeeded(); err != nil { - return fmt.Errorf("failed to create original backup: %w", err) + if err := r.createOriginalBackupForIfNeeded(path); err != nil { + return fmt.Errorf("failed to create original backup for %s: %w", path, err) + } + + if err := r.createBackupFor(path); err != nil { + return fmt.Errorf("failed to create backup for %s: %w", path, err) } - if err := r.createBackup(); err != nil { - return fmt.Errorf("failed to create backup: %w", err) + // Resolve symlinks before atomic rename so we don't replace a symlink with a regular file. + target := path + if resolved, err := filepath.EvalSymlinks(path); err == nil { + target = resolved } - if err := r.fileSystem.Rename(tempFile, r.configPath); err != nil { - return fmt.Errorf("failed to atomically replace config file: %w", err) + if err := r.fileSystem.Rename(tempFile, target); err != nil { + return fmt.Errorf("failed to atomically replace %s: %w", target, err) } - r.logger.Infof("SSH config successfully updated: %s", r.configPath) + r.logger.Infof("SSH config successfully updated: %s", target) return nil } -// writeConfigToFile writes the SSH config content to the specified file +// writeConfigToFile writes the SSH config content to the specified file. func (r *Repository) writeConfigToFile(filePath string, cfg *ssh_config.Config) error { file, err := r.fileSystem.OpenFile(filePath, os.O_WRONLY|os.O_TRUNC, SSHConfigPerms) if err != nil { @@ -107,13 +117,12 @@ func (r *Repository) writeConfigToFile(filePath string, cfg *ssh_config.Config) return nil } -// createTempFile creates a temporary file in the specified directory +// createTempFile creates a temporary file in the specified directory. func (r *Repository) createTempFile(dir string) (string, error) { - timestamp := time.Now().Format("20060102150405") + timestamp := time.Now().Format("20060102150405.000000") tempFileName := fmt.Sprintf("config%s%s", timestamp, TempSuffix) tempFilePath := filepath.Join(dir, tempFileName) - // Create the temp file with explicit 0600 permissions f, err := r.fileSystem.OpenFile(tempFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, SSHConfigPerms) if err != nil { return "", err diff --git a/internal/adapters/data/ssh_config_file/crud.go b/internal/adapters/data/ssh_config_file/crud.go index c1bfc406..435dad68 100644 --- a/internal/adapters/data/ssh_config_file/crud.go +++ b/internal/adapters/data/ssh_config_file/crud.go @@ -16,6 +16,7 @@ package ssh_config_file import ( "fmt" + "slices" "strings" "github.com/Adembc/lazyssh/internal/core/domain" @@ -66,19 +67,44 @@ func (r *Repository) matchesQuery(server domain.Server, query string) bool { return false } -// serverExists checks if a server with the given alias already exists in the config. -func (r *Repository) serverExists(cfg *ssh_config.Config, alias string) bool { - return r.findHostByAlias(cfg, alias) != nil +// hostMatch records a single occurrence of an alias somewhere in the loaded +// config tree. allMatches > 1 means the alias is duplicated across files. +type hostMatch struct { + path string + cfg *ssh_config.Config + host *ssh_config.Host } -// findHostByAlias finds a host by its alias in the SSH config. -func (r *Repository) findHostByAlias(cfg *ssh_config.Config, alias string) *ssh_config.Host { - for _, host := range cfg.Hosts { - if r.hostContainsPattern(host, alias) { - return host +// serverExists checks if a server with the given alias already exists anywhere +// in the loaded config (main or any included file). +func (r *Repository) serverExists(lc *loadedConfig, alias string) bool { + matches := r.findHostMatches(lc, alias) + return len(matches) > 0 +} + +// findHostMatches returns every file in lc that defines the alias, in OpenSSH +// precedence order (main first, then includes depth-first). +func (r *Repository) findHostMatches(lc *loadedConfig, alias string) []hostMatch { + var out []hostMatch + for i := range lc.files { + cf := &lc.files[i] + for _, host := range cf.cfg.Hosts { + if r.hostContainsPattern(host, alias) { + out = append(out, hostMatch{path: cf.path, cfg: cf.cfg, host: host}) + break + } } } - return nil + return out +} + +// matchPaths returns just the file paths from a slice of hostMatches. +func matchPaths(ms []hostMatch) []string { + out := make([]string, 0, len(ms)) + for _, m := range ms { + out = append(out, m.path) + } + return out } // hostContainsPattern checks if a host contains a specific pattern. @@ -258,121 +284,122 @@ func removeNodesByKey(nodes []ssh_config.Node, key string) []ssh_config.Node { return filtered } -// updateHostNodes updates the nodes of an existing host with new server details. -func (r *Repository) updateHostNodes(host *ssh_config.Host, newServer domain.Server) { - // Handle Port - include if explicitly set (even if it's 22) +// scalarFieldMap returns the canonical keyโ†’value map for a server's scalar +// SSH config fields. Used by updateHostNodes to diff old vs new and apply +// only the keys that actually changed (so editing a host in one of several +// Include files doesn't pollute it with the merged view's other fields). +func scalarFieldMap(s domain.Server) map[string]string { portValue := "" - if newServer.Port != 0 { - portValue = fmt.Sprintf("%d", newServer.Port) + if s.Port != 0 { + portValue = fmt.Sprintf("%d", s.Port) } - - updates := map[string]string{ - "hostname": newServer.Host, - "user": newServer.User, + return map[string]string{ + "hostname": s.Host, + "user": s.User, "port": portValue, - "proxycommand": newServer.ProxyCommand, - "proxyjump": newServer.ProxyJump, - "remotecommand": newServer.RemoteCommand, - "requesttty": newServer.RequestTTY, - "sessiontype": newServer.SessionType, - "connecttimeout": newServer.ConnectTimeout, - "connectionattempts": newServer.ConnectionAttempts, - "bindaddress": newServer.BindAddress, - "bindinterface": newServer.BindInterface, - "addressfamily": newServer.AddressFamily, - "exitonforwardfailure": newServer.ExitOnForwardFailure, - "ipqos": newServer.IPQoS, - "canonicalizehostname": newServer.CanonicalizeHostname, - "canonicaldomains": newServer.CanonicalDomains, - "canonicalizefallbacklocal": newServer.CanonicalizeFallbackLocal, - "canonicalizemaxdots": newServer.CanonicalizeMaxDots, - "canonicalizepermittedcnames": newServer.CanonicalizePermittedCNAMEs, - "clearallforwardings": newServer.ClearAllForwardings, - "gatewayports": newServer.GatewayPorts, - "pubkeyauthentication": newServer.PubkeyAuthentication, - "passwordauthentication": newServer.PasswordAuthentication, - "preferredauthentications": newServer.PreferredAuthentications, - "pubkeyacceptedalgorithms": newServer.PubkeyAcceptedAlgorithms, - "pubkeyacceptedkeytypes": newServer.PubkeyAcceptedAlgorithms, // Deprecated alias (since OpenSSH 8.5) - "hostbasedacceptedalgorithms": newServer.HostbasedAcceptedAlgorithms, - "hostbasedkeytypes": newServer.HostbasedAcceptedAlgorithms, // Deprecated alias (since OpenSSH 8.5) - "hostbasedacceptedkeytypes": newServer.HostbasedAcceptedAlgorithms, // Deprecated alias (since OpenSSH 8.5) - "identitiesonly": newServer.IdentitiesOnly, - "addkeystoagent": newServer.AddKeysToAgent, - "identityagent": newServer.IdentityAgent, - "kbdinteractiveauthentication": newServer.KbdInteractiveAuthentication, - "challengeresponseauthentication": newServer.KbdInteractiveAuthentication, // Deprecated alias - "numberofpasswordprompts": newServer.NumberOfPasswordPrompts, - "forwardagent": newServer.ForwardAgent, - "forwardx11": newServer.ForwardX11, - "forwardx11trusted": newServer.ForwardX11Trusted, - "controlmaster": newServer.ControlMaster, - "controlpath": newServer.ControlPath, - "controlpersist": newServer.ControlPersist, - "serveraliveinterval": newServer.ServerAliveInterval, - "serveralivecountmax": newServer.ServerAliveCountMax, - "compression": newServer.Compression, - "tcpkeepalive": newServer.TCPKeepAlive, - "batchmode": newServer.BatchMode, - "stricthostkeychecking": newServer.StrictHostKeyChecking, - "checkhostip": newServer.CheckHostIP, - "fingerprinthash": newServer.FingerprintHash, - "userknownhostsfile": newServer.UserKnownHostsFile, - "hostkeyalgorithms": newServer.HostKeyAlgorithms, - "macs": newServer.MACs, - "ciphers": newServer.Ciphers, - "kexalgorithms": newServer.KexAlgorithms, - "verifyhostkeydns": newServer.VerifyHostKeyDNS, - "updatehostkeys": newServer.UpdateHostKeys, - "hashknownhosts": newServer.HashKnownHosts, - "visualhostkey": newServer.VisualHostKey, - "localcommand": newServer.LocalCommand, - "permitlocalcommand": newServer.PermitLocalCommand, - "escapechar": newServer.EscapeChar, - "loglevel": newServer.LogLevel, - } - - // Update or remove nodes based on value - for key, value := range updates { - if value != "" { - r.updateOrAddKVNode(host, key, value) + "proxycommand": s.ProxyCommand, + "proxyjump": s.ProxyJump, + "remotecommand": s.RemoteCommand, + "requesttty": s.RequestTTY, + "sessiontype": s.SessionType, + "connecttimeout": s.ConnectTimeout, + "connectionattempts": s.ConnectionAttempts, + "bindaddress": s.BindAddress, + "bindinterface": s.BindInterface, + "addressfamily": s.AddressFamily, + "exitonforwardfailure": s.ExitOnForwardFailure, + "ipqos": s.IPQoS, + "canonicalizehostname": s.CanonicalizeHostname, + "canonicaldomains": s.CanonicalDomains, + "canonicalizefallbacklocal": s.CanonicalizeFallbackLocal, + "canonicalizemaxdots": s.CanonicalizeMaxDots, + "canonicalizepermittedcnames": s.CanonicalizePermittedCNAMEs, + "clearallforwardings": s.ClearAllForwardings, + "gatewayports": s.GatewayPorts, + "pubkeyauthentication": s.PubkeyAuthentication, + "passwordauthentication": s.PasswordAuthentication, + "preferredauthentications": s.PreferredAuthentications, + "pubkeyacceptedalgorithms": s.PubkeyAcceptedAlgorithms, + "pubkeyacceptedkeytypes": s.PubkeyAcceptedAlgorithms, + "hostbasedacceptedalgorithms": s.HostbasedAcceptedAlgorithms, + "hostbasedkeytypes": s.HostbasedAcceptedAlgorithms, + "hostbasedacceptedkeytypes": s.HostbasedAcceptedAlgorithms, + "identitiesonly": s.IdentitiesOnly, + "addkeystoagent": s.AddKeysToAgent, + "identityagent": s.IdentityAgent, + "kbdinteractiveauthentication": s.KbdInteractiveAuthentication, + "challengeresponseauthentication": s.KbdInteractiveAuthentication, + "numberofpasswordprompts": s.NumberOfPasswordPrompts, + "forwardagent": s.ForwardAgent, + "forwardx11": s.ForwardX11, + "forwardx11trusted": s.ForwardX11Trusted, + "controlmaster": s.ControlMaster, + "controlpath": s.ControlPath, + "controlpersist": s.ControlPersist, + "serveraliveinterval": s.ServerAliveInterval, + "serveralivecountmax": s.ServerAliveCountMax, + "compression": s.Compression, + "tcpkeepalive": s.TCPKeepAlive, + "batchmode": s.BatchMode, + "stricthostkeychecking": s.StrictHostKeyChecking, + "checkhostip": s.CheckHostIP, + "fingerprinthash": s.FingerprintHash, + "userknownhostsfile": s.UserKnownHostsFile, + "hostkeyalgorithms": s.HostKeyAlgorithms, + "macs": s.MACs, + "ciphers": s.Ciphers, + "kexalgorithms": s.KexAlgorithms, + "verifyhostkeydns": s.VerifyHostKeyDNS, + "updatehostkeys": s.UpdateHostKeys, + "hashknownhosts": s.HashKnownHosts, + "visualhostkey": s.VisualHostKey, + "localcommand": s.LocalCommand, + "permitlocalcommand": s.PermitLocalCommand, + "escapechar": s.EscapeChar, + "loglevel": s.LogLevel, + } +} + +// updateHostNodes applies the diff between oldServer and newServer to host's +// KV nodes. Unchanged fields are not touched, so editing a single field on a +// host that's defined across multiple Include files won't drag the merged +// view's other fields into the file being written. +func (r *Repository) updateHostNodes(host *ssh_config.Host, oldServer, newServer domain.Server) { + oldVals := scalarFieldMap(oldServer) + newVals := scalarFieldMap(newServer) + for key, newVal := range newVals { + if oldVals[key] == newVal { + continue + } + if newVal != "" { + r.updateOrAddKVNode(host, key, newVal) } else { - // Remove the key if value is empty (user selected default) r.removeKVNode(host, key) } } - // Replace multi-value entries entirely to reflect the new state - host.Nodes = removeNodesByKey(host.Nodes, "IdentityFile") - for _, identityFile := range newServer.IdentityFiles { - r.addKVNodeIfNotEmpty(host, "IdentityFile", identityFile) - } - - host.Nodes = removeNodesByKey(host.Nodes, "LocalForward") - for _, forward := range newServer.LocalForward { - configFormat := r.convertCLIForwardToConfigFormat(forward) - r.addKVNodeIfNotEmpty(host, "LocalForward", configFormat) - } - - host.Nodes = removeNodesByKey(host.Nodes, "RemoteForward") - for _, forward := range newServer.RemoteForward { - configFormat := r.convertCLIForwardToConfigFormat(forward) - r.addKVNodeIfNotEmpty(host, "RemoteForward", configFormat) - } - - host.Nodes = removeNodesByKey(host.Nodes, "DynamicForward") - for _, forward := range newServer.DynamicForward { - r.addKVNodeIfNotEmpty(host, "DynamicForward", forward) - } + r.updateListField(host, "IdentityFile", oldServer.IdentityFiles, newServer.IdentityFiles, nil) + r.updateListField(host, "LocalForward", oldServer.LocalForward, newServer.LocalForward, r.convertCLIForwardToConfigFormat) + r.updateListField(host, "RemoteForward", oldServer.RemoteForward, newServer.RemoteForward, r.convertCLIForwardToConfigFormat) + r.updateListField(host, "DynamicForward", oldServer.DynamicForward, newServer.DynamicForward, nil) + r.updateListField(host, "SendEnv", oldServer.SendEnv, newServer.SendEnv, nil) + r.updateListField(host, "SetEnv", oldServer.SetEnv, newServer.SetEnv, nil) +} - host.Nodes = removeNodesByKey(host.Nodes, "SendEnv") - for _, env := range newServer.SendEnv { - r.addKVNodeIfNotEmpty(host, "SendEnv", env) +// updateListField rewrites a multi-valued KV (e.g. IdentityFile) only when +// its values actually changed. transform is applied per-value before writing +// (used for converting CLI forwarding format to SSH config format); pass nil +// for an identity transform. +func (r *Repository) updateListField(host *ssh_config.Host, key string, oldVals, newVals []string, transform func(string) string) { + if slices.Equal(oldVals, newVals) { + return } - - host.Nodes = removeNodesByKey(host.Nodes, "SetEnv") - for _, env := range newServer.SetEnv { - r.addKVNodeIfNotEmpty(host, "SetEnv", env) + host.Nodes = removeNodesByKey(host.Nodes, key) + for _, v := range newVals { + if transform != nil { + v = transform(v) + } + r.addKVNodeIfNotEmpty(host, key, v) } } @@ -573,3 +600,32 @@ func (r *Repository) removeHostByAlias(hosts []*ssh_config.Host, alias string) [ } return hosts } + +// preferenceResolves reports whether preferPath unambiguously selects one of +// the matches. Empty preferPath never resolves; an unknown path also doesn't. +func preferenceResolves(matches []hostMatch, preferPath string) bool { + if preferPath == "" { + return false + } + for _, m := range matches { + if m.path == preferPath { + return true + } + } + return false +} + +// pickWritableMatch chooses which match to mutate when callers haven't passed +// a preferred file. If preferPath is non-empty and matches one of the +// candidates, we use that. Otherwise the first (highest-precedence) match +// wins. Callers must ensure matches is non-empty. +func pickWritableMatch(matches []hostMatch, preferPath string) hostMatch { + if preferPath != "" { + for _, m := range matches { + if m.path == preferPath { + return m + } + } + } + return matches[0] +} diff --git a/internal/adapters/data/ssh_config_file/crud_multifile_test.go b/internal/adapters/data/ssh_config_file/crud_multifile_test.go new file mode 100644 index 00000000..0828b15b --- /dev/null +++ b/internal/adapters/data/ssh_config_file/crud_multifile_test.go @@ -0,0 +1,313 @@ +// Copyright 2025. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_config_file + +import ( + "errors" + "path/filepath" + "strings" + "testing" + + "github.com/Adembc/lazyssh/internal/core/domain" + "go.uber.org/zap" +) + +func newRepoForFS(t *testing.T, fs *memFS, metaPath string) *Repository { + t.Helper() + logger := zap.NewNop().Sugar() + return &Repository{ + logger: logger, + configPath: "/home/u/.ssh/config", + fileSystem: fs, + metadataManager: newMetadataManager(metaPath, logger), + } +} + +func TestUpdateServer_AmbiguousAcrossFiles(t *testing.T) { + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + work := "/home/u/.ssh/work" + personal := "/home/u/.ssh/personal" + fs.write(main, "Include "+work+"\nInclude "+personal+"\n") + fs.write(work, "Host shared\n HostName work.example.com\n") + fs.write(personal, "Host shared\n HostName home.example.com\n") + + tmpMeta := filepath.Join(t.TempDir(), "metadata.json") + r := newRepoForFS(t, fs, tmpMeta) + + srv := domain.Server{Alias: "shared", Host: "work.example.com", User: "u"} + newSrv := srv + newSrv.User = "ubuntu" + + err := r.UpdateServer(srv, newSrv) + var ambig *domain.ErrAmbiguousHost + if !errors.As(err, &ambig) { + t.Fatalf("want ErrAmbiguousHost, got %v", err) + } + if ambig.Alias != "shared" { + t.Errorf("alias = %q", ambig.Alias) + } + if len(ambig.Candidates) != 2 { + t.Fatalf("want 2 candidates, got %v", ambig.Candidates) + } + + // Re-invoke with explicit SourceFile picks the right file. + srv2 := srv + srv2.SourceFile = personal + newSrv2 := newSrv + newSrv2.SourceFile = personal + if err := r.UpdateServer(srv2, newSrv2); err != nil { + t.Fatalf("update with SourceFile: %v", err) + } + + personalContent := fs.read(personal) + if !strings.Contains(personalContent, "User ubuntu") { + t.Errorf("personal file missing update: %s", personalContent) + } + workContent := fs.read(work) + if strings.Contains(workContent, "User ubuntu") { + t.Errorf("work file should be untouched: %s", workContent) + } +} + +func TestUpdateServer_RoutesToOwningFile(t *testing.T) { + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + work := "/home/u/.ssh/work" + fs.write(main, "Include "+work+"\nHost local\n HostName 127.0.0.1\n") + fs.write(work, "Host prod\n HostName prod.example.com\n") + + tmpMeta := filepath.Join(t.TempDir(), "metadata.json") + r := newRepoForFS(t, fs, tmpMeta) + + srv := domain.Server{Alias: "prod", Host: "prod.example.com", User: ""} + newSrv := srv + newSrv.User = "deploy" + + if err := r.UpdateServer(srv, newSrv); err != nil { + t.Fatalf("update: %v", err) + } + + workContent := fs.read(work) + if !strings.Contains(workContent, "User deploy") { + t.Errorf("work file missing update: %s", workContent) + } + mainContent := fs.read(main) + if strings.Contains(mainContent, "User deploy") { + t.Errorf("main file should be untouched") + } +} + +func TestAddServer_DefaultsToMainFile(t *testing.T) { + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + work := "/home/u/.ssh/work" + fs.write(main, "Include "+work+"\n") + fs.write(work, "Host existing\n HostName 1.1.1.1\n") + + tmpMeta := filepath.Join(t.TempDir(), "metadata.json") + r := newRepoForFS(t, fs, tmpMeta) + + srv := domain.Server{Alias: "fresh", Host: "fresh.example.com", User: "u"} + if err := r.AddServer(srv); err != nil { + t.Fatalf("add: %v", err) + } + + mainContent := fs.read(main) + if !strings.Contains(mainContent, "Host fresh") { + t.Errorf("new host should be in main file: %s", mainContent) + } + workContent := fs.read(work) + if strings.Contains(workContent, "Host fresh") { + t.Errorf("new host should not be in include file") + } +} + +func TestDeleteServer_AmbiguousReturnsErr(t *testing.T) { + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + work := "/home/u/.ssh/work" + fs.write(main, "Include "+work+"\nHost shared\n HostName a\n") + fs.write(work, "Host shared\n HostName b\n") + + tmpMeta := filepath.Join(t.TempDir(), "metadata.json") + r := newRepoForFS(t, fs, tmpMeta) + + err := r.DeleteServer(domain.Server{Alias: "shared"}) + var ambig *domain.ErrAmbiguousHost + if !errors.As(err, &ambig) { + t.Fatalf("want ErrAmbiguousHost, got %v", err) + } +} + +func TestListServers_MergesDirectivesAcrossFiles(t *testing.T) { + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + override := "/home/u/.ssh/config.d/ogma.override" + conf := "/home/u/.ssh/config.d/ogma.conf" + // Order matters: override is included first, so its ProxyCommand wins + // (first-seen). Scalars only present in conf (HostName, User) must still + // reach the merged server. + fs.write(main, "Include "+override+"\nInclude "+conf+"\n") + fs.write(override, "Host ogma\n ProxyCommand ssh -W %h:%p eostre\n") + fs.write(conf, "Host ogma\n HostName ogma.hrafn.xyz\n User DelphicOkami\n") + + tmpMeta := filepath.Join(t.TempDir(), "metadata.json") + r := newRepoForFS(t, fs, tmpMeta) + + servers, err := r.ListServers("") + if err != nil { + t.Fatalf("ListServers: %v", err) + } + if len(servers) != 1 { + t.Fatalf("want 1 server, got %d", len(servers)) + } + got := servers[0] + if got.ProxyCommand != "ssh -W %h:%p eostre" { + t.Errorf("ProxyCommand from override missing: %q", got.ProxyCommand) + } + if got.Host != "ogma.hrafn.xyz" { + t.Errorf("HostName from .conf missing: %q", got.Host) + } + if got.User != "DelphicOkami" { + t.Errorf("User from .conf missing: %q", got.User) + } + if len(got.SourceFiles) != 2 { + t.Errorf("SourceFiles = %v, want both files tracked", got.SourceFiles) + } +} + +func TestUpdateServer_OnlyWritesChangedFields(t *testing.T) { + // Regression: editing a host defined across multiple Include files must + // only write the *changed* fields to the chosen file. Unchanged fields + // (drawn from the merged view) must not be dragged into the file. + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + override := "/home/u/.ssh/config.d/ogma.override" + conf := "/home/u/.ssh/config.d/ogma.conf" + fs.write(main, "Include "+override+"\nInclude "+conf+"\n") + fs.write(override, "Host ogma\n ProxyCommand ssh -W %h:%p eostre\n") + fs.write(conf, "Host ogma\n HostName ogma.hrafn.xyz\n User original\n") + + tmpMeta := filepath.Join(t.TempDir(), "metadata.json") + r := newRepoForFS(t, fs, tmpMeta) + + servers, err := r.ListServers("") + if err != nil { + t.Fatalf("ListServers: %v", err) + } + srv := servers[0] + srv.SourceFile = conf // pretend the user picked .conf in the modal + newSrv := srv + newSrv.User = "deploy" // only change User + + if err := r.UpdateServer(srv, newSrv); err != nil { + t.Fatalf("UpdateServer: %v", err) + } + + confContent := fs.read(conf) + overrideContent := fs.read(override) + + if !strings.Contains(confContent, "User deploy") { + t.Errorf(".conf missing new User: %q", confContent) + } + if strings.Contains(confContent, "ProxyCommand") { + t.Errorf(".conf should not have gained ProxyCommand: %q", confContent) + } + if strings.Contains(overrideContent, "User") { + t.Errorf("override should not have gained User: %q", overrideContent) + } + if strings.Contains(overrideContent, "HostName") { + t.Errorf("override should not have gained HostName: %q", overrideContent) + } +} + +func TestUpdateServer_PromptsOnFirstEditWhenSplit(t *testing.T) { + // Regression: an alias defined across multiple files (no remembered + // metadata.File) must surface ErrAmbiguousHost on first edit, not be + // auto-routed to the first-seen file. + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + override := "/home/u/.ssh/config.d/ogma.override" + conf := "/home/u/.ssh/config.d/ogma.conf" + fs.write(main, "Include "+override+"\nInclude "+conf+"\n") + fs.write(override, "Host ogma\n ProxyCommand ssh -W %h:%p eostre\n") + fs.write(conf, "Host ogma\n HostName ogma.hrafn.xyz\n User u\n") + + tmpMeta := filepath.Join(t.TempDir(), "metadata.json") + r := newRepoForFS(t, fs, tmpMeta) + + servers, err := r.ListServers("") + if err != nil { + t.Fatalf("ListServers: %v", err) + } + if len(servers) != 1 { + t.Fatalf("want 1 server, got %d", len(servers)) + } + srv := servers[0] + if srv.SourceFile != "" { + t.Errorf("SourceFile should be empty for split host, got %q", srv.SourceFile) + } + + newSrv := srv + newSrv.User = "deploy" + err = r.UpdateServer(srv, newSrv) + var ambig *domain.ErrAmbiguousHost + if !errors.As(err, &ambig) { + t.Fatalf("want ErrAmbiguousHost on first edit of split host, got %v", err) + } +} + +func TestUpdateServer_PersistsFileChoiceToMetadata(t *testing.T) { + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + work := "/home/u/.ssh/work" + fs.write(main, "Include "+work+"\n") + fs.write(work, "Host pinned\n HostName 1.1.1.1\n") + + tmpMeta := filepath.Join(t.TempDir(), "metadata.json") + r := newRepoForFS(t, fs, tmpMeta) + + srv := domain.Server{Alias: "pinned", Host: "1.1.1.1"} + newSrv := srv + newSrv.User = "deploy" + if err := r.UpdateServer(srv, newSrv); err != nil { + t.Fatalf("update: %v", err) + } + + meta, err := r.metadataManager.loadAll() + if err != nil { + t.Fatalf("loadAll: %v", err) + } + if got := meta["pinned"].File; got != work { + t.Errorf("metadata.File = %q, want %q", got, work) + } +} diff --git a/internal/adapters/data/ssh_config_file/include_resolver.go b/internal/adapters/data/ssh_config_file/include_resolver.go new file mode 100644 index 00000000..bd67b6c8 --- /dev/null +++ b/internal/adapters/data/ssh_config_file/include_resolver.go @@ -0,0 +1,313 @@ +// Copyright 2025. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_config_file + +import ( + "bufio" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + + "github.com/kevinburke/ssh_config" +) + +const maxIncludeDepth = 16 + +// hasGlobMeta reports whether s contains any character special to filepath.Glob. +func hasGlobMeta(s string) bool { + return strings.ContainsAny(s, "*?[") +} + +// configFile pairs a parsed ssh_config.Config with the absolute path of the +// file it came from. The path is the canonical key used everywhere CRUD +// operations need to know "which file does this host live in?". +type configFile struct { + path string + cfg *ssh_config.Config +} + +// loadedConfig holds the main SSH config plus every file pulled in via +// `Include` directives, in OpenSSH precedence order (main first, then includes +// depth-first in the order they appeared). +type loadedConfig struct { + files []configFile + mainPath string +} + +// findFile returns the configFile whose path matches absPath, or nil. +func (lc *loadedConfig) findFile(absPath string) *configFile { + for i := range lc.files { + if lc.files[i].path == absPath { + return &lc.files[i] + } + } + return nil +} + +// paths returns the absolute paths of every loaded file in order. +func (lc *loadedConfig) paths() []string { + out := make([]string, 0, len(lc.files)) + for _, f := range lc.files { + out = append(out, f.path) + } + return out +} + +// resolveIncludes parses mainPath and walks `Include` directives recursively, +// returning every file encountered. Missing globs are tolerated silently +// (matches OpenSSH). Cycles and excessive depth produce errors. +func (r *Repository) resolveIncludes(mainPath string) (*loadedConfig, error) { + absMain, err := filepath.Abs(mainPath) + if err != nil { + absMain = mainPath + } + + visited := make(map[string]bool) + lc := &loadedConfig{mainPath: absMain} + + if err := r.loadFileAndIncludes(absMain, lc, visited, 0); err != nil { + return nil, err + } + + if len(lc.files) == 0 { + // File didn't exist; preserve loadConfig's first-run behavior. + lc.files = append(lc.files, configFile{ + path: absMain, + cfg: &ssh_config.Config{Hosts: []*ssh_config.Host{}}, + }) + } + return lc, nil +} + +func (r *Repository) loadFileAndIncludes(path string, lc *loadedConfig, visited map[string]bool, depth int) error { + if depth > maxIncludeDepth { + return fmt.Errorf("ssh config include depth exceeded %d at %s", maxIncludeDepth, path) + } + + abs, err := filepath.Abs(path) + if err != nil { + abs = path + } + if visited[abs] { + return fmt.Errorf("ssh config include cycle detected at %s", abs) + } + visited[abs] = true + + file, err := r.fileSystem.Open(abs) + if err != nil { + if r.fileSystem.IsNotExist(err) { + // OpenSSH silently ignores missing includes; do the same. + if depth == 0 { + return nil + } + return nil + } + return fmt.Errorf("open %s: %w", abs, err) + } + + cfg, decodeErr := ssh_config.Decode(file) + if cerr := file.Close(); cerr != nil { + r.logger.Warnf("failed to close %s: %v", abs, cerr) + } + if decodeErr != nil { + return fmt.Errorf("decode %s: %w", abs, decodeErr) + } + + lc.files = append(lc.files, configFile{path: abs, cfg: cfg}) + + includes, err := r.parseIncludeDirectives(abs) + if err != nil { + return err + } + + for _, pattern := range includes { + expanded, err := r.expandIncludePattern(pattern) + if err != nil { + r.logger.Warnf("include pattern %q: %v", pattern, err) + continue + } + for _, child := range expanded { + if err := r.loadFileAndIncludes(child, lc, visited, depth+1); err != nil { + return err + } + } + } + return nil +} + +// parseIncludeDirectives scans a config file's raw text for top-level `Include` +// lines and returns their (still-unexpanded) glob patterns. We do our own +// scanning rather than relying on the parser's internal Include handling โ€” +// kevinburke/ssh_config keeps the resolved file map unexported. +// +// Limitations (documented): +// - Only top-level Includes are honored. Includes inside Host/Match blocks +// are ignored. OpenSSH allows them but they're rare; flagging as a v1 cap. +func (r *Repository) parseIncludeDirectives(absPath string) ([]string, error) { + file, err := r.fileSystem.Open(absPath) + if err != nil { + if r.fileSystem.IsNotExist(err) { + return nil, nil + } + return nil, fmt.Errorf("re-open %s for include scan: %w", absPath, err) + } + defer func() { + if cerr := file.Close(); cerr != nil { + r.logger.Warnf("failed to close %s during include scan: %v", absPath, cerr) + } + }() + + var patterns []string + scanner := bufio.NewScanner(file) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + + inHostOrMatch := false + for scanner.Scan() { + line := scanner.Text() + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + continue + } + + key, rest := splitDirective(trimmed) + if key == "" { + continue + } + lower := strings.ToLower(key) + + switch lower { + case "host", "match": + inHostOrMatch = true + continue + case "include": + if inHostOrMatch { + continue + } + patterns = append(patterns, splitIncludeArgs(rest)...) + } + } + if err := scanner.Err(); err != nil { + return nil, fmt.Errorf("scan %s: %w", absPath, err) + } + return patterns, nil +} + +// splitDirective splits "Key value..." (or "Key=value...") into (key, rest). +func splitDirective(line string) (string, string) { + for i := 0; i < len(line); i++ { + c := line[i] + if c == ' ' || c == '\t' || c == '=' { + key := line[:i] + rest := strings.TrimLeft(line[i:], " \t=") + return key, rest + } + } + return line, "" +} + +// splitIncludeArgs splits the argument list of an `Include` directive, +// respecting double-quoted patterns (which may contain spaces). +func splitIncludeArgs(rest string) []string { + var out []string + var cur strings.Builder + inQuote := false + for i := 0; i < len(rest); i++ { + c := rest[i] + switch { + case c == '"': + inQuote = !inQuote + case (c == ' ' || c == '\t') && !inQuote: + if cur.Len() > 0 { + out = append(out, cur.String()) + cur.Reset() + } + default: + cur.WriteByte(c) + } + } + if cur.Len() > 0 { + out = append(out, cur.String()) + } + return out +} + +// expandIncludePattern resolves a single `Include` pattern: handles `~/` +// expansion, glob-expands, and resolves relative paths against `~/.ssh/` +// (matching OpenSSH semantics). +func (r *Repository) expandIncludePattern(pattern string) ([]string, error) { + if pattern == "" { + return nil, nil + } + + expanded := pattern + if strings.HasPrefix(expanded, "~/") || expanded == "~" { + home, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("resolve ~ for include %q: %w", pattern, err) + } + if expanded == "~" { + expanded = home + } else { + expanded = filepath.Join(home, expanded[2:]) + } + } + + if !filepath.IsAbs(expanded) { + // Relative paths in user config resolve against ~/.ssh. + home, err := os.UserHomeDir() + if err != nil { + return nil, fmt.Errorf("resolve relative include %q: %w", pattern, err) + } + expanded = filepath.Join(home, ".ssh", expanded) + } + + var matches []string + if hasGlobMeta(expanded) { + m, err := filepath.Glob(expanded) + if err != nil { + return nil, fmt.Errorf("glob %q: %w", expanded, err) + } + matches = m + } else { + // Literal path โ€” go through the FileSystem abstraction so tests using + // an in-memory FS still resolve. + if _, err := r.fileSystem.Stat(expanded); err == nil { + matches = []string{expanded} + } else if !r.fileSystem.IsNotExist(err) { + return nil, fmt.Errorf("stat %q: %w", expanded, err) + } + } + sort.Strings(matches) + + resolved := make([]string, 0, len(matches)) + for _, m := range matches { + // Skip directories โ€” OpenSSH only Includes regular files. + info, statErr := r.fileSystem.Stat(m) + if statErr != nil { + continue + } + if info.IsDir() { + continue + } + abs, err := filepath.Abs(m) + if err != nil { + abs = m + } + resolved = append(resolved, abs) + } + return resolved, nil +} diff --git a/internal/adapters/data/ssh_config_file/include_resolver_test.go b/internal/adapters/data/ssh_config_file/include_resolver_test.go new file mode 100644 index 00000000..169886c8 --- /dev/null +++ b/internal/adapters/data/ssh_config_file/include_resolver_test.go @@ -0,0 +1,158 @@ +// Copyright 2025. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_config_file + +import ( + "strings" + "testing" + + "go.uber.org/zap" +) + +func newTestRepo(t *testing.T, fs FileSystem) *Repository { + t.Helper() + logger := zap.NewNop().Sugar() + return &Repository{ + logger: logger, + configPath: "/home/u/.ssh/config", + fileSystem: fs, + metadataManager: newMetadataManager("/dev/null", logger), + } +} + +func TestResolveIncludes_NoIncludes(t *testing.T) { + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + fs.write(main, "Host alpha\n HostName 1.1.1.1\n") + + r := newTestRepo(t, fs) + lc, err := r.resolveIncludes(main) + if err != nil { + t.Fatalf("resolveIncludes: %v", err) + } + if len(lc.files) != 1 { + t.Fatalf("want 1 file, got %d", len(lc.files)) + } + if lc.files[0].path != main { + t.Errorf("want path %s, got %s", main, lc.files[0].path) + } +} + +func TestResolveIncludes_AbsoluteInclude(t *testing.T) { + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + inc := "/home/u/.ssh/work" + fs.write(main, "Include "+inc+"\nHost alpha\n HostName 1.1.1.1\n") + fs.write(inc, "Host beta\n HostName 2.2.2.2\n") + + r := newTestRepo(t, fs) + lc, err := r.resolveIncludes(main) + if err != nil { + t.Fatalf("resolveIncludes: %v", err) + } + if len(lc.files) != 2 { + t.Fatalf("want 2 files, got %d (%v)", len(lc.files), lc.paths()) + } + if lc.files[0].path != main || lc.files[1].path != inc { + t.Errorf("ordering wrong: %v", lc.paths()) + } +} + +func TestResolveIncludes_MissingIncludeIsSilent(t *testing.T) { + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + fs.write(main, "Include /nonexistent/file\nHost a\n HostName x\n") + + r := newTestRepo(t, fs) + lc, err := r.resolveIncludes(main) + if err != nil { + t.Fatalf("missing include should not error: %v", err) + } + if len(lc.files) != 1 { + t.Errorf("want 1 file, got %d", len(lc.files)) + } +} + +func TestResolveIncludes_CycleDetected(t *testing.T) { + fs := newMemFS(t) + defer fs.cleanup() + + a := "/home/u/.ssh/config" + b := "/home/u/.ssh/loop" + fs.write(a, "Include "+b+"\n") + fs.write(b, "Include "+a+"\n") + + r := newTestRepo(t, fs) + _, err := r.resolveIncludes(a) + if err == nil || !strings.Contains(err.Error(), "cycle") { + t.Fatalf("want cycle error, got %v", err) + } +} + +func TestResolveIncludes_IgnoresIncludeInsideHostBlock(t *testing.T) { + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + inc := "/home/u/.ssh/should-not-load" + fs.write(main, "Host alpha\n HostName 1.1.1.1\n Include "+inc+"\n") + fs.write(inc, "Host beta\n HostName 2.2.2.2\n") + + r := newTestRepo(t, fs) + lc, err := r.resolveIncludes(main) + if err != nil { + t.Fatalf("resolveIncludes: %v", err) + } + if len(lc.files) != 1 { + t.Errorf("Include inside Host block should be ignored, got files=%v", lc.paths()) + } +} + +func TestResolveIncludes_FirstFileMissing(t *testing.T) { + fs := newMemFS(t) + defer fs.cleanup() + + main := "/home/u/.ssh/config" + r := newTestRepo(t, fs) + lc, err := r.resolveIncludes(main) + if err != nil { + t.Fatalf("resolveIncludes: %v", err) + } + if len(lc.files) != 1 { + t.Fatalf("want 1 placeholder file, got %d", len(lc.files)) + } + if got := len(lc.files[0].cfg.Hosts); got != 0 { + t.Errorf("placeholder cfg should be empty, got %d hosts", got) + } +} + +func TestSplitIncludeArgs(t *testing.T) { + got := splitIncludeArgs(`config.d/* "with spaces.conf" plain`) + want := []string{"config.d/*", "with spaces.conf", "plain"} + if len(got) != len(want) { + t.Fatalf("len mismatch: got %v want %v", got, want) + } + for i := range want { + if got[i] != want[i] { + t.Errorf("arg %d: got %q want %q", i, got[i], want[i]) + } + } +} diff --git a/internal/adapters/data/ssh_config_file/mapper.go b/internal/adapters/data/ssh_config_file/mapper.go index f8a31f4a..822795f6 100644 --- a/internal/adapters/data/ssh_config_file/mapper.go +++ b/internal/adapters/data/ssh_config_file/mapper.go @@ -15,6 +15,7 @@ package ssh_config_file import ( + "slices" "strconv" "strings" "time" @@ -23,46 +24,96 @@ import ( "github.com/kevinburke/ssh_config" ) -// toDomainServer converts ssh_config.Config to a slice of domain.Server. -func (r *Repository) toDomainServer(cfg *ssh_config.Config) []domain.Server { - servers := make([]domain.Server, 0, len(cfg.Hosts)) - for _, host := range cfg.Hosts { - - aliases := make([]string, 0, len(host.Patterns)) +// toDomainServer converts a loadedConfig (main file plus any included files) +// into a slice of domain.Server. +// +// OpenSSH semantics: when an alias appears in multiple Host blocks (across +// files or within one file), directives are merged with first-seen value +// winning per key; list-style directives (IdentityFile, SendEnv, etc.) append +// across all matching blocks. We replicate that by mapping every matching +// block's KVs into the same domain.Server, suppressing scalar keys we've +// already seen for that alias. +func (r *Repository) toDomainServer(lc *loadedConfig) []domain.Server { + byAlias := make(map[string]int) + seenKeys := make(map[string]map[string]bool) + servers := make([]domain.Server, 0) - for _, pattern := range host.Patterns { - alias := pattern.String() - // Skip if alias contains wildcards (not a concrete Host) - if strings.ContainsAny(alias, "!*?[]") { + for _, cf := range lc.files { + for _, host := range cf.cfg.Hosts { + aliases := make([]string, 0, len(host.Patterns)) + for _, pattern := range host.Patterns { + alias := pattern.String() + if strings.ContainsAny(alias, "!*?[]") { + continue + } + aliases = append(aliases, alias) + } + if len(aliases) == 0 { continue } - aliases = append(aliases, alias) - } - if len(aliases) == 0 { - continue - } - server := domain.Server{ - Alias: aliases[0], - Aliases: aliases, - Port: 22, - IdentityFiles: []string{}, - } - for _, node := range host.Nodes { - kvNode, ok := node.(*ssh_config.KV) - if !ok { - continue + primaryAlias := aliases[0] + idx, exists := byAlias[primaryAlias] + if !exists { + servers = append(servers, domain.Server{ + Alias: primaryAlias, + Aliases: aliases, + Port: 22, + IdentityFiles: []string{}, + SourceFile: cf.path, + SourceFiles: []string{cf.path}, + }) + idx = len(servers) - 1 + byAlias[primaryAlias] = idx + seenKeys[primaryAlias] = make(map[string]bool) + } else if !slices.Contains(servers[idx].SourceFiles, cf.path) { + servers[idx].SourceFiles = append(servers[idx].SourceFiles, cf.path) } - r.mapKVToServer(&server, kvNode) + seen := seenKeys[primaryAlias] + for _, node := range host.Nodes { + kvNode, ok := node.(*ssh_config.KV) + if !ok { + continue + } + key := strings.ToLower(kvNode.Key) + if !isAppendingKey(key) && seen[key] { + continue + } + r.mapKVToServer(&servers[idx], kvNode) + seen[key] = true + } } + } - servers = append(servers, server) + // Clear SourceFile when an alias is defined in more than one file: the + // "first-seen" file isn't a recorded user preference, so it must not + // auto-resolve the ambiguity prompt on edit/delete. mergeMetadata will + // populate SourceFile later if the user has previously chosen a file. + for i := range servers { + if len(servers[i].SourceFiles) > 1 { + servers[i].SourceFile = "" + } } return servers } +// isAppendingKey reports whether the SSH config key accumulates values across +// multiple Host blocks (rather than first-write-wins). +func isAppendingKey(key string) bool { + switch key { + case "identityfile", + "sendenv", + "setenv", + "localforward", + "remoteforward", + "dynamicforward": + return true + } + return false +} + // mapKVToServer maps an ssh_config.KV node to the corresponding fields in domain.Server. func (r *Repository) mapKVToServer(server *domain.Server, kvNode *ssh_config.KV) { key := strings.ToLower(kvNode.Key) @@ -299,6 +350,9 @@ func (r *Repository) mergeMetadata(servers []domain.Server, metadata map[string] if meta, exists := metadata[server.Alias]; exists { servers[i].Tags = meta.Tags servers[i].SSHCount = meta.SSHCount + if meta.File != "" { + servers[i].SourceFile = meta.File + } if meta.LastSeen != "" { if lastSeen, err := time.Parse(time.RFC3339, meta.LastSeen); err == nil { diff --git a/internal/adapters/data/ssh_config_file/memfs_test.go b/internal/adapters/data/ssh_config_file/memfs_test.go new file mode 100644 index 00000000..0b41c96a --- /dev/null +++ b/internal/adapters/data/ssh_config_file/memfs_test.go @@ -0,0 +1,207 @@ +// Copyright 2025. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ssh_config_file + +import ( + "bytes" + "errors" + "io" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +// memFS is an in-memory FileSystem implementation for tests. It implements the +// subset of the FileSystem interface our code touches; calls that need to hit +// real disk (OpenFile returning *os.File) fall through to a temp-dir backing +// so atomic-rename tests still work. +type memFS struct { + mu sync.Mutex + files map[string][]byte + tempDir string + openReal map[string]string // logical path โ†’ real on-disk path for OpenFile redirects +} + +func newMemFS(t interface { + Helper() + Fatalf(string, ...any) +}, +) *memFS { + t.Helper() + dir, err := os.MkdirTemp("", "lazyssh-memfs-") + if err != nil { + t.Fatalf("mkdir temp: %v", err) + } + return &memFS{files: map[string][]byte{}, tempDir: dir, openReal: map[string]string{}} +} + +func (m *memFS) cleanup() { _ = os.RemoveAll(m.tempDir) } + +func (m *memFS) write(path string, content string) { + m.mu.Lock() + defer m.mu.Unlock() + m.files[path] = []byte(content) +} + +func (m *memFS) read(path string) string { + m.mu.Lock() + defer m.mu.Unlock() + return string(m.files[path]) +} + +// --- FileSystem interface --- + +func (m *memFS) Open(name string) (io.ReadCloser, error) { + m.mu.Lock() + defer m.mu.Unlock() + b, ok := m.files[name] + if !ok { + return nil, &os.PathError{Op: "open", Path: name, Err: os.ErrNotExist} + } + return io.NopCloser(bytes.NewReader(b)), nil +} + +func (m *memFS) Create(name string) (io.WriteCloser, error) { + return &memWriter{fs: m, path: name}, nil +} + +type memFileInfo struct { + name string + size int64 + dir bool + mode os.FileMode +} + +func (i memFileInfo) Name() string { return i.name } +func (i memFileInfo) Size() int64 { return i.size } +func (i memFileInfo) Mode() os.FileMode { return i.mode } +func (i memFileInfo) ModTime() time.Time { return time.Time{} } +func (i memFileInfo) IsDir() bool { return i.dir } +func (i memFileInfo) Sys() any { return nil } + +func (m *memFS) Stat(name string) (os.FileInfo, error) { + m.mu.Lock() + defer m.mu.Unlock() + if b, ok := m.files[name]; ok { + return memFileInfo{name: filepath.Base(name), size: int64(len(b)), mode: 0o600}, nil + } + // Treat any prefix that's a parent of a known file as a directory. + for p := range m.files { + if strings.HasPrefix(p, name+string(os.PathSeparator)) { + return memFileInfo{name: filepath.Base(name), dir: true, mode: 0o755 | os.ModeDir}, nil + } + } + return nil, &os.PathError{Op: "stat", Path: name, Err: os.ErrNotExist} +} + +func (m *memFS) IsNotExist(err error) bool { return errors.Is(err, os.ErrNotExist) } + +func (m *memFS) realPathFor(logical string) string { + if rp, ok := m.openReal[logical]; ok { + return rp + } + rel := strings.ReplaceAll(filepath.Clean(logical), string(os.PathSeparator), "_") + rp := filepath.Join(m.tempDir, rel) + m.openReal[logical] = rp + return rp +} + +func (m *memFS) Remove(file string) error { + m.mu.Lock() + defer m.mu.Unlock() + if _, ok := m.files[file]; ok { + delete(m.files, file) + return nil + } + if rp, ok := m.openReal[file]; ok { + delete(m.openReal, file) + return os.Remove(rp) + } + return &os.PathError{Op: "remove", Path: file, Err: os.ErrNotExist} +} + +func (m *memFS) Rename(src, dst string) error { + m.mu.Lock() + srcReal, ok := m.openReal[src] + m.mu.Unlock() + if !ok { + srcReal = src + } + b, err := os.ReadFile(srcReal) // #nosec G304 + if err != nil { + return err + } + m.mu.Lock() + m.files[dst] = b + delete(m.openReal, src) + m.mu.Unlock() + return os.Remove(srcReal) +} + +func (m *memFS) Chmod(path string, perms os.FileMode) error { return nil } + +func (m *memFS) OpenFile(path string, flag int, perms os.FileMode) (*os.File, error) { + m.mu.Lock() + rp := m.realPathFor(path) + m.mu.Unlock() + return os.OpenFile(rp, flag, perms) // #nosec G304 +} + +func (m *memFS) ReadDir(dir string) ([]os.DirEntry, error) { + m.mu.Lock() + defer m.mu.Unlock() + var entries []os.DirEntry + prefix := dir + if !strings.HasSuffix(prefix, string(os.PathSeparator)) { + prefix += string(os.PathSeparator) + } + for p := range m.files { + if strings.HasPrefix(p, prefix) { + rest := strings.TrimPrefix(p, prefix) + if !strings.Contains(rest, string(os.PathSeparator)) { + entries = append(entries, memDirEntry{name: rest, size: int64(len(m.files[p]))}) + } + } + } + return entries, nil +} + +type memDirEntry struct { + name string + size int64 +} + +func (e memDirEntry) Name() string { return e.name } +func (e memDirEntry) IsDir() bool { return false } +func (e memDirEntry) Type() os.FileMode { return 0 } +func (e memDirEntry) Info() (os.FileInfo, error) { + return memFileInfo{name: e.name, size: e.size, mode: 0o600}, nil +} + +type memWriter struct { + fs *memFS + path string + buf bytes.Buffer +} + +func (w *memWriter) Write(b []byte) (int, error) { return w.buf.Write(b) } +func (w *memWriter) Close() error { + w.fs.mu.Lock() + w.fs.files[w.path] = append([]byte(nil), w.buf.Bytes()...) + w.fs.mu.Unlock() + return nil +} diff --git a/internal/adapters/data/ssh_config_file/metadata_manager.go b/internal/adapters/data/ssh_config_file/metadata_manager.go index a4e7be7e..db29061c 100644 --- a/internal/adapters/data/ssh_config_file/metadata_manager.go +++ b/internal/adapters/data/ssh_config_file/metadata_manager.go @@ -30,6 +30,11 @@ type ServerMetadata struct { LastSeen string `json:"last_seen,omitempty"` PinnedAt string `json:"pinned_at,omitempty"` SSHCount int `json:"ssh_count,omitempty"` + // File is the absolute path of the SSH config file lazyssh should + // write to when editing or deleting this host. Populated lazily on + // the first successful write and used to suppress the ambiguity + // prompt on subsequent edits. + File string `json:"file,omitempty"` } type metadataManager struct { @@ -120,6 +125,23 @@ func (m *metadataManager) updateServer(server domain.Server, oldAlias string) er return m.saveAll(metadata) } +// setFile records the config file lazyssh should write to next time the +// alias is edited or deleted. Empty path clears the memory. +func (m *metadataManager) setFile(alias, path string) error { + metadata, err := m.loadAll() + if err != nil { + return fmt.Errorf("load metadata: %w", err) + } + + meta := metadata[alias] + if meta.File == path { + return nil + } + meta.File = path + metadata[alias] = meta + return m.saveAll(metadata) +} + func (m *metadataManager) deleteServer(alias string) error { metadata, err := m.loadAll() if err != nil { diff --git a/internal/adapters/data/ssh_config_file/ssh_config_file_repo.go b/internal/adapters/data/ssh_config_file/ssh_config_file_repo.go index 37e8004b..1711da7d 100644 --- a/internal/adapters/data/ssh_config_file/ssh_config_file_repo.go +++ b/internal/adapters/data/ssh_config_file/ssh_config_file_repo.go @@ -54,12 +54,12 @@ func NewRepositoryWithFS(logger *zap.SugaredLogger, configPath string, metaDataP // ListServers returns all servers matching the query pattern. // Empty query returns all servers. func (r *Repository) ListServers(query string) ([]domain.Server, error) { - cfg, err := r.loadConfig() + lc, err := r.loadConfig() if err != nil { - return nil, fmt.Errorf("failed to load config: %w", err) + return nil, err } - servers := r.toDomainServer(cfg) + servers := r.toDomainServer(lc) metadata, err := r.metadataManager.loadAll() if err != nil { r.logger.Warnf("Failed to load metadata: %v", err) @@ -73,41 +73,57 @@ func (r *Repository) ListServers(query string) ([]domain.Server, error) { return r.filterServers(servers, query), nil } -// AddServer adds a new server to the SSH config. +// AddServer adds a new server to the SSH config. If server.SourceFile is set +// and matches a loaded file, the new host is written there; otherwise it +// goes into the main config file. func (r *Repository) AddServer(server domain.Server) error { - cfg, err := r.loadConfig() + lc, err := r.loadConfig() if err != nil { - return fmt.Errorf("failed to load config: %w", err) + return err } - if r.serverExists(cfg, server.Alias) { + if r.serverExists(lc, server.Alias) { return fmt.Errorf("server with alias '%s' already exists", server.Alias) } + target := lc.findFile(server.SourceFile) + if target == nil { + // Default: main file. + target = &lc.files[0] + } + host := r.createHostFromServer(server) - cfg.Hosts = append(cfg.Hosts, host) + target.cfg.Hosts = append(target.cfg.Hosts, host) - if err := r.saveConfig(cfg); err != nil { + if err := r.saveFiles(lc, []string{target.path}); err != nil { r.logger.Warnf("Failed to save config while adding new server: %v", err) return fmt.Errorf("failed to save config: %w", err) } return r.metadataManager.updateServer(server, server.Alias) } -// UpdateServer updates an existing server in the SSH config. +// UpdateServer updates an existing server in the SSH config. The host is +// mutated in whichever file currently defines it (preferring server.SourceFile +// when the alias is defined in multiple files). func (r *Repository) UpdateServer(server domain.Server, newServer domain.Server) error { - cfg, err := r.loadConfig() + lc, err := r.loadConfig() if err != nil { - return fmt.Errorf("failed to load config: %w", err) + return err } - host := r.findHostByAlias(cfg, server.Alias) - if host == nil { + matches := r.findHostMatches(lc, server.Alias) + if len(matches) == 0 { return fmt.Errorf("server with alias '%s' not found", server.Alias) } + if len(matches) > 1 && !preferenceResolves(matches, server.SourceFile) { + return &domain.ErrAmbiguousHost{Alias: server.Alias, Candidates: matchPaths(matches)} + } + + picked := pickWritableMatch(matches, server.SourceFile) + host := picked.host if server.Alias != newServer.Alias { - if r.serverExists(cfg, newServer.Alias) { + if r.serverExists(lc, newServer.Alias) { return fmt.Errorf("server with alias '%s' already exists", newServer.Alias) } @@ -119,36 +135,41 @@ func (r *Repository) UpdateServer(server domain.Server, newServer domain.Server) newPatterns = append(newPatterns, pattern) } } - host.Patterns = newPatterns - } - r.updateHostNodes(host, newServer) + r.updateHostNodes(host, server, newServer) - if err := r.saveConfig(cfg); err != nil { + if err := r.saveFiles(lc, []string{picked.path}); err != nil { r.logger.Warnf("Failed to save config while updating server: %v", err) return fmt.Errorf("failed to save config: %w", err) } - // Update metadata; pass old alias to allow inline migration - return r.metadataManager.updateServer(newServer, server.Alias) + if err := r.metadataManager.updateServer(newServer, server.Alias); err != nil { + return err + } + return r.metadataManager.setFile(newServer.Alias, picked.path) } -// DeleteServer removes a server from the SSH config. +// DeleteServer removes a server from the SSH config (from whichever file +// currently defines it; preferring server.SourceFile on ambiguity). func (r *Repository) DeleteServer(server domain.Server) error { - cfg, err := r.loadConfig() + lc, err := r.loadConfig() if err != nil { - return fmt.Errorf("failed to load config: %w", err) + return err } - initialCount := len(cfg.Hosts) - cfg.Hosts = r.removeHostByAlias(cfg.Hosts, server.Alias) - - if len(cfg.Hosts) == initialCount { + matches := r.findHostMatches(lc, server.Alias) + if len(matches) == 0 { return fmt.Errorf("server with alias '%s' not found", server.Alias) } + if len(matches) > 1 && !preferenceResolves(matches, server.SourceFile) { + return &domain.ErrAmbiguousHost{Alias: server.Alias, Candidates: matchPaths(matches)} + } + + picked := pickWritableMatch(matches, server.SourceFile) + picked.cfg.Hosts = r.removeHostByAlias(picked.cfg.Hosts, server.Alias) - if err := r.saveConfig(cfg); err != nil { + if err := r.saveFiles(lc, []string{picked.path}); err != nil { r.logger.Warnf("Failed to save config while deleting server: %v", err) return fmt.Errorf("failed to save config: %w", err) } diff --git a/internal/adapters/ui/handlers.go b/internal/adapters/ui/handlers.go index 897e053d..602400f9 100644 --- a/internal/adapters/ui/handlers.go +++ b/internal/adapters/ui/handlers.go @@ -15,6 +15,7 @@ package ui import ( + "errors" "fmt" "strings" "time" @@ -264,6 +265,23 @@ func (t *tui) handleServerSave(server domain.Server, original *domain.Server) { err = t.serverService.AddServer(server) } if err != nil { + var ambig *domain.ErrAmbiguousHost + if errors.As(err, &ambig) && original != nil { + t.showFileChoiceModal(ambig.Alias, ambig.Candidates, "Save", func(chosen string) { + origCopy := *original + origCopy.SourceFile = chosen + newCopy := server + newCopy.SourceFile = chosen + if err := t.serverService.UpdateServer(origCopy, newCopy); err != nil { + t.showStatusTempColor("Save failed: "+err.Error(), "#FF6B6B") + return + } + t.showStatusTemp(fmt.Sprintf("Updated %s in %s", newCopy.Alias, chosen)) + t.refreshServerList() + t.handleFormCancel() + }) + return + } // Stay on form; show a small modal with the error modal := tview.NewModal(). SetText(fmt.Sprintf("Save failed: %v", err)). @@ -273,6 +291,13 @@ func (t *tui) handleServerSave(server domain.Server, original *domain.Server) { return } + if server.SourceFile != "" { + verb := "Added" + if original != nil { + verb = "Updated" + } + t.showStatusTemp(fmt.Sprintf("%s %s in %s", verb, server.Alias, server.SourceFile)) + } t.refreshServerList() t.handleFormCancel() } @@ -355,38 +380,79 @@ func (t *tui) showDeleteConfirmModal(server domain.Server) { msg := fmt.Sprintf("Delete server %s (%s@%s:%d)?\n\nThis action cannot be undone.", server.Alias, server.User, server.Host, server.Port) + doDelete := func() { + err := t.serverService.DeleteServer(server) + if err != nil { + var ambig *domain.ErrAmbiguousHost + if errors.As(err, &ambig) { + t.showFileChoiceModal(ambig.Alias, ambig.Candidates, "Delete", func(chosen string) { + srv := server + srv.SourceFile = chosen + if err := t.serverService.DeleteServer(srv); err != nil { + t.showStatusTempColor("Delete failed: "+err.Error(), "#FF6B6B") + return + } + t.showStatusTemp(fmt.Sprintf("Deleted %s from %s", srv.Alias, chosen)) + t.refreshServerList() + }) + return + } + t.showStatusTempColor("Delete failed: "+err.Error(), "#FF6B6B") + return + } + t.refreshServerList() + t.handleModalClose() + } + modal := tview.NewModal(). SetText(msg). AddButtons([]string{"[yellow]C[-]ancel", "[yellow]D[-]elete"}). SetDoneFunc(func(buttonIndex int, buttonLabel string) { if buttonIndex == 1 { - _ = t.serverService.DeleteServer(server) - t.refreshServerList() + doDelete() + return } t.handleModalClose() }) - // Add keyboard shortcuts for the modal modal.SetInputCapture(func(event *tcell.EventKey) *tcell.EventKey { switch event.Rune() { case 'c', 'C': - // Cancel t.handleModalClose() return nil case 'd', 'D': - // Delete - _ = t.serverService.DeleteServer(server) - t.refreshServerList() - t.handleModalClose() + doDelete() return nil } - // ESC key already handled by default modal behavior return event }) t.app.SetRoot(modal, true) } +// showFileChoiceModal asks the user which config file to apply an action to +// when the same alias is defined in multiple files. action is a verb shown +// on the confirmation buttons (e.g. "Save", "Delete"). onChoose is called +// with the chosen absolute file path; Cancel closes the modal. +func (t *tui) showFileChoiceModal(alias string, candidates []string, action string, onChoose func(path string)) { + msg := fmt.Sprintf("Host %q is defined in multiple files.\nWhich file should %s use?", alias, action) + buttons := append([]string{}, candidates...) + buttons = append(buttons, "Cancel") + + modal := tview.NewModal(). + SetText(msg). + AddButtons(buttons). + SetDoneFunc(func(idx int, label string) { + if idx < 0 || idx >= len(candidates) { + t.handleModalClose() + return + } + t.handleModalClose() + onChoose(candidates[idx]) + }) + t.app.SetRoot(modal, true) +} + func (t *tui) showEditTagsForm(server domain.Server) { form := tview.NewForm() form.SetBorder(true). @@ -408,8 +474,27 @@ func (t *tui) showEditTagsForm(server domain.Server) { newServer := server newServer.Tags = tags - _ = t.serverService.UpdateServer(server, newServer) - // Refresh UI and go back + err := t.serverService.UpdateServer(server, newServer) + if err != nil { + var ambig *domain.ErrAmbiguousHost + if errors.As(err, &ambig) { + t.showFileChoiceModal(ambig.Alias, ambig.Candidates, "Update", func(chosen string) { + orig := server + orig.SourceFile = chosen + nu := newServer + nu.SourceFile = chosen + if err := t.serverService.UpdateServer(orig, nu); err != nil { + t.showStatusTempColor("Tags update failed: "+err.Error(), "#FF6B6B") + return + } + t.refreshServerList() + t.showStatusTemp("Tags updated") + }) + return + } + t.showStatusTempColor("Tags update failed: "+err.Error(), "#FF6B6B") + return + } t.refreshServerList() t.returnToMain() t.showStatusTemp("Tags updated") diff --git a/internal/core/domain/errors.go b/internal/core/domain/errors.go new file mode 100644 index 00000000..34d12708 --- /dev/null +++ b/internal/core/domain/errors.go @@ -0,0 +1,35 @@ +// Copyright 2025. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package domain + +import ( + "fmt" + "strings" +) + +// ErrAmbiguousHost is returned when an alias is defined in more than one SSH +// config file (the main file plus one or more `Include`-d files) and the +// caller hasn't told the repository which file to write to. +// +// The TUI catches this error, prompts the user to pick a file, and re-invokes +// the operation with Server.SourceFile set to the chosen path. +type ErrAmbiguousHost struct { + Alias string + Candidates []string +} + +func (e *ErrAmbiguousHost) Error() string { + return fmt.Sprintf("alias %q is defined in multiple files: %s", e.Alias, strings.Join(e.Candidates, ", ")) +} diff --git a/internal/core/domain/server.go b/internal/core/domain/server.go index c23b301d..8d2cdfd8 100644 --- a/internal/core/domain/server.go +++ b/internal/core/domain/server.go @@ -114,4 +114,13 @@ type Server struct { // Debugging settings LogLevel string + + // SourceFile is the absolute path of the SSH config file this host was + // loaded from (or where it should be written when adding a new host). + // Provenance metadata, not part of SSH semantics. + SourceFile string + // SourceFiles lists every config file that defines this alias, in + // OpenSSH precedence order. A length > 1 means the alias is defined in + // multiple files; the UI uses this to prompt on edit/delete. + SourceFiles []string }