Skip to content
61 changes: 44 additions & 17 deletions exec.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package hookah

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -184,8 +185,30 @@
return env
}

func (h *HookExec) execFile(f string, data io.ReadSeeker, timeout time.Duration, env ...string) error {
cmd := exec.Command(f)
// execFile executes the hook script at path f with data piped to stdin and the given environment variables.
// If timeout is greater than zero, the process and its children are killed via process group termination after
// the timeout expires. If timeout is zero, the process runs without a timeout. The function always waits for
// the process to exit, preventing zombie processes.
func (h *HookExec) execFile(f string, data io.ReadSeeker, timeout time.Duration, env ...string) (err error) {
ctx := context.Background()

var cancel context.CancelFunc
if timeout > 0 {
ctx, cancel = context.WithTimeout(ctx, timeout)
} else {
cancel = func() {}
}
defer cancel()

cmd := exec.CommandContext(ctx, f)

Check failure

Code scanning / CodeQL

Command built from user-controlled sources Critical

This command depends on a
user-provided value
.
Comment thread
github-advanced-security[bot] marked this conversation as resolved.
Fixed
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
cmd.Cancel = func() error {
// Kill the entire process group instead of just the parent
if cmd.Process == nil {
return os.ErrProcessDone
}
return syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
}

if h.Stdout != nil {
cmd.Stdout = h.Stdout
Expand All @@ -205,32 +228,36 @@
if err != nil {
return err
}
defer stdin.Close()

err = cmd.Start()
if err != nil {
if _, err := data.Seek(0, 0); err != nil {
_ = stdin.Close()
return err
}

_, err = data.Seek(0, 0)
if err != nil {
if err := cmd.Start(); err != nil {
_ = stdin.Close()
return err
}

_, err = io.Copy(stdin, data)
if err != nil {
defer func() {
waitErr := cmd.Wait()

if waitErr != nil && ctx.Err() == context.DeadlineExceeded {
waitErr = fmt.Errorf("hook timed out after %s: %w", timeout, waitErr)
}

err = errors.Join(err, waitErr)
}()

if _, err := io.Copy(stdin, data); err != nil {
_ = stdin.Close()
return err
}
Comment thread
donatj marked this conversation as resolved.
stdin.Close()

timer := time.AfterFunc(timeout, func() {
cmd.Process.Kill()
})

err = cmd.Wait()
timer.Stop()
// Ignore close error - child may exit early without reading all stdin
_ = stdin.Close()

return err
return nil
}

// todo: base this on OS
Expand Down
71 changes: 71 additions & 0 deletions exec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@ package hookah

import (
"bytes"
"errors"
"io"
"log"
"os"
"strings"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestOnlyExecutableBinsFound(t *testing.T) {
Expand Down Expand Up @@ -144,3 +148,70 @@ func TestEnvPopulatedCorrectly(t *testing.T) {
}

}

// TestExecFileTimeout verifies that execFile respects the timeout and returns
// a timeout error without hanging for long-running scripts.
func TestExecFileTimeout(t *testing.T) {
f, err := os.CreateTemp("", "hookah-test-*.sh")
require.NoError(t, err)
defer os.Remove(f.Name())

_, _ = io.WriteString(f, "#!/bin/sh\nsleep 30\n")
require.NoError(t, f.Close())
require.NoError(t, os.Chmod(f.Name(), 0700))

h := HookExec{
Stdout: io.Discard,
Stderr: io.Discard,
}
data := strings.NewReader(`{}`)

start := time.Now()
err = h.execFile(f.Name(), data, 200*time.Millisecond)
elapsed := time.Since(start)

require.Error(t, err)
assert.Less(t, elapsed, 5*time.Second, "execFile should not hang after timeout")
assert.Contains(t, err.Error(), "timed out")
}

// TestExecFileCopyError verifies that a Read error during stdin copy still allows
// the child process to be reaped without the call hanging (no zombie processes).
func TestExecFileCopyError(t *testing.T) {
f, err := os.CreateTemp("", "hookah-test-*.sh")
require.NoError(t, err)
defer os.Remove(f.Name())

_, _ = io.WriteString(f, "#!/bin/sh\ncat\n")
require.NoError(t, f.Close())
require.NoError(t, os.Chmod(f.Name(), 0700))

h := HookExec{
Stdout: io.Discard,
Stderr: io.Discard,
}

readErr := errors.New("simulated read error")
data := &readErrSeeker{readErr: readErr}

done := make(chan error, 1)
go func() {
done <- h.execFile(f.Name(), data, 5*time.Second)
}()

select {
case err := <-done:
assert.ErrorContains(t, err, readErr.Error())
case <-time.After(3 * time.Second):
t.Fatal("execFile hung waiting for process to be reaped")
}
}

// readErrSeeker is a ReadSeeker whose Seek always succeeds but whose Read always
// returns the configured error, simulating an io.Copy failure mid-transfer.
type readErrSeeker struct {
readErr error
}

func (r *readErrSeeker) Seek(_ int64, _ int) (int64, error) { return 0, nil }
func (r *readErrSeeker) Read(_ []byte) (int, error) { return 0, r.readErr }