summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@datawire.io>2022-08-21 21:39:59 -0600
committerLuke Shumaker <lukeshu@datawire.io>2022-08-21 21:39:59 -0600
commit325838f35ce90080aa6c892a998d960c06c1c144 (patch)
tree60e87ee9a622a0261b93ef1e47dbca632c588276
parentd25456172946e5921747cd57fb04eb5b6da72fb6 (diff)
Add tests for the actual usability of the Decodable and Encodable interfaces
-rw-r--r--decode.go34
-rw-r--r--decode_test.go13
-rw-r--r--encode.go11
-rw-r--r--methods_test.go115
4 files changed, 162 insertions, 11 deletions
diff --git a/decode.go b/decode.go
index c160192..fbf2373 100644
--- a/decode.go
+++ b/decode.go
@@ -300,9 +300,13 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) {
case val.CanAddr() && reflect.PointerTo(typ).Implements(decodableType):
t := dec.peekRuneType()
obj := val.Addr().Interface().(Decodable)
- if err := obj.DecodeJSON(dec.limitingScanner()); err != nil {
+ l := dec.limitingScanner()
+ if err := obj.DecodeJSON(l); err != nil {
dec.panicType(t.jsonType(), typ, err)
}
+ if _, _, err := l.ReadRune(); err != io.EOF {
+ dec.panicType(t.jsonType(), typ, fmt.Errorf("did not consume entire %s", t.jsonType()))
+ }
case val.CanAddr() && reflect.PointerTo(typ).Implements(jsonUnmarshalerType):
t := dec.peekRuneType()
var buf bytes.Buffer
@@ -724,7 +728,7 @@ func (dec *Decoder) decodeAny() any {
}
// DecodeObject is a helper function for implementing the Decoder interface.
-func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func() error) (err error) {
+func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func(io.RuneScanner) error) (err error) {
defer func() {
if r := recover(); r != nil {
if de, ok := r.(decodeError); ok {
@@ -738,13 +742,22 @@ func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func() error) (err erro
dec := NewDecoder(r)
dec.decodeObject(nil,
func() {
- if err := decodeKey(); err != nil {
+ l := dec.limitingScanner()
+ if err := decodeKey(l); err != nil {
dec.panicType("string", nil, err)
}
+ if _, _, err := l.ReadRune(); err != io.EOF {
+ dec.panicType("string", nil, fmt.Errorf("did not consume entire string"))
+ }
},
func() {
- if err := decodeVal(); err != nil {
- dec.panicType("", nil, err)
+ t := dec.peekRuneType()
+ l := dec.limitingScanner()
+ if err := decodeVal(l); err != nil {
+ dec.panicType(t.jsonType(), nil, err)
+ }
+ if _, _, err := l.ReadRune(); err != io.EOF {
+ dec.panicType(t.jsonType(), nil, fmt.Errorf("did not consume entire %s", t.jsonType()))
}
})
return
@@ -778,7 +791,7 @@ func (dec *Decoder) decodeObject(gTyp reflect.Type, decodeKey, decodeVal func())
}
// DecodeArray is a helper function for implementing the Decoder interface.
-func DecodeArray(r io.RuneScanner, decodeMember func() error) (err error) {
+func DecodeArray(r io.RuneScanner, decodeMember func(r io.RuneScanner) error) (err error) {
defer func() {
if r := recover(); r != nil {
if de, ok := r.(decodeError); ok {
@@ -791,8 +804,13 @@ func DecodeArray(r io.RuneScanner, decodeMember func() error) (err error) {
}()
dec := NewDecoder(r)
dec.decodeArray(nil, func() {
- if err := decodeMember(); err != nil {
- dec.panicType("array", nil, err)
+ t := dec.peekRuneType()
+ l := dec.limitingScanner()
+ if err := decodeMember(l); err != nil {
+ dec.panicType(t.jsonType(), nil, err)
+ }
+ if _, _, err := l.ReadRune(); err != io.EOF {
+ dec.panicType(t.jsonType(), nil, fmt.Errorf("did not consume entire %s", t.jsonType()))
}
})
return
diff --git a/decode_test.go b/decode_test.go
index 8220e39..a9e81e0 100644
--- a/decode_test.go
+++ b/decode_test.go
@@ -5,6 +5,7 @@
package lowmemjson
import (
+ "io"
"strings"
"testing"
@@ -19,3 +20,15 @@ func TestDecodeNumber(t *testing.T) {
assert.Equal(t, 1, num)
assert.Equal(t, 2, r.Len()) // check that it didn't read too far
}
+
+func TestDecodeObject(t *testing.T) {
+ err := DecodeObject(strings.NewReader(`{"foo":9}`),
+ func(r io.RuneScanner) error {
+ return nil
+ },
+ func(r io.RuneScanner) error {
+ var n int
+ return Decode(r, &n)
+ })
+ assert.ErrorContains(t, err, "did not consume entire")
+}
diff --git a/encode.go b/encode.go
index 44fd985..f93cba4 100644
--- a/encode.go
+++ b/encode.go
@@ -40,7 +40,8 @@ func encodeWriteString(w io.Writer, str string) {
}
type Encoder struct {
- w *ReEncoder
+ w *ReEncoder
+ closeAfterEncode bool
}
// NewEncoder returns a new encoder.
@@ -57,7 +58,8 @@ func NewEncoder(w io.Writer) *Encoder {
}
}
return &Encoder{
- w: re,
+ w: re,
+ closeAfterEncode: len(re.par.stack) == 0 || (len(re.par.stack) == 1 && re.par.stack[0] == RuneTypeError),
}
}
@@ -72,7 +74,10 @@ func (enc *Encoder) Encode(obj any) (err error) {
}
}()
encode(enc.w, reflect.ValueOf(obj), enc.w.BackslashEscape, false, 0, map[any]struct{}{})
- return enc.w.Close()
+ if enc.closeAfterEncode {
+ return enc.w.Close()
+ }
+ return nil
}
// Encode encodes a value to w.
diff --git a/methods_test.go b/methods_test.go
new file mode 100644
index 0000000..8280a94
--- /dev/null
+++ b/methods_test.go
@@ -0,0 +1,115 @@
+// Copyright (C) 2022 Luke Shumaker <lukeshu@lukeshu.com>
+//
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+package lowmemjson_test
+
+import (
+ "bytes"
+ "fmt"
+ "io"
+ "testing"
+
+ "github.com/stretchr/testify/assert"
+
+ "git.lukeshu.com/go/lowmemjson"
+)
+
+type SumRun struct {
+ ChecksumSize int `json:",omitempty"`
+ Addr int64 `json:",omitempty"`
+ Sums string
+}
+
+type SumRunWithGaps struct {
+ Addr int64
+ Size int64
+ Runs []SumRun
+}
+
+func (sg SumRunWithGaps) EncodeJSON(w io.Writer) error {
+ if _, err := fmt.Fprintf(w, `{"Addr":%d,"Size":%d,"Runs":[`, sg.Addr, sg.Size); err != nil {
+ return err
+ }
+ cur := sg.Addr
+ for i, run := range sg.Runs {
+ if i > 0 {
+ if _, err := w.Write([]byte{','}); err != nil {
+ return err
+ }
+ }
+ if run.Addr > cur {
+ if _, err := fmt.Fprintf(w, `{"Gap":%d},`, run.Addr-cur); err != nil {
+ return err
+ }
+ }
+ if err := lowmemjson.Encode(w, run); err != nil {
+ return err
+ }
+ }
+ end := sg.Addr + sg.Size
+ switch {
+ case end < cur:
+ return fmt.Errorf("invalid %T: addr went backwards: %v < %v", sg, end, cur)
+ case end > cur:
+ if _, err := fmt.Fprintf(w, `,{"Gap":%d}`, end-cur); err != nil {
+ return err
+ }
+ }
+ if _, err := w.Write([]byte("]}")); err != nil {
+ return err
+ }
+ return nil
+}
+
+func (sg *SumRunWithGaps) DecodeJSON(r io.RuneScanner) error {
+ *sg = SumRunWithGaps{}
+ var name string
+ return lowmemjson.DecodeObject(r,
+ func(r io.RuneScanner) error {
+ return lowmemjson.Decode(r, &name)
+ },
+ func(r io.RuneScanner) error {
+ switch name {
+ case "Addr":
+ return lowmemjson.Decode(r, &sg.Addr)
+ case "Size":
+ return lowmemjson.Decode(r, &sg.Size)
+ case "Runs":
+ return lowmemjson.DecodeArray(r, func(r io.RuneScanner) error {
+ var run SumRun
+ if err := lowmemjson.Decode(r, &run); err != nil {
+ return err
+ }
+ if run.ChecksumSize > 0 {
+ sg.Runs = append(sg.Runs, run)
+ }
+ return nil
+ })
+ default:
+ return fmt.Errorf("unknown key %q", name)
+ }
+ })
+}
+
+func TestMethods(t *testing.T) {
+ in := SumRunWithGaps{
+ Addr: 13631488,
+ Size: 416033783808,
+ Runs: []SumRun{
+ {
+ ChecksumSize: 4,
+ Addr: 1095761920,
+ Sums: "c160817cb5c72bbbe",
+ },
+ },
+ }
+ var buf bytes.Buffer
+ assert.NoError(t, lowmemjson.Encode(&buf, in))
+ assert.Equal(t,
+ `{"Addr":13631488,"Size":416033783808,"Runs":[{"Gap":1082130432},{"ChecksumSize":4,"Addr":1095761920,"Sums":"c160817cb5c72bbbe"},{"Gap":416033783808}]}`,
+ buf.String())
+ var out SumRunWithGaps
+ assert.NoError(t, lowmemjson.Decode(&buf, &out))
+ assert.Equal(t, in, out)
+}