diff --git a/lib/serial/README.md b/lib/serial/README.md index 7e8d4a7..065a491 100644 --- a/lib/serial/README.md +++ b/lib/serial/README.md @@ -67,7 +67,7 @@ print(dumps({'b': 2, 'a': 1})) Reconstruct the value from a `dumps` string, interpreting the type tags. The result is a fresh, **unfrozen** value (the same as `json.decode`), so scripts can read or mutate it. -**Errors on**: invalid JSON (`serial.loads: …`), an unknown type tag (`unknown type tag "…"`), or a malformed envelope payload (`invalid bytes payload`, `invalid bigint payload`, `invalid time payload`, `invalid mapkv entry`, `invalid object payload`). A `set` with an unhashable element or a `mapkv` with an unhashable key errors (`unhashable`). Bare JSON numbers decode without a tag: an integer (any precision) to `int`, a number with `.`/`e`/`E` to `float`. +**Errors on**: invalid JSON (`serial.loads: …`), trailing content after the value (`unexpected trailing data after JSON value` — a second JSON document or garbage; trailing whitespace is fine), an unknown type tag (`unknown type tag "…"`), or a malformed envelope payload whose `v` has the wrong shape (`invalid bytes payload`, `invalid bigint payload`, `invalid time payload`, `invalid tuple payload`, `invalid set payload`, `invalid mapkv payload`/`invalid mapkv entry`, `invalid object payload`). A `set` with an unhashable element or a `mapkv` with an unhashable key errors (`unhashable`). Bare JSON numbers decode without a tag: an integer (any precision) to `int`, a number with `.`/`e`/`E` to `float`. ```python load('serial', 'loads') diff --git a/lib/serial/serial.go b/lib/serial/serial.go index 0059907..1011760 100644 --- a/lib/serial/serial.go +++ b/lib/serial/serial.go @@ -124,6 +124,13 @@ func generateLoads(try bool) func(*starlark.Thread, *starlark.Builtin, starlark. 
if err := dec.Decode(&raw); err != nil { return failResult(try, err, fn, true) } + // loads round-trips a single value; reject a second JSON value or + // trailing garbage; trailing whitespace/newlines still pass. NOTE(review): + // per encoding/json, More() is also false when the next byte is ']' or '}', so + // '1]' likely slips through; prefer `_, err := dec.Token(); err != io.EOF`. + if dec.More() { + return failResult(try, fmt.Errorf("unexpected trailing data after JSON value"), fn, true) + } val, err := decode(raw) if err != nil { return failResult(try, err, fn, true) @@ -399,7 +406,10 @@ func decodeObject(m map[string]interface{}) (starlark.Value, error) { } func decodeBytes(raw interface{}) (starlark.Value, error) { - s, _ := raw.(string) + s, ok := raw.(string) + if !ok { + return nil, fmt.Errorf("invalid bytes payload: expected a base64 string") + } b, err := base64.StdEncoding.DecodeString(s) if err != nil { return nil, fmt.Errorf("invalid bytes payload: %w", err) @@ -408,16 +418,22 @@ func decodeBytes(raw interface{}) (starlark.Value, error) { } func decodeBigint(raw interface{}) (starlark.Value, error) { - s, _ := raw.(string) - bi, ok := new(big.Int).SetString(s, 10) + s, ok := raw.(string) if !ok { + return nil, fmt.Errorf("invalid bigint payload: expected a decimal string") + } + bi, valid := new(big.Int).SetString(s, 10) + if !valid { return nil, fmt.Errorf("invalid bigint payload %q", s) } return starlark.MakeBigInt(bi), nil } func decodeTuple(raw interface{}) (starlark.Value, error) { - arr, _ := raw.([]interface{}) + arr, ok := raw.([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid tuple payload: expected an array") + } elems := make([]starlark.Value, len(arr)) for i, e := range arr { ev, err := decode(e) @@ -430,7 +446,10 @@ func decodeTuple(raw interface{}) (starlark.Value, error) { } func decodeSet(raw interface{}) (starlark.Value, error) { - arr, _ := raw.([]interface{}) + arr, ok := raw.([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid set payload: expected an array") + } set := 
starlark.NewSet(len(arr)) for _, e := range arr { ev, err := decode(e) @@ -445,7 +464,10 @@ func decodeSet(raw interface{}) (starlark.Value, error) { } func decodeTime(raw interface{}) (starlark.Value, error) { - s, _ := raw.(string) + s, ok := raw.(string) + if !ok { + return nil, fmt.Errorf("invalid time payload: expected an RFC3339 string") + } tm, err := time.Parse(time.RFC3339Nano, s) if err != nil { return nil, fmt.Errorf("invalid time payload %q: %w", s, err) @@ -454,7 +476,10 @@ func decodeTime(raw interface{}) (starlark.Value, error) { } func decodeMapKV(raw interface{}) (starlark.Value, error) { - arr, _ := raw.([]interface{}) + arr, ok := raw.([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid mapkv payload: expected an array of [key, value] pairs") + } d := starlark.NewDict(len(arr)) for _, pr := range arr { kvp, ok := pr.([]interface{}) diff --git a/lib/serial/serial_test.go b/lib/serial/serial_test.go index 83b2eee..c866cd3 100644 --- a/lib/serial/serial_test.go +++ b/lib/serial/serial_test.go @@ -387,6 +387,36 @@ func TestLoadModule_Serial(t *testing.T) { `), wantErr: `unknown type tag`, }, + { + // a wrong-typed envelope payload (e.g. a number where an array or + // string is expected) used to be silently coerced to an empty + // value instead of erroring, violating the lossless-or-error + // contract. Each tag decoder must now reject the wrong shape. 
+ name: `error: malformed envelope payloads reject the wrong type`, + script: itn.HereDoc(` + load('serial', 'loads') + assert.fails(lambda: loads('{"$t":"tuple","v":123}'), 'invalid tuple payload') + assert.fails(lambda: loads('{"$t":"set","v":123}'), 'invalid set payload') + assert.fails(lambda: loads('{"$t":"mapkv","v":123}'), 'invalid mapkv payload') + assert.fails(lambda: loads('{"$t":"bytes","v":123}'), 'invalid bytes payload') + assert.fails(lambda: loads('{"$t":"bigint","v":123}'), 'invalid bigint payload') + assert.fails(lambda: loads('{"$t":"time","v":123}'), 'invalid time payload') + `), + }, + { + // loads round-trips a single value; a second JSON value or + // trailing garbage was silently dropped. Trailing whitespace is + // still fine. + name: `error: trailing content after the JSON value is rejected`, + script: itn.HereDoc(` + load('serial', 'loads') + assert.eq(loads('1'), 1) + assert.eq(loads(' [1, 2] \n'), [1, 2]) + assert.fails(lambda: loads('1 2'), 'trailing data') + assert.fails(lambda: loads('{"a":1}{"b":2}'), 'trailing data') + assert.fails(lambda: loads('[1,2] junk'), 'trailing data') + `), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {