From 54bbd1e59317a6e9658eb8098657078cc8e81979 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Sun, 14 Aug 2022 20:52:06 -0600 Subject: wip: Reduce test differences [ci-skip] - Handle UTF-16 surrogate pairs - Handle cycles in values - Handle cycles in types - Better errors - Handle case-folding of struct field names - Allow []byteTypeWithMethods - Fix struct field-order - Fix handling of interfaces storing pointers - Enforce a maximum decode depth - Validate struct tags --- borrowed_misc.go | 20 ++ compat/json/borrowed_decode_test.go | 37 ++-- compat/json/borrowed_encode_test.go | 15 +- compat/json/borrowed_misc.go | 14 ++ compat/json/borrowed_scanner_test.go | 4 - compat/json/borrowed_tagkey_test.go | 1 - compat/json/compat.go | 45 ++++- compat/json/compat_test.go | 6 + decode.go | 342 +++++++++++++++++++++++------------ decode_scan.go | 34 +--- encode.go | 69 +++++-- errors.go | 130 +++++++++++++ parse.go | 16 +- reencode.go | 7 +- struct.go | 61 ++++--- 15 files changed, 575 insertions(+), 226 deletions(-) create mode 100644 compat/json/borrowed_misc.go create mode 100644 errors.go diff --git a/borrowed_misc.go b/borrowed_misc.go index 343c924..d5ace19 100644 --- a/borrowed_misc.go +++ b/borrowed_misc.go @@ -6,6 +6,8 @@ package lowmemjson import ( "reflect" + "strings" + "unicode" ) // from encode.go @@ -26,3 +28,21 @@ func isEmptyValue(v reflect.Value) bool { } return false } + +// from encode.go +func isValidTag(s string) bool { + if s == "" { + return false + } + for _, c := range s { + switch { + case strings.ContainsRune("!#$%&()*+-./:;<=>?@[]^_{|}~ ", c): + // Backslash and quote chars are reserved, but + // otherwise any punctuation chars are allowed + // in a tag name. 
+ case !unicode.IsLetter(c) && !unicode.IsDigit(c): + return false + } + } + return true +} diff --git a/compat/json/borrowed_decode_test.go b/compat/json/borrowed_decode_test.go index 306f85e..4b84718 100644 --- a/compat/json/borrowed_decode_test.go +++ b/compat/json/borrowed_decode_test.go @@ -455,8 +455,8 @@ var unmarshalTests = []unmarshalTest{ {in: `{"X": "foo", "Y"}`, err: &SyntaxError{"invalid character '}' after object key", 17}}, {in: `[1, 2, 3+]`, err: &SyntaxError{"invalid character '+' after array element", 9}}, {in: `{"X":12x}`, err: &SyntaxError{"invalid character 'x' after object key:value pair", 8}, useNumber: true}, - {in: `[2, 3`, err: &SyntaxError{Err: "unexpected end of JSON input", Offset: 5}}, // MODIFIED - {in: `{"F3": -}`, ptr: new(V), out: V{F3: Number("-")}, err: &SyntaxError{Err: "invalid character '}' in numeric literal", Offset: 9}}, // MODIFIED + {in: `[2, 3`, err: &SyntaxError{msg: "unexpected end of JSON input", Offset: 5}}, + {in: `{"F3": -}`, ptr: new(V), out: V{F3: Number("-")}, err: &SyntaxError{msg: "invalid character '}' in numeric literal", Offset: 9}}, // raw value errors {in: "\x01 42", err: &SyntaxError{"invalid character '\\x01' looking for beginning of value", 1}}, @@ -957,7 +957,7 @@ var unmarshalTests = []unmarshalTest{ in: `invalid`, ptr: new(Number), err: &SyntaxError{ - Err: "invalid character 'i' looking for beginning of value", // MODIFIED + msg: "invalid character 'i' looking for beginning of value", Offset: 1, }, }, @@ -1040,7 +1040,6 @@ func TestMarshalNumberZeroVal(t *testing.T) { } func TestMarshalEmbeds(t *testing.T) { - t.Skip() // TODO top := &Top{ Level0: 1, Embed0: Embed0{ @@ -1089,17 +1088,16 @@ func equalError(a, b error) bool { if b == nil { return a == nil } - return true // a.Error() == b.Error() // MODIFIED + return a.Error() == b.Error() } func TestUnmarshal(t *testing.T) { - t.Skip() // TODO for i, tt := range unmarshalTests { scan := lowmemjson.ReEncoder{Out: io.Discard} // MODIFIED in := 
[]byte(tt.in) - if _, err := scan.Write(in); err != nil { + if err := checkValid(in, &scan); err != nil { if !equalError(err, tt.err) { - t.Errorf("#%d: checkValid: %#v\n\n%s", i, err, tt.in) + t.Errorf("#%d: checkValid: %#v", i, err) continue } } @@ -1142,11 +1140,11 @@ func TestUnmarshal(t *testing.T) { continue } if !reflect.DeepEqual(v.Elem().Interface(), tt.out) { - t.Errorf("#%d: mismatch\nhave: %#+v\nwant: %#+v\n\n%s", i, v.Elem().Interface(), tt.out, tt.in) + t.Errorf("#%d: mismatch\nhave: %#+v\nwant: %#+v", i, v.Elem().Interface(), tt.out) data, _ := Marshal(v.Elem().Interface()) - println(string(data)) + t.Log(string(data)) // MODIFIED data, _ = Marshal(tt.out) - println(string(data)) + t.Log(string(data)) // MODIFIED continue } @@ -1311,7 +1309,7 @@ func TestErrorMessageFromMisusedString(t *testing.T) { var s WrongString err := NewDecoder(r).Decode(&s) got := fmt.Sprintf("%v", err) - if err == nil { // if got != tt.err { // MODIFIED + if got != tt.err { t.Errorf("%d. got err = %q, want %q", n, got, tt.err) } } @@ -1742,7 +1740,6 @@ var interfaceSetTests = []struct { } func TestInterfaceSet(t *testing.T) { - t.Skip() // TODO for _, tt := range interfaceSetTests { b := struct{ X any }{tt.pre} blob := `{"X":` + tt.json + `}` @@ -2015,7 +2012,7 @@ var decodeTypeErrorTests = []struct { func TestUnmarshalTypeError(t *testing.T) { for _, item := range decodeTypeErrorTests { err := Unmarshal([]byte(item.src), item.dest) - if err == nil { // if _, ok := err.(*UnmarshalTypeError); !ok { // MODIFIED + if _, ok := err.(*UnmarshalTypeError); !ok { t.Errorf("expected type error for Unmarshal(%q, type %T): got %T", item.src, item.dest, err) } @@ -2037,7 +2034,7 @@ func TestUnmarshalSyntax(t *testing.T) { var x any for _, src := range unmarshalSyntaxTests { err := Unmarshal([]byte(src), &x) - if err == nil { // _, ok := err.(*SyntaxError); !ok { // MODIFIED + if _, ok := err.(*SyntaxError); !ok { t.Errorf("expected syntax error for Unmarshal(%q): got %T", src, err) } } 
@@ -2202,9 +2199,9 @@ func TestInvalidUnmarshalText(t *testing.T) { t.Errorf("Unmarshal expecting error, got nil") continue } - // if got := err.Error(); got != tt.want { // MODIFIED - // t.Errorf("Unmarshal = %q; want %q", got, tt.want) // MODIFIED - // } // MODIFIED + if got := err.Error(); got != tt.want { + t.Errorf("Unmarshal = %q; want %q", got, tt.want) + } } } @@ -2243,7 +2240,6 @@ func TestInvalidStringOption(t *testing.T) { // (Issue 28145) If the embedded struct is given an explicit name and has // exported methods, don't cause a panic trying to get its value. func TestUnmarshalEmbeddedUnexported(t *testing.T) { - t.Skip() // TODO type ( embed1 struct{ Q int } embed2 struct{ Q int } @@ -2365,7 +2361,6 @@ func TestUnmarshalEmbeddedUnexported(t *testing.T) { } func TestUnmarshalErrorAfterMultipleJSON(t *testing.T) { - t.Skip() // TODO tests := []struct { in string err error @@ -2417,7 +2412,6 @@ func TestUnmarshalPanic(t *testing.T) { // The decoder used to hang if decoding into an interface pointing to its own address. // See golang.org/issues/31740. 
func TestUnmarshalRecursivePointer(t *testing.T) { - t.Skip() // TODO var v any v = &v data := []byte(`{"a": "b"}`) @@ -2493,7 +2487,6 @@ func TestUnmarshalRescanLiteralMangledUnquote(t *testing.T) { } func TestUnmarshalMaxDepth(t *testing.T) { - t.Skip() // TODO testcases := []struct { name string data string diff --git a/compat/json/borrowed_encode_test.go b/compat/json/borrowed_encode_test.go index bb7c9dc..11c2db4 100644 --- a/compat/json/borrowed_encode_test.go +++ b/compat/json/borrowed_encode_test.go @@ -226,11 +226,11 @@ var unsupportedValues = []any{ math.NaN(), math.Inf(-1), math.Inf(1), - //pointerCycle, // MODIFIED - //pointerCycleIndirect, // MODIFIED - //mapCycle, // MODIFIED - //sliceCycle, // MODIFIED - //recursiveSliceCycle, // MODIFIED + pointerCycle, + pointerCycleIndirect, + mapCycle, + sliceCycle, + recursiveSliceCycle, } func TestUnsupportedValues(t *testing.T) { @@ -344,7 +344,6 @@ func (CText) MarshalText() ([]byte, error) { } func TestMarshalerEscaping(t *testing.T) { - t.Skip() // MODIFIED var c C want := `"\u003c\u0026\u003e"` b, err := Marshal(c) @@ -877,7 +876,6 @@ func (f textfloat) MarshalText() ([]byte, error) { return tenc(`TF:%0.2f`, f) } // Issue 13783 func TestEncodeBytekind(t *testing.T) { - t.Skip() // TODO testdata := []struct { data any want string @@ -1139,7 +1137,6 @@ func TestMarshalRawMessageValue(t *testing.T) { if err != nil { t.Errorf("test %d, unexpected failure: %v", i, err) } else { - t.Skip() // MODIFIED t.Errorf("test %d, unexpected success", i) } } @@ -1178,7 +1175,6 @@ func TestMarshalUncommonFieldNames(t *testing.T) { } } -/* // MODIFIED func TestMarshalerError(t *testing.T) { s := "test variable" st := reflect.TypeOf(s) @@ -1205,4 +1201,3 @@ func TestMarshalerError(t *testing.T) { } } } -*/ // MODIFIED diff --git a/compat/json/borrowed_misc.go b/compat/json/borrowed_misc.go new file mode 100644 index 0000000..30a3b0e --- /dev/null +++ b/compat/json/borrowed_misc.go @@ -0,0 +1,14 @@ +// Copyright 2010 The Go 
Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package json + +// A SyntaxError is a description of a JSON syntax error. +// Unmarshal will return a SyntaxError if the JSON can't be parsed. +type SyntaxError struct { + msg string // description of error + Offset int64 // error occurred after reading Offset bytes +} + +func (e *SyntaxError) Error() string { return e.msg } diff --git a/compat/json/borrowed_scanner_test.go b/compat/json/borrowed_scanner_test.go index 4955405..3474b3e 100644 --- a/compat/json/borrowed_scanner_test.go +++ b/compat/json/borrowed_scanner_test.go @@ -65,7 +65,6 @@ var ex1i = `[ ]` func TestCompact(t *testing.T) { - t.Skip() // TODO var buf bytes.Buffer for _, tt := range examples { buf.Reset() @@ -86,7 +85,6 @@ func TestCompact(t *testing.T) { } func TestCompactSeparators(t *testing.T) { - t.Skip() // TODO // U+2028 and U+2029 should be escaped inside strings. // They should not appear outside strings. 
tests := []struct { @@ -106,7 +104,6 @@ func TestCompactSeparators(t *testing.T) { } func TestIndent(t *testing.T) { - t.Skip() // TODO var buf bytes.Buffer for _, tt := range examples { buf.Reset() @@ -192,7 +189,6 @@ var indentErrorTests = []indentErrorTest{ } func TestIndentErrors(t *testing.T) { - t.Skip() // TODO for i, tt := range indentErrorTests { slice := make([]uint8, 0) buf := bytes.NewBuffer(slice) diff --git a/compat/json/borrowed_tagkey_test.go b/compat/json/borrowed_tagkey_test.go index 6a2d612..6330efd 100644 --- a/compat/json/borrowed_tagkey_test.go +++ b/compat/json/borrowed_tagkey_test.go @@ -96,7 +96,6 @@ var structTagObjectKeyTests = []struct { } func TestStructTagObjectKey(t *testing.T) { - t.Skip() // TODO for _, tt := range structTagObjectKeyTests { b, err := Marshal(tt.raw) if err != nil { diff --git a/compat/json/compat.go b/compat/json/compat.go index 78a9d5f..b26914b 100644 --- a/compat/json/compat.go +++ b/compat/json/compat.go @@ -8,7 +8,9 @@ import ( "bufio" "bytes" "encoding/json" + "errors" "io" + "strconv" "git.lukeshu.com/go/lowmemjson" ) @@ -19,7 +21,7 @@ type ( RawMessage = json.RawMessage // low-level decode errors - SyntaxError = lowmemjson.SyntaxError + //SyntaxError = lowmemjson.DecodeSyntaxError // expose a field UnmarshalFieldError = json.UnmarshalFieldError UnmarshalTypeError = json.UnmarshalTypeError // lowmemjson.DecodeTypeError @@ -28,7 +30,7 @@ type ( // marshal errors InvalidUTF8Error = json.InvalidUTF8Error - MarshalerError = json.MarshalerError + MarshalerError = lowmemjson.EncodeMethodError // expose a field UnsupportedTypeError = json.UnsupportedTypeError UnsupportedValueError = json.UnsupportedValueError ) @@ -92,7 +94,7 @@ func Valid(data []byte) bool { } func Unmarshal(data []byte, ptr any) error { - return lowmemjson.Decode(bytes.NewReader(data), ptr) + return NewDecoder(bytes.NewReader(data)).Decode(ptr) } ///////////////////////////////////////////////////////////////////// @@ -113,6 +115,43 @@ func 
NewDecoder(r io.Reader) *Decoder { } } +func (dec *Decoder) Decode(ptr any) error { + err := dec.Decoder.Decode(ptr) + if derr, ok := err.(*lowmemjson.DecodeError); ok { + switch terr := derr.Err.(type) { + case *lowmemjson.DecodeSyntaxError: + err = &SyntaxError{ + msg: terr.Err.Error(), + Offset: terr.Offset, + } + case *lowmemjson.DecodeTypeError: + if typeErr, ok := terr.Err.(*json.UnmarshalTypeError); ok { + err = &UnmarshalTypeError{ + Value: typeErr.Value, + Type: typeErr.Type, + Offset: typeErr.Offset, + Struct: derr.FieldParent, + Field: derr.FieldName, + } + } else if _, isArgErr := terr.Err.(*lowmemjson.DecodeArgumentError); terr.Err != nil && + !isArgErr && + !errors.Is(terr.Err, strconv.ErrSyntax) && + !errors.Is(terr.Err, strconv.ErrRange) { + err = terr.Err + } else { + err = &UnmarshalTypeError{ + Value: terr.JSONType, + Type: terr.GoType, + Offset: terr.Offset, + Struct: derr.FieldParent, + Field: derr.FieldName, + } + } + } + } + return err +} + func (dec *Decoder) Buffered() io.Reader { dat, _ := dec.buf.Peek(dec.buf.Buffered()) return bytes.NewReader(dat) diff --git a/compat/json/compat_test.go b/compat/json/compat_test.go index 399ff02..5a34d22 100644 --- a/compat/json/compat_test.go +++ b/compat/json/compat_test.go @@ -7,11 +7,17 @@ package json import ( "bytes" + "git.lukeshu.com/go/lowmemjson" "git.lukeshu.com/go/lowmemjson/internal" ) var parseTag = internal.ParseTag +func checkValid(in []byte, scan *lowmemjson.ReEncoder) error { + _, err := scan.Write(in) + return err +} + const ( startDetectingCyclesAfter = 1000 ) diff --git a/decode.go b/decode.go index 8426526..e4fdd77 100644 --- a/decode.go +++ b/decode.go @@ -14,6 +14,8 @@ import ( "reflect" "strconv" "strings" + "unicode/utf16" + "unicode/utf8" ) type Decodable interface { @@ -26,6 +28,11 @@ type runeBuffer interface { Reset() } +type decodeStackItem struct { + par reflect.Type + idx any +} + type Decoder struct { io runeTypeScanner @@ -35,7 +42,7 @@ type Decoder struct { // state 
err error - stack []any + stack []decodeStackItem } func NewDecoder(r io.Reader) *Decoder { @@ -63,8 +70,18 @@ func (dec *Decoder) More() bool { return e == nil && t != RuneTypeEOF } -func (dec *Decoder) stackPush(idx any) { - dec.stack = append(dec.stack, idx) +const maxNestingDepth = 10000 + +func (dec *Decoder) stackPush(par reflect.Type, idx any) { + dec.stack = append(dec.stack, decodeStackItem{par, idx}) + if len(dec.stack) > maxNestingDepth { + panic(decodeError{ + Field: dec.stackStr(), + FieldParent: dec.stackParent(), + FieldName: dec.stackName(), + Err: ErrDecodeExceededMaxDepth, + }) + } } func (dec *Decoder) stackPop() { dec.stack = dec.stack[:len(dec.stack)-1] @@ -73,11 +90,30 @@ func (dec *Decoder) stackStr() string { var buf strings.Builder buf.WriteString("v") for _, item := range dec.stack { - fmt.Fprintf(&buf, "[%#v]", item) + fmt.Fprintf(&buf, "[%#v]", item.idx) } return buf.String() } +func (dec *Decoder) stackParent() string { + if len(dec.stack) > 0 && dec.stack[len(dec.stack)-1].par.Kind() == reflect.Struct { + return dec.stack[len(dec.stack)-1].par.Name() + } + return "" +} + +func (dec *Decoder) stackName() string { + var fields []string + for i := len(dec.stack) - 1; i >= 0 && dec.stack[i].par.Kind() == reflect.Struct; i-- { + fields = append(fields, dec.stack[i].idx.(string)) + } + for i := 0; i < len(fields)/2; i++ { + j := (len(fields) - 1) - i + fields[i], fields[j] = fields[j], fields[i] + } + return strings.Join(fields, ".") +} + func Decode(r io.Reader, ptr any) error { return NewDecoder(r).Decode(ptr) } @@ -85,7 +121,7 @@ func Decode(r io.Reader, ptr any) error { func (dec *Decoder) Decode(ptr any) (err error) { ptrVal := reflect.ValueOf(ptr) if ptrVal.Kind() != reflect.Pointer || ptrVal.IsNil() || !ptrVal.Elem().CanSet() { - return &json.InvalidUnmarshalError{ + return &DecodeArgumentError{ // don't use ptrVal.Type() because ptrVal might be invalid if ptr==nil Type: reflect.TypeOf(ptr), } @@ -99,7 +135,8 @@ func (dec *Decoder) 
Decode(ptr any) (err error) { defer func() { if r := recover(); r != nil { if de, ok := r.(decodeError); ok { - dec.err = de.Err + pub := DecodeError(de) + dec.err = &pub err = dec.err } else { panic(r) @@ -112,19 +149,31 @@ func (dec *Decoder) Decode(ptr any) (err error) { // io helpers ////////////////////////////////////////////////////////////////////////////////////// -type decodeError struct { - Err error -} - -func (dec *Decoder) panicType(typ reflect.Type, err error) { - panic(decodeError{fmt.Errorf("json: type mismatch error at input byte %v: %s: type %v: %w", - dec.InputOffset(), dec.stackStr(), typ, err)}) +type decodeError DecodeError + +func (dec *Decoder) panicType(jTyp string, gTyp reflect.Type, err error) { + panic(decodeError{ + Field: dec.stackStr(), + FieldParent: dec.stackParent(), + FieldName: dec.stackName(), + Err: &DecodeTypeError{ + GoType: gTyp, + JSONType: jTyp, + Err: err, + Offset: dec.InputOffset(), + }, + }) } func (dec *Decoder) readRune() (rune, RuneType) { c, _, t, e := dec.io.ReadRuneType() if e != nil { - panic(decodeError{e}) + panic(decodeError{ + Field: dec.stackStr(), + FieldParent: dec.stackParent(), + FieldName: dec.stackName(), + Err: e, + }) } return c, t } @@ -150,10 +199,10 @@ func (dec *Decoder) expectRune(ec rune, et RuneType) { } } -func (dec *Decoder) expectRuneType(ec rune, et RuneType) { +func (dec *Decoder) expectRuneType(ec rune, et RuneType, gt reflect.Type) { ac, at := dec.readRune() if ac != ec || at != et { - dec.panicType(nil, fmt.Errorf("TODO error message")) + dec.panicType(at.jsonType(), gt, nil) } } @@ -164,7 +213,12 @@ type decRuneTypeScanner struct { func (sc *decRuneTypeScanner) ReadRuneType() (rune, int, RuneType, error) { c, s, t, e := sc.dec.io.ReadRuneType() if e != nil { - panic(decodeError{e}) + panic(decodeError{ + Field: sc.dec.stackStr(), + FieldParent: sc.dec.stackParent(), + FieldName: sc.dec.stackName(), + Err: e, + }) } return c, s, t, nil } @@ -223,22 +277,25 @@ func (dec *Decoder) 
decode(val reflect.Value, nullOK bool) { typ := val.Type() switch { case val.CanAddr() && reflect.PointerTo(typ) == rawMessagePtrType: + t := dec.peekRuneType() var buf bytes.Buffer dec.scan(&buf) if err := val.Addr().Interface().(*json.RawMessage).UnmarshalJSON(buf.Bytes()); err != nil { - dec.panicType(typ, err) + dec.panicType(t.jsonType(), typ, err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(decodableType): + t := dec.peekRuneType() obj := val.Addr().Interface().(Decodable) if err := obj.DecodeJSON(dec.limitingScanner()); err != nil { - dec.panicType(typ, err) + dec.panicType(t.jsonType(), typ, err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(jsonUnmarshalerType): + t := dec.peekRuneType() var buf bytes.Buffer dec.scan(&buf) obj := val.Addr().Interface().(json.Unmarshaler) if err := obj.UnmarshalJSON(buf.Bytes()); err != nil { - dec.panicType(typ, err) + dec.panicType(t.jsonType(), typ, err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType): if nullOK && dec.peekRuneType() == RuneTypeNullN { @@ -246,30 +303,29 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { return } var buf bytes.Buffer - dec.decodeString(&buf) + dec.decodeString(typ, &buf) obj := val.Addr().Interface().(encoding.TextUnmarshaler) if err := obj.UnmarshalText(buf.Bytes()); err != nil { - dec.panicType(typ, err) + dec.panicType("string", typ, err) } default: - kind := typ.Kind() - switch kind { + switch kind := typ.Kind(); kind { case reflect.Bool: if nullOK && dec.peekRuneType() == RuneTypeNullN { dec.decodeNull() return } - val.SetBool(dec.decodeBool()) + val.SetBool(dec.decodeBool(typ)) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: if nullOK && dec.peekRuneType() == RuneTypeNullN { dec.decodeNull() return } var buf strings.Builder - dec.scanNumber(&buf) + dec.scanNumber(typ, &buf) n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind]) if err != nil { - dec.panicType(typ, err) + 
dec.panicType("number "+buf.String(), typ, err) } val.SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: @@ -278,10 +334,10 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { return } var buf strings.Builder - dec.scanNumber(&buf) + dec.scanNumber(typ, &buf) n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind]) if err != nil { - dec.panicType(typ, err) + dec.panicType("number "+buf.String(), typ, err) } val.SetUint(n) case reflect.Float32, reflect.Float64: @@ -290,10 +346,10 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { return } var buf strings.Builder - dec.scanNumber(&buf) + dec.scanNumber(typ, &buf) n, err := strconv.ParseFloat(buf.String(), kind2bits[kind]) if err != nil { - dec.panicType(typ, err) + dec.panicType("number "+buf.String(), typ, err) } val.SetFloat(n) case reflect.String: @@ -303,32 +359,44 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { } var buf strings.Builder if typ == numberType { - dec.scanNumber(&buf) + dec.scanNumber(typ, &buf) val.SetString(buf.String()) } else { - dec.decodeString(&buf) + dec.decodeString(typ, &buf) val.SetString(buf.String()) } case reflect.Interface: if typ.NumMethod() > 0 { - dec.panicType(typ, fmt.Errorf("cannot decode in to non-empty interface")) + dec.panicType("", typ, fmt.Errorf("cannot decode in to non-empty interface")) + } + // If the interface stores a pointer, try to use the type information of the pointer. + if !val.IsNil() && val.Elem().Kind() == reflect.Pointer { + // Follow a chain of pointers until we find the first settable + // pointer (if any). + ptr := val.Elem() + for ptr.Kind() == reflect.Pointer { + if ptr.CanSet() { + break + } + if ptr.IsNil() { + break + } + ptr = ptr.Elem() + } + // We only need to be able to set the pointer itself if we're + // decoding "null", so add a "||" clause.
+ if ptr.Kind() == reflect.Pointer && (ptr.CanSet() || dec.peekRuneType() != RuneTypeNullN) { + dec.decode(ptr, false) + break + } } + // Couldn't get type information from a pointer; fall back to untyped mode. switch dec.peekRuneType() { case RuneTypeNullN: - if !val.IsNil() && val.Elem().Kind() == reflect.Pointer && val.Elem().Elem().Kind() == reflect.Pointer { - // XXX: I can't justify this case, other than "it's what encoding/json does, but - // I don't understand their rationale". - dec.decode(val.Elem(), false) - } else { - dec.decodeNull() - val.Set(reflect.Zero(typ)) - } + dec.decodeNull() + val.Set(reflect.Zero(typ)) default: - if !val.IsNil() && val.Elem().Kind() == reflect.Pointer { - dec.decode(val.Elem(), false) - } else { - val.Set(reflect.ValueOf(dec.decodeAny())) - } + val.Set(reflect.ValueOf(dec.decodeAny())) } case reflect.Struct: if nullOK && dec.peekRuneType() == RuneTypeNullN { @@ -337,14 +405,23 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { } index := indexStruct(typ) var nameBuf strings.Builder - dec.decodeObject(&nameBuf, func() { + dec.decodeObject(typ, &nameBuf, func() { name := nameBuf.String() - dec.stackPush(name) + dec.stackPush(typ, name) defer dec.stackPop() idx, ok := index.byName[name] + if !ok { + for oname, oidx := range index.byName { + if strings.EqualFold(name, oname) { + idx = oidx + ok = true + break + } + } + } if !ok { if dec.disallowUnknownFields { - dec.panicType(typ, fmt.Errorf("unknown field %q", name)) + dec.panicType("", typ, fmt.Errorf("json: unknown field %q", name)) } dec.scan(io.Discard) return @@ -353,18 +430,20 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { fVal := val for _, idx := range field.Path { if fVal.Kind() == reflect.Pointer { - if fVal.IsNil() { - if !fVal.CanSet() { // https://golang.org/issue/21357 - dec.panicType(fVal.Type().Elem(), fmt.Errorf("cannot set embedded pointer to unexported type")) + if fVal.IsNil() && !fVal.CanSet() { // 
https://golang.org/issue/21357 + dec.panicType("", fVal.Type().Elem(), fmt.Errorf("cannot set embedded pointer to unexported type")) + } + if dec.peekRuneType() != RuneTypeNullN { + if fVal.IsNil() { + fVal.Set(reflect.New(fVal.Type().Elem())) } - fVal.Set(reflect.New(fVal.Type().Elem())) + fVal = fVal.Elem() } - fVal = fVal.Elem() } fVal = fVal.Field(idx) } if field.Quote { - switch dec.peekRuneType() { + switch t := dec.peekRuneType(); t { case RuneTypeNullN: dec.decodeNull() switch fVal.Kind() { @@ -378,18 +457,25 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { case RuneTypeStringBeg: // TODO: Figure out how to do this without buffering, have correct offsets. var buf bytes.Buffer - dec.decodeString(&buf) - subD := NewDecoder(&buf) - subD.decode(fVal, false) + dec.decodeString(nil, &buf) + if err := Decode(bytes.NewReader(buf.Bytes()), fVal.Addr().Interface()); err != nil { + if str := buf.String(); str != "null" { + dec.panicType("", fVal.Type(), + fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", + str, fVal.Type())) + } + } default: - dec.panicType(typ, fmt.Errorf(",string field TODO ERROR MESSAGE")) + dec.panicType(t.jsonType(), fVal.Type(), + fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", + fVal.Type())) } } else { dec.decode(fVal, true) } }) case reflect.Map: - switch dec.peekRuneType() { + switch t := dec.peekRuneType(); t { case RuneTypeNullN: dec.decodeNull() val.Set(reflect.Zero(typ)) @@ -398,14 +484,14 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { val.Set(reflect.MakeMap(typ)) } var nameBuf bytes.Buffer - dec.decodeObject(&nameBuf, func() { + dec.decodeObject(typ, &nameBuf, func() { nameValTyp := typ.Key() nameValPtr := reflect.New(nameValTyp) switch { case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType): obj := nameValPtr.Interface().(encoding.TextUnmarshaler) if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil { 
- dec.panicType(nameValTyp, err) + dec.panicType("string", nameValTyp, err) } default: switch nameValTyp.Kind() { @@ -414,20 +500,20 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) if err != nil { - dec.panicType(nameValTyp, err) + dec.panicType("number "+nameBuf.String(), nameValTyp, err) } nameValPtr.Elem().SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) if err != nil { - dec.panicType(nameValTyp, err) + dec.panicType("number "+nameBuf.String(), nameValTyp, err) } nameValPtr.Elem().SetUint(n) default: - dec.panicType(typ, fmt.Errorf("invalid map key type: %v", nameValTyp)) + dec.panicType("object", typ, &DecodeArgumentError{nameValTyp}) } } - dec.stackPush(nameValPtr.Elem()) + dec.stackPush(typ, nameValPtr.Elem()) defer dec.stackPop() fValPtr := reflect.New(typ.Elem()) @@ -436,33 +522,39 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { val.SetMapIndex(nameValPtr.Elem(), fValPtr.Elem()) }) default: - dec.panicType(typ, fmt.Errorf("map: TODO")) + dec.panicType(t.jsonType(), typ, nil) } case reflect.Slice: switch { - case typ.Elem().Kind() == reflect.Uint8: - switch dec.peekRuneType() { + case typ.Elem().Kind() == reflect.Uint8 && !(dec.peekRuneType() == RuneTypeArrayBeg && (false || + reflect.PointerTo(typ.Elem()).Implements(decodableType) || + reflect.PointerTo(typ.Elem()).Implements(jsonUnmarshalerType) || + reflect.PointerTo(typ.Elem()).Implements(textUnmarshalerType))): + switch t := dec.peekRuneType(); t { case RuneTypeNullN: dec.decodeNull() val.Set(reflect.Zero(typ)) case RuneTypeStringBeg: - var buf bytes.Buffer - dec.decodeString(newBase64Decoder(&buf)) if typ.Elem() == byteType { + var buf bytes.Buffer + dec.decodeString(typ, 
newBase64Decoder(&buf)) val.Set(reflect.ValueOf(buf.Bytes())) } else { + // TODO: Surely there's a better way. At the very least, we should + // avoid buffering. + var buf bytes.Buffer + dec.decodeString(typ, newBase64Decoder(&buf)) bs := buf.Bytes() - // TODO: Surely there's a better way. val.Set(reflect.MakeSlice(typ, len(bs), len(bs))) for i := 0; i < len(bs); i++ { val.Index(i).Set(reflect.ValueOf(bs[i]).Convert(typ.Elem())) } } default: - dec.panicType(typ, fmt.Errorf("byte slice: TODO")) + dec.panicType(t.jsonType(), typ, nil) } default: - switch dec.peekRuneType() { + switch t := dec.peekRuneType(); t { case RuneTypeNullN: dec.decodeNull() val.Set(reflect.Zero(typ)) @@ -474,8 +566,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { val.Set(val.Slice(0, 0)) } i := 0 - dec.decodeArray(func() { - dec.stackPush(i) + dec.decodeArray(typ, func() { + dec.stackPush(typ, i) defer dec.stackPop() mValPtr := reflect.New(typ.Elem()) dec.decode(mValPtr.Elem(), false) @@ -483,7 +575,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { i++ }) default: - dec.panicType(typ, fmt.Errorf("slice: TODO")) + dec.panicType(t.jsonType(), typ, nil) } } case reflect.Array: @@ -493,8 +585,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { } i := 0 n := val.Len() - dec.decodeArray(func() { - dec.stackPush(i) + dec.decodeArray(typ, func() { + dec.stackPush(typ, i) defer dec.stackPop() if i < n { mValPtr := reflect.New(typ.Elem()) @@ -512,15 +604,6 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { switch dec.peekRuneType() { case RuneTypeNullN: dec.decodeNull() - /* - for typ.Elem().Kind() == reflect.Pointer { - if val.IsNil() || !val.Elem().CanSet() { - val.Set(reflect.New(typ.Elem())) - } - val = val.Elem() - typ = val.Type() - } - */ val.Set(reflect.Zero(typ)) default: if val.IsNil() { @@ -529,7 +612,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) { dec.decode(val.Elem(), false) } default: - dec.panicType(typ, 
fmt.Errorf("unsupported type (kind=%v)", typ.Kind())) + dec.panicType("", typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind())) } } } @@ -545,9 +628,9 @@ func (dec *Decoder) scan(out io.Writer) { } } -func (dec *Decoder) scanNumber(out io.Writer) { - if !dec.peekRuneType().IsNumber() { - dec.panicType(numberType, fmt.Errorf("number: not a number")) +func (dec *Decoder) scanNumber(gTyp reflect.Type, out io.Writer) { + if t := dec.peekRuneType(); !t.IsNumber() { + dec.panicType(t.jsonType(), gTyp, nil) } dec.scan(out) } @@ -558,25 +641,27 @@ func (dec *Decoder) decodeAny() any { switch c { case '{': ret := make(map[string]any) + typ := reflect.TypeOf(ret) var nameBuf strings.Builder - dec.decodeObject(&nameBuf, func() { + dec.decodeObject(typ, &nameBuf, func() { name := nameBuf.String() - dec.stackPush(name) + dec.stackPush(typ, name) defer dec.stackPop() ret[name] = dec.decodeAny() }) return ret case '[': ret := []any{} - dec.decodeArray(func() { - dec.stackPush(len(ret)) + typ := reflect.TypeOf(ret) + dec.decodeArray(typ, func() { + dec.stackPush(typ, len(ret)) defer dec.stackPop() ret = append(ret, dec.decodeAny()) }) return ret case '"': var buf strings.Builder - dec.decodeString(&buf) + dec.decodeString(nil, &buf) return buf.String() case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': var buf strings.Builder @@ -591,7 +676,7 @@ func (dec *Decoder) decodeAny() any { } return f64 case 't', 'f': - return dec.decodeBool() + return dec.decodeBool(nil) case 'n': dec.decodeNull() return nil @@ -600,8 +685,8 @@ func (dec *Decoder) decodeAny() any { } } -func (dec *Decoder) decodeObject(nameBuf runeBuffer, decodeKVal func()) { - dec.expectRuneType('{', RuneTypeObjectBeg) +func (dec *Decoder) decodeObject(gTyp reflect.Type, nameBuf runeBuffer, decodeKVal func()) { + dec.expectRuneType('{', RuneTypeObjectBeg, gTyp) _, t := dec.readRune() switch t { case RuneTypeObjectEnd: @@ -610,7 +695,7 @@ func (dec *Decoder) decodeObject(nameBuf runeBuffer, decodeKVal 
func()) { decodeMember: dec.unreadRune() nameBuf.Reset() - dec.decodeString(nameBuf) + dec.decodeString(nil, nameBuf) dec.expectRune(':', RuneTypeObjectColon) decodeKVal() _, t := dec.readRune() @@ -628,8 +713,8 @@ func (dec *Decoder) decodeObject(nameBuf runeBuffer, decodeKVal func()) { } } -func (dec *Decoder) decodeArray(decodeMember func()) { - dec.expectRuneType('[', RuneTypeArrayBeg) +func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func()) { + dec.expectRuneType('[', RuneTypeArrayBeg, gTyp) _, t := dec.readRune() switch t { case RuneTypeArrayEnd: @@ -650,8 +735,8 @@ func (dec *Decoder) decodeArray(decodeMember func()) { } } -func (dec *Decoder) decodeString(out io.Writer) { - dec.expectRuneType('"', RuneTypeStringBeg) +func (dec *Decoder) decodeString(gTyp reflect.Type, out io.Writer) { + dec.expectRuneType('"', RuneTypeStringBeg, gTyp) var uhex [4]byte for { c, t := dec.readRune() @@ -694,7 +779,42 @@ func (dec *Decoder) decodeString(out io.Writer) { rune(uhex[1])<<8 | rune(uhex[2])<<4 | rune(uhex[3])<<0 - _, _ = writeRune(out, c) + handleUnicode: + if utf16.IsSurrogate(c) { + if dec.peekRuneType() != RuneTypeStringEsc { + _, _ = writeRune(out, utf8.RuneError) + break + } + dec.expectRune('\\', RuneTypeStringEsc) + if dec.peekRuneType() != RuneTypeStringEscU { + _, _ = writeRune(out, utf8.RuneError) + break + } + dec.expectRune('u', RuneTypeStringEscU) + + b, _ := dec.readRune() + uhex[0], _ = hex2int(b) + b, _ = dec.readRune() + uhex[1], _ = hex2int(b) + b, _ = dec.readRune() + uhex[2], _ = hex2int(b) + b, _ = dec.readRune() + uhex[3], _ = hex2int(b) + c2 := 0 | + rune(uhex[0])<<12 | + rune(uhex[1])<<8 | + rune(uhex[2])<<4 | + rune(uhex[3])<<0 + d := utf16.DecodeRune(c, c2) + if d == utf8.RuneError { + _, _ = writeRune(out, utf8.RuneError) + c = c2 + goto handleUnicode + } + _, _ = writeRune(out, d) + } else { + _, _ = writeRune(out, c) + } case RuneTypeStringEnd: return default: @@ -703,8 +823,8 @@ func (dec *Decoder) decodeString(out 
io.Writer) { } } -func (dec *Decoder) decodeBool() bool { - c, _ := dec.readRune() +func (dec *Decoder) decodeBool(gTyp reflect.Type) bool { + c, t := dec.readRune() switch c { case 't': dec.expectRune('r', RuneTypeTrueR) @@ -718,13 +838,13 @@ func (dec *Decoder) decodeBool() bool { dec.expectRune('e', RuneTypeFalseE) return false default: - dec.panicType(boolType, fmt.Errorf("bool: expected %q or %q but got %q", 't', 'f', c)) + dec.panicType(t.jsonType(), gTyp, nil) panic("not reached") } } func (dec *Decoder) decodeNull() { - dec.expectRuneType('n', RuneTypeNullN) + dec.expectRune('n', RuneTypeNullN) dec.expectRune('u', RuneTypeNullU) dec.expectRune('l', RuneTypeNullL1) dec.expectRune('l', RuneTypeNullL2) diff --git a/decode_scan.go b/decode_scan.go index e75f1c5..3c41df6 100644 --- a/decode_scan.go +++ b/decode_scan.go @@ -6,29 +6,9 @@ package lowmemjson import ( "errors" - "fmt" "io" ) -type ReadError struct { - Err error - Offset int64 -} - -func (e *ReadError) Error() string { - return fmt.Sprintf("json: I/O error at input byte %v: %v", e.Offset, e.Err) -} -func (e *ReadError) Unwrap() error { return e.Err } - -type SyntaxError struct { - Err string - Offset int64 -} - -func (e *SyntaxError) Error() string { - return fmt.Sprintf("json: syntax error at input byte %v: %v", e.Offset, e.Err) -} - type runeTypeScanner interface { // The returned error is a *ReadError, a *SyntaxError, or nil. // An EOF condition is represented either as @@ -37,9 +17,9 @@ type runeTypeScanner interface { // // or // - // (char, size, RuneTypeError, &SyntaxError{Offset: offset: Err: io.ErrUnexepctedEOF}) + // (char, size, RuneTypeError, &DecodeSyntaxError{Offset: offset: Err: io.ErrUnexepctedEOF}) ReadRuneType() (rune, int, RuneType, error) - // The returned error is a *ReadError, a *SyntaxError, io.EOF, or nil. + // The returned error is a *DecodeReadError, a *DecodeSyntaxError, io.EOF, or nil. 
ReadRune() (rune, int, error) UnreadRune() error Reset() @@ -86,9 +66,9 @@ func (sc *runeTypeScannerImpl) ReadRuneType() (rune, int, RuneType, error) { case nil: sc.rType, err = sc.parser.HandleRune(sc.rRune) if err != nil { - sc.rErr = &SyntaxError{ + sc.rErr = &DecodeSyntaxError{ Offset: sc.offset, - Err: err.Error(), + Err: err, } } else { sc.rErr = nil @@ -96,16 +76,16 @@ func (sc *runeTypeScannerImpl) ReadRuneType() (rune, int, RuneType, error) { case io.EOF: sc.rType, err = sc.parser.HandleEOF() if err != nil { - sc.rErr = &SyntaxError{ + sc.rErr = &DecodeSyntaxError{ Offset: sc.offset, - Err: err.Error(), + Err: err, } } else { sc.rErr = nil } default: sc.rType = 0 - sc.rErr = &ReadError{ + sc.rErr = &DecodeReadError{ Offset: sc.offset, Err: err, } diff --git a/encode.go b/encode.go index 8479785..c881369 100644 --- a/encode.go +++ b/encode.go @@ -9,11 +9,13 @@ import ( "encoding" "encoding/base64" "encoding/json" + "fmt" "io" "reflect" "sort" "strconv" "strings" + "unsafe" ) type Encodable interface { @@ -46,7 +48,7 @@ func Encode(w io.Writer, obj any) (err error) { } } }() - encode(w, reflect.ValueOf(obj), false) + encode(w, reflect.ValueOf(obj), false, 0, map[unsafe.Pointer]struct{}{}) if f, ok := w.(interface{ Flush() error }); ok { return f.Flush() } @@ -59,7 +61,9 @@ var ( textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() ) -func encode(w io.Writer, val reflect.Value, quote bool) { +const startDetectingCyclesAfter = 1000 + +func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) { if !val.IsValid() { encodeWriteString(w, "null") return @@ -187,7 +191,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) { if val.IsNil() { encodeWriteString(w, "null") } else { - encode(w, val.Elem(), quote) + encode(w, val.Elem(), quote, cycleDepth, cycleSeen) } case reflect.Struct: encodeWriteByte(w, '{') @@ -206,7 +210,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) { 
empty = false encodeString(w, field.Name) encodeWriteByte(w, ':') - encode(w, fVal, field.Quote) + encode(w, fVal, field.Quote, cycleDepth, cycleSeen) } encodeWriteByte(w, '}') case reflect.Map: @@ -218,6 +222,17 @@ func encode(w io.Writer, val reflect.Value, quote bool) { encodeWriteString(w, "{}") return } + if cycleDepth++; cycleDepth > startDetectingCyclesAfter { + ptr := val.UnsafePointer() + if _, seen := cycleSeen[ptr]; seen { + panic(encodeError{&EncodeValueError{ + Value: val, + Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), + }}) + } + cycleSeen[ptr] = struct{}{} + defer delete(cycleSeen, ptr) + } encodeWriteByte(w, '{') type kv struct { @@ -228,7 +243,7 @@ func encode(w io.Writer, val reflect.Value, quote bool) { iter := val.MapRange() for i := 0; iter.Next(); i++ { var k strings.Builder - encode(&k, iter.Key(), false) + encode(&k, iter.Key(), false, cycleDepth, cycleSeen) kStr := k.String() if kStr == "null" { kStr = `""` @@ -251,14 +266,20 @@ func encode(w io.Writer, val reflect.Value, quote bool) { } encodeWriteString(w, kv.K) encodeWriteByte(w, ':') - encode(w, kv.V, false) + encode(w, kv.V, false, cycleDepth, cycleSeen) } encodeWriteByte(w, '}') case reflect.Slice: switch { case val.IsNil(): encodeWriteString(w, "null") - case val.Type().Elem().Kind() == reflect.Uint8: + case val.Type().Elem().Kind() == reflect.Uint8 && !(false || + val.Type().Elem().Implements(encodableType) || + reflect.PointerTo(val.Type().Elem()).Implements(encodableType) || + val.Type().Elem().Implements(jsonMarshalerType) || + reflect.PointerTo(val.Type().Elem()).Implements(jsonMarshalerType) || + val.Type().Elem().Implements(textMarshalerType) || + reflect.PointerTo(val.Type().Elem()).Implements(textMarshalerType)): encodeWriteByte(w, '"') enc := base64.NewEncoder(base64.StdEncoding, w) if val.CanConvert(byteSliceType) { @@ -280,18 +301,40 @@ func encode(w io.Writer, val reflect.Value, quote bool) { } encodeWriteByte(w, '"') default: - encodeArray(w, val) + if 
cycleDepth++; cycleDepth > startDetectingCyclesAfter { + ptr := val.UnsafePointer() + if _, seen := cycleSeen[ptr]; seen { + panic(encodeError{&EncodeValueError{ + Value: val, + Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), + }}) + } + cycleSeen[ptr] = struct{}{} + defer delete(cycleSeen, ptr) + } + encodeArray(w, val, cycleDepth, cycleSeen) } case reflect.Array: - encodeArray(w, val) + encodeArray(w, val, cycleDepth, cycleSeen) case reflect.Pointer: if val.IsNil() { encodeWriteString(w, "null") } else { - encode(w, val.Elem(), quote) + if cycleDepth++; cycleDepth > startDetectingCyclesAfter { + ptr := val.UnsafePointer() + if _, seen := cycleSeen[ptr]; seen { + panic(encodeError{&EncodeValueError{ + Value: val, + Str: fmt.Sprintf("encountered a cycle via %s", val.Type()), + }}) + } + cycleSeen[ptr] = struct{}{} + defer delete(cycleSeen, ptr) + } + encode(w, val.Elem(), quote, cycleDepth, cycleSeen) } default: - panic(encodeError{&json.UnsupportedTypeError{ + panic(encodeError{&EncodeTypeError{ Type: val.Type(), }}) } @@ -310,14 +353,14 @@ func encodeString[T interface{ []byte | string }](w io.Writer, str T) { encodeWriteByte(w, '"') } -func encodeArray(w io.Writer, val reflect.Value) { +func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) { encodeWriteByte(w, '[') n := val.Len() for i := 0; i < n; i++ { if i > 0 { encodeWriteByte(w, ',') } - encode(w, val.Index(i), false) + encode(w, val.Index(i), false, cycleDepth, cycleSeen) } encodeWriteByte(w, ']') } diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..a5a4080 --- /dev/null +++ b/errors.go @@ -0,0 +1,130 @@ +// Copyright (C) 2022 Luke Shumaker +// +// SPDX-License-Identifier: GPL-2.0-or-later + +package lowmemjson + +import ( + "encoding/json" + "fmt" + "reflect" + "strings" + "errors" +) + +// low-level decode errors ///////////////////////////////////////////////////////////////////////// +// These will be wrapped in a 
*DecodeError. + +// A *DecodeReadError is returned from Decode if there is an I/O error +// reading the input. +type DecodeReadError struct { + Err error + Offset int64 +} + +func (e *DecodeReadError) Error() string { + return fmt.Sprintf("json: I/O error at input byte %v: %v", e.Offset, e.Err) +} +func (e *DecodeReadError) Unwrap() error { return e.Err } + +// A *DecodeSyntaxError is returned from Decode if there is a syntax +// error in the input. +type DecodeSyntaxError struct { + Err error + Offset int64 +} + +func (e *DecodeSyntaxError) Error() string { + return fmt.Sprintf("json: syntax error at input byte %v: %v", e.Offset, e.Err) +} + +// A *DecodeTypeError is returned from Decode if the JSON input is not +// appropriate for the given Go type. +// +// If a .DecodeJSON, .UnmarshalJSON, or .UnmarshalText method returns +// an error, it is wrapped in a *DecodeTypeError. +type DecodeTypeError struct { + JSONType string // (optional) + GoType reflect.Type + Offset int64 + Err error // (optional) +} + +func (e *DecodeTypeError) Error() string { + var buf strings.Builder + buf.WriteString("json: cannot decode ") + if e.JSONType != "" { + fmt.Fprintf(&buf, "JSON %s ", e.JSONType) + } + fmt.Fprintf(&buf, "at input byte %v in to Go %v", e.Offset, e.GoType) + if e.Err != nil { + fmt.Fprintf(&buf, ": %v", strings.TrimPrefix(e.Err.Error(), "json: ")) + } + return buf.String() +} + +func (e *DecodeTypeError) Unwrap() error { return e.Err } + +var ErrDecodeExceededMaxDepth = errors.New("exceeded max depth") + +// high-level decode errors //////////////////////////////////////////////////////////////////////// + +// A *DecodeArgumentError is returned from Decode if the argument is +// not a non-nil pointer or is not settable. +// +// Alternatively, a *DecodeArgumentError may be found inside of a +// *DecodeTypeError if the type being decoded in to is not a type that +// can be decoded in to (such as map with non-stringable type as +// keys). 
+// +// type DecodeArgumentError struct { +// Type reflect.Type +// } +type DecodeArgumentError = json.InvalidUnmarshalError + +type DecodeError struct { + Field string + Err error + + FieldParent string // for compat + FieldName string // for compat +} + +func (e *DecodeError) Error() string { + return fmt.Sprintf("json: %s: %s", e.Field, strings.TrimPrefix(e.Err.Error(), "json: ")) +} +func (e *DecodeError) Unwrap() error { return e.Err } + +// encode errors /////////////////////////////////////////////////////////////////////////////////// + +// An *EncodeTypeError is returned by Encode when attempting to encode +// an unsupported type. +// +// type EncodeTypeError struct { +// Type reflect.Type +// } +type EncodeTypeError = json.UnsupportedTypeError + +// An *EncodeValueError is returned by Encode when attempting to +// encode an unsupported value (such as a data structure with a cycle). +// +// type EncodeValueError struct { +// Value reflect.Value +// Str string +// } +type EncodeValueError = json.UnsupportedValueError + +// An *EncodeMethodError is returned by Encode when a method called +// while encoding a value of the given type returns an error. 
+type EncodeMethodError struct { + Type reflect.Type + Err error + SourceFunc string +} + +func (e *EncodeMethodError) Error() string { + return fmt.Sprintf("json: error calling %v for type %v: %v", + e.SourceFunc, e.Type, strings.TrimPrefix(e.Err.Error(), "json: ")) +} + +func (e *EncodeMethodError) Unwrap() error { return e.Err } diff --git a/parse.go b/parse.go index 866e9f4..23df5bc 100644 --- a/parse.go +++ b/parse.go @@ -388,7 +388,7 @@ func (par *Parser) HandleRune(c rune) (RuneType, error) { case 'n': return par.replaceState(RuneTypeNullN), nil default: - return RuneTypeError, fmt.Errorf("any: unexpected character: %q", c) + return RuneTypeError, fmt.Errorf("invalid character %q looking for beginning of value", c) } // object ////////////////////////////////////////////////////////////////////////////////// case RuneTypeObjectBeg: // waiting for key to start or '}' @@ -413,7 +413,7 @@ func (par *Parser) HandleRune(c rune) (RuneType, error) { par.pushState(RuneTypeError) return RuneTypeObjectColon, nil default: - return RuneTypeError, fmt.Errorf("object member: unexpected character: %q", c) + return RuneTypeError, fmt.Errorf("invalid character %q after object key", c) } case RuneTypeObjectComma: // waiting for ',' or '}' switch c { @@ -426,7 +426,7 @@ func (par *Parser) HandleRune(c rune) (RuneType, error) { par.popState() return RuneTypeObjectEnd, nil default: - return RuneTypeError, fmt.Errorf("object member: unexpected character: %q", c) + return RuneTypeError, fmt.Errorf("invalid character %q after object key:value pair", c) } // array /////////////////////////////////////////////////////////////////////////////////// case RuneTypeArrayBeg: // waiting for item to start or ']' @@ -452,7 +452,7 @@ func (par *Parser) HandleRune(c rune) (RuneType, error) { par.popState() return RuneTypeArrayEnd, nil default: - return RuneTypeError, fmt.Errorf("array: unexpected character: %q", c) + return RuneTypeError, fmt.Errorf("invalid character %q after array element", 
c) } // string ////////////////////////////////////////////////////////////////////////////////// case RuneTypeStringBeg: // waiting for char or '"' @@ -549,7 +549,7 @@ func (par *Parser) HandleRune(c rune) (RuneType, error) { case '1', '2', '3', '4', '5', '6', '7', '8', '9': return par.replaceState(RuneTypeNumberIntDig), nil default: - return RuneTypeError, fmt.Errorf("number: unexpected character: %q", c) + return RuneTypeError, fmt.Errorf("invalid character %q in numeric literal", c) } case RuneTypeNumberIntZero: // C switch c { @@ -578,7 +578,7 @@ func (par *Parser) HandleRune(c rune) (RuneType, error) { case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': return par.replaceState(RuneTypeNumberFracDig), nil default: - return RuneTypeError, fmt.Errorf("number: unexpected character: %q", c) + return RuneTypeError, fmt.Errorf("invalid character %q in numeric literal", c) } case RuneTypeNumberFracDig: // F switch c { @@ -597,14 +597,14 @@ func (par *Parser) HandleRune(c rune) (RuneType, error) { case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': return par.replaceState(RuneTypeNumberExpDig), nil default: - return RuneTypeError, fmt.Errorf("number: unexpected character: %c", c) + return RuneTypeError, fmt.Errorf("invalid character %q in numeric literal", c) } case RuneTypeNumberExpSign: // H switch c { case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': return par.replaceState(RuneTypeNumberExpDig), nil default: - return RuneTypeError, fmt.Errorf("number: unexpected character: %c", c) + return RuneTypeError, fmt.Errorf("invalid character %q in numeric literal", c) } case RuneTypeNumberExpDig: // I switch c { diff --git a/reencode.go b/reencode.go index 6bd1e48..92e870a 100644 --- a/reencode.go +++ b/reencode.go @@ -20,8 +20,7 @@ type ReEncoder struct { Compact bool // String to use to indent; ignored if Compact is true. Indent string - // String to put before indents, for testing-compat with - // encoding/json only. + // String to put before indents. 
Prefix string // Returns whether a given character in a string should be // backslash-escaped. The bool argument is whether it was @@ -80,9 +79,9 @@ func (enc *ReEncoder) Write(p []byte) (int, error) { func (enc *ReEncoder) Close() error { if enc.bufLen > 0 { - return &SyntaxError{ + return &DecodeSyntaxError{ Offset: enc.inputPos, - Err: fmt.Sprintf("%v: unflushed unicode garbage: %q", io.ErrUnexpectedEOF, enc.buf[:enc.bufLen]), + Err: fmt.Errorf("%w: unflushed unicode garbage: %q", io.ErrUnexpectedEOF, enc.buf[:enc.bufLen]), } } if _, err := enc.par.HandleEOF(); err != nil { diff --git a/struct.go b/struct.go index ad142d6..ee2bbf3 100644 --- a/struct.go +++ b/struct.go @@ -22,24 +22,24 @@ type structIndex struct { } func indexStruct(typ reflect.Type) structIndex { - byName := make(map[string][]structField) - var byPos []string + var byPos []structField + byName := make(map[string][]int) - indexStructInner(typ, nil, byName, &byPos) + indexStructInner(typ, nil, &byPos, byName, map[reflect.Type]struct{}{}) ret := structIndex{ byName: make(map[string]int), } - for _, name := range byPos { - fields := byName[name] - delete(byName, name) - switch len(fields) { + for curPos, _field := range byPos { + name := _field.Name + fieldPoss := byName[name] + switch len(fieldPoss) { case 0: // do nothing case 1: ret.byName[name] = len(ret.byPos) - ret.byPos = append(ret.byPos, fields[0]) + ret.byPos = append(ret.byPos, _field) default: // To quote the encoding/json docs (version 1.18.4): // @@ -56,27 +56,29 @@ func indexStruct(typ reflect.Type) structIndex { // // 3) Otherwise there are multiple fields, and all are ignored; no error // occurs. 
- leastLevel := len(fields[0].Path) - for _, field := range fields[1:] { + leastLevel := len(byPos[fieldPoss[0]].Path) + for _, fieldPos := range fieldPoss[1:] { + field := byPos[fieldPos] if len(field.Path) < leastLevel { leastLevel = len(field.Path) } } var numUntagged, numTagged int - var untaggedIdx, taggedIdx int - for i, field := range fields { + var untaggedPos, taggedPos int + for _, fieldPos := range fieldPoss { + field := byPos[fieldPos] if len(field.Path) != leastLevel { continue } if field.Tagged { numTagged++ - taggedIdx = i + taggedPos = fieldPos if numTagged > 1 { break // optimization } } else { numUntagged++ - untaggedIdx = i + untaggedPos = fieldPos } } switch numTagged { @@ -85,12 +87,16 @@ func indexStruct(typ reflect.Type) structIndex { case 0: // do nothing case 1: - ret.byName[name] = len(ret.byPos) - ret.byPos = append(ret.byPos, fields[untaggedIdx]) + if curPos == untaggedPos { + ret.byName[name] = len(ret.byPos) + ret.byPos = append(ret.byPos, byPos[curPos]) + } } case 1: - ret.byName[name] = len(ret.byPos) - ret.byPos = append(ret.byPos, fields[taggedIdx]) + if curPos == taggedPos { + ret.byName[name] = len(ret.byPos) + ret.byPos = append(ret.byPos, byPos[curPos]) + } } } } @@ -98,7 +104,13 @@ func indexStruct(typ reflect.Type) structIndex { return ret } -func indexStructInner(typ reflect.Type, prefix []int, byName map[string][]structField, byPos *[]string) { +func indexStructInner(typ reflect.Type, prefix []int, byPos *[]structField, byName map[string][]int, seen map[reflect.Type]struct{}) { + if _, ok := seen[typ]; ok { + return + } + seen[typ] = struct{}{} + defer delete(seen, typ) + n := typ.NumField() for i := 0; i < n; i++ { path := append(append([]int(nil), prefix...), i) @@ -123,25 +135,28 @@ func indexStructInner(typ reflect.Type, prefix []int, byName map[string][]struct } tagName, opts := parseTag(tag) name := tagName + if !isValidTag(name) { + name = "" + } if name == "" { name = fTyp.Name } - if embed { + if embed && tagName 
== "" { t := fTyp.Type if t.Kind() == reflect.Pointer { t = t.Elem() } - indexStructInner(t, path, byName, byPos) + indexStructInner(t, path, byPos, byName, seen) } else { - byName[name] = append(byName[name], structField{ + byName[name] = append(byName[name], len(*byPos)) + *byPos = append(*byPos, structField{ Name: name, Path: path, Tagged: tagName != "", OmitEmpty: opts.Contains("omitempty"), Quote: opts.Contains("string") && isQuotable(fTyp.Type), }) - *byPos = append(*byPos, name) } } } -- cgit v1.2.3-54-g00ecf