From 49ee8be679add0bd3cf08a2669331b3be7a835f8 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Fri, 17 Feb 2023 19:21:37 -0700 Subject: compat/json: Correctly handle syntax-error-in-decode --- ReleaseNotes.md | 14 ++++++++ compat/json/compat.go | 82 ++++++++++++++++++++++++++++++++++++++++----- compat/json/compat_test.go | 78 ++++++++++++++++++++++++++++++++++++++++++ decode.go | 2 ++ decode_scan.go | 6 ++++ internal/jsonparse/parse.go | 15 +++++++++ 6 files changed, 188 insertions(+), 9 deletions(-) diff --git a/ReleaseNotes.md b/ReleaseNotes.md index a8496e0..5e8dab7 100644 --- a/ReleaseNotes.md +++ b/ReleaseNotes.md @@ -13,6 +13,11 @@ default, add a `CompactFloats` ReEncoderConfig option to control this. + + Decoder: Decoding `json.Unmarshaler` or `lowmemjson.Decodable` + as a top-level value no longer needs to read past the closing + `"`/`]`/`}`; this can be significant when reading streaming + input, as that next read may block. + - Compatibility bugfixes: + compat/json.Valid: No longer considers truncated JSON @@ -32,6 +37,15 @@ + compat/json.Indent: Preserve trailing whitespace, same as `encoding/json`. + + compat/json.Decoder: No longer transforms "unexpected EOF" + errors to "unexpected end of JSON input". This makes it + different than `compat/json.Unmarshal`, but the same as + `encoding/json`. + + + compat/json.Decoder, compat/json.Unmarshal: No longer mutate + the target value at all if there is a syntax error in the + input. + - Unicode: + Feature: Encoder, ReEncoder: Add an `InvalidUTF8` diff --git a/compat/json/compat.go b/compat/json/compat.go index 3a9bd6c..695c1a8 100644 --- a/compat/json/compat.go +++ b/compat/json/compat.go @@ -237,14 +237,14 @@ func Valid(data []byte) bool { // Decode wrappers /////////////////////////////////////////////////// -func convertDecodeError(err error) error { +func convertDecodeError(err error, isUnmarshal bool) error { if derr, ok := err.(*lowmemjson.DecodeError); ok { switch terr := derr.Err.(type) { case *lowmemjson.DecodeSyntaxError: switch { case errors.Is(terr.Err, io.EOF): err = io.EOF - case errors.Is(terr.Err, io.ErrUnexpectedEOF): + case errors.Is(terr.Err, io.ErrUnexpectedEOF) && isUnmarshal: err = &SyntaxError{ msg: "unexpected end of JSON input", Offset: terr.Offset, @@ -284,13 +284,66 @@ func convertDecodeError(err error) error { return err } +type decodeValidator struct{} + +func (*decodeValidator) DecodeJSON(r io.RuneScanner) error { + for { + if _, _, err := r.ReadRune(); err != nil { + + if err == io.EOF { + return nil + } + return err + } + } +} + +var _ lowmemjson.Decodable = (*decodeValidator)(nil) + func Unmarshal(data []byte, ptr any) error { - return convertDecodeError(lowmemjson.NewDecoder(bytes.NewReader(data)).DecodeThenEOF(ptr)) + if err := convertDecodeError(lowmemjson.NewDecoder(bytes.NewReader(data)).DecodeThenEOF(&decodeValidator{}), true); err != nil { + return err + } + if err := convertDecodeError(lowmemjson.NewDecoder(bytes.NewReader(data)).DecodeThenEOF(ptr), true); err != nil { + return err + } + return nil +} + +type teeRuneScanner struct { + src io.RuneScanner + dst *bytes.Buffer + lastSize int +} + +func (tee *teeRuneScanner) ReadRune() (r rune, size int, err error) { + r, size, err = tee.src.ReadRune() + if err == nil { + if _, err := tee.dst.WriteRune(r); err != nil { + return 0, 0, err + } + } + + tee.lastSize = size + return +} + +func (tee *teeRuneScanner) UnreadRune() error { + if tee.lastSize == 0 { + return lowmemjson.ErrInvalidUnreadRune + } + _ = tee.src.UnreadRune() + tee.dst.Truncate(tee.dst.Len() - tee.lastSize) + tee.lastSize = 0 + return nil } type Decoder struct { + validatorBuf *bufio.Reader + validator *lowmemjson.Decoder + + decoderBuf bytes.Buffer *lowmemjson.Decoder - buf *bufio.Reader } func NewDecoder(r io.Reader) *Decoder { @@ -298,18 +351,29 @@ func NewDecoder(r io.Reader) *Decoder { if !ok { br = bufio.NewReader(r) } - return &Decoder{ - Decoder: lowmemjson.NewDecoder(br), - buf: br, + ret := &Decoder{ + validatorBuf: br, } + ret.validator = lowmemjson.NewDecoder(&teeRuneScanner{ + src: ret.validatorBuf, + dst: &ret.decoderBuf, + }) + ret.Decoder = lowmemjson.NewDecoder(&ret.decoderBuf) + return ret } func (dec *Decoder) Decode(ptr any) error { - return convertDecodeError(dec.Decoder.Decode(ptr)) + if err := convertDecodeError(dec.validator.Decode(&decodeValidator{}), false); err != nil { + return err + } + if err := convertDecodeError(dec.Decoder.Decode(ptr), false); err != nil { + return err + } + return nil } func (dec *Decoder) Buffered() io.Reader { - dat, _ := dec.buf.Peek(dec.buf.Buffered()) + dat, _ := dec.validatorBuf.Peek(dec.validatorBuf.Buffered()) return bytes.NewReader(dat) } diff --git a/compat/json/compat_test.go b/compat/json/compat_test.go index 29a8b37..df9d387 100644 --- a/compat/json/compat_test.go +++ b/compat/json/compat_test.go @@ -6,6 +6,8 @@ package json import ( "bytes" + "reflect" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -161,3 +163,79 @@ func TestCompatMarshal(t *testing.T) { }) } } + +func TestCompatUnmarshal(t *testing.T) { + t.Parallel() + type testcase struct { + In string + InPtr any + ExpOut any + ExpErr string + } + testcases := map[string]testcase{ + "empty-obj": {In: `{}`, ExpOut: map[string]any{}}, + "partial-obj": {In: `{"foo":"bar",`, ExpOut: nil, ExpErr: `unexpected end of JSON input`}, + "existing-obj": {In: `{"baz":"quz"}`, InPtr: &map[string]string{"foo": "bar"}, ExpOut: map[string]string{"foo": "bar", "baz": "quz"}}, + "existing-obj-partial": {In: `{"baz":"quz"`, InPtr: &map[string]string{"foo": "bar"}, ExpOut: map[string]string{"foo": "bar"}, ExpErr: "unexpected end of JSON input"}, + "empty-ary": {In: `[]`, ExpOut: []any{}}, + "two-objs": {In: `{} {}`, ExpOut: nil, ExpErr: `invalid character '{' after top-level value`}, + "two-numbers1": {In: `00`, ExpOut: nil, ExpErr: `invalid character '0' after top-level value`}, + "two-numbers2": {In: `1 2`, ExpOut: nil, ExpErr: `invalid character '2' after top-level value`}, + } + for tcName, tc := range testcases { + tc := tc + t.Run(tcName, func(t *testing.T) { + t.Parallel() + ptr := tc.InPtr + if ptr == nil { + var out any + ptr = &out + } + err := Unmarshal([]byte(tc.In), ptr) + assert.Equal(t, tc.ExpOut, reflect.ValueOf(ptr).Elem().Interface()) + if tc.ExpErr == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tc.ExpErr) + } + }) + } +} + +func TestCompatDecode(t *testing.T) { + t.Parallel() + type testcase struct { + In string + InPtr any + ExpOut any + ExpErr string + } + testcases := map[string]testcase{ + "empty-obj": {In: `{}`, ExpOut: map[string]any{}}, + "partial-obj": {In: `{"foo":"bar",`, ExpOut: nil, ExpErr: `unexpected EOF`}, + "existing-obj": {In: `{"baz":"quz"}`, InPtr: &map[string]string{"foo": "bar"}, ExpOut: map[string]string{"foo": "bar", "baz": "quz"}}, + "existing-obj-partial": {In: `{"baz":"quz"`, InPtr: &map[string]string{"foo": "bar"}, ExpOut: map[string]string{"foo": "bar"}, ExpErr: "unexpected EOF"}, + "empty-ary": {In: `[]`, ExpOut: []any{}}, + "two-objs": {In: `{} {}`, ExpOut: map[string]any{}}, + "two-numbers1": {In: `00`, ExpOut: float64(0)}, + "two-numbers2": {In: `1 2`, ExpOut: float64(1)}, + } + for tcName, tc := range testcases { + tc := tc + t.Run(tcName, func(t *testing.T) { + t.Parallel() + ptr := tc.InPtr + if ptr == nil { + var out any + ptr = &out + } + err := NewDecoder(strings.NewReader(tc.In)).Decode(ptr) + assert.Equal(t, tc.ExpOut, reflect.ValueOf(ptr).Elem().Interface()) + if tc.ExpErr == "" { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, tc.ExpErr) + } + }) + } +} diff --git a/decode.go b/decode.go index 491971a..a136668 100644 --- a/decode.go +++ b/decode.go @@ -53,6 +53,8 @@ import ( // or another is encountered; if it does not, then the parent Decode // call will return a *DecodeTypeError. // +// DecodeJSON should return nil (not io.EOF) on success. +// // Implementor's note: "withLimitingScanner" is the thing to search // for in decode.go if you want to read up on that io.RuneScanner. type Decodable interface { diff --git a/decode_scan.go b/decode_scan.go index fcf47ff..63694c4 100644 --- a/decode_scan.go +++ b/decode_scan.go @@ -41,6 +41,12 @@ func (sc *runeTypeScanner) ReadRuneType() (rune, int, jsonparse.RuneType, error) case sc.repeat: sc.offset += int64(sc.rSize) _, _, _ = sc.inner.ReadRune() + case sc.parser.IsAtBarrier(): + sc.rTypeOK = true + sc.rType = jsonparse.RuneTypeEOF + sc.rRune = 0 + sc.rSize = 0 + sc.rErr = nil default: sc.rTypeOK = true again: diff --git a/internal/jsonparse/parse.go b/internal/jsonparse/parse.go index 1c35533..6432d75 100644 --- a/internal/jsonparse/parse.go +++ b/internal/jsonparse/parse.go @@ -525,6 +525,21 @@ func (par *Parser) HandleEOF() (RuneType, error) { } } +// IsAtBarrier returns whether a read-barrier has been reached and the +// next HandleRune call would definitely return RuneTypeEOF. +func (par *Parser) IsAtBarrier() bool { + return par.initialized && + // HandleRune wouldn't return early with an error. + !par.closed && + par.err == nil && + // The current (sub-)parser has reached its end, and + len(par.stack) == 0 && + // there is a barrier, and + len(par.barriers) > 0 && + // that barrier would definitely return RuneTypeEOF. + !par.barriers[len(par.barriers)-1].allowWS +} + // HandleRune feeds a Unicode rune to the Parser. // // An error is returned if and only if the RuneType is RuneTypeError. -- cgit v1.2.3