Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions internal/adapters/nylas/demo_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ func (d *DemoClient) ListPolicies(ctx context.Context) ([]domain.Policy, error)
Name: "Demo Policy",
ApplicationID: "app-demo",
OrganizationID: "org-demo",
Rules: []string{"rule-demo-1"},
},
}, nil
}
Expand All @@ -23,6 +24,7 @@ func (d *DemoClient) GetPolicy(ctx context.Context, policyID string) (*domain.Po
Name: "Demo Policy",
ApplicationID: "app-demo",
OrganizationID: "org-demo",
Rules: []string{"rule-demo-1"},
}, nil
}

Expand All @@ -33,6 +35,7 @@ func (d *DemoClient) CreatePolicy(ctx context.Context, payload map[string]any) (
Name: name,
ApplicationID: "app-demo",
OrganizationID: "org-demo",
Rules: []string{"rule-demo-1"},
}, nil
}

Expand All @@ -43,6 +46,7 @@ func (d *DemoClient) UpdatePolicy(ctx context.Context, policyID string, payload
Name: name,
ApplicationID: "app-demo",
OrganizationID: "org-demo",
Rules: []string{"rule-demo-1"},
}, nil
}

Expand Down
4 changes: 4 additions & 0 deletions internal/adapters/nylas/mock_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ func (m *MockClient) ListPolicies(ctx context.Context) ([]domain.Policy, error)
Name: "Default Policy",
ApplicationID: "app-123",
OrganizationID: "org-123",
Rules: []string{"rule-1"},
},
}, nil
}
Expand All @@ -23,6 +24,7 @@ func (m *MockClient) GetPolicy(ctx context.Context, policyID string) (*domain.Po
Name: "Default Policy",
ApplicationID: "app-123",
OrganizationID: "org-123",
Rules: []string{"rule-1"},
}, nil
}

Expand All @@ -33,6 +35,7 @@ func (m *MockClient) CreatePolicy(ctx context.Context, payload map[string]any) (
Name: name,
ApplicationID: "app-123",
OrganizationID: "org-123",
Rules: []string{"rule-1"},
}, nil
}

Expand All @@ -43,6 +46,7 @@ func (m *MockClient) UpdatePolicy(ctx context.Context, policyID string, payload
Name: name,
ApplicationID: "app-123",
OrganizationID: "org-123",
Rules: []string{"rule-1"},
}, nil
}

Expand Down
11 changes: 10 additions & 1 deletion internal/air/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ func sanitizeEmail(email string) string {
return safe + ".db"
}

func isAccountDBFile(name string) bool {
if !strings.HasSuffix(name, ".db") || strings.HasSuffix(name, "-wal") || strings.HasSuffix(name, "-shm") {
return false
}

// Shared databases are not per-account caches.
return name != "photos.db"
}

// DBPath returns the database path for an email.
func (m *Manager) DBPath(email string) string {
return filepath.Join(m.basePath, sanitizeEmail(email))
Expand Down Expand Up @@ -240,7 +249,7 @@ func (m *Manager) ListCachedAccounts() ([]string, error) {
var emails []string
for _, entry := range entries {
name := entry.Name()
if strings.HasSuffix(name, ".db") && !strings.HasSuffix(name, "-wal") && !strings.HasSuffix(name, "-shm") {
if isAccountDBFile(name) {
email := strings.TrimSuffix(name, ".db")
emails = append(emails, email)
}
Expand Down
21 changes: 21 additions & 0 deletions internal/air/cache/emails.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,15 @@ func (s *EmailStore) UpdateFlags(id string, unread, starred *bool) error {
return nil
}

return s.UpdateMessage(id, unread, starred, nil)
}

// UpdateMessage updates cached message flags and folder placement.
func (s *EmailStore) UpdateMessage(id string, unread, starred *bool, folders []string) error {
if unread == nil && starred == nil && folders == nil {
return nil
}

// Use strings.Builder to avoid string concatenation in loop
var query strings.Builder
query.WriteString("UPDATE emails SET")
Expand All @@ -290,6 +299,18 @@ func (s *EmailStore) UpdateFlags(id string, unread, starred *bool) error {
}
query.WriteString(" starred = ?")
args = append(args, boolToInt(*starred))
needComma = true
}
if folders != nil {
if needComma {
query.WriteString(",")
}
folderID := ""
if len(folders) > 0 {
folderID = folders[0]
}
query.WriteString(" folder_id = ?")
args = append(args, folderID)
}

query.WriteString(" WHERE id = ?")
Expand Down
93 changes: 78 additions & 15 deletions internal/air/cache/encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"encoding/hex"
"fmt"
"os"
"time"

// Import for side effects - registers sqlite3 driver and adiantum VFS
_ "github.com/ncruces/go-sqlite3/driver"
Expand All @@ -29,14 +30,21 @@ const (
// allowedTables is a whitelist of table names that can be used in SQL queries.
// This prevents SQL injection by ensuring only known table names are used.
var allowedTables = map[string]bool{
"emails": true,
"events": true,
"contacts": true,
"folders": true,
"calendars": true,
"sync_state": true,
"emails": true,
"events": true,
"contacts": true,
"folders": true,
"calendars": true,
"sync_state": true,
"attachments": true,
"offline_queue": true,
}

var (
getOrCreateKeyFunc = getOrCreateKey
deleteKeyFunc = deleteKey
)

// tableNames returns the list of allowed table names for migration operations.
func tableNames() []string {
names := make([]string, 0, len(allowedTables))
Expand Down Expand Up @@ -108,8 +116,9 @@ func openEncryptedDB(dbPath string, key []byte) (*sql.DB, error) {
return nil, fmt.Errorf("open encrypted database: %w", err)
}

// Verify the key works by running a simple query
if _, err := db.Exec("SELECT 1"); err != nil {
// Verify the key works by reading the schema, which fails with the wrong key.
var schemaObjects int
if err := db.QueryRow("SELECT COUNT(*) FROM sqlite_master").Scan(&schemaObjects); err != nil {
_ = db.Close()
return nil, fmt.Errorf("verify encryption key: %w", err)
}
Expand Down Expand Up @@ -161,7 +170,7 @@ func (m *EncryptedManager) GetDB(email string) (*sql.DB, error) {
}

// Get or create encryption key
key, err := getOrCreateKey(email)
key, err := getOrCreateKeyFunc(email)
if err != nil {
return nil, fmt.Errorf("get encryption key for %s: %w", email, err)
}
Expand Down Expand Up @@ -207,7 +216,7 @@ func (m *EncryptedManager) ClearCache(email string) error {
// Remove encryption key if encryption is enabled
if m.encryption.Enabled {
delete(m.keys, email)
if err := deleteKey(email); err != nil {
if err := deleteKeyFunc(email); err != nil {
// Log but don't fail - key might not exist
fmt.Fprintf(os.Stderr, "warning: failed to delete encryption key: %v\n", err)
}
Expand All @@ -233,7 +242,7 @@ func (m *EncryptedManager) MigrateToEncrypted(email string) error {
defer func() { _ = unencryptedDB.Close() }()

// Get or create encryption key
key, err := getOrCreateKey(email)
key, err := getOrCreateKeyFunc(email)
if err != nil {
return fmt.Errorf("get encryption key: %w", err)
}
Expand Down Expand Up @@ -298,7 +307,7 @@ func (m *EncryptedManager) MigrateToUnencrypted(email string) error {
key, ok := m.keys[email]
if !ok {
var err error
key, err = getOrCreateKey(email)
key, err = getOrCreateKeyFunc(email)
if err != nil {
return fmt.Errorf("get encryption key: %w", err)
}
Expand Down Expand Up @@ -352,12 +361,33 @@ func (m *EncryptedManager) MigrateToUnencrypted(email string) error {
_ = os.Remove(backupPath)
_ = os.Remove(backupPath + "-wal")
_ = os.Remove(backupPath + "-shm")
_ = deleteKey(email)
_ = deleteKeyFunc(email)
delete(m.keys, email)

return nil
}

// ClearAllCaches removes all encrypted cache databases and associated keys.
func (m *EncryptedManager) ClearAllCaches() error {
accounts, err := m.ListCachedAccounts()
if err != nil {
return err
}

if err := m.Manager.ClearAllCaches(); err != nil {
return err
}

for _, email := range accounts {
delete(m.keys, email)
if err := deleteKeyFunc(email); err != nil {
fmt.Fprintf(os.Stderr, "warning: failed to delete encryption key: %v\n", err)
}
}

return nil
}

// copyTable copies all rows from one table to another.
func copyTable(src, dst *sql.DB, table string) error {
// Validate table name against whitelist to prevent SQL injection
Expand Down Expand Up @@ -437,12 +467,45 @@ func IsEncrypted(dbPath string) (bool, error) {
}
defer func() { _ = db.Close() }()

// Try a simple query - will fail if encrypted
_, err = db.Exec("SELECT 1")
// Read the schema - this fails when the database is encrypted and opened without a key.
var schemaObjects int
err = db.QueryRow("SELECT COUNT(*) FROM sqlite_master").Scan(&schemaObjects)
if err != nil {
// Database exists but can't be read - likely encrypted
return true, nil
}

return false, nil
}

// GetStats returns statistics for an encrypted cache database.
func (m *EncryptedManager) GetStats(email string) (*CacheStats, error) {
db, err := m.GetDB(email)
if err != nil {
return nil, err
}

stats := &CacheStats{Email: email}

info, err := os.Stat(m.DBPath(email))
if err == nil {
stats.SizeBytes = info.Size()
}

row := db.QueryRow("SELECT COUNT(*) FROM emails")
_ = row.Scan(&stats.EmailCount)

row = db.QueryRow("SELECT COUNT(*) FROM events")
_ = row.Scan(&stats.EventCount)

row = db.QueryRow("SELECT COUNT(*) FROM contacts")
_ = row.Scan(&stats.ContactCount)

var lastSync int64
row = db.QueryRow("SELECT MAX(last_sync) FROM sync_state")
if err := row.Scan(&lastSync); err == nil && lastSync > 0 {
stats.LastSync = time.Unix(lastSync, 0)
}

return stats, nil
}
Loading
Loading