From 234e0836f1040f7724251b4120a2351bcbf64131 Mon Sep 17 00:00:00 2001 From: Luke Shumaker Date: Sat, 13 Aug 2022 15:11:17 -0600 Subject: set up as a separate repo --- decode.go | 812 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 812 insertions(+) create mode 100644 decode.go (limited to 'decode.go') diff --git a/decode.go b/decode.go new file mode 100644 index 0000000..03d5b7a --- /dev/null +++ b/decode.go @@ -0,0 +1,812 @@ +// Copyright (C) 2022 Luke Shumaker +// +// SPDX-License-Identifier: GPL-2.0-or-later + +package lowmemjson + +import ( + "bufio" + "bytes" + "encoding" + "encoding/json" + "fmt" + "io" + "reflect" + "strconv" + "strings" +) + +type Decodable interface { + DecodeJSON(io.RuneScanner) error +} + +type runeBuffer interface { + io.Writer + WriteRune(rune) (int, error) + Reset() +} + +type Decoder struct { + r io.RuneScanner + + // config + disallowUnknownFields bool + useNumber bool + + // state + err error + curPos int64 + nxtPos int64 + stack []any +} + +var forceBufio bool + +func NewDecoder(r io.Reader) *Decoder { + rs, ok := r.(io.RuneScanner) + if forceBufio || !ok { + rs = bufio.NewReader(r) + } + return &Decoder{ + r: rs, + } +} + +func (dec *Decoder) DisallowUnknownFields() { dec.disallowUnknownFields = true } +func (dec *Decoder) UseNumber() { dec.useNumber = true } +func (dec *Decoder) InputOffset() int64 { return dec.curPos } + +func (dec *Decoder) More() bool { + dec.decodeWS() + _, ok := dec.peekRuneOrEOF() + return ok +} + +func (dec *Decoder) stackStr() string { + var buf strings.Builder + buf.WriteString("v") + for _, item := range dec.stack { + fmt.Fprintf(&buf, "[%#v]", item) + } + return buf.String() +} + +func (dec *Decoder) stackPush(idx any) { + dec.stack = append(dec.stack, idx) +} +func (dec *Decoder) stackPop() { + dec.stack = dec.stack[:len(dec.stack)-1] +} + +type decodeError struct { + Err error +} + +func (dec *Decoder) panicIO(err error) { + panic(decodeError{fmt.Errorf("json: I/O error at input byte %v: %s: %w", + dec.nxtPos, dec.stackStr(), err)}) +} +func (dec *Decoder) panicSyntax(err error) { + panic(decodeError{fmt.Errorf("json: syntax error at input byte %v: %s: %w", + dec.curPos, dec.stackStr(), err)}) +} +func (dec *Decoder) panicType(typ reflect.Type, err error) { + panic(decodeError{fmt.Errorf("json: type mismatch error at input byte %v: %s: type %v: %w", + dec.curPos, dec.stackStr(), typ, err)}) +} + +func Decode(r io.Reader, ptr any) error { + return NewDecoder(r).Decode(ptr) +} + +func (dec *Decoder) Decode(ptr any) (err error) { + ptrVal := reflect.ValueOf(ptr) + if ptrVal.Kind() != reflect.Pointer || ptrVal.IsNil() || !ptrVal.Elem().CanSet() { + return &json.InvalidUnmarshalError{ + // don't use ptrVal.Type() because ptrVal might be invalid if ptr==nil + Type: reflect.TypeOf(ptr), + } + } + + if dec.err != nil { + return dec.err + } + + defer func() { + if r := recover(); r != nil { + if de, ok := r.(decodeError); ok { + dec.err = de.Err + err = dec.err + } else { + panic(r) + } + } + }() + dec.decodeWS() + dec.decode(ptrVal.Elem(), false) + return nil +} + +func (dec *Decoder) readRune() rune { + c, size, err := dec.r.ReadRune() + if err != nil { + if err == io.EOF { + dec.panicSyntax(io.ErrUnexpectedEOF) + } + dec.panicIO(err) + } + dec.curPos = dec.nxtPos + dec.nxtPos = dec.curPos + int64(size) + return c +} + +func (dec *Decoder) readRuneOrEOF() (c rune, ok bool) { + c, size, err := dec.r.ReadRune() + if err != nil { + if err == io.EOF { + return 0, false + } + dec.panicIO(err) + } + dec.curPos = dec.nxtPos + dec.nxtPos = dec.curPos + int64(size) + return c, true +} + +func (dec *Decoder) unreadRune() { + if err := dec.r.UnreadRune(); err != nil { + // .UnreadRune() must succeed if the previous call was + // .ReadRune(), which it always is for this code. + panic(err) + } + dec.nxtPos = dec.curPos +} + +func (dec *Decoder) peekRune() rune { + c, _, err := dec.r.ReadRune() + if err != nil { + if err == io.EOF { + dec.panicSyntax(io.ErrUnexpectedEOF) + } + dec.panicIO(err) + } + if err := dec.r.UnreadRune(); err != nil { + // .UnreadRune() must succeed if the previous call was + // .ReadRune(), which it always is for this code. + panic(err) + } + return c +} + +func (dec *Decoder) peekRuneOrEOF() (rune, bool) { + c, _, err := dec.r.ReadRune() + if err != nil { + if err == io.EOF { + return 0, false + } + dec.panicIO(err) + } + if err := dec.r.UnreadRune(); err != nil { + // .UnreadRune() must succeed if the previous call was + // .ReadRune(), which it always is for this code. + panic(err) + } + return c, true +} + +func (dec *Decoder) expectRune(exp rune) { + act := dec.readRune() + if act != exp { + dec.panicSyntax(fmt.Errorf("expected %q but got %q", exp, act)) + } +} + +var ( + rawMessagePtrType = reflect.TypeOf((*json.RawMessage)(nil)) + decodableType = reflect.TypeOf((*Decodable)(nil)).Elem() + jsonUnmarshalerType = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() + textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() +) + +var kind2bits = map[reflect.Kind]int{ + reflect.Int: int(32 << (^uint(0) >> 63)), + reflect.Int8: 8, + reflect.Int16: 16, + reflect.Int32: 32, + reflect.Int64: 64, + + reflect.Uint: int(32 << (^uint(0) >> 63)), + reflect.Uint8: 8, + reflect.Uint16: 16, + reflect.Uint32: 32, + reflect.Uint64: 64, + + reflect.Uintptr: int(32 << (^uintptr(0) >> 63)), + + reflect.Float32: 32, + reflect.Float64: 64, +} + +func (dec *Decoder) decode(val reflect.Value, nullOK bool) { + typ := val.Type() + switch { + case val.CanAddr() && reflect.PointerTo(typ) == rawMessagePtrType: + var buf bytes.Buffer + dec.scan(&buf) + if err := val.Addr().Interface().(*json.RawMessage).UnmarshalJSON(buf.Bytes()); err != nil { + dec.panicSyntax(err) + } + case val.CanAddr() && reflect.PointerTo(typ).Implements(decodableType): + obj := val.Addr().Interface().(Decodable) + if err := obj.DecodeJSON(dec.r); err != nil { + dec.panicSyntax(err) + } + case val.CanAddr() && reflect.PointerTo(typ).Implements(jsonUnmarshalerType): + var buf bytes.Buffer + dec.scan(&buf) + obj := val.Addr().Interface().(json.Unmarshaler) + if err := obj.UnmarshalJSON(buf.Bytes()); err != nil { + dec.panicSyntax(err) + } + case val.CanAddr() && reflect.PointerTo(typ).Implements(textUnmarshalerType): + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } + var buf bytes.Buffer + dec.decodeString(&buf) + obj := val.Addr().Interface().(encoding.TextUnmarshaler) + if err := obj.UnmarshalText(buf.Bytes()); err != nil { + dec.panicSyntax(err) + } + default: + kind := typ.Kind() + switch kind { + case reflect.Bool: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } + val.SetBool(dec.decodeBool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } + var buf strings.Builder + dec.scanNumber(&buf) + n, err := strconv.ParseInt(buf.String(), 10, kind2bits[kind]) + if err != nil { + dec.panicSyntax(err) + } + val.SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } + var buf strings.Builder + dec.scanNumber(&buf) + n, err := strconv.ParseUint(buf.String(), 10, kind2bits[kind]) + if err != nil { + dec.panicSyntax(err) + } + val.SetUint(n) + case reflect.Float32, reflect.Float64: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } + var buf strings.Builder + dec.scanNumber(&buf) + n, err := strconv.ParseFloat(buf.String(), kind2bits[kind]) + if err != nil { + dec.panicSyntax(err) + } + val.SetFloat(n) + case reflect.String: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } + var buf strings.Builder + if typ == numberType { + dec.scanNumber(&buf) + val.SetString(buf.String()) + } else { + dec.decodeString(&buf) + val.SetString(buf.String()) + } + case reflect.Interface: + if typ.NumMethod() > 0 { + dec.panicType(typ, fmt.Errorf("cannot decode in to non-empty interface")) + } + switch dec.peekRune() { + case 'n': + if !val.IsNil() && val.Elem().Kind() == reflect.Pointer && val.Elem().Elem().Kind() == reflect.Pointer { + // XXX: I can't justify this case, other than "it's what encoding/json does, but + // I don't understand their rationale". + dec.decode(val.Elem(), false) + } else { + dec.decodeNull() + val.Set(reflect.Zero(typ)) + } + default: + if !val.IsNil() && val.Elem().Kind() == reflect.Pointer { + dec.decode(val.Elem(), false) + } else { + val.Set(reflect.ValueOf(dec.decodeAny())) + } + } + case reflect.Struct: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } + index := indexStruct(typ) + var nameBuf strings.Builder + dec.decodeObject(&nameBuf, func() { + name := nameBuf.String() + dec.stackPush(name) + defer dec.stackPop() + idx, ok := index.byName[name] + if !ok { + if dec.disallowUnknownFields { + dec.panicType(typ, fmt.Errorf("unknown field %q", name)) + } + dec.scan(io.Discard) + return + } + field := index.byPos[idx] + fVal := val + for _, idx := range field.Path { + if fVal.Kind() == reflect.Pointer { + if fVal.IsNil() { + if !fVal.CanSet() { // https://golang.org/issue/21357 + dec.panicType(fVal.Type().Elem(), fmt.Errorf("cannot set embedded pointer to unexported type")) + } + fVal.Set(reflect.New(fVal.Type().Elem())) + } + fVal = fVal.Elem() + } + fVal = fVal.Field(idx) + } + if field.Quote { + switch dec.peekRune() { + case 'n': + dec.decodeNull() + switch fVal.Kind() { + // XXX: I can't justify this list, other than "it's what encoding/json + // does, but I don't understand their rationale". + case reflect.Interface, reflect.Pointer, reflect.Map, reflect.Slice: + fVal.Set(reflect.Zero(fVal.Type())) + } + case '"': + // TODO: Figure out how to do this without buffering. + var buf bytes.Buffer + subD := *dec // capture the .curPos *before* calling .decodeString + dec.decodeString(&buf) + subD.r = &buf + subD.decode(fVal, false) + default: + dec.panicSyntax(fmt.Errorf(",string field: expected %q or %q but got %q", + 'n', '"', dec.peekRune())) + } + } else { + dec.decode(fVal, true) + } + }) + case reflect.Map: + switch dec.peekRune() { + case 'n': + dec.decodeNull() + val.Set(reflect.Zero(typ)) + case '{': + if val.IsNil() { + val.Set(reflect.MakeMap(typ)) + } + var nameBuf bytes.Buffer + dec.decodeObject(&nameBuf, func() { + nameValTyp := typ.Key() + nameValPtr := reflect.New(nameValTyp) + switch { + case reflect.PointerTo(nameValTyp).Implements(textUnmarshalerType): + obj := nameValPtr.Interface().(encoding.TextUnmarshaler) + if err := obj.UnmarshalText(nameBuf.Bytes()); err != nil { + dec.panicSyntax(err) + } + default: + switch nameValTyp.Kind() { + case reflect.String: + nameValPtr.Elem().SetString(nameBuf.String()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + n, err := strconv.ParseInt(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) + if err != nil { + dec.panicSyntax(err) + } + nameValPtr.Elem().SetInt(n) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + n, err := strconv.ParseUint(nameBuf.String(), 10, kind2bits[nameValTyp.Kind()]) + if err != nil { + dec.panicSyntax(err) + } + nameValPtr.Elem().SetUint(n) + default: + dec.panicType(typ, fmt.Errorf("invalid map key type: %v", nameValTyp)) + } + } + dec.stackPush(nameValPtr.Elem()) + defer dec.stackPop() + + fValPtr := reflect.New(typ.Elem()) + dec.decode(fValPtr.Elem(), false) + + val.SetMapIndex(nameValPtr.Elem(), fValPtr.Elem()) + }) + default: + dec.panicSyntax(fmt.Errorf("map: expected %q or %q bug got %q", 'n', '{', dec.peekRune())) + } + case reflect.Slice: + switch { + case typ.Elem().Kind() == reflect.Uint8: + switch dec.peekRune() { + case 'n': + dec.decodeNull() + val.Set(reflect.Zero(typ)) + case '"': + var buf bytes.Buffer + dec.decodeString(newBase64Decoder(&buf)) + if typ.Elem() == byteType { + val.Set(reflect.ValueOf(buf.Bytes())) + } else { + bs := buf.Bytes() + // TODO: Surely there's a better way. + val.Set(reflect.MakeSlice(typ, len(bs), len(bs))) + for i := 0; i < len(bs); i++ { + val.Index(i).Set(reflect.ValueOf(bs[i]).Convert(typ.Elem())) + } + } + default: + dec.panicSyntax(fmt.Errorf("byte slice: expected %q or %q but got %q", 'n', '"', dec.peekRune())) + } + default: + switch dec.peekRune() { + case 'n': + dec.decodeNull() + val.Set(reflect.Zero(typ)) + case '[': + if val.IsNil() { + val.Set(reflect.MakeSlice(typ, 0, 0)) + } + if val.Len() > 0 { + val.Set(val.Slice(0, 0)) + } + i := 0 + dec.decodeArray(func() { + dec.stackPush(i) + defer dec.stackPop() + mValPtr := reflect.New(typ.Elem()) + dec.decode(mValPtr.Elem(), false) + val.Set(reflect.Append(val, mValPtr.Elem())) + i++ + }) + default: + dec.panicSyntax(fmt.Errorf("slice: expected %q or %q but got %q", 'n', '[', dec.peekRune())) + } + } + case reflect.Array: + if nullOK && dec.peekRune() == 'n' { + dec.decodeNull() + return + } + i := 0 + n := val.Len() + dec.decodeArray(func() { + dec.stackPush(i) + defer dec.stackPop() + if i < n { + mValPtr := reflect.New(typ.Elem()) + dec.decode(mValPtr.Elem(), false) + val.Index(i).Set(mValPtr.Elem()) + } else { + dec.scan(io.Discard) + } + i++ + }) + for ; i < n; i++ { + val.Index(i).Set(reflect.Zero(typ.Elem())) + } + case reflect.Pointer: + switch dec.peekRune() { + case 'n': + dec.decodeNull() + /* + for typ.Elem().Kind() == reflect.Pointer { + if val.IsNil() || !val.Elem().CanSet() { + val.Set(reflect.New(typ.Elem())) + } + val = val.Elem() + typ = val.Type() + } + */ + val.Set(reflect.Zero(typ)) + default: + if val.IsNil() { + val.Set(reflect.New(typ.Elem())) + } + dec.decode(val.Elem(), false) + } + default: + dec.panicType(typ, fmt.Errorf("unsupported type (kind=%v)", typ.Kind())) + } + } +} + +func (dec *Decoder) decodeWS() { + for { + c, ok := dec.readRuneOrEOF() + if !ok { + return + } + switch c { + // NB: The JSON definition of whitespace is more + // narrow than unicode.IsSpace + case 0x0020, 0x000A, 0x000D, 0x0009: + // do nothing + default: + dec.unreadRune() + return + } + } +} + +func (dec *Decoder) scan(out io.Writer) { + scanner := &ReEncoder{ + Out: out, + Compact: true, + } + if _, err := scanner.WriteRune(dec.readRune()); err != nil { + dec.panicSyntax(err) + } + scanner.bailAfterCurrent = true + var err error + var eof bool + for err == nil { + c, ok := dec.readRuneOrEOF() + if ok { + _, err = scanner.WriteRune(c) + } else { + eof = true + err = scanner.Flush() + break + } + } + if err != nil { + if err == errBailedAfterCurrent { + if !eof { + dec.unreadRune() + } + } else { + dec.panicSyntax(err) + } + } +} + +func (dec *Decoder) scanNumber(out io.Writer) { + c := dec.peekRune() + switch c { + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + dec.scan(out) + default: + dec.panicSyntax(fmt.Errorf("number: expected %q or a digit, but got %q", '-', c)) + } +} + +func (dec *Decoder) decodeAny() any { + c := dec.peekRune() + switch c { + case '{': + ret := make(map[string]any) + var nameBuf strings.Builder + dec.decodeObject(&nameBuf, func() { + name := nameBuf.String() + dec.stackPush(name) + defer dec.stackPop() + ret[name] = dec.decodeAny() + }) + return ret + case '[': + ret := []any{} + dec.decodeArray(func() { + dec.stackPush(len(ret)) + defer dec.stackPop() + ret = append(ret, dec.decodeAny()) + }) + return ret + case '"': + var buf strings.Builder + dec.decodeString(&buf) + return buf.String() + case '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': + var buf strings.Builder + dec.scanNumber(&buf) + num := json.Number(buf.String()) + if dec.useNumber { + return num + } + f64, err := num.Float64() + if err != nil { + dec.panicSyntax(err) + } + return f64 + case 't', 'f': + return dec.decodeBool() + case 'n': + dec.decodeNull() + return nil + default: + dec.panicSyntax(fmt.Errorf("any: unexpected character: %c", c)) + panic("not reached") + } +} + +func (dec *Decoder) decodeObject(nameBuf runeBuffer, decodeKVal func()) { + dec.expectRune('{') + dec.decodeWS() + c := dec.readRune() + switch c { + case '"': + decodeMember: + dec.unreadRune() + nameBuf.Reset() + dec.decodeString(nameBuf) + dec.decodeWS() + dec.expectRune(':') + dec.decodeWS() + decodeKVal() + dec.decodeWS() + c := dec.readRune() + switch c { + case ',': + dec.decodeWS() + dec.expectRune('"') + goto decodeMember + case '}': + return + default: + dec.panicSyntax(fmt.Errorf("object: expected %q or %q but got %q", ',', '}', c)) + } + case '}': + return + default: + dec.panicSyntax(fmt.Errorf("object: expected %q or %q but got %q", '"', '}', c)) + } +} + +func (dec *Decoder) decodeArray(decodeMember func()) { + dec.expectRune('[') + dec.decodeWS() + c := dec.readRune() + switch c { + case ']': + return + default: + dec.unreadRune() + decodeNextMember: + decodeMember() + dec.decodeWS() + c := dec.readRune() + switch c { + case ',': + dec.decodeWS() + goto decodeNextMember + case ']': + return + default: + dec.panicSyntax(fmt.Errorf("array: expected %c or %c but got %c", ',', ']', c)) + } + } +} + +func (dec *Decoder) decodeHex() rune { + c := dec.readRune() + switch { + case '0' <= c && c <= '9': + return c - '0' + case 'a' <= c && c <= 'f': + return c - 'a' + 10 + case 'A' <= c && c <= 'F': + return c - 'A' + 10 + default: + dec.panicSyntax(fmt.Errorf("string: expected a hex digit but got %q", c)) + panic("not reached") + } +} + +func (dec *Decoder) decodeString(out io.Writer) { + dec.expectRune('"') + for { + c := dec.readRune() + switch { + case 0x0020 <= c && c <= 0x10FFFF && c != '"' && c != '\\': + if _, err := writeRune(out, c); err != nil { + dec.panicSyntax(err) + } + case c == '\\': + c = dec.readRune() + switch c { + case '"': + if _, err := writeRune(out, '"'); err != nil { + dec.panicSyntax(err) + } + case '\\': + if _, err := writeRune(out, '\\'); err != nil { + dec.panicSyntax(err) + } + case '/': + if _, err := writeRune(out, '/'); err != nil { + dec.panicSyntax(err) + } + case 'b': + if _, err := writeRune(out, '\b'); err != nil { + dec.panicSyntax(err) + } + case 'f': + if _, err := writeRune(out, '\f'); err != nil { + dec.panicSyntax(err) + } + case 'n': + if _, err := writeRune(out, '\n'); err != nil { + dec.panicSyntax(err) + } + case 'r': + if _, err := writeRune(out, '\r'); err != nil { + dec.panicSyntax(err) + } + case 't': + if _, err := writeRune(out, '\t'); err != nil { + dec.panicSyntax(err) + } + case 'u': + c = dec.decodeHex() + c = (c << 4) | dec.decodeHex() + c = (c << 4) | dec.decodeHex() + c = (c << 4) | dec.decodeHex() + if _, err := writeRune(out, c); err != nil { + dec.panicSyntax(err) + } + } + case c == '"': + return + default: + dec.panicSyntax(fmt.Errorf("string: unexpected %c", c)) + } + } +} + +func (dec *Decoder) decodeBool() bool { + c := dec.readRune() + switch c { + case 't': + dec.expectRune('r') + dec.expectRune('u') + dec.expectRune('e') + return true + case 'f': + dec.expectRune('a') + dec.expectRune('l') + dec.expectRune('s') + dec.expectRune('e') + return false + default: + dec.panicSyntax(fmt.Errorf("bool: expected %q or %q but got %q", 't', 'f', c)) + panic("not reached") + } +} + +func (dec *Decoder) decodeNull() { + dec.expectRune('n') + dec.expectRune('u') + dec.expectRune('l') + dec.expectRune('l') +} -- cgit v1.2.3-54-g00ecf