diff options
-rw-r--r-- | decode.go | 34 | ||||
-rw-r--r-- | decode_test.go | 13 | ||||
-rw-r--r-- | encode.go | 11 | ||||
-rw-r--r-- | methods_test.go | 115 |
4 files changed, 162 insertions, 11 deletions
@@ -300,9 +300,13 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { case val.CanAddr() && reflect.PointerTo(typ).Implements(decodableType): t := dec.peekRuneType() obj := val.Addr().Interface().(Decodable) - if err := obj.DecodeJSON(dec.limitingScanner()); err != nil { + l := dec.limitingScanner() + if err := obj.DecodeJSON(l); err != nil { dec.panicType(t.jsonType(), typ, err) } + if _, _, err := l.ReadRune(); err != io.EOF { + dec.panicType(t.jsonType(), typ, fmt.Errorf("did not consume entire %s", t.jsonType())) + } case val.CanAddr() && reflect.PointerTo(typ).Implements(jsonUnmarshalerType): t := dec.peekRuneType() var buf bytes.Buffer @@ -724,7 +728,7 @@ func (dec *Decoder) decodeAny() any { } // DecodeObject is a helper function for implementing the Decoder interface. -func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func() error) (err error) { +func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func(io.RuneScanner) error) (err error) { defer func() { if r := recover(); r != nil { if de, ok := r.(decodeError); ok { @@ -738,13 +742,22 @@ func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func() error) (err erro dec := NewDecoder(r) dec.decodeObject(nil, func() { - if err := decodeKey(); err != nil { + l := dec.limitingScanner() + if err := decodeKey(l); err != nil { dec.panicType("string", nil, err) } + if _, _, err := l.ReadRune(); err != io.EOF { + dec.panicType("string", nil, fmt.Errorf("did not consume entire string")) + } }, func() { - if err := decodeVal(); err != nil { - dec.panicType("", nil, err) + t := dec.peekRuneType() + l := dec.limitingScanner() + if err := decodeVal(l); err != nil { + dec.panicType(t.jsonType(), nil, err) + } + if _, _, err := l.ReadRune(); err != io.EOF { + dec.panicType(t.jsonType(), nil, fmt.Errorf("did not consume entire %s", t.jsonType())) } }) return @@ -778,7 +791,7 @@ func (dec *Decoder) decodeObject(gTyp reflect.Type, decodeKey, decodeVal func()) } // DecodeArray is a helper function for implementing the Decoder interface. -func DecodeArray(r io.RuneScanner, decodeMember func() error) (err error) { +func DecodeArray(r io.RuneScanner, decodeMember func(r io.RuneScanner) error) (err error) { defer func() { if r := recover(); r != nil { if de, ok := r.(decodeError); ok { @@ -791,8 +804,13 @@ func DecodeArray(r io.RuneScanner, decodeMember func() error) (err error) { }() dec := NewDecoder(r) dec.decodeArray(nil, func() { - if err := decodeMember(); err != nil { - dec.panicType("array", nil, err) + t := dec.peekRuneType() + l := dec.limitingScanner() + if err := decodeMember(l); err != nil { + dec.panicType(t.jsonType(), nil, err) + } + if _, _, err := l.ReadRune(); err != io.EOF { + dec.panicType(t.jsonType(), nil, fmt.Errorf("did not consume entire %s", t.jsonType())) } }) return diff --git a/decode_test.go b/decode_test.go index 8220e39..a9e81e0 100644 --- a/decode_test.go +++ b/decode_test.go @@ -5,6 +5,7 @@ package lowmemjson import ( + "io" "strings" "testing" @@ -19,3 +20,15 @@ func TestDecodeNumber(t *testing.T) { assert.Equal(t, 1, num) assert.Equal(t, 2, r.Len()) // check that it didn't read too far } + +func TestDecodeObject(t *testing.T) { + err := DecodeObject(strings.NewReader(`{"foo":9}`), + func(r io.RuneScanner) error { + return nil + }, + func(r io.RuneScanner) error { + var n int + return Decode(r, &n) + }) + assert.ErrorContains(t, err, "did not consume entire") +} @@ -40,7 +40,8 @@ func encodeWriteString(w io.Writer, str string) { } type Encoder struct { - w *ReEncoder + w *ReEncoder + closeAfterEncode bool } // NewEncoder returns a new encoder. @@ -57,7 +58,8 @@ func NewEncoder(w io.Writer) *Encoder { } } return &Encoder{ - w: re, + w: re, + closeAfterEncode: len(re.par.stack) == 0 || (len(re.par.stack) == 1 && re.par.stack[0] == RuneTypeError), } } @@ -72,7 +74,10 @@ func (enc *Encoder) Encode(obj any) (err error) { } }() encode(enc.w, reflect.ValueOf(obj), enc.w.BackslashEscape, false, 0, map[any]struct{}{}) - return enc.w.Close() + if enc.closeAfterEncode { + return enc.w.Close() + } + return nil } // Encode encodes a value to w. diff --git a/methods_test.go b/methods_test.go new file mode 100644 index 0000000..8280a94 --- /dev/null +++ b/methods_test.go @@ -0,0 +1,115 @@ +// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com> +// +// SPDX-License-Identifier: GPL-2.0-or-later + +package lowmemjson_test + +import ( + "bytes" + "fmt" + "io" + "testing" + + "github.com/stretchr/testify/assert" + + "git.lukeshu.com/go/lowmemjson" +) + +type SumRun struct { + ChecksumSize int `json:",omitempty"` + Addr int64 `json:",omitempty"` + Sums string +} + +type SumRunWithGaps struct { + Addr int64 + Size int64 + Runs []SumRun +} + +func (sg SumRunWithGaps) EncodeJSON(w io.Writer) error { + if _, err := fmt.Fprintf(w, `{"Addr":%d,"Size":%d,"Runs":[`, sg.Addr, sg.Size); err != nil { + return err + } + cur := sg.Addr + for i, run := range sg.Runs { + if i > 0 { + if _, err := w.Write([]byte{','}); err != nil { + return err + } + } + if run.Addr > cur { + if _, err := fmt.Fprintf(w, `{"Gap":%d},`, run.Addr-cur); err != nil { + return err + } + } + if err := lowmemjson.Encode(w, run); err != nil { + return err + } + } + end := sg.Addr + sg.Size + switch { + case end < cur: + return fmt.Errorf("invalid %T: addr went backwards: %v < %v", sg, end, cur) + case end > cur: + if _, err := fmt.Fprintf(w, `,{"Gap":%d}`, end-cur); err != nil { + return err + } + } + if _, err := w.Write([]byte("]}")); err != nil { + return err + } + return nil +} + +func (sg *SumRunWithGaps) DecodeJSON(r io.RuneScanner) error { + *sg = SumRunWithGaps{} + var name string + return lowmemjson.DecodeObject(r, + func(r io.RuneScanner) error { + return lowmemjson.Decode(r, &name) + }, + func(r io.RuneScanner) error { + switch name { + case "Addr": + return lowmemjson.Decode(r, &sg.Addr) + case "Size": + return lowmemjson.Decode(r, &sg.Size) + case "Runs": + return lowmemjson.DecodeArray(r, func(r io.RuneScanner) error { + var run SumRun + if err := lowmemjson.Decode(r, &run); err != nil { + return err + } + if run.ChecksumSize > 0 { + sg.Runs = append(sg.Runs, run) + } + return nil + }) + default: + return fmt.Errorf("unknown key %q", name) + } + }) +} + +func TestMethods(t *testing.T) { + in := SumRunWithGaps{ + Addr: 13631488, + Size: 416033783808, + Runs: []SumRun{ + { + ChecksumSize: 4, + Addr: 1095761920, + Sums: "c160817cb5c72bbbe", + }, + }, + } + var buf bytes.Buffer + assert.NoError(t, lowmemjson.Encode(&buf, in)) + assert.Equal(t, + `{"Addr":13631488,"Size":416033783808,"Runs":[{"Gap":1082130432},{"ChecksumSize":4,"Addr":1095761920,"Sums":"c160817cb5c72bbbe"},{"Gap":416033783808}]}`, + buf.String()) + var out SumRunWithGaps + assert.NoError(t, lowmemjson.Decode(&buf, &out)) + assert.Equal(t, in, out) +} |