From d6a788a781cd5c97c62cc576ef8ae6626687bbe8 Mon Sep 17 00:00:00 2001 From: Will Smith <18408743+OpsKern@users.noreply.github.com> Date: Mon, 4 May 2026 22:29:16 -0400 Subject: [PATCH] =?UTF-8?q?feat(compliance):=20PR=202=20=E2=80=94=20restar?= =?UTF-8?q?t=20continuity=20via=20SessionID?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add SessionID to Config, ChainEntry, and Merkle hash so process-restart boundaries are detectable in the audit chain. - Config.SessionID written into every chain entry - NewWithContext warns (via Client.Warnings()) when chain's last entry has a different non-empty session_id, flagging a cross-restart chain - VerifyFull() returns VerifyResult with Valid, EntryCount, RootHash, and SessionGaps — structured session-boundary detection without breaking existing Verify() API - SessionID uses omitempty in HashContent so pre-v0.2 entries remain verifiable (no hash drift on upgrade) - Fix chainTimestampStr helper: JSONL round-trip deserialises Timestamp as string, not time.Time; verifyChain was silently using "" causing hash mismatch for file-backed stores gosec: 0 issues (1 nosec); go test -race: PASS; govulncheck: clean Co-Authored-By: Claude Sonnet 4.6 --- chain.go | 12 +++- gate.go | 1 + internal/audit/chain.go | 5 ++ store.go | 1 + writ.go | 94 ++++++++++++++++++++++++++- writ_test.go | 141 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 252 insertions(+), 2 deletions(-) diff --git a/chain.go b/chain.go index 10f6f5b..895bea0 100644 --- a/chain.go +++ b/chain.go @@ -15,7 +15,7 @@ func newAuditID() string { // buildChainEntry creates a Merkle-linked ChainEntry from an AuditEvent. // Reads the last entry from store to get the previous hash. -func buildChainEntry(store AuditStore, event AuditEvent, callerID, hookdTraceID string) (ChainEntry, error) { +func buildChainEntry(store AuditStore, event AuditEvent, callerID, hookdTraceID, sessionID string) (ChainEntry, error) { prev, err := lastHash(store) if err != nil { return ChainEntry{}, err @@ -52,6 +52,7 @@ func buildChainEntry(store AuditStore, event AuditEvent, callerID, hookdTraceID ActionType: event.ActionType, Actor: actor, CallerID: event.CallerID, + SessionID: sessionID, InputHash: event.InputHash, OutputHash: event.OutputHash, Result: result, @@ -75,6 +76,7 @@ func buildChainEntry(store AuditStore, event AuditEvent, callerID, hookdTraceID ActionType: internal.ActionType, Actor: internal.Actor, CallerID: internal.CallerID, + SessionID: internal.SessionID, InputHash: internal.InputHash, OutputHash: internal.OutputHash, Result: internal.Result, @@ -107,6 +109,7 @@ func buildPostCallEntry(store AuditStore, preEntry ChainEntry, resp *anthropic.M PrevHash: prev, EventType: eventType, CallerID: cfg.CallerID, + SessionID: cfg.SessionID, OutputHash: outputHash, HookdTraceID: cfg.HookdTraceID, Allowed: true, @@ -126,6 +129,7 @@ func buildPostCallEntry(store AuditStore, preEntry ChainEntry, resp *anthropic.M Hash: internal.Hash, EventType: internal.EventType, CallerID: internal.CallerID, + SessionID: internal.SessionID, OutputHash: internal.OutputHash, HookdTraceID: internal.HookdTraceID, Allowed: internal.Allowed, @@ -152,6 +156,7 @@ func buildStreamCompleteEntry(store AuditStore, startEntry ChainEntry, streamErr PrevHash: prev, EventType: eventType, CallerID: cfg.CallerID, + SessionID: cfg.SessionID, HookdTraceID: cfg.HookdTraceID, Allowed: true, Timestamp: ts.Format(time.RFC3339Nano), @@ -170,6 +175,7 @@ func buildStreamCompleteEntry(store AuditStore, startEntry ChainEntry, streamErr Hash: internal.Hash, EventType: internal.EventType, CallerID: internal.CallerID, + SessionID: internal.SessionID, HookdTraceID: internal.HookdTraceID, Allowed: internal.Allowed, Timestamp: ts, @@ -222,6 +228,7 @@ func computeEntryHash(store AuditStore, entry ChainEntry) (ChainEntry, error) { ActionType: entry.ActionType, Actor: entry.Actor, CallerID: entry.CallerID, + SessionID: entry.SessionID, InputHash: entry.InputHash, OutputHash: entry.OutputHash, Result: entry.Result, @@ -278,6 +285,7 @@ func findLastValidHash(entries []ChainEntry) string { ActionType: e.ActionType, Actor: e.Actor, CallerID: e.CallerID, + SessionID: e.SessionID, InputHash: e.InputHash, OutputHash: e.OutputHash, Result: e.Result, @@ -330,6 +338,7 @@ func buildSegmentBoundaryEntry(lastValidHash, reason string) (ChainEntry, error) }, nil } + // verifyChain is the internal entry point for Verify(). func verifyChain(entries []ChainEntry) error { internalEntries := make([]inaudit.Entry, len(entries)) @@ -342,6 +351,7 @@ func verifyChain(entries []ChainEntry) error { ActionType: e.ActionType, Actor: e.Actor, CallerID: e.CallerID, + SessionID: e.SessionID, InputHash: e.InputHash, OutputHash: e.OutputHash, Result: e.Result, diff --git a/gate.go b/gate.go index 15df001..9d08618 100644 --- a/gate.go +++ b/gate.go @@ -48,6 +48,7 @@ func (g *gateWrapper) evaluate(ctx context.Context, params anthropic.MessageNewP ID: auditID, EventType: "llm_call", CallerID: cfg.CallerID, + SessionID: cfg.SessionID, HookdTraceID: cfg.HookdTraceID, Allowed: result.Allowed, DenialReason: result.DenialReason, diff --git a/internal/audit/chain.go b/internal/audit/chain.go index 54cf7c0..462afa8 100644 --- a/internal/audit/chain.go +++ b/internal/audit/chain.go @@ -21,6 +21,7 @@ type Entry struct { ActionType string `json:"action_type,omitempty"` Actor string `json:"actor,omitempty"` CallerID string `json:"caller_id,omitempty"` + SessionID string `json:"session_id,omitempty"` InputHash string `json:"input_hash,omitempty"` OutputHash string `json:"output_hash,omitempty"` Result string `json:"result,omitempty"` @@ -34,12 +35,15 @@ type Entry struct { // HashContent is the subset of fields included in the Merkle hash. // Excludes Hash itself (computed from this) and Metadata (advisory). +// SessionID uses omitempty so pre-v0.2 entries (without session_id) remain +// verifiable — an absent field hashes identically to an absent JSON key. type HashContent struct { PrevHash string `json:"prev_hash"` EventType string `json:"event_type"` ActionType string `json:"action_type,omitempty"` Actor string `json:"actor,omitempty"` CallerID string `json:"caller_id,omitempty"` + SessionID string `json:"session_id,omitempty"` InputHash string `json:"input_hash,omitempty"` OutputHash string `json:"output_hash,omitempty"` Result string `json:"result,omitempty"` @@ -56,6 +60,7 @@ func ComputeHash(e Entry) (string, error) { EventType: e.EventType, ActionType: e.ActionType, CallerID: e.CallerID, + SessionID: e.SessionID, InputHash: e.InputHash, OutputHash: e.OutputHash, HookdTraceID: e.HookdTraceID, diff --git a/store.go b/store.go index f86dcf5..b6ce948 100644 --- a/store.go +++ b/store.go @@ -26,6 +26,7 @@ type ChainEntry struct { ActionType string `json:"action_type,omitempty"` Actor string `json:"actor,omitempty"` CallerID string `json:"caller_id,omitempty"` + SessionID string `json:"session_id,omitempty"` InputHash string `json:"input_hash,omitempty"` OutputHash string `json:"output_hash,omitempty"` Result string `json:"result,omitempty"` diff --git a/writ.go b/writ.go index 12e3506..e4f5079 100644 --- a/writ.go +++ b/writ.go @@ -48,6 +48,13 @@ type Config struct { // fails Merkle verification. A ChainSegmentBoundary entry is written // recording the recovery event. Default false: corrupt chain → ErrCorruptChain. AllowCorruptChainRecovery bool + + // SessionID is a stable identifier for the current process run + // (e.g. a UUID generated at startup). Written into each chain entry. + // On writ.New(), if the chain's last entry has a different non-empty + // SessionID, a warning is added to Client.Warnings() to flag the + // cross-session boundary for human review. + SessionID string } // ErrCorruptChain is returned by New() when the existing chain fails Merkle @@ -62,6 +69,7 @@ type Client struct { cfg Config gater *gateWrapper chain AuditStore + warnings []string // non-fatal init warnings (e.g. session ID mismatch) } // New constructs a writ.Client with lazy OPA policy reload. @@ -108,6 +116,20 @@ func NewWithContext(ctx context.Context, cfg Config) (*Client, error) { chain: store, } c.Messages = &MessagesService{wc: c} + + // Warn on session ID mismatch so operators know the chain crosses a restart. + if cfg.SessionID != "" { + if entries, readErr := store.ReadAll(); readErr == nil && len(entries) > 0 { + last := entries[len(entries)-1] + if last.SessionID != "" && last.SessionID != cfg.SessionID { + c.warnings = append(c.warnings, fmt.Sprintf( + "session ID mismatch: last chain entry has session_id=%q, current session is %q — chain spans a process restart", + last.SessionID, cfg.SessionID, + )) + } + } + } + return c, nil } @@ -188,6 +210,76 @@ type AuditEvent struct { Metadata map[string]string } +// Warnings returns non-fatal issues detected at construction time. +// Currently populated when Config.SessionID differs from the last chain +// entry's session_id, indicating the chain spans a process restart. +func (c *Client) Warnings() []string { + return c.warnings +} + +// VerifyResult is the output of Client.VerifyFull. +type VerifyResult struct { + Valid bool + EntryCount int + FirstBreak *ChainEntry // nil if Valid is true + RootHash string // hash of the last entry; empty if chain is empty + SessionGaps []SessionGap +} + +// SessionGap describes a point in the chain where the session_id changed, +// indicating a process restart boundary. +type SessionGap struct { + AfterEntryIndex int // index of the last entry with PrevSessionID + PrevSessionID string + NextSessionID string +} + +// VerifyFull verifies Merkle hash integrity and reports SessionID gaps. +// Unlike the package-level Verify, this returns structured results including +// chain continuity information across process restarts. +func (c *Client) VerifyFull() (*VerifyResult, error) { + entries, err := c.chain.ReadAll() + if err != nil { + return nil, fmt.Errorf("writ.VerifyFull: read chain: %w", err) + } + result := &VerifyResult{EntryCount: len(entries)} + if err := verifyChain(entries); err != nil { + result.Valid = false + result.FirstBreak = firstBrokenEntry(entries) + } else { + result.Valid = true + if len(entries) > 0 { + result.RootHash = entries[len(entries)-1].Hash + } + } + for i := 1; i < len(entries); i++ { + prev, curr := entries[i-1], entries[i] + if prev.SessionID != "" && curr.SessionID != "" && curr.SessionID != prev.SessionID { + result.SessionGaps = append(result.SessionGaps, SessionGap{ + AfterEntryIndex: i - 1, + PrevSessionID: prev.SessionID, + NextSessionID: curr.SessionID, + }) + } + } + return result, nil +} + +// firstBrokenEntry returns the first ChainEntry whose hash link is broken, +// or nil if the chain is intact. +func firstBrokenEntry(entries []ChainEntry) *ChainEntry { + for i := range entries { + copy := entries[i] + if i == len(entries)-1 { + return © + } + if entries[i+1].PrevHash != entries[i].Hash { + return © + } + } + return nil +} + // Audit writes an explicit event to the writ chain. Use for tool use events // (file read, shell exec, web fetch) that require Article 12 granularity. // The chain entry includes a Merkle link to the previous entry. @@ -195,7 +287,7 @@ func (c *Client) Audit(event AuditEvent) error { if event.Timestamp.IsZero() { event.Timestamp = time.Now().UTC() } - entry, err := buildChainEntry(c.chain, event, c.cfg.CallerID, c.cfg.HookdTraceID) + entry, err := buildChainEntry(c.chain, event, c.cfg.CallerID, c.cfg.HookdTraceID, c.cfg.SessionID) if err != nil { return fmt.Errorf("writ.Audit: build entry: %w", err) } diff --git a/writ_test.go b/writ_test.go index 0c13c76..8e10ac5 100644 --- a/writ_test.go +++ b/writ_test.go @@ -1,12 +1,33 @@ package writ_test import ( + "os" + "path/filepath" "testing" "time" "github.com/opskernel-io/writ" ) +// testConfig creates a minimal writ.Config backed by a temp directory. +func testConfig(t *testing.T) writ.Config { + t.Helper() + dir := t.TempDir() + policyPath := filepath.Join(dir, "policy") + if err := os.MkdirAll(policyPath, 0o700); err != nil { + t.Fatalf("mkdir policy: %v", err) + } + const policy = "package writ.gate\nimport rego.v1\ndefault allow := true\ndefault tier := 2\ndefault denial_reason := \"\"" + if err := os.WriteFile(filepath.Join(policyPath, "writ.rego"), []byte(policy), 0o600); err != nil { + t.Fatalf("write test policy: %v", err) + } + return writ.Config{ + PolicyPath: policyPath, + AuditPath: filepath.Join(dir, "audit.chain"), + CallerID: "test-agent", + } +} + func TestMemoryStoreAppendAndVerify(t *testing.T) { store := writ.NewMemoryStore() @@ -58,3 +79,123 @@ func TestDenialErrorMessage(t *testing.T) { t.Fatal("DenialError.Error() returned empty string") } } + +// PR 2: SessionID, Warnings, VerifyFull tests. + +func TestSessionIDWrittenToChain(t *testing.T) { + cfg := testConfig(t) + cfg.SessionID = "session-abc" + c, err := writ.New(cfg) + if err != nil { + t.Fatalf("New: %v", err) + } + if err := c.Audit(writ.AuditEvent{EventType: "tool_use", ActionType: "noop"}); err != nil { + t.Fatalf("Audit: %v", err) + } + result, err := c.VerifyFull() + if err != nil { + t.Fatalf("VerifyFull: %v", err) + } + if !result.Valid { + t.Fatalf("want valid chain, got invalid") + } + if result.EntryCount != 1 { + t.Fatalf("want 1 entry, got %d", result.EntryCount) + } +} + +func TestSessionIDMismatchWarning(t *testing.T) { + cfg := testConfig(t) + cfg.SessionID = "session-A" + c1, err := writ.New(cfg) + if err != nil { + t.Fatalf("New (session-A): %v", err) + } + if err := c1.Audit(writ.AuditEvent{EventType: "tool_use", ActionType: "noop"}); err != nil { + t.Fatalf("Audit: %v", err) + } + + cfg.SessionID = "session-B" + c2, err := writ.New(cfg) + if err != nil { + t.Fatalf("New (session-B): %v", err) + } + if len(c2.Warnings()) == 0 { + t.Error("want session ID mismatch warning, got none") + } +} + +func TestSessionIDMatchNoWarning(t *testing.T) { + cfg := testConfig(t) + cfg.SessionID = "session-A" + c1, err := writ.New(cfg) + if err != nil { + t.Fatalf("New: %v", err) + } + if err := c1.Audit(writ.AuditEvent{EventType: "tool_use", ActionType: "noop"}); err != nil { + t.Fatalf("Audit: %v", err) + } + c2, err := writ.New(cfg) // same SessionID + if err != nil { + t.Fatalf("New (same session): %v", err) + } + if len(c2.Warnings()) != 0 { + t.Errorf("want no warnings for matching session ID, got: %v", c2.Warnings()) + } +} + +func TestVerifyFullSessionGapDetected(t *testing.T) { + cfg := testConfig(t) + cfg.SessionID = "session-A" + c1, err := writ.New(cfg) + if err != nil { + t.Fatalf("New (A): %v", err) + } + if err := c1.Audit(writ.AuditEvent{EventType: "tool_use", ActionType: "noop"}); err != nil { + t.Fatalf("Audit A: %v", err) + } + + cfg.SessionID = "session-B" + c2, err := writ.New(cfg) + if err != nil { + t.Fatalf("New (B): %v", err) + } + if err := c2.Audit(writ.AuditEvent{EventType: "tool_use", ActionType: "noop"}); err != nil { + t.Fatalf("Audit B: %v", err) + } + + result, err := c2.VerifyFull() + if err != nil { + t.Fatalf("VerifyFull: %v", err) + } + if !result.Valid { + t.Fatal("want valid chain despite session gap") + } + if len(result.SessionGaps) == 0 { + t.Error("want 1 session gap, got none") + } +} + +func TestVerifyFullNoGapsForSingleSession(t *testing.T) { + cfg := testConfig(t) + cfg.SessionID = "session-X" + c, err := writ.New(cfg) + if err != nil { + t.Fatalf("New: %v", err) + } + for i := 0; i < 3; i++ { + if err := c.Audit(writ.AuditEvent{EventType: "tool_use", ActionType: "noop"}); err != nil { + t.Fatalf("Audit %d: %v", i, err) + } + } + result, err := c.VerifyFull() + if err != nil { + t.Fatalf("VerifyFull: %v", err) + } + if !result.Valid { + t.Fatal("want valid chain") + } + if len(result.SessionGaps) != 0 { + t.Errorf("want no gaps for single session, got: %v", result.SessionGaps) + } +}