diff options
Diffstat (limited to 'compat/json/compat.go')
-rw-r--r-- | compat/json/compat.go | 200 |
1 files changed, 121 insertions, 79 deletions
diff --git a/compat/json/compat.go b/compat/json/compat.go index 695c1a8..4dc15ab 100644 --- a/compat/json/compat.go +++ b/compat/json/compat.go @@ -40,22 +40,99 @@ type ( // MarshalerError = json.MarshalerError // Duplicated to access a private field. ) -// Encode wrappers /////////////////////////////////////////////////// +// Error conversion ////////////////////////////////////////////////// -func convertEncodeError(err error) error { - if me, ok := err.(*lowmemjson.EncodeMethodError); ok { - err = &MarshalerError{ - Type: me.Type, - Err: me.Err, - sourceFunc: me.SourceFunc, +func convertError(err error, isUnmarshal bool) error { + switch err := err.(type) { + case nil: + return nil + case *lowmemjson.DecodeArgumentError: + return err + case *lowmemjson.DecodeError: + switch suberr := err.Err.(type) { + case *lowmemjson.DecodeReadError: + return err + case *lowmemjson.DecodeSyntaxError: + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + if isUnmarshal { + return &SyntaxError{ + msg: "unexpected end of JSON input", + Offset: suberr.Offset, + } + } + return suberr.Err + } + return &SyntaxError{ + msg: suberr.Err.Error(), + Offset: suberr.Offset + 1, + } + case *lowmemjson.DecodeTypeError: + switch subsuberr := suberr.Err.(type) { + case *UnmarshalTypeError: + // Populate the .Struct and .Field members. + subsuberr.Struct = err.FieldParent + subsuberr.Field = err.FieldName + return subsuberr + default: + switch { + case errors.Is(err, lowmemjson.ErrDecodeNonEmptyInterface), + errors.Is(err, strconv.ErrSyntax), + errors.Is(err, strconv.ErrRange): + return &UnmarshalTypeError{ + Value: suberr.JSONType, + Type: suberr.GoType, + Offset: suberr.Offset, + Struct: err.FieldParent, + Field: err.FieldName, + } + default: + return subsuberr + } + case nil, *lowmemjson.DecodeArgumentError: + return &UnmarshalTypeError{ + Value: suberr.JSONType, + Type: suberr.GoType, + Offset: suberr.Offset, + Struct: err.FieldParent, + Field: err.FieldName, + } + } + default: + panic(fmt.Errorf("should not happen: unexpected lowmemjson.DecodeError sub-type: %T: %w", suberr, err)) } + case *lowmemjson.EncodeWriteError: + return err + case *lowmemjson.EncodeTypeError: + return err + case *lowmemjson.EncodeValueError: + return err + case *lowmemjson.EncodeMethodError: + return &MarshalerError{ + Type: err.Type, + Err: err.Err, + sourceFunc: err.SourceFunc, + } + case *lowmemjson.ReEncodeWriteError: + return err + case *lowmemjson.ReEncodeSyntaxError: + ret := &SyntaxError{ + msg: err.Err.Error(), + Offset: err.Offset + 1, + } + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + ret.msg = "unexpected end of JSON input" + } + return ret + default: + panic(fmt.Errorf("should not happen: unexpected lowmemjson error type: %T: %w", err, err)) } - return err } +// Encode wrappers /////////////////////////////////////////////////// + func marshal(v any, cfg lowmemjson.ReEncoderConfig) ([]byte, error) { var buf bytes.Buffer - if err := convertEncodeError(lowmemjson.NewEncoder(lowmemjson.NewReEncoder(&buf, cfg)).Encode(v)); err != nil { + if err := convertError(lowmemjson.NewEncoder(lowmemjson.NewReEncoder(&buf, cfg)).Encode(v), false); err != nil { return nil, err } return buf.Bytes(), nil @@ -105,7 +182,7 @@ func (enc *Encoder) refreshConfig() { } func (enc *Encoder) Encode(v any) error { - if err := convertEncodeError(enc.encoder.Encode(v)); err != nil { + if err := convertError(enc.encoder.Encode(v), false); err != nil { enc.buf.Reset() return err } @@ -133,19 +210,6 @@ func (enc *Encoder) SetIndent(prefix, indent string) { // ReEncode wrappers ///////////////////////////////////////////////// -func convertReEncodeError(err error) error { - if se, ok := err.(*lowmemjson.ReEncodeSyntaxError); ok { - err = &SyntaxError{ - msg: se.Err.Error(), - Offset: se.Offset + 1, - } - if errors.Is(se.Err, io.ErrUnexpectedEOF) { - err.(*SyntaxError).msg = "unexpected end of JSON input" - } - } - return err -} - func HTMLEscape(dst *bytes.Buffer, src []byte) { for n := 0; n < len(src); { c, size := utf8.DecodeRune(src[n:]) @@ -172,7 +236,7 @@ func reencode(dst io.Writer, src []byte, cfg lowmemjson.ReEncoderConfig) error { if err == nil { err = formatter.Close() } - return convertReEncodeError(err) + return convertError(err, false) } func Compact(dst *bytes.Buffer, src []byte) error { @@ -237,53 +301,6 @@ func Valid(data []byte) bool { // Decode wrappers /////////////////////////////////////////////////// -func convertDecodeError(err error, isUnmarshal bool) error { - if derr, ok := err.(*lowmemjson.DecodeError); ok { - switch terr := derr.Err.(type) { - case *lowmemjson.DecodeSyntaxError: - switch { - case errors.Is(terr.Err, io.EOF): - err = io.EOF - case errors.Is(terr.Err, io.ErrUnexpectedEOF) && isUnmarshal: - err = &SyntaxError{ - msg: "unexpected end of JSON input", - Offset: terr.Offset, - } - default: - err = &SyntaxError{ - msg: terr.Err.Error(), - Offset: terr.Offset + 1, - } - } - case *lowmemjson.DecodeTypeError: - if typeErr, ok := terr.Err.(*json.UnmarshalTypeError); ok { - err = &UnmarshalTypeError{ - Value: typeErr.Value, - Type: typeErr.Type, - Offset: typeErr.Offset, - Struct: derr.FieldParent, - Field: derr.FieldName, - } - } else if _, isArgErr := terr.Err.(*lowmemjson.DecodeArgumentError); terr.Err != nil && - !isArgErr && - !errors.Is(terr.Err, lowmemjson.ErrDecodeNonEmptyInterface) && - !errors.Is(terr.Err, strconv.ErrSyntax) && - !errors.Is(terr.Err, strconv.ErrRange) { - err = terr.Err - } else { - err = &UnmarshalTypeError{ - Value: terr.JSONType, - Type: terr.GoType, - Offset: terr.Offset, - Struct: derr.FieldParent, - Field: derr.FieldName, - } - } - } - } - return err -} - type decodeValidator struct{} func (*decodeValidator) DecodeJSON(r io.RuneScanner) error { @@ -301,17 +318,20 @@ func (*decodeValidator) DecodeJSON(r io.RuneScanner) error { var _ lowmemjson.Decodable = (*decodeValidator)(nil) func Unmarshal(data []byte, ptr any) error { - if err := convertDecodeError(lowmemjson.NewDecoder(bytes.NewReader(data)).DecodeThenEOF(&decodeValidator{}), true); err != nil { + if err := convertError(lowmemjson.NewDecoder(bytes.NewReader(data)).DecodeThenEOF(&decodeValidator{}), true); err != nil { return err } - if err := convertDecodeError(lowmemjson.NewDecoder(bytes.NewReader(data)).DecodeThenEOF(ptr), true); err != nil { + if err := convertError(lowmemjson.NewDecoder(bytes.NewReader(data)).DecodeThenEOF(ptr), true); err != nil { return err } return nil } type teeRuneScanner struct { - src io.RuneScanner + src interface { + io.RuneScanner + io.ByteScanner + } dst *bytes.Buffer lastSize int } @@ -319,11 +339,14 @@ type teeRuneScanner struct { func (tee *teeRuneScanner) ReadRune() (r rune, size int, err error) { r, size, err = tee.src.ReadRune() if err == nil { - if _, err := tee.dst.WriteRune(r); err != nil { - return 0, 0, err + if r == utf8.RuneError && size == 1 { + _ = tee.src.UnreadRune() + b, _ := tee.src.ReadByte() + _ = tee.dst.WriteByte(b) + } else { + _, _ = tee.dst.WriteRune(r) } } - tee.lastSize = size return } @@ -338,6 +361,25 @@ func (tee *teeRuneScanner) UnreadRune() error { return nil } +func (tee *teeRuneScanner) ReadByte() (b byte, err error) { + b, err = tee.src.ReadByte() + if err == nil { + _ = tee.dst.WriteByte(b) + tee.lastSize = 1 + } + return +} + +func (tee *teeRuneScanner) UnreadByte() error { + if tee.lastSize != 1 { + return lowmemjson.ErrInvalidUnreadRune + } + _ = tee.src.UnreadByte() + tee.dst.Truncate(tee.dst.Len() - tee.lastSize) + tee.lastSize = 0 + return nil +} + type Decoder struct { validatorBuf *bufio.Reader validator *lowmemjson.Decoder @@ -363,10 +405,10 @@ func NewDecoder(r io.Reader) *Decoder { } func (dec *Decoder) Decode(ptr any) error { - if err := convertDecodeError(dec.validator.Decode(&decodeValidator{}), false); err != nil { + if err := convertError(dec.validator.Decode(&decodeValidator{}), false); err != nil { return err } - if err := convertDecodeError(dec.Decoder.Decode(ptr), false); err != nil { + if err := convertError(dec.Decoder.Decode(ptr), false); err != nil { return err } return nil |