diff options
Diffstat (limited to 'decode.go')
-rw-r--r-- | decode.go | 267 |
1 files changed, 141 insertions, 126 deletions
@@ -96,6 +96,7 @@ type Decoder struct { // state posStack []int64 structStack []decodeStackItem + typeErr *DecodeError } const maxNestingDepth = 10000 @@ -241,19 +242,26 @@ func (dec *Decoder) Decode(ptr any) (err error) { } } + dec.typeErr = nil dec.io.Reset() dec.io.PushReadBarrier() if err := dec.decode(ptrVal.Elem(), false); err != nil { return err } dec.io.PopReadBarrier() + if dec.typeErr != nil { + return dec.typeErr + } return nil } // io helpers ////////////////////////////////////////////////////////////////////////////////////// -func (dec *Decoder) newTypeError(jTyp string, gTyp reflect.Type, err error) *DecodeError { - return &DecodeError{ +func (dec *Decoder) newTypeError(jTyp string, gTyp reflect.Type, err error) { + if dec.typeErr != nil { + return + } + dec.typeErr = &DecodeError{ Field: dec.structStackStr(), FieldParent: dec.structStackParent(), FieldName: dec.structStackName(), @@ -296,7 +304,13 @@ func (dec *Decoder) peekRuneType() (jsonparse.RuneType, *DecodeError) { return t, nil } -func (dec *Decoder) expectRune(ec rune, et jsonparse.RuneType) *DecodeError { +// expectRuneOrPanic is for when you *know* what the next +// non-whitespace rune is going to be; for it to be anything else +// would be a syntax error. It will return an error for I/O errors +// and syntax errors, but panic if the result is not what was +// expected; as that would indicate a bug in the agreement between the +// parser and the decoder. +func (dec *Decoder) expectRuneOrPanic(ec rune, et jsonparse.RuneType) *DecodeError { ac, at, err := dec.readRune() if err != nil { return err @@ -307,17 +321,6 @@ func (dec *Decoder) expectRune(ec rune, et jsonparse.RuneType) *DecodeError { return nil } -func (dec *Decoder) expectRuneType(ec rune, et jsonparse.RuneType, gt reflect.Type) *DecodeError { - ac, at, err := dec.readRune() - if err != nil { - return err - } - if ac != ec || at != et { - return dec.newTypeError(at.JSONType(), gt, nil) - } - return nil -} - type decRuneScanner struct { dec *Decoder eof bool @@ -350,7 +353,11 @@ func (sc *decRuneScanner) UnreadRune() error { return sc.dec.io.UnreadRune() } -func (dec *Decoder) withLimitingScanner(fn func(io.RuneScanner) *DecodeError) (err *DecodeError) { +func (dec *Decoder) withLimitingScanner(gTyp reflect.Type, fn func(io.RuneScanner) error) (err *DecodeError) { + t, err := dec.peekRuneType() + if err != nil { + return err + } dec.io.PushReadBarrier() defer func() { if r := recover(); r != nil { @@ -361,8 +368,15 @@ func (dec *Decoder) withLimitingScanner(fn func(io.RuneScanner) *DecodeError) (e } } }() - if err := fn(&decRuneScanner{dec: dec}); err != nil { - return err + l := &decRuneScanner{dec: dec} + if err := fn(l); err != nil { + dec.newTypeError(t.JSONType(), gTyp, err) + } + if _, _, err := l.ReadRune(); err != io.EOF { + dec.newTypeError(t.JSONType(), gTyp, fmt.Errorf("did not consume entire %s", t.JSONType())) + for err != io.EOF { + _, _, err = l.ReadRune() + } } return nil } @@ -403,23 +417,11 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return err } if err := val.Addr().Interface().(*RawMessage).UnmarshalJSON(buf.Bytes()); err != nil { - return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) + dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(decodableType): - t, err := dec.peekRuneType() - if err != nil { - return err - } obj := val.Addr().Interface().(Decodable) - return dec.withLimitingScanner(func(l io.RuneScanner) *DecodeError { - if err := obj.DecodeJSON(l); err != nil { - return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) - } - if _, _, err := l.ReadRune(); err != io.EOF { - return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), fmt.Errorf("did not consume entire %s", t.JSONType())) - } - return nil - }) + return dec.withLimitingScanner(reflect.PointerTo(typ), obj.DecodeJSON) case val.CanAddr() && reflect.PointerTo(typ).Implements(jsonUnmarshalerType): t, err := dec.peekRuneType() if err != nil { @@ -431,7 +433,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } obj := val.Addr().Interface().(jsonUnmarshaler) if err := obj.UnmarshalJSON(buf.Bytes()); err != nil { - return dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) + dec.newTypeError(t.JSONType(), reflect.PointerTo(typ), err) } case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType): if ok, err := dec.maybeDecodeNull(nullOK); ok { @@ -443,7 +445,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } obj := val.Addr().Interface().(encoding.TextUnmarshaler) if err := obj.UnmarshalText(buf.Bytes()); err != nil { - return dec.newTypeError("string", reflect.PointerTo(typ), err) + dec.newTypeError("string", reflect.PointerTo(typ), err) } default: switch kind := typ.Kind(); kind { @@ -460,39 +462,60 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { if ok, err := dec.maybeDecodeNull(nullOK); ok { return err } + if t, err := dec.peekRuneType(); err != nil { + return err + } else if !t.IsNumber() { + dec.newTypeError(t.JSONType(), typ, nil) + return dec.scan(fastio.Discard) + } var buf strings.Builder - if err := dec.scanNumber(typ, &buf); err != nil { + if err := dec.scan(&buf); err != nil { return err } n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind]) if err != nil { - return dec.newTypeError("number "+buf.String(), typ, err) + dec.newTypeError("number "+buf.String(), typ, err) + return nil } val.SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: if ok, err := dec.maybeDecodeNull(nullOK); ok { return err } + if t, err := dec.peekRuneType(); err != nil { + return err + } else if !t.IsNumber() { + dec.newTypeError(t.JSONType(), typ, nil) + return dec.scan(fastio.Discard) + } var buf strings.Builder - if err := dec.scanNumber(typ, &buf); err != nil { + if err := dec.scan(&buf); err != nil { return err } n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind]) if err != nil { - return dec.newTypeError("number "+buf.String(), typ, err) + dec.newTypeError("number "+buf.String(), typ, err) + return nil } val.SetUint(n) case reflect.Float32, reflect.Float64: if ok, err := dec.maybeDecodeNull(nullOK); ok { return err } + if t, err := dec.peekRuneType(); err != nil { + return err + } else if !t.IsNumber() { + dec.newTypeError(t.JSONType(), typ, nil) + return dec.scan(fastio.Discard) + } var buf strings.Builder - if err := dec.scanNumber(typ, &buf); err != nil { + if err := dec.scan(&buf); err != nil { return err } n, err := strconv.ParseFloat(buf.String(), kind2bits[kind]) if err != nil { - return dec.newTypeError("number "+buf.String(), typ, err) + dec.newTypeError("number "+buf.String(), typ, err) + return nil } val.SetFloat(n) case reflect.String: @@ -509,9 +532,10 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return err } if !t.IsNumber() { - return dec.newTypeError(t.JSONType(), typ, + dec.newTypeError(t.JSONType(), typ, fmt.Errorf("json: invalid number literal, trying to unmarshal %q into Number", buf.String())) + return nil } val.SetString(buf.String()) } else { @@ -526,7 +550,8 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return err } if typ.NumMethod() > 0 { - return dec.newTypeError(t.JSONType(), typ, ErrDecodeNonEmptyInterface) + dec.newTypeError(t.JSONType(), typ, ErrDecodeNonEmptyInterface) + return dec.scan(fastio.Discard) } // If the interface stores a pointer, try to use the type information of the pointer. if !val.IsNil() && val.Elem().Kind() == reflect.Pointer { @@ -570,7 +595,9 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { if err != nil { return err } - val.Set(reflect.ValueOf(v)) + if v != nil { + val.Set(reflect.ValueOf(v)) + } } case reflect.Struct: if ok, err := dec.maybeDecodeNull(nullOK); ok { @@ -601,7 +628,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } if !ok { if dec.disallowUnknownFields { - return dec.newTypeError("", typ, fmt.Errorf("json: unknown field %q", name)) + dec.newTypeError("", typ, fmt.Errorf("json: unknown field %q", name)) } return dec.scan(fastio.Discard) } @@ -610,9 +637,10 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { for _, idx := range field.Path { if fVal.Kind() == reflect.Pointer { if fVal.IsNil() && !fVal.CanSet() { // https://golang.org/issue/21357 - return dec.newTypeError("", fVal.Type().Elem(), + dec.newTypeError("", fVal.Type().Elem(), fmt.Errorf("json: cannot set embedded pointer to unexported struct: %v", fVal.Type().Elem())) + return dec.scan(fastio.Discard) } t, err := dec.peekRuneType() if err != nil { @@ -653,15 +681,16 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } if err := NewDecoder(bytes.NewReader(buf.Bytes())).Decode(fVal.Addr().Interface()); err != nil { if str := buf.String(); str != "null" { - return dec.newTypeError("", fVal.Type(), + dec.newTypeError("", fVal.Type(), fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal %q into %v", str, fVal.Type())) } } default: - return dec.newTypeError(t.JSONType(), fVal.Type(), + dec.newTypeError(t.JSONType(), fVal.Type(), fmt.Errorf("json: invalid use of ,string struct tag, trying to unmarshal unquoted value into %v", fVal.Type())) + return dec.scan(fastio.Discard) } return nil } else { @@ -698,7 +727,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType): obj := nameValPtr.Interface().(encoding.TextUnmarshaler) if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil { - return dec.newTypeError("string", reflect.PointerTo(nameValTyp), err) + dec.newTypeError("string", reflect.PointerTo(nameValTyp), err) } default: switch nameValTyp.Kind() { @@ -707,17 +736,19 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) if err != nil { - return dec.newTypeError("number "+nameBuf.String(), nameValTyp, err) + dec.newTypeError("number "+nameBuf.String(), nameValTyp, err) + return nil } nameValPtr.Elem().SetInt(n) case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) if err != nil { - return dec.newTypeError("number "+nameBuf.String(), nameValTyp, err) + dec.newTypeError("number "+nameBuf.String(), nameValTyp, err) + return nil } nameValPtr.Elem().SetUint(n) default: - return dec.newTypeError("object", typ, &DecodeArgumentError{Type: nameValTyp}) + dec.newTypeError("object", typ, &DecodeArgumentError{Type: nameValTyp}) } } return nil @@ -736,7 +767,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return nil }) default: - return dec.newTypeError(t.JSONType(), typ, nil) + dec.newTypeError(t.JSONType(), typ, nil) } case reflect.Slice: t, err := dec.peekRuneType() @@ -775,7 +806,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { } } default: - return dec.newTypeError(t.JSONType(), typ, nil) + dec.newTypeError(t.JSONType(), typ, nil) } default: switch t { @@ -806,7 +837,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return nil }) default: - return dec.newTypeError(t.JSONType(), typ, nil) + dec.newTypeError(t.JSONType(), typ, nil) } } case reflect.Array: @@ -857,7 +888,7 @@ func (dec *Decoder) decode(val reflect.Value, nullOK bool) *DecodeError { return dec.decode(val.Elem(), false) } default: - return dec.newTypeError("", typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind())) + dec.newTypeError("", typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind())) } } return nil @@ -879,17 +910,6 @@ func (dec *Decoder) scan(out fastio.RuneWriter) *DecodeError { return nil } -func (dec *Decoder) scanNumber(gTyp reflect.Type, out fastio.RuneWriter) *DecodeError { - t, err := dec.peekRuneType() - if err != nil { - return err - } - if !t.IsNumber() { - return dec.newTypeError(t.JSONType(), gTyp, nil) - } - return dec.scan(out) -} - func (dec *Decoder) decodeAny() (any, *DecodeError) { t, err := dec.peekRuneType() if err != nil { @@ -956,7 +976,8 @@ func (dec *Decoder) decodeAny() (any, *DecodeError) { } f64, err := num.Float64() if err != nil { - return nil, dec.newTypeError("number "+buf.String(), float64Type, err) + dec.newTypeError("number "+buf.String(), float64Type, err) + return nil, nil } return f64, nil case jsonparse.RuneTypeTrueT, jsonparse.RuneTypeFalseF: @@ -983,51 +1004,41 @@ func DecodeObject(r io.RuneScanner, decodeKey, decodeVal func(io.RuneScanner) er } else { dec = NewDecoder(r) } + if dec.typeErr != nil { + oldTypeErr := dec.typeErr + dec.typeErr = nil + defer func() { dec.typeErr = oldTypeErr }() + } dec.posStackPush() defer dec.posStackPop() if err := dec.decodeObject(nil, func() *DecodeError { dec.posStackPush() defer dec.posStackPop() - return dec.withLimitingScanner(func(l io.RuneScanner) *DecodeError { - if err := decodeKey(l); err != nil { - // TODO: Find a better Go type to use than `nil`. - return dec.newTypeError("string", nil, err) - } - if _, _, err := l.ReadRune(); err != io.EOF { - // TODO: Find a better Go type to use than `nil`. - return dec.newTypeError("string", nil, fmt.Errorf("did not consume entire string")) - } - return nil - }) + // TODO: Find a better Go type to use than `nil`. + return dec.withLimitingScanner(nil, decodeKey) }, func() *DecodeError { dec.posStackPush() defer dec.posStackPop() - t, err := dec.peekRuneType() - if err != nil { - return err - } - return dec.withLimitingScanner(func(l io.RuneScanner) *DecodeError { - if err := decodeVal(l); err != nil { - // TODO: Find a better Go type to use than `nil`. - return dec.newTypeError(t.JSONType(), nil, err) - } - if _, _, err := l.ReadRune(); err != io.EOF { - // TODO: Find a better Go type to use than `nil`. - return dec.newTypeError(t.JSONType(), nil, fmt.Errorf("did not consume entire %s", t.JSONType())) - } - return nil - }) + // TODO: Find a better Go type to use than `nil`. + return dec.withLimitingScanner(nil, decodeVal) }); err != nil { return err } + if dec.typeErr != nil { + return dec.typeErr + } return nil } func (dec *Decoder) decodeObject(gTyp reflect.Type, decodeKey, decodeVal func() *DecodeError) *DecodeError { - if err := dec.expectRuneType('{', jsonparse.RuneTypeObjectBeg, gTyp); err != nil { + if _, t, err := dec.readRune(); err != nil { return err + } else if t != jsonparse.RuneTypeObjectBeg { + dec.newTypeError(t.JSONType(), gTyp, nil) + dec.unreadRune() + return dec.scan(fastio.Discard) } _, t, err := dec.readRune() if err != nil { @@ -1042,7 +1053,7 @@ func (dec *Decoder) decodeObject(gTyp reflect.Type, decodeKey, decodeVal func() if err := decodeKey(); err != nil { return err } - if err := dec.expectRune(':', jsonparse.RuneTypeObjectColon); err != nil { + if err := dec.expectRuneOrPanic(':', jsonparse.RuneTypeObjectColon); err != nil { return err } if err := decodeVal(); err != nil { @@ -1054,7 +1065,7 @@ func (dec *Decoder) decodeObject(gTyp reflect.Type, decodeKey, decodeVal func() } switch t { case jsonparse.RuneTypeObjectComma: - if err := dec.expectRune('"', jsonparse.RuneTypeStringBeg); err != nil { + if err := dec.expectRuneOrPanic('"', jsonparse.RuneTypeStringBeg); err != nil { return err } goto decodeMember @@ -1083,35 +1094,34 @@ func DecodeArray(r io.RuneScanner, decodeMember func(r io.RuneScanner) error) er } else { dec = NewDecoder(r) } + if dec.typeErr != nil { + oldTypeErr := dec.typeErr + dec.typeErr = nil + defer func() { dec.typeErr = oldTypeErr }() + } dec.posStackPush() defer dec.posStackPop() if err := dec.decodeArray(nil, func() *DecodeError { dec.posStackPush() defer dec.posStackPop() - t, err := dec.peekRuneType() - if err != nil { - return err - } - return dec.withLimitingScanner(func(l io.RuneScanner) *DecodeError { - if err := decodeMember(l); err != nil { - // TODO: Find a better Go type to use than `nil`. - return dec.newTypeError(t.JSONType(), nil, err) - } - if _, _, err := l.ReadRune(); err != io.EOF { - // TODO: Find a better Go type to use than `nil`. - return dec.newTypeError(t.JSONType(), nil, fmt.Errorf("did not consume entire %s", t.JSONType())) - } - return nil - }) + // TODO: Find a better Go type to use than `nil`. + return dec.withLimitingScanner(nil, decodeMember) }); err != nil { return err } + if dec.typeErr != nil { + return dec.typeErr + } return nil } func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func() *DecodeError) *DecodeError { - if err := dec.expectRuneType('[', jsonparse.RuneTypeArrayBeg, gTyp); err != nil { + if _, t, err := dec.readRune(); err != nil { return err + } else if t != jsonparse.RuneTypeArrayBeg { + dec.newTypeError(t.JSONType(), gTyp, nil) + dec.unreadRune() + return dec.scan(fastio.Discard) } _, t, err := dec.readRune() if err != nil { @@ -1142,8 +1152,12 @@ func (dec *Decoder) decodeArray(gTyp reflect.Type, decodeMember func() *DecodeEr } func (dec *Decoder) decodeString(gTyp reflect.Type, out fastio.RuneWriter) *DecodeError { - if err := dec.expectRuneType('"', jsonparse.RuneTypeStringBeg, gTyp); err != nil { + if _, t, err := dec.readRune(); err != nil { return err + } else if t != jsonparse.RuneTypeStringBeg { + dec.newTypeError(t.JSONType(), gTyp, nil) + dec.unreadRune() + return dec.scan(fastio.Discard) } var uhex [3]byte for { @@ -1195,7 +1209,7 @@ func (dec *Decoder) decodeString(gTyp reflect.Type, out fastio.RuneWriter) *Deco _, _ = out.WriteRune(utf8.RuneError) break } - if err := dec.expectRune('\\', jsonparse.RuneTypeStringEsc); err != nil { + if err := dec.expectRuneOrPanic('\\', jsonparse.RuneTypeStringEsc); err != nil { return err } t, err = dec.peekRuneType() @@ -1206,7 +1220,7 @@ func (dec *Decoder) decodeString(gTyp reflect.Type, out fastio.RuneWriter) *Deco _, _ = out.WriteRune(utf8.RuneError) break } - if err := dec.expectRune('u', jsonparse.RuneTypeStringEscU); err != nil { + if err := dec.expectRuneOrPanic('u', jsonparse.RuneTypeStringEscU); err != nil { return err } @@ -1255,46 +1269,47 @@ func (dec *Decoder) decodeBool(gTyp reflect.Type) (bool, *DecodeError) { } switch c { case 't': - if err := dec.expectRune('r', jsonparse.RuneTypeTrueR); err != nil { + if err := dec.expectRuneOrPanic('r', jsonparse.RuneTypeTrueR); err != nil { return false, err } - if err := dec.expectRune('u', jsonparse.RuneTypeTrueU); err != nil { + if err := dec.expectRuneOrPanic('u', jsonparse.RuneTypeTrueU); err != nil { return false, err } - if err := dec.expectRune('e', jsonparse.RuneTypeTrueE); err != nil { + if err := dec.expectRuneOrPanic('e', jsonparse.RuneTypeTrueE); err != nil { return false, err } return true, nil case 'f': - if err := dec.expectRune('a', jsonparse.RuneTypeFalseA); err != nil { + if err := dec.expectRuneOrPanic('a', jsonparse.RuneTypeFalseA); err != nil { return false, err } - if err := dec.expectRune('l', jsonparse.RuneTypeFalseL); err != nil { + if err := dec.expectRuneOrPanic('l', jsonparse.RuneTypeFalseL); err != nil { return false, err } - if err := dec.expectRune('s', jsonparse.RuneTypeFalseS); err != nil { + if err := dec.expectRuneOrPanic('s', jsonparse.RuneTypeFalseS); err != nil { return false, err } - if err := dec.expectRune('e', jsonparse.RuneTypeFalseE); err != nil { + if err := dec.expectRuneOrPanic('e', jsonparse.RuneTypeFalseE); err != nil { return false, err } return false, nil default: - return false, dec.newTypeError(t.JSONType(), gTyp, nil) + dec.newTypeError(t.JSONType(), gTyp, nil) + return false, nil } } func (dec *Decoder) decodeNull() *DecodeError { - if err := dec.expectRune('n', jsonparse.RuneTypeNullN); err != nil { + if err := dec.expectRuneOrPanic('n', jsonparse.RuneTypeNullN); err != nil { return err } - if err := dec.expectRune('u', jsonparse.RuneTypeNullU); err != nil { + if err := dec.expectRuneOrPanic('u', jsonparse.RuneTypeNullU); err != nil { return err } - if err := dec.expectRune('l', jsonparse.RuneTypeNullL1); err != nil { + if err := dec.expectRuneOrPanic('l', jsonparse.RuneTypeNullL1); err != nil { return err } - if err := dec.expectRune('l', jsonparse.RuneTypeNullL2); err != nil { + if err := dec.expectRuneOrPanic('l', jsonparse.RuneTypeNullL2); err != nil { return err } return nil |