diff --git a/exec.go b/exec.go
index e4247f6..1ca7739 100644
--- a/exec.go
+++ b/exec.go
@@ -1,6 +1,7 @@
 package hookah
 
 import (
+	"context"
 	"errors"
 	"fmt"
 	"io"
@@ -184,8 +185,31 @@ func getErrorHandlerEnv(f string, err error) []string {
 	return env
 }
 
-func (h *HookExec) execFile(f string, data io.ReadSeeker, timeout time.Duration, env ...string) error {
-	cmd := exec.Command(f)
+// execFile executes the hook script at path f with data piped to stdin and the given environment variables.
+// If timeout is greater than zero, the script's entire process group is killed once the timeout expires.
+// A zero or negative timeout runs the script without a deadline. Once the script has started, execFile
+// always waits for it to exit, preventing zombie processes.
+func (h *HookExec) execFile(f string, data io.ReadSeeker, timeout time.Duration, env ...string) (err error) {
+	ctx := context.Background()
+
+	var cancel context.CancelFunc
+	if timeout > 0 {
+		ctx, cancel = context.WithTimeout(ctx, timeout)
+	} else {
+		cancel = func() {}
+	}
+	defer cancel()
+
+	cmd := exec.CommandContext(ctx, f)
+	// Run the hook in its own process group so a timeout can kill its children too.
+	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
+	cmd.Cancel = func() error {
+		// Kill the entire process group instead of just the parent
+		if cmd.Process == nil {
+			return os.ErrProcessDone
+		}
+		return syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
+	}
 
 	if h.Stdout != nil {
 		cmd.Stdout = h.Stdout
@@ -205,32 +229,40 @@ func (h *HookExec) execFile(f string, data io.ReadSeeker, timeout time.Duration,
 	if err != nil {
 		return err
 	}
-	defer stdin.Close()
 
-	err = cmd.Start()
-	if err != nil {
+	if _, err := data.Seek(0, 0); err != nil {
+		_ = stdin.Close()
 		return err
 	}
 
-	_, err = data.Seek(0, 0)
-	if err != nil {
+	if err := cmd.Start(); err != nil {
+		_ = stdin.Close()
 		return err
 	}
 
-	_, err = io.Copy(stdin, data)
-	if err != nil {
+	// Reap the child on every return path below and fold the Wait error into
+	// the named result so neither the copy error nor the exit status is lost.
+	defer func() {
+		waitErr := cmd.Wait()
+
+		if waitErr != nil && errors.Is(ctx.Err(), context.DeadlineExceeded) {
+			waitErr = fmt.Errorf("hook timed out after %s: %w", timeout, waitErr)
+		}
+
+		err = errors.Join(err, waitErr)
+	}()
+
+	// A child that exits (or is killed) before draining stdin makes the copy fail
+	// with EPIPE; that alone is not a hook failure, Wait reports the real outcome.
+	if _, err := io.Copy(stdin, data); err != nil && !errors.Is(err, syscall.EPIPE) {
+		_ = stdin.Close()
 		return err
 	}
 
-	stdin.Close()
-
-	timer := time.AfterFunc(timeout, func() {
-		cmd.Process.Kill()
-	})
-	err = cmd.Wait()
-	timer.Stop()
+	// Ignore close error - child may exit early without reading all stdin
+	_ = stdin.Close()
 
-	return err
+	return nil
 }
 
 // todo: base this on OS
diff --git a/exec_test.go b/exec_test.go
index 9ba2416..6915cc0 100644
--- a/exec_test.go
+++ b/exec_test.go
@@ -2,12 +2,16 @@ package hookah
 
 import (
 	"bytes"
+	"errors"
+	"io"
 	"log"
+	"os"
 	"strings"
 	"testing"
 	"time"
 
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 func TestOnlyExecutableBinsFound(t *testing.T) {
@@ -144,3 +148,72 @@ func TestEnvPopulatedCorrectly(t *testing.T) {
 	}
 
 }
+
+// TestExecFileTimeout verifies that execFile respects the timeout and returns
+// a timeout error without hanging for long-running scripts.
+func TestExecFileTimeout(t *testing.T) {
+	f, err := os.CreateTemp("", "hookah-test-*.sh")
+	require.NoError(t, err)
+	defer os.Remove(f.Name())
+
+	_, err = io.WriteString(f, "#!/bin/sh\nsleep 30\n")
+	require.NoError(t, err)
+	require.NoError(t, f.Close())
+	require.NoError(t, os.Chmod(f.Name(), 0700))
+
+	h := HookExec{
+		Stdout: io.Discard,
+		Stderr: io.Discard,
+	}
+	data := strings.NewReader(`{}`)
+
+	start := time.Now()
+	err = h.execFile(f.Name(), data, 200*time.Millisecond)
+	elapsed := time.Since(start)
+
+	require.Error(t, err)
+	assert.Less(t, elapsed, 5*time.Second, "execFile should not hang after timeout")
+	assert.Contains(t, err.Error(), "timed out")
+}
+
+// TestExecFileCopyError verifies that a Read error during stdin copy still allows
+// the child process to be reaped without the call hanging (no zombie processes).
+func TestExecFileCopyError(t *testing.T) {
+	f, err := os.CreateTemp("", "hookah-test-*.sh")
+	require.NoError(t, err)
+	defer os.Remove(f.Name())
+
+	_, err = io.WriteString(f, "#!/bin/sh\ncat\n")
+	require.NoError(t, err)
+	require.NoError(t, f.Close())
+	require.NoError(t, os.Chmod(f.Name(), 0700))
+
+	h := HookExec{
+		Stdout: io.Discard,
+		Stderr: io.Discard,
+	}
+
+	readErr := errors.New("simulated read error")
+	data := &readErrSeeker{readErr: readErr}
+
+	done := make(chan error, 1)
+	go func() {
+		done <- h.execFile(f.Name(), data, 5*time.Second)
+	}()
+
+	select {
+	case err := <-done:
+		assert.ErrorContains(t, err, readErr.Error())
+	case <-time.After(3 * time.Second):
+		t.Fatal("execFile hung waiting for process to be reaped")
+	}
+}
+
+// readErrSeeker is a ReadSeeker whose Seek always succeeds but whose Read always
+// returns the configured error, so io.Copy fails on its very first read.
+type readErrSeeker struct {
+	readErr error
+}
+
+func (r *readErrSeeker) Seek(_ int64, _ int) (int64, error) { return 0, nil }
+func (r *readErrSeeker) Read(_ []byte) (int, error) { return 0, r.readErr }