diff --git a/internal/json/json.go b/internal/json/json.go index 37c515a4..7c9b5b1c 100644 --- a/internal/json/json.go +++ b/internal/json/json.go @@ -6,10 +6,23 @@ import ( json "github.com/bytedance/sonic" ) +// useInt64API decodes JSON integers into int64 instead of float64, preserving +// full precision for values whose magnitude exceeds 2^53. Non-integer numbers +// still decode as float64. +var useInt64API = json.Config{UseInt64: true}.Froze() + func Unmarshal(b []byte, v any) error { return json.Unmarshal(b, v) } +// UnmarshalUseInt64 behaves like Unmarshal but decodes JSON integers into +// int64 (rather than float64) when the destination is an interface{}. This +// matters for full-range bigint values from wal2json: float64 only has 53 +// bits of mantissa and would silently round. +func UnmarshalUseInt64(b []byte, v any) error { + return useInt64API.Unmarshal(b, v) +} + func Marshal(v any) ([]byte, error) { return json.Marshal(v) } diff --git a/internal/json/json_test.go b/internal/json/json_test.go new file mode 100644 index 00000000..2dd195b1 --- /dev/null +++ b/internal/json/json_test.go @@ -0,0 +1,105 @@ +// SPDX-License-Identifier: Apache-2.0 + +package json + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestUnmarshalUseInt64(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + expected any + }{ + { + name: "small integer decodes as int64", + input: `{"value": 42}`, + expected: int64(42), + }, + { + name: "bigint above 2^53 preserved", + input: `{"value": 9007199254740993}`, + expected: int64(9007199254740993), + }, + { + name: "max int64 preserved", + input: `{"value": 9223372036854775807}`, + expected: int64(9223372036854775807), + }, + { + name: "negative bigint preserved", + input: `{"value": -9223372036854775808}`, + expected: int64(-9223372036854775808), + }, + { + name: "non-integer number decodes as float64", + input: `{"value": 1.5}`, + expected: 1.5, + }, + { + // `1.0` is mathematically an integer but the JSON literal + // contains a decimal point, so sonic treats it as float64. + // Pin this down — Postgres serialises non-integer numerics + // this way and pgstream relies on the type discriminator. + name: "integer-valued literal with decimal point stays float64", + input: `{"value": 1.0}`, + expected: 1.0, + }, + { + // Scientific notation always decodes as float64 regardless of + // whether the value is mathematically an integer. + name: "scientific notation decodes as float64", + input: `{"value": 1e10}`, + expected: 1e10, + }, + { + name: "negative scientific notation decodes as float64", + input: `{"value": -1.5e2}`, + expected: -1.5e2, + }, + { + name: "zero decodes as int64", + input: `{"value": 0}`, + expected: int64(0), + }, + { + name: "small negative integer decodes as int64", + input: `{"value": -1}`, + expected: int64(-1), + }, + { + // 2^63 doesn't fit in int64; sonic falls back to float64 + // rather than wrapping around or returning an error. + name: "value above MaxInt64 falls back to float64", + input: `{"value": 9223372036854775808}`, + expected: float64(9223372036854775808), + }, + { + name: "null decodes as nil", + input: `{"value": null}`, + expected: nil, + }, + { + // Sanity check that the rule applies inside arrays / nested + // objects too — this is what makes the snapshot/jsonb fix + // for #686 work end-to-end. + name: "nested large integer preserved through array", + input: `{"value": [1, 9223372036854775807, 3]}`, + expected: []any{int64(1), int64(9223372036854775807), int64(3)}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + var got map[string]any + require.NoError(t, UnmarshalUseInt64([]byte(tt.input), &got)) + require.Equal(t, tt.expected, got["value"]) + }) + } +} diff --git a/internal/postgres/pg_utils.go b/internal/postgres/pg_utils.go index 5d7f8350..86189eb0 100644 --- a/internal/postgres/pg_utils.go +++ b/internal/postgres/pg_utils.go @@ -16,6 +16,7 @@ import ( "github.com/jackc/pgx/v5/pgtype" "github.com/lib/pq" pgxvec "github.com/pgvector/pgvector-go/pgx" + pgjson "github.com/xataio/pgstream/internal/json" ) type QualifiedName struct { @@ -164,6 +165,9 @@ type extensionType struct { // extensionTypes lists the postgres extension types pgstream teaches pgx // about on every connection. var extensionTypes = []extensionType{ + {name: "json", register: registerWithCodec("json", &pgtype.JSONCodec{Marshal: pgjson.Marshal, Unmarshal: pgjson.UnmarshalUseInt64})}, + {name: "jsonb", register: registerWithCodec("jsonb", &pgtype.JSONBCodec{Marshal: pgjson.Marshal, Unmarshal: pgjson.UnmarshalUseInt64})}, + {name: "hstore", register: registerWithCodec("hstore", pgtype.HstoreCodec{})}, {name: "vector", register: func(ctx context.Context, conn *pgx.Conn, _ uint32) error { // pgxvec registers vector, halfvec and sparsevec in one call — diff --git a/pkg/stream/integration/pg_pg_integration_test.go b/pkg/stream/integration/pg_pg_integration_test.go index 0158d9f1..ac4c70c5 100644 --- a/pkg/stream/integration/pg_pg_integration_test.go +++ b/pkg/stream/integration/pg_pg_integration_test.go @@ -6,6 +6,7 @@ import ( "context" "fmt" "os" + "reflect" "testing" "time" @@ -947,3 +948,128 @@ func getRoles(t *testing.T, ctx context.Context, conn pglib.Querier) []string { return roles } + +// Test_PostgresToPostgres_LargeIntegerPrecisionWAL is the WAL-replication +// counterpart of Test_SnapshotToPostgres_LargeIntegerPrecision. It pins down +// that INSERT / UPDATE events carrying large integers — both as plain bigint +// column values and as integer fields inside a jsonb payload — replicate +// bit-for-bit through the wal2json → listener → target path. +// +// Covers the WAL side of #824 (bigint) and #686 (jsonb large int). +func Test_PostgresToPostgres_LargeIntegerPrecisionWAL(t *testing.T) { + if os.Getenv("PGSTREAM_INTEGRATION_TESTS") == "" { + t.Skip("skipping integration test...") + } + + cfg := &stream.Config{ + Listener: testPostgresListenerCfg(), + Processor: testPostgresProcessorCfg(), + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + runStream(t, ctx, cfg) + + testTable := "pg2pg_largeint_wal" + + // 9007199254740993 = 2^53 + 1, the smallest int64 that float64 cannot + // represent exactly. 9223372036854775807 is MaxInt64. + execQuery(t, ctx, fmt.Sprintf( + `CREATE TABLE %s( + id bigint PRIMARY KEY, + amount bigint NOT NULL, + payload jsonb NOT NULL + )`, testTable)) + + type row struct { + id int64 + amount int64 + payload string + } + + targetConn, err := pglib.NewConn(ctx, targetPGURL) + require.NoError(t, err) + + fetch := func() ([]row, error) { + rows, err := targetConn.Query(ctx, + fmt.Sprintf("SELECT id, amount, payload::text FROM %s ORDER BY id", testTable)) + if err != nil { + return nil, err + } + defer rows.Close() + out := []row{} + for rows.Next() { + var r row + if err := rows.Scan(&r.id, &r.amount, &r.payload); err != nil { + return nil, err + } + out = append(out, r) + } + return out, rows.Err() + } + + waitFor := func(want []row) { + t.Helper() + timer := time.NewTimer(20 * time.Second) + defer timer.Stop() + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + for { + select { + case <-timer.C: + cancel() + t.Fatalf("timeout waiting for WAL replication; last fetch did not match %v", want) + case <-ticker.C: + got, err := fetch() + if err != nil || len(got) != len(want) { + continue + } + if reflect.DeepEqual(got, want) { + require.Equal(t, want, got) + return + } + } + } + } + + t.Run("insert bigint above 2^53 + jsonb with large int", func(t *testing.T) { + execQuery(t, ctx, fmt.Sprintf( + `INSERT INTO %s(id, amount, payload) VALUES + (9007199254740993, 9007199254740993, '{"n":9007199254740993}'), + (9223372036854775807, 9223372036854775807, '{"n":9223372036854775807,"nested":{"k":1234567890123456789}}')`, + testTable)) + + waitFor([]row{ + {id: 9007199254740993, amount: 9007199254740993, payload: `{"n": 9007199254740993}`}, + {id: 9223372036854775807, amount: 9223372036854775807, payload: `{"n": 9223372036854775807, "nested": {"k": 1234567890123456789}}`}, + }) + }) + + t.Run("update jsonb payload with new large int", func(t *testing.T) { + execQuery(t, ctx, fmt.Sprintf( + `UPDATE %s + SET payload = jsonb_set(payload, '{k2}', '987654321987654321'::jsonb) + WHERE id = 9223372036854775807`, + testTable)) + + // jsonb stores keys ordered by length then alphabetically, so + // `n` (1) before `k2` (2) before `nested` (6). + waitFor([]row{ + {id: 9007199254740993, amount: 9007199254740993, payload: `{"n": 9007199254740993}`}, + {id: 9223372036854775807, amount: 9223372036854775807, payload: `{"n": 9223372036854775807, "k2": 987654321987654321, "nested": {"k": 1234567890123456789}}`}, + }) + }) + + t.Run("update bigint above 2^53 by +1", func(t *testing.T) { + // The repro from #824: incrementing a bigint above 2^53 must + // land on the destination as exactly source+1, not source+0. + execQuery(t, ctx, fmt.Sprintf( + `UPDATE %s SET amount = amount + 1 WHERE id = 9007199254740993`, + testTable)) + + waitFor([]row{ + {id: 9007199254740993, amount: 9007199254740994, payload: `{"n": 9007199254740993}`}, + {id: 9223372036854775807, amount: 9223372036854775807, payload: `{"n": 9223372036854775807, "k2": 987654321987654321, "nested": {"k": 1234567890123456789}}`}, + }) + }) +} diff --git a/pkg/stream/integration/snapshot_pg_integration_test.go b/pkg/stream/integration/snapshot_pg_integration_test.go index a7e8af85..98b6b4e0 100644 --- a/pkg/stream/integration/snapshot_pg_integration_test.go +++ b/pkg/stream/integration/snapshot_pg_integration_test.go @@ -407,3 +407,123 @@ func Test_SnapshotToPostgres_IdentityAndGeneratedColumns(t *testing.T) { run("batch") }) } + +// Test_SnapshotToPostgres_LargeIntegerPrecision verifies that snapshot/restore +// preserves int64 precision for both: +// - plain bigint columns whose value exceeds 2^53 (#824), and +// - large integer values embedded inside json/jsonb columns (#686). +// +// Without the custom internal/json.UnmarshalUseInt64 decoder, JSON integers +// are silently rounded to float64 — the destination row diverges from the +// source by one or more units with no error or warning. +func Test_SnapshotToPostgres_LargeIntegerPrecision(t *testing.T) { + if os.Getenv("PGSTREAM_INTEGRATION_TESTS") == "" { + t.Skip("skipping integration test...") + } + + var snapshotPGURL string + pgcleanup, err := testcontainers.SetupPostgresContainer(context.Background(), &snapshotPGURL, testcontainers.Postgres14, "config/postgresql.conf") + require.NoError(t, err) + defer pgcleanup() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + run := func(suffix string, opts ...option) { + testTable := fmt.Sprintf("largeint_%s", suffix) + + // 9007199254740993 = 2^53 + 1, the smallest int64 that float64 + // cannot represent exactly. 9223372036854775807 is MaxInt64. + execQueryWithURL(t, ctx, snapshotPGURL, fmt.Sprintf( + `CREATE TABLE %s( + id bigint PRIMARY KEY, + amount bigint NOT NULL, + payload jsonb NOT NULL + )`, testTable)) + execQueryWithURL(t, ctx, snapshotPGURL, fmt.Sprintf( + `INSERT INTO %s(id, amount, payload) VALUES + (9007199254740993, 9007199254740993, '{"n":9007199254740993}'), + (9223372036854775807, 9223372036854775807, '{"n":9223372036854775807,"nested":{"k":1234567890123456789}}')`, + testTable)) + + cfg := &stream.Config{ + Listener: testPostgresListenerCfgWithSnapshot(snapshotPGURL, targetPGURL, []string{"*.*"}), + Processor: testPostgresProcessorCfg(opts...), + } + initStream(t, ctx, snapshotPGURL) + runSnapshot(t, ctx, cfg) + + targetConn, err := pglib.NewConn(ctx, targetPGURL) + require.NoError(t, err) + sourceConn, err := pglib.NewConn(ctx, snapshotPGURL) + require.NoError(t, err) + + timer := time.NewTimer(20 * time.Second) + defer timer.Stop() + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + type row struct { + id int64 + amount int64 + payload string + } + query := fmt.Sprintf( + "SELECT id, amount, payload::text FROM %s ORDER BY id", testTable) + fetch := func(conn pglib.Querier) ([]row, error) { + rows, err := conn.Query(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + out := []row{} + for rows.Next() { + var r row + if err := rows.Scan(&r.id, &r.amount, &r.payload); err != nil { + return nil, err + } + out = append(out, r) + } + return out, rows.Err() + } + + want, err := fetch(sourceConn) + require.NoError(t, err) + require.Len(t, want, 2) + // Belt-and-braces: confirm the source actually stores the + // large ints we asked for. If this fails the test setup is + // wrong, not the production code. + require.Equal(t, int64(9007199254740993), want[0].id) + require.Equal(t, int64(9223372036854775807), want[1].id) + require.Contains(t, want[1].payload, `"k": 1234567890123456789`) + + validation := func() bool { + got, err := fetch(targetConn) + if err != nil || len(got) != len(want) { + return false + } + require.Equal(t, want, got) + return true + } + + for { + select { + case <-timer.C: + cancel() + t.Error("timeout waiting for large-int snapshot sync") + return + case <-ticker.C: + if validation() { + return + } + } + } + } + + t.Run("bulk ingest", func(t *testing.T) { + run("bulk", withBulkIngestionEnabled()) + }) + t.Run("batch writer", func(t *testing.T) { + run("batch") + }) +} diff --git a/pkg/wal/listener/postgres/wal_pg_listener.go b/pkg/wal/listener/postgres/wal_pg_listener.go index 5df73aeb..c1c97031 100644 --- a/pkg/wal/listener/postgres/wal_pg_listener.go +++ b/pkg/wal/listener/postgres/wal_pg_listener.go @@ -50,7 +50,7 @@ func New(handler replicationHandler, processEvent listenerProcessWalEvent, opts logger: loglib.NewNoopLogger(), replicationHandler: handler, processEvent: processEvent, - walDataDeserialiser: json.Unmarshal, + walDataDeserialiser: json.UnmarshalUseInt64, lsnParser: handler.GetLSNParser(), } diff --git a/pkg/wal/processor/postgres/postgres_wal_dml_adapter.go b/pkg/wal/processor/postgres/postgres_wal_dml_adapter.go index 31f67174..1b3c5e3a 100644 --- a/pkg/wal/processor/postgres/postgres_wal_dml_adapter.go +++ b/pkg/wal/processor/postgres/postgres_wal_dml_adapter.go @@ -128,7 +128,7 @@ func (a *dmlAdapter) buildInsertQueries(d *wal.Data, schemaInfo schemaInfo) []*q // handle sequence columns that need to be updated after insert for _, col := range d.Columns { if seqName, ok := schemaInfo.sequenceColumns[pglib.QuoteIdentifier(col.Name)]; ok { - colValueFloat, ok := col.Value.(float64) + seqVal, ok := toInt64(col.Value) if !ok { a.logger.Warn(nil, "unexpected value type for sequence column, expected integer", loglib.Fields{ "column_name": col.Name, "column_type": col.Type, "column_value": col.Value, @@ -139,7 +139,7 @@ func (a *dmlAdapter) buildInsertQueries(d *wal.Data, schemaInfo schemaInfo) []*q table: d.Table, schema: d.Schema, sql: "SELECT setval($1::regclass, $2::bigint, true)", - args: []any{seqName, int64(colValueFloat)}, + args: []any{seqName, seqVal}, }) } } @@ -424,6 +424,24 @@ func needsTextCopy(columnTypes []string) bool { return false } +// toInt64 converts a wal.Column.Value into an int64 if it represents an +// integer. WAL data deserialised with UseInt64 produces int64, but snapshots +// and tests may produce other integer types or float64. +func toInt64(v any) (int64, bool) { + switch n := v.(type) { + case int64: + return n, true + case int: + return int64(n), true + case int32: + return int64(n), true + case float64: + return int64(n), true + default: + return 0, false + } +} + func isArray(colType string) bool { // PostgreSQL array types can be represented in two ways: // 1. With [] suffix: text[], int[], etc. diff --git a/pkg/wal/processor/postgres/postgres_wal_dml_adapter_bulk.go b/pkg/wal/processor/postgres/postgres_wal_dml_adapter_bulk.go index f36d55a1..0b13cdbf 100644 --- a/pkg/wal/processor/postgres/postgres_wal_dml_adapter_bulk.go +++ b/pkg/wal/processor/postgres/postgres_wal_dml_adapter_bulk.go @@ -252,14 +252,13 @@ func (a *dmlAdapter) buildBulkInsertQueries(events []*wal.Data, si schemaInfo) [ if !a.forCopy { for _, col := range e.Columns { if seqName, ok := si.sequenceColumns[pglib.QuoteIdentifier(col.Name)]; ok { - colValueFloat, ok := col.Value.(float64) + val, ok := toInt64(col.Value) if !ok { a.logger.Warn(nil, "unexpected value type for sequence column, expected integer", loglib.Fields{ "column_name": col.Name, "column_type": col.Type, "column_value": col.Value, }) continue } - val := int64(colValueFloat) if current, exists := seqMaxValues[seqName]; !exists || val > current { seqMaxValues[seqName] = val } diff --git a/pkg/wal/processor/postgres/postgres_wal_dml_adapter_test.go b/pkg/wal/processor/postgres/postgres_wal_dml_adapter_test.go index d53ed584..1869e9b9 100644 --- a/pkg/wal/processor/postgres/postgres_wal_dml_adapter_test.go +++ b/pkg/wal/processor/postgres/postgres_wal_dml_adapter_test.go @@ -224,6 +224,41 @@ func TestDMLAdapter_walDataToQueries(t *testing.T) { }, }, }, + { + name: "insert with int64 sequence value preserves precision above 2^53", + walData: &wal.Data{ + Action: "I", + Schema: testSchema, + Table: testTable, + Columns: []wal.Column{ + {ID: columnID(1), Name: "id", Value: int64(9007199254740993)}, + {ID: columnID(2), Name: "name", Value: "alice"}, + }, + Metadata: wal.Metadata{ + InternalColIDs: []string{columnID(1)}, + }, + }, + sequenceColumns: map[string]string{ + `"id"`: `"id_seq"`, + }, + forCopy: false, + + wantQueries: []*query{ + { + schema: testSchema, + table: testTable, + columnNames: quotedColumnNames, + sql: fmt.Sprintf("INSERT INTO %s(\"id\", \"name\") OVERRIDING SYSTEM VALUE VALUES($1, $2)", quotedTestTable), + args: []any{int64(9007199254740993), "alice"}, + }, + { + schema: testSchema, + table: testTable, + sql: "SELECT setval($1::regclass, $2::bigint, true)", + args: []any{`"id_seq"`, int64(9007199254740993)}, + }, + }, + }, { name: "insert with sequences - for copy enabled", walData: &wal.Data{