From e87c9b4d8b629f5df19e9dd182162889d279b4f2 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Sat, 28 Jan 2023 23:26:26 -0700 Subject: encode: Fix errors for marshalers/encodables with bad output --- encode.go | 18 +++++++-- errors.go | 5 ++- methods_test.go | 111 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 5 deletions(-) diff --git a/encode.go b/encode.go index fa337ad..00848ed 100644 --- a/encode.go +++ b/encode.go @@ -146,7 +146,11 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool }}) } if err := validator.Close(); err != nil && !errors.Is(err, iofs.ErrClosed) { - panic(encodeError{err}) + panic(encodeError{&EncodeMethodError{ + Type: val.Type(), + SourceFunc: "EncodeJSON", + Err: err, + }}) } case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(jsonMarshalerType): @@ -173,10 +177,18 @@ func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool // Use a sub-ReEncoder to check that it's a full element. validator := &ReEncoder{Out: w, BackslashEscape: escaper} if _, err := validator.Write(dat); err != nil { - panic(encodeError{err}) + panic(encodeError{&EncodeMethodError{ + Type: val.Type(), + SourceFunc: "MarshalJSON", + Err: err, + }}) } if err := validator.Close(); err != nil { - panic(encodeError{err}) + panic(encodeError{&EncodeMethodError{ + Type: val.Type(), + SourceFunc: "MarshalJSON", + Err: err, + }}) } case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(textMarshalerType): diff --git a/errors.go b/errors.go index d36fc83..fe48723 100644 --- a/errors.go +++ b/errors.go @@ -138,8 +138,9 @@ type EncodeTypeError = json.UnsupportedTypeError // } type EncodeValueError = json.UnsupportedValueError -// An EncodeMethodError wraps an error that is returned from an -// object's method when encoding that object to JSON. +// An EncodeMethodError either wraps an error that is returned from an +// object's method when encoding that object to JSON, or wraps a +// *ReEncodeSyntaxError for the method's output. type EncodeMethodError struct { Type reflect.Type // The Go type that the method is on SourceFunc string // The method: "EncodeJSON", "MarshalJSON", or "MarshalText" diff --git a/methods_test.go b/methods_test.go index 5e2209a..46e2601 100644 --- a/methods_test.go +++ b/methods_test.go @@ -6,8 +6,10 @@ package lowmemjson_test import ( "bytes" + "errors" "fmt" "io" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -121,3 +123,112 @@ func TestMethods(t *testing.T) { assert.NoError(t, lowmemjson.NewDecoder(&buf).Decode(&out)) assert.Equal(t, in, out) } + +type strEncoder string + +func (s strEncoder) EncodeJSON(w io.Writer) error { + _, err := io.WriteString(w, string(s)) + return err +} + +type strMarshaler string + +func (s strMarshaler) MarshalJSON() ([]byte, error) { + return []byte(s), nil +} + +type strTextMarshaler struct { + str string + err string +} + +func (m strTextMarshaler) MarshalText() (txt []byte, err error) { + if len(m.str) > 0 { + txt = []byte(m.str) + } + if len(m.err) > 0 { + err = errors.New(m.err) + } + return +} + +func TestMethodsEncode(t *testing.T) { + t.Parallel() + type testcase struct { + In string + ExpectedErr string + } + testcases := map[string]testcase{ + "basic": {In: `{}`}, + "empty": {In: ``, ExpectedErr: `syntax error at input byte 0: EOF`}, + "short": {In: `{`, ExpectedErr: `syntax error at input byte 1: unexpected EOF`}, + "long": {In: `{}{}`, ExpectedErr: `syntax error at input byte 2: invalid character '{' after top-level value`}, + } + t.Run("encodable", func(t *testing.T) { + t.Parallel() + for tcName, tc := range testcases { + tc := tc + t.Run(tcName, func(t *testing.T) { + t.Parallel() + var buf strings.Builder + err := lowmemjson.NewEncoder(&buf).Encode([]any{strEncoder(tc.In)}) + if tc.ExpectedErr == "" { + assert.NoError(t, err) + assert.Equal(t, "["+tc.In+"]", buf.String()) + } else { + assert.EqualError(t, err, + `json: error calling EncodeJSON for type lowmemjson_test.strEncoder: `+ + tc.ExpectedErr) + } + }) + } + }) + t.Run("marshaler", func(t *testing.T) { + t.Parallel() + for tcName, tc := range testcases { + tc := tc + t.Run(tcName, func(t *testing.T) { + t.Parallel() + var buf strings.Builder + err := lowmemjson.NewEncoder(&buf).Encode([]any{strMarshaler(tc.In)}) + if tc.ExpectedErr == "" { + assert.NoError(t, err) + assert.Equal(t, "["+tc.In+"]", buf.String()) + } else { + assert.EqualError(t, err, + `json: error calling MarshalJSON for type lowmemjson_test.strMarshaler: `+ + tc.ExpectedErr) + } + }) + } + }) + t.Run("text", func(t *testing.T) { + t.Parallel() + type testcase struct { + Str string + Err string + } + testcases := map[string]testcase{ + "basic": {Str: `a`}, + "err": {Err: `xxx`}, + "both": {Str: `a`, Err: `xxx`}, + } + for tcName, tc := range testcases { + tc := tc + t.Run(tcName, func(t *testing.T) { + t.Parallel() + var buf strings.Builder + err := lowmemjson.NewEncoder(&buf).Encode([]any{strTextMarshaler{str: tc.Str, err: tc.Err}}) + if tc.Err == "" { + assert.NoError(t, err) + assert.Equal(t, `["`+tc.Str+`"]`, buf.String()) + } else { + assert.EqualError(t, err, + `json: error calling MarshalText for type lowmemjson_test.strTextMarshaler: `+ + tc.Err) + assert.Equal(t, "[", buf.String()) + } + }) + } + }) +} -- cgit v1.2.3-54-g00ecf