// Copyright (C) 2022-2023 Luke Shumaker // // SPDX-License-Identifier: GPL-2.0-or-later package lowmemjson import ( "bytes" "encoding" "encoding/base64" "fmt" "io" "reflect" "sort" "strconv" "strings" "unsafe" "git.lukeshu.com/go/lowmemjson/internal/jsonstring" "git.lukeshu.com/go/lowmemjson/internal/jsonstruct" ) // Encodable is the interface implemented by types that can encode // themselves to JSON. Encodable is a low-memory-overhead replacement // for the json.Marshaler interface. // // The io.Writer passed to EncodeJSON returns an error if invalid JSON // is written to it. type Encodable interface { EncodeJSON(w io.Writer) error } // An Encoder encodes and writes values to a stream of JSON elements. // // Encoder is analogous to, and has a similar API to the standar // library's encoding/json.Encoder. Differences are that rather than // having .SetEscapeHTML and .SetIndent methods, the io.Writer passed // to it may be a *ReEncoder that has these settings (and more). If // something more similar to a json.Encoder is desired, // lowmemjson/compat/json.Encoder offers those .SetEscapeHTML and // .SetIndent methods. type Encoder struct { w *ReEncoder isRoot bool } // NewEncoder returns a new Encoder that writes to w. // // If w is an *ReEncoder, then the inner backslash-escaping of // double-encoded ",string" tagged string values obeys the // *ReEncoder's BackslashEscape policy. // // An Encoder tends to make many small writes; if w.Write calls are // syscalls, then you may want to wrap w in a bufio.Writer. func NewEncoder(w io.Writer) *Encoder { re, ok := w.(*ReEncoder) if !ok { re = NewReEncoder(w, ReEncoderConfig{ AllowMultipleValues: true, }) } return &Encoder{ w: re, isRoot: re.par.StackIsEmpty(), } } // Encode encodes obj to JSON and writes that JSON to the Encoder's // output stream. 
//
// See the [documentation for encoding/json.Marshal] for details about
// the conversion of Go values to JSON; Encode behaves identically to
// that, with the exception that in addition to the json.Marshaler
// interface it also checks for the Encodable interface.
//
// Unlike encoding/json.Encoder.Encode, lowmemjson.Encoder.Encode does
// not buffer its output; if an encode-error is encountered, lowmemjson
// may write partial output, whereas encoding/json would not have
// written anything.
//
// A *ReEncodeWriteError coming out of the encoding step is returned
// as an *EncodeWriteError carrying the same Err and Offset.
//
// [documentation for encoding/json.Marshal]: https://pkg.go.dev/encoding/json@go1.20#Marshal
func (enc *Encoder) Encode(obj any) (err error) {
	// A root Encoder resets the parser before each value...
	if enc.isRoot {
		enc.w.par.Reset()
	}
	// Fall back to the default escaping policy if the ReEncoder
	// doesn't specify one.
	escaper := enc.w.esc
	if escaper == nil {
		escaper = EscapeDefault
	}
	// Each top-level Encode starts with depth 0 and a fresh
	// cycle-detection set.
	if err := encode(enc.w, reflect.ValueOf(obj), escaper, enc.w.utf, false, 0, map[any]struct{}{}); err != nil {
		if rwe, ok := err.(*ReEncodeWriteError); ok {
			err = &EncodeWriteError{
				Err:    rwe.Err,
				Offset: rwe.Offset,
			}
		}
		return err
	}
	// ...and closes the ReEncoder after each value.
	if enc.isRoot {
		return enc.w.Close()
	}
	return nil
}

// discardInt drops the int from an (int, error) result pair, so that
// calls such as w.WriteString(...) can be used directly in a return
// statement of a function that returns only an error.
func discardInt(_ int, err error) error {
	return err
}

// startDetectingCyclesAfter is the threshold for encode's cycleDepth
// counter: only once more than this many nested containers
// (non-empty maps, slices encoded as arrays, and non-nil pointers)
// have been entered does encode start recording them in cycleSeen to
// detect reference cycles.
const startDetectingCyclesAfter = 1000

// encode writes the JSON encoding of val to w.
//
// The encoding is chosen in this order: the Encodable interface, the
// jsonMarshaler interface, encoding.TextMarshaler, and finally the
// reflect.Kind of val.  For each interface, an addressable non-pointer
// value whose pointer type implements the interface is encoded via
// its address.
//
// Parameters:
//   - escaper and utf are handed to the jsonstring encoding helpers
//     for every string that is written.
//   - quote asks for the value to be wrapped in a JSON string.  It is
//     honored for bools, integers, floats, numberType values, and
//     strings (which get encoded twice); it is passed through
//     interfaces and pointers; it is not consulted for values
//     encoded via an interface method, structs, maps, slices, or
//     arrays.
//   - cycleDepth and cycleSeen implement cycle detection; see
//     startDetectingCyclesAfter.
//
// Errors from EncodeJSON/MarshalJSON/MarshalText (and from writing
// or closing after EncodeJSON/MarshalJSON) are wrapped in
// *EncodeMethodError; a detected cycle yields *EncodeValueError; an
// unsupported kind yields *EncodeTypeError; other errors are returned
// as they came from the writer or helper.
func encode(w *ReEncoder, val reflect.Value, escaper BackslashEscaper, utf InvalidUTF8Mode, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) error {
	// The zero reflect.Value (e.g. from reflect.ValueOf(nil)) encodes
	// as null.
	if !val.IsValid() {
		return discardInt(w.WriteString("null"))
	}
	switch {

	// Encodable, possibly via the value's address.
	case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(encodableType):
		val = val.Addr()
		fallthrough
	case val.Type().Implements(encodableType):
		if val.Kind() == reflect.Pointer && val.IsNil() {
			return discardInt(w.WriteString("null"))
		}
		obj, ok := val.Interface().(Encodable)
		if !ok {
			return discardInt(w.WriteString("null"))
		}
		// The barrier is pushed before the method writes to w and
		// popped only after w.Close succeeds; on error it is left
		// in place.
		w.pushWriteBarrier()
		if err := obj.EncodeJSON(w); err != nil {
			return &EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "EncodeJSON",
				Err:        err,
			}
		}
		if err := w.Close(); err != nil {
			return &EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "EncodeJSON",
				Err:        err,
			}
		}
		w.popWriteBarrier()

	// jsonMarshaler, possibly via the value's address.
	case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(jsonMarshalerType):
		val = val.Addr()
		fallthrough
	case val.Type().Implements(jsonMarshalerType):
		if val.Kind() == reflect.Pointer && val.IsNil() {
			return discardInt(w.WriteString("null"))
		}
		obj, ok := val.Interface().(jsonMarshaler)
		if !ok {
			return discardInt(w.WriteString("null"))
		}
		dat, err := obj.MarshalJSON()
		if err != nil {
			return &EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "MarshalJSON",
				Err:        err,
			}
		}
		// Same barrier/Close sequence as for Encodable, but around
		// writing the already-marshaled bytes.
		w.pushWriteBarrier()
		if _, err := w.Write(dat); err != nil {
			return &EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "MarshalJSON",
				Err:        err,
			}
		}
		if err := w.Close(); err != nil {
			return &EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "MarshalJSON",
				Err:        err,
			}
		}
		w.popWriteBarrier()

	// encoding.TextMarshaler, possibly via the value's address; the
	// marshaled text is written as a JSON string.
	case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(textMarshalerType):
		val = val.Addr()
		fallthrough
	case val.Type().Implements(textMarshalerType):
		if val.Kind() == reflect.Pointer && val.IsNil() {
			return discardInt(w.WriteString("null"))
		}
		obj, ok := val.Interface().(encoding.TextMarshaler)
		if !ok {
			return discardInt(w.WriteString("null"))
		}
		text, err := obj.MarshalText()
		if err != nil {
			return &EncodeMethodError{
				Type:       val.Type(),
				SourceFunc: "MarshalText",
				Err:        err,
			}
		}
		if err := jsonstring.EncodeStringFromBytes(w, escaper, utf, val, text); err != nil {
			return err
		}

	// No custom encoding method; encode according to the value's kind.
	default:
		switch val.Kind() {
		case reflect.Bool:
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
			if val.Bool() {
				if _, err := w.WriteString("true"); err != nil {
					return err
				}
			} else {
				if _, err := w.WriteString("false"); err != nil {
					return err
				}
			}
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
			// The longest possible value is 20 bytes:
			//
			// MaxInt64 =  9223372036854775807
			// MinInt64 = -9223372036854775808
			//             0        1         2
			//             12345678901234567890
			var buf [20]byte
			if _, err := w.Write(strconv.AppendInt(buf[:0], val.Int(), 10)); err != nil {
				return err
			}
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
			// The longest possible value is 20 bytes:
			//
			// MaxUint64 = 18446744073709551615
			//             0        1         2
			//             12345678901234567890
			var buf [20]byte
			if _, err := w.Write(strconv.AppendUint(buf[:0], val.Uint(), 10)); err != nil {
				return err
			}
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
		case reflect.Float32:
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
			if err := encodeFloat(w, 32, val); err != nil {
				return err
			}
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
		case reflect.Float64:
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
			if err := encodeFloat(w, 64, val); err != nil {
				return err
			}
			if quote {
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			}
		case reflect.String:
			if val.Type() == numberType {
				// A numberType value is written verbatim (not as a
				// JSON string); the empty string is written as 0.
				numStr := val.String()
				if numStr == "" {
					numStr = "0"
				}
				if quote {
					if err := w.WriteByte('"'); err != nil {
						return err
					}
				}
				if _, err := w.WriteString(numStr); err != nil {
					return err
				}
				if quote {
					if err := w.WriteByte('"'); err != nil {
						return err
					}
				}
			} else {
				if quote {
					// Double-encode: first encode the string to
					// JSON in a scratch buffer, then encode that
					// JSON text as a string to w.
					var buf bytes.Buffer
					if err := jsonstring.EncodeStringFromString(&buf, escaper, utf, val, val.String()); err != nil {
						return err
					}
					if err := jsonstring.EncodeStringFromBytes(w, escaper, utf, val, buf.Bytes()); err != nil {
						return err
					}
				} else {
					if err := jsonstring.EncodeStringFromString(w, escaper, utf, val, val.String()); err != nil {
						return err
					}
				}
			}
		case reflect.Interface:
			if val.IsNil() {
				if _, err := w.WriteString("null"); err != nil {
					return err
				}
			} else {
				// Encode the dynamic value; quote is passed through.
				if err := encode(w, val.Elem(), escaper, utf, quote, cycleDepth, cycleSeen); err != nil {
					return err
				}
			}
		case reflect.Struct:
			if err := w.WriteByte('{'); err != nil {
				return err
			}
			// empty tracks whether any member has been written yet,
			// i.e. whether a comma separator is needed.
			empty := true
			for _, field := range jsonstruct.IndexStruct(val.Type()).ByPos {
				// FieldByIndexErr fails if reaching the field would
				// step through a nil embedded pointer; such fields
				// are skipped.
				fVal, err := val.FieldByIndexErr(field.Path)
				if err != nil {
					continue
				}
				if field.OmitEmpty && isEmptyValue(fVal) {
					continue
				}
				if !empty {
					if err := w.WriteByte(','); err != nil {
						return err
					}
				}
				empty = false
				if err := jsonstring.EncodeStringFromString(w, escaper, utf, val, field.Name); err != nil {
					return err
				}
				if err := w.WriteByte(':'); err != nil {
					return err
				}
				// The field's own Quote setting replaces the
				// caller's quote.
				if err := encode(w, fVal, escaper, utf, field.Quote, cycleDepth, cycleSeen); err != nil {
					return err
				}
			}
			if err := w.WriteByte('}'); err != nil {
				return err
			}
		case reflect.Map:
			if val.IsNil() {
				return discardInt(w.WriteString("null"))
			}
			if val.Len() == 0 {
				return discardInt(w.WriteString("{}"))
			}
			// Past the depth threshold, remember this map until this
			// call returns; meeting it again in the meantime means
			// there is a cycle.
			if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
				ptr := val.UnsafePointer()
				if _, seen := cycleSeen[ptr]; seen {
					return &EncodeValueError{
						Value: val,
						Str:   fmt.Sprintf("encountered a cycle via %s", val.Type()),
					}
				}
				cycleSeen[ptr] = struct{}{}
				defer delete(cycleSeen, ptr)
			}
			if err := w.WriteByte('{'); err != nil {
				return err
			}

			// Keys are first encoded into a scratch buffer through
			// their own ReEncoder, so that the key text can be
			// collected and sorted before anything is written to w.
			var kBuf strings.Builder
			kEnc := NewReEncoder(&kBuf, ReEncoderConfig{
				AllowMultipleValues: true,
				Compact:             true,
				BackslashEscape:     escaper,
				InvalidUTF8:         utf,
			})

			type kv struct {
				KStr string        // key text used for sorting and output
				K    reflect.Value // original key
				V    reflect.Value // value stored under that key
			}
			kvs := make([]kv, val.Len())
			iter := val.MapRange()
			for i := 0; iter.Next(); i++ {
				if err := encode(kEnc, iter.Key(), escaper, utf, false, cycleDepth, cycleSeen); err != nil {
					return err
				}
				if err := kEnc.Close(); err != nil {
					return err
				}
				// Strip any newlines surrounding the encoded key
				// (presumably emitted by kEnc between values).
				kStr := strings.Trim(kBuf.String(), "\n")
				kBuf.Reset()
				// A key that encodes as null becomes the empty key.
				if kStr == "null" {
					kStr = ""
				}

				// A key that encoded as a JSON string is decoded
				// back to its plain text; it is re-encoded when it is
				// written to w below.
				//
				// TODO(lukeshu): Have kEnc look at the first byte, and feed directly to a decoder,
				// instead of needing to buffer the whole thing twice.
				if strings.HasPrefix(kStr, `"`) {
					if err := DecodeString(strings.NewReader(kStr), &kBuf); err != nil {
						return err
					}
					kStr = kBuf.String()
					kBuf.Reset()
				}
				kvs[i].KStr = kStr
				kvs[i].K = iter.Key()
				kvs[i].V = iter.Value()
			}
			// Emit members ordered by key text.
			sort.Slice(kvs, func(i, j int) bool {
				return kvs[i].KStr < kvs[j].KStr
			})

			for i, kv := range kvs {
				if i > 0 {
					if err := w.WriteByte(','); err != nil {
						return err
					}
				}
				if err := jsonstring.EncodeStringFromString(w, escaper, utf, kv.K, kv.KStr); err != nil {
					return err
				}
				if err := w.WriteByte(':'); err != nil {
					return err
				}
				if err := encode(w, kv.V, escaper, utf, false, cycleDepth, cycleSeen); err != nil {
					return err
				}
			}
			if err := w.WriteByte('}'); err != nil {
				return err
			}
		case reflect.Slice:
			switch {
			case val.IsNil():
				if _, err := w.WriteString("null"); err != nil {
					return err
				}
			// A slice whose element kind is uint8, and whose element
			// type (or pointer to it) implements none of the three
			// encoding interfaces, is written as a base64 string.
			case val.Type().Elem().Kind() == reflect.Uint8 && !(false ||
				val.Type().Elem().Implements(encodableType) ||
				reflect.PointerTo(val.Type().Elem()).Implements(encodableType) ||
				val.Type().Elem().Implements(jsonMarshalerType) ||
				reflect.PointerTo(val.Type().Elem()).Implements(jsonMarshalerType) ||
				val.Type().Elem().Implements(textMarshalerType) ||
				reflect.PointerTo(val.Type().Elem()).Implements(textMarshalerType)):
				if err := w.WriteByte('"'); err != nil {
					return err
				}
				// Stream the base64 text (standard alphabet)
				// straight into w.
				enc := base64.NewEncoder(base64.StdEncoding, w)
				if val.CanConvert(byteSliceType) {
					if _, err := enc.Write(val.Convert(byteSliceType).Interface().([]byte)); err != nil {
						return err
					}
				} else {
					// The slice as a whole can't be converted to
					// byteSliceType, so convert and write one
					// element at a time.
					//
					// TODO: Surely there's a better way.
					for i, n := 0, val.Len(); i < n; i++ {
						var buf [1]byte
						buf[0] = val.Index(i).Convert(byteType).Interface().(byte)
						if _, err := enc.Write(buf[:]); err != nil {
							return err
						}
					}
				}
				// Close flushes any partially-written base64 block.
				if err := enc.Close(); err != nil {
					return err
				}
				if err := w.WriteByte('"'); err != nil {
					return err
				}
			default:
				if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
					// For slices, val.UnsafePointer() doesn't return a pointer to the slice header
					// or anything like that, it returns a pointer *to the first element in the
					// slice*.  That means that the pointer isn't enough to uniquely identify the
					// slice!  So we pair the pointer with the length of the slice, which is
					// sufficient.
					ptr := struct {
						ptr unsafe.Pointer
						len int
					}{val.UnsafePointer(), val.Len()}
					if _, seen := cycleSeen[ptr]; seen {
						return &EncodeValueError{
							Value: val,
							Str:   fmt.Sprintf("encountered a cycle via %s", val.Type()),
						}
					}
					cycleSeen[ptr] = struct{}{}
					defer delete(cycleSeen, ptr)
				}
				if err := encodeArray(w, val, escaper, utf, cycleDepth, cycleSeen); err != nil {
					return err
				}
			}
		case reflect.Array:
			// Arrays are always written element by element; there is
			// no nil check, base64 form, or cycle bookkeeping here.
			if err := encodeArray(w, val, escaper, utf, cycleDepth, cycleSeen); err != nil {
				return err
			}
		case reflect.Pointer:
			if val.IsNil() {
				if _, err := w.WriteString("null"); err != nil {
					return err
				}
			} else {
				// Same cycle bookkeeping as for maps, keyed on the
				// pointer itself.
				if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
					ptr := val.UnsafePointer()
					if _, seen := cycleSeen[ptr]; seen {
						return &EncodeValueError{
							Value: val,
							Str:   fmt.Sprintf("encountered a cycle via %s", val.Type()),
						}
					}
					cycleSeen[ptr] = struct{}{}
					defer delete(cycleSeen, ptr)
				}
				// Encode the pointed-to value; quote is passed
				// through.
				if err := encode(w, val.Elem(), escaper, utf, quote, cycleDepth, cycleSeen); err != nil {
					return err
				}
			}
		default:
			// Every remaining kind is unsupported.
			return &EncodeTypeError{
				Type: val.Type(),
			}
		}
	}
	return nil
}

// encodeArray writes val, which encode calls with a slice or an
// array, to w as a JSON array: '[', the elements separated by ',',
// and ']'.  Each element is encoded with encode, with quote set to
// false; escaper, utf, cycleDepth, and cycleSeen are passed along
// unchanged.
func encodeArray(w *ReEncoder, val reflect.Value, escaper BackslashEscaper, utf InvalidUTF8Mode, cycleDepth uint, cycleSeen map[any]struct{}) error {
	if err := w.WriteByte('['); err != nil {
		return err
	}
	n := val.Len()
	for i := 0; i < n; i++ {
		if i > 0 {
			if err := w.WriteByte(','); err != nil {
				return err
			}
		}
		if err := encode(w, val.Index(i), escaper, utf, false, cycleDepth, cycleSeen); err != nil {
			return err
		}
	}
	if err := w.WriteByte(']'); err != nil {
		return err
	}
	return nil
}