Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func newAuditID() string {

// buildChainEntry creates a Merkle-linked ChainEntry from an AuditEvent.
// Reads the last entry from store to get the previous hash.
func buildChainEntry(store AuditStore, event AuditEvent, callerID, hookdTraceID string) (ChainEntry, error) {
func buildChainEntry(store AuditStore, event AuditEvent, callerID, hookdTraceID, sessionID string) (ChainEntry, error) {
prev, err := lastHash(store)
if err != nil {
return ChainEntry{}, err
Expand Down Expand Up @@ -52,6 +52,7 @@ func buildChainEntry(store AuditStore, event AuditEvent, callerID, hookdTraceID
ActionType: event.ActionType,
Actor: actor,
CallerID: event.CallerID,
SessionID: sessionID,
InputHash: event.InputHash,
OutputHash: event.OutputHash,
Result: result,
Expand All @@ -75,6 +76,7 @@ func buildChainEntry(store AuditStore, event AuditEvent, callerID, hookdTraceID
ActionType: internal.ActionType,
Actor: internal.Actor,
CallerID: internal.CallerID,
SessionID: internal.SessionID,
InputHash: internal.InputHash,
OutputHash: internal.OutputHash,
Result: internal.Result,
Expand Down Expand Up @@ -107,6 +109,7 @@ func buildPostCallEntry(store AuditStore, preEntry ChainEntry, resp *anthropic.M
PrevHash: prev,
EventType: eventType,
CallerID: cfg.CallerID,
SessionID: cfg.SessionID,
OutputHash: outputHash,
HookdTraceID: cfg.HookdTraceID,
Allowed: true,
Expand All @@ -126,6 +129,7 @@ func buildPostCallEntry(store AuditStore, preEntry ChainEntry, resp *anthropic.M
Hash: internal.Hash,
EventType: internal.EventType,
CallerID: internal.CallerID,
SessionID: internal.SessionID,
OutputHash: internal.OutputHash,
HookdTraceID: internal.HookdTraceID,
Allowed: internal.Allowed,
Expand All @@ -152,6 +156,7 @@ func buildStreamCompleteEntry(store AuditStore, startEntry ChainEntry, streamErr
PrevHash: prev,
EventType: eventType,
CallerID: cfg.CallerID,
SessionID: cfg.SessionID,
HookdTraceID: cfg.HookdTraceID,
Allowed: true,
Timestamp: ts.Format(time.RFC3339Nano),
Expand All @@ -170,6 +175,7 @@ func buildStreamCompleteEntry(store AuditStore, startEntry ChainEntry, streamErr
Hash: internal.Hash,
EventType: internal.EventType,
CallerID: internal.CallerID,
SessionID: internal.SessionID,
HookdTraceID: internal.HookdTraceID,
Allowed: internal.Allowed,
Timestamp: ts,
Expand Down Expand Up @@ -222,6 +228,7 @@ func computeEntryHash(store AuditStore, entry ChainEntry) (ChainEntry, error) {
ActionType: entry.ActionType,
Actor: entry.Actor,
CallerID: entry.CallerID,
SessionID: entry.SessionID,
InputHash: entry.InputHash,
OutputHash: entry.OutputHash,
Result: entry.Result,
Expand Down Expand Up @@ -278,6 +285,7 @@ func findLastValidHash(entries []ChainEntry) string {
ActionType: e.ActionType,
Actor: e.Actor,
CallerID: e.CallerID,
SessionID: e.SessionID,
InputHash: e.InputHash,
OutputHash: e.OutputHash,
Result: e.Result,
Expand Down Expand Up @@ -330,6 +338,7 @@ func buildSegmentBoundaryEntry(lastValidHash, reason string) (ChainEntry, error)
}, nil
}


// verifyChain is the internal entry point for Verify().
func verifyChain(entries []ChainEntry) error {
internalEntries := make([]inaudit.Entry, len(entries))
Expand All @@ -342,6 +351,7 @@ func verifyChain(entries []ChainEntry) error {
ActionType: e.ActionType,
Actor: e.Actor,
CallerID: e.CallerID,
SessionID: e.SessionID,
InputHash: e.InputHash,
OutputHash: e.OutputHash,
Result: e.Result,
Expand Down
1 change: 1 addition & 0 deletions gate.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ func (g *gateWrapper) evaluate(ctx context.Context, params anthropic.MessageNewP
ID: auditID,
EventType: "llm_call",
CallerID: cfg.CallerID,
SessionID: cfg.SessionID,
HookdTraceID: cfg.HookdTraceID,
Allowed: result.Allowed,
DenialReason: result.DenialReason,
Expand Down
5 changes: 5 additions & 0 deletions internal/audit/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Entry struct {
ActionType string `json:"action_type,omitempty"`
Actor string `json:"actor,omitempty"`
CallerID string `json:"caller_id,omitempty"`
SessionID string `json:"session_id,omitempty"`
InputHash string `json:"input_hash,omitempty"`
OutputHash string `json:"output_hash,omitempty"`
Result string `json:"result,omitempty"`
Expand All @@ -34,12 +35,15 @@ type Entry struct {

// HashContent is the subset of fields included in the Merkle hash.
// Excludes Hash itself (computed from this) and Metadata (advisory).
// SessionID uses omitempty so pre-v0.2 entries (without session_id) remain
// verifiable — an absent field hashes identically to an absent JSON key.
type HashContent struct {
PrevHash string `json:"prev_hash"`
EventType string `json:"event_type"`
ActionType string `json:"action_type,omitempty"`
Actor string `json:"actor,omitempty"`
CallerID string `json:"caller_id,omitempty"`
SessionID string `json:"session_id,omitempty"`
InputHash string `json:"input_hash,omitempty"`
OutputHash string `json:"output_hash,omitempty"`
Result string `json:"result,omitempty"`
Expand All @@ -56,6 +60,7 @@ func ComputeHash(e Entry) (string, error) {
EventType: e.EventType,
ActionType: e.ActionType,
CallerID: e.CallerID,
SessionID: e.SessionID,
InputHash: e.InputHash,
OutputHash: e.OutputHash,
HookdTraceID: e.HookdTraceID,
Expand Down
1 change: 1 addition & 0 deletions store.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ type ChainEntry struct {
ActionType string `json:"action_type,omitempty"`
Actor string `json:"actor,omitempty"`
CallerID string `json:"caller_id,omitempty"`
SessionID string `json:"session_id,omitempty"`
InputHash string `json:"input_hash,omitempty"`
OutputHash string `json:"output_hash,omitempty"`
Result string `json:"result,omitempty"`
Expand Down
94 changes: 93 additions & 1 deletion writ.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ type Config struct {
// fails Merkle verification. A ChainSegmentBoundary entry is written
// recording the recovery event. Default false: corrupt chain → ErrCorruptChain.
AllowCorruptChainRecovery bool

// SessionID is a stable identifier for the current process run
// (e.g. a UUID generated at startup). Written into each chain entry.
// On writ.New(), if the chain's last entry has a different non-empty
// SessionID, a warning is added to Client.Warnings() to flag the
// cross-session boundary for human review.
SessionID string
}

// ErrCorruptChain is returned by New() when the existing chain fails Merkle
Expand All @@ -62,6 +69,7 @@ type Client struct {
cfg Config
gater *gateWrapper
chain AuditStore
warnings []string // non-fatal init warnings (e.g. session ID mismatch)
}

// New constructs a writ.Client with lazy OPA policy reload.
Expand Down Expand Up @@ -108,6 +116,20 @@ func NewWithContext(ctx context.Context, cfg Config) (*Client, error) {
chain: store,
}
c.Messages = &MessagesService{wc: c}

// Warn on session ID mismatch so operators know the chain crosses a restart.
if cfg.SessionID != "" {
if entries, readErr := store.ReadAll(); readErr == nil && len(entries) > 0 {
last := entries[len(entries)-1]
if last.SessionID != "" && last.SessionID != cfg.SessionID {
c.warnings = append(c.warnings, fmt.Sprintf(
"session ID mismatch: last chain entry has session_id=%q, current session is %q — chain spans a process restart",
last.SessionID, cfg.SessionID,
))
}
}
}

return c, nil
}

Expand Down Expand Up @@ -188,14 +210,84 @@ type AuditEvent struct {
Metadata map[string]string
}

// Warnings returns non-fatal issues detected at construction time.
// Currently populated when Config.SessionID differs from the last chain
// entry's session_id, indicating the chain spans a process restart.
func (c *Client) Warnings() []string {
return c.warnings
}

// VerifyResult is the output of Client.VerifyFull.
type VerifyResult struct {
Valid bool
EntryCount int
FirstBreak *ChainEntry // nil if Valid is true
RootHash string // hash of the last entry; empty if chain is empty
SessionGaps []SessionGap
}

// SessionGap describes a point in the chain where the session_id changed,
// indicating a process restart boundary.
type SessionGap struct {
AfterEntryIndex int // index of the last entry with PrevSessionID
PrevSessionID string
NextSessionID string
}

// VerifyFull verifies Merkle hash integrity and reports SessionID gaps.
// Unlike the package-level Verify, this returns structured results including
// chain continuity information across process restarts.
func (c *Client) VerifyFull() (*VerifyResult, error) {
entries, err := c.chain.ReadAll()
if err != nil {
return nil, fmt.Errorf("writ.VerifyFull: read chain: %w", err)
}
result := &VerifyResult{EntryCount: len(entries)}
if err := verifyChain(entries); err != nil {
result.Valid = false
result.FirstBreak = firstBrokenEntry(entries)
} else {
result.Valid = true
if len(entries) > 0 {
result.RootHash = entries[len(entries)-1].Hash
}
}
for i := 1; i < len(entries); i++ {
prev, curr := entries[i-1], entries[i]
if prev.SessionID != "" && curr.SessionID != "" && curr.SessionID != prev.SessionID {
result.SessionGaps = append(result.SessionGaps, SessionGap{
AfterEntryIndex: i - 1,
PrevSessionID: prev.SessionID,
NextSessionID: curr.SessionID,
})
}
}
return result, nil
}

// firstBrokenEntry returns the first ChainEntry whose hash link is broken,
// or nil if the chain is intact.
func firstBrokenEntry(entries []ChainEntry) *ChainEntry {
for i := range entries {
copy := entries[i]
if i == len(entries)-1 {
return &copy
}
if entries[i+1].PrevHash != entries[i].Hash {
return &copy
}
}
return nil
}

// Audit writes an explicit event to the writ chain. Use for tool use events
// (file read, shell exec, web fetch) that require Article 12 granularity.
// The chain entry includes a Merkle link to the previous entry.
func (c *Client) Audit(event AuditEvent) error {
if event.Timestamp.IsZero() {
event.Timestamp = time.Now().UTC()
}
entry, err := buildChainEntry(c.chain, event, c.cfg.CallerID, c.cfg.HookdTraceID)
entry, err := buildChainEntry(c.chain, event, c.cfg.CallerID, c.cfg.HookdTraceID, c.cfg.SessionID)
if err != nil {
return fmt.Errorf("writ.Audit: build entry: %w", err)
}
Expand Down
Loading