summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2023-01-28 23:26:26 -0700
committerLuke Shumaker <lukeshu@lukeshu.com>2023-01-30 13:49:17 -0700
commite87c9b4d8b629f5df19e9dd182162889d279b4f2 (patch)
tree78c368370dbe7f35246aa353d934c906b36be0c1
parentff6dc0bc519886905e758a84e572f5e34d6c03d1 (diff)
encode: Fix errors for marshalers/encodables with bad output
-rw-r--r--encode.go18
-rw-r--r--errors.go5
-rw-r--r--methods_test.go111
3 files changed, 129 insertions, 5 deletions
diff --git a/encode.go b/encode.go
index fa337ad..00848ed 100644
--- a/encode.go
+++ b/encode.go
@@ -146,7 +146,11 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool
}})
}
if err := validator.Close(); err != nil && !errors.Is(err, iofs.ErrClosed) {
- panic(encodeError{err})
+ panic(encodeError{&EncodeMethodError{
+ Type: val.Type(),
+ SourceFunc: "EncodeJSON",
+ Err: err,
+ }})
}
case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(jsonMarshalerType):
@@ -173,10 +177,18 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool
// Use a sub-ReEncoder to check that it's a full element.
validator := &ReEncoder{Out: w, BackslashEscape: escaper}
if _, err := validator.Write(dat); err != nil {
- panic(encodeError{err})
+ panic(encodeError{&EncodeMethodError{
+ Type: val.Type(),
+ SourceFunc: "MarshalJSON",
+ Err: err,
+ }})
}
if err := validator.Close(); err != nil {
- panic(encodeError{err})
+ panic(encodeError{&EncodeMethodError{
+ Type: val.Type(),
+ SourceFunc: "MarshalJSON",
+ Err: err,
+ }})
}
case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(textMarshalerType):
diff --git a/errors.go b/errors.go
index d36fc83..fe48723 100644
--- a/errors.go
+++ b/errors.go
@@ -138,8 +138,9 @@ type EncodeTypeError = json.UnsupportedTypeError
// }
type EncodeValueError = json.UnsupportedValueError
-// An EncodeMethodError wraps an error that is returned from an
-// object's method when encoding that object to JSON.
+// An EncodeMethodError either wraps an error that is returned from an
+// object's method when encoding that object to JSON, or wraps a
+// *ReEncodeSyntaxError for the method's output.
type EncodeMethodError struct {
Type reflect.Type // The Go type that the method is on
SourceFunc string // The method: "EncodeJSON", "MarshalJSON", or "MarshalText"
diff --git a/methods_test.go b/methods_test.go
index 5e2209a..46e2601 100644
--- a/methods_test.go
+++ b/methods_test.go
@@ -6,8 +6,10 @@ package lowmemjson_test
import (
"bytes"
+ "errors"
"fmt"
"io"
+ "strings"
"testing"
"github.com/stretchr/testify/assert"
@@ -121,3 +123,112 @@ func TestMethods(t *testing.T) {
assert.NoError(t, lowmemjson.NewDecoder(&buf).Decode(&out))
assert.Equal(t, in, out)
}
+
+type strEncoder string
+
+func (s strEncoder) EncodeJSON(w io.Writer) error {
+ _, err := io.WriteString(w, string(s))
+ return err
+}
+
+type strMarshaler string
+
+func (s strMarshaler) MarshalJSON() ([]byte, error) {
+ return []byte(s), nil
+}
+
+type strTextMarshaler struct {
+ str string
+ err string
+}
+
+func (m strTextMarshaler) MarshalText() (txt []byte, err error) {
+ if len(m.str) > 0 {
+ txt = []byte(m.str)
+ }
+ if len(m.err) > 0 {
+ err = errors.New(m.err)
+ }
+ return
+}
+
+func TestMethodsEncode(t *testing.T) {
+ t.Parallel()
+ type testcase struct {
+ In string
+ ExpectedErr string
+ }
+ testcases := map[string]testcase{
+ "basic": {In: `{}`},
+ "empty": {In: ``, ExpectedErr: `syntax error at input byte 0: EOF`},
+ "short": {In: `{`, ExpectedErr: `syntax error at input byte 1: unexpected EOF`},
+ "long": {In: `{}{}`, ExpectedErr: `syntax error at input byte 2: invalid character '{' after top-level value`},
+ }
+ t.Run("encodable", func(t *testing.T) {
+ t.Parallel()
+ for tcName, tc := range testcases {
+ tc := tc
+ t.Run(tcName, func(t *testing.T) {
+ t.Parallel()
+ var buf strings.Builder
+ err := lowmemjson.NewEncoder(&buf).Encode([]any{strEncoder(tc.In)})
+ if tc.ExpectedErr == "" {
+ assert.NoError(t, err)
+ assert.Equal(t, "["+tc.In+"]", buf.String())
+ } else {
+ assert.EqualError(t, err,
+ `json: error calling EncodeJSON for type lowmemjson_test.strEncoder: `+
+ tc.ExpectedErr)
+ }
+ })
+ }
+ })
+ t.Run("marshaler", func(t *testing.T) {
+ t.Parallel()
+ for tcName, tc := range testcases {
+ tc := tc
+ t.Run(tcName, func(t *testing.T) {
+ t.Parallel()
+ var buf strings.Builder
+ err := lowmemjson.NewEncoder(&buf).Encode([]any{strMarshaler(tc.In)})
+ if tc.ExpectedErr == "" {
+ assert.NoError(t, err)
+ assert.Equal(t, "["+tc.In+"]", buf.String())
+ } else {
+ assert.EqualError(t, err,
+ `json: error calling MarshalJSON for type lowmemjson_test.strMarshaler: `+
+ tc.ExpectedErr)
+ }
+ })
+ }
+ })
+ t.Run("text", func(t *testing.T) {
+ t.Parallel()
+ type testcase struct {
+ Str string
+ Err string
+ }
+ testcases := map[string]testcase{
+ "basic": {Str: `a`},
+ "err": {Err: `xxx`},
+ "both": {Str: `a`, Err: `xxx`},
+ }
+ for tcName, tc := range testcases {
+ tc := tc
+ t.Run(tcName, func(t *testing.T) {
+ t.Parallel()
+ var buf strings.Builder
+ err := lowmemjson.NewEncoder(&buf).Encode([]any{strTextMarshaler{str: tc.Str, err: tc.Err}})
+ if tc.Err == "" {
+ assert.NoError(t, err)
+ assert.Equal(t, `["`+tc.Str+`"]`, buf.String())
+ } else {
+ assert.EqualError(t, err,
+ `json: error calling MarshalText for type lowmemjson_test.strTextMarshaler: `+
+ tc.Err)
+ assert.Equal(t, "[", buf.String())
+ }
+ })
+ }
+ })
+}