diff options
Diffstat (limited to 'encode.go')
-rw-r--r-- | encode.go | 61 |
1 files changed, 36 insertions, 25 deletions
@@ -39,7 +39,12 @@ func encodeWriteString(w io.Writer, str string) { } } -func Encode(w io.Writer, obj any) (err error) { +type Encoder struct { + W io.Writer + BackslashEscape BackslashEscaper +} + +func (enc *Encoder) Encode(obj any) (err error) { defer func() { if r := recover(); r != nil { if e, ok := r.(encodeError); ok { @@ -49,13 +54,18 @@ func Encode(w io.Writer, obj any) (err error) { } } }() - encode(w, reflect.ValueOf(obj), false, 0, map[any]struct{}{}) - if f, ok := w.(interface{ Flush() error }); ok { + encode(enc.W, reflect.ValueOf(obj), enc.BackslashEscape, false, 0, map[any]struct{}{}) + if f, ok := enc.W.(interface{ Flush() error }); ok { return f.Flush() } return nil } +func Encode(w io.Writer, obj any) (err error) { + enc := &Encoder{W: w} + return enc.Encode(obj) +} + var ( encodableType = reflect.TypeOf((*Encodable)(nil)).Elem() jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem() @@ -64,7 +74,7 @@ var ( const startDetectingCyclesAfter = 1000 -func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) { +func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) { if !val.IsValid() { encodeWriteString(w, "null") return @@ -84,7 +94,7 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe encodeWriteString(w, "null") return } - validator := &ReEncoder{Out: w} + validator := &ReEncoder{Out: w, BackslashEscape: escaper} if err := obj.EncodeJSON(validator); err != nil { panic(encodeError{&EncodeMethodError{ Type: val.Type(), @@ -117,7 +127,7 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe SourceFunc: "MarshalJSON", }}) } - validator := &ReEncoder{Out: w} + validator := &ReEncoder{Out: w, BackslashEscape: escaper} if _, err := validator.Write(dat); err != nil { panic(encodeError{err}) } @@ -146,7 +156,7 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe SourceFunc: "MarshalText", }}) } - encodeStringFromBytes(w, text) + encodeStringFromBytes(w, escaper, text) default: switch val.Kind() { @@ -202,17 +212,17 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe } else { if quote { var buf bytes.Buffer - encodeStringFromString(&buf, val.String()) - encodeStringFromBytes(w, buf.Bytes()) + encodeStringFromString(&buf, escaper, val.String()) + encodeStringFromBytes(w, escaper, buf.Bytes()) } else { - encodeStringFromString(w, val.String()) + encodeStringFromString(w, escaper, val.String()) } } case reflect.Interface: if val.IsNil() { encodeWriteString(w, "null") } else { - encode(w, val.Elem(), quote, cycleDepth, cycleSeen) + encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen) } case reflect.Struct: encodeWriteByte(w, '{') @@ -229,9 +239,9 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe encodeWriteByte(w, ',') } empty = false - encodeStringFromString(w, field.Name) + encodeStringFromString(w, escaper, field.Name) encodeWriteByte(w, ':') - encode(w, fVal, field.Quote, cycleDepth, cycleSeen) + encode(w, fVal, escaper, field.Quote, cycleDepth, cycleSeen) } encodeWriteByte(w, '}') case reflect.Map: @@ -263,15 +273,16 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe kvs := make([]kv, val.Len()) iter := val.MapRange() for i := 0; iter.Next(); i++ { + // TODO: Avoid buffering the map key var k strings.Builder - encode(&k, iter.Key(), false, cycleDepth, cycleSeen) + encode(&k, iter.Key(), escaper, false, cycleDepth, cycleSeen) kStr := k.String() if kStr == "null" { kStr = `""` } if !strings.HasPrefix(kStr, `"`) { k.Reset() - encodeStringFromString(&k, kStr) + encodeStringFromString(&k, escaper, kStr) kStr = k.String() } kvs[i].K = kStr @@ -287,7 +298,7 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe } encodeWriteString(w, kv.K) encodeWriteByte(w, ':') - encode(w, kv.V, false, cycleDepth, cycleSeen) + encode(w, kv.V, escaper, false, cycleDepth, cycleSeen) } encodeWriteByte(w, '}') case reflect.Slice: @@ -341,10 +352,10 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe cycleSeen[ptr] = struct{}{} defer delete(cycleSeen, ptr) } - encodeArray(w, val, cycleDepth, cycleSeen) + encodeArray(w, val, escaper, cycleDepth, cycleSeen) } case reflect.Array: - encodeArray(w, val, cycleDepth, cycleSeen) + encodeArray(w, val, escaper, cycleDepth, cycleSeen) case reflect.Pointer: if val.IsNil() { encodeWriteString(w, "null") @@ -360,7 +371,7 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe cycleSeen[ptr] = struct{}{} defer delete(cycleSeen, ptr) } - encode(w, val.Elem(), quote, cycleDepth, cycleSeen) + encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen) } default: panic(encodeError{&EncodeTypeError{ @@ -370,21 +381,21 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe } } -func encodeStringFromString(w io.Writer, str string) { +func encodeStringFromString(w io.Writer, escaper BackslashEscaper, str string) { encodeWriteByte(w, '"') for _, c := range str { - if _, err := writeStringChar(w, c, BackslashEscapeNone, nil); err != nil { + if _, err := writeStringChar(w, c, BackslashEscapeNone, escaper); err != nil { panic(encodeError{err}) } } encodeWriteByte(w, '"') } -func encodeStringFromBytes(w io.Writer, str []byte) { +func encodeStringFromBytes(w io.Writer, escaper BackslashEscaper, str []byte) { encodeWriteByte(w, '"') for i := 0; i < len(str); { c, size := utf8.DecodeRune(str[i:]) - if _, err := writeStringChar(w, c, BackslashEscapeNone, nil); err != nil { + if _, err := writeStringChar(w, c, BackslashEscapeNone, escaper); err != nil { panic(encodeError{err}) } i += size @@ -392,14 +403,14 @@ func encodeStringFromBytes(w io.Writer, str []byte) { encodeWriteByte(w, '"') } -func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[any]struct{}) { +func encodeArray(w io.Writer, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) { encodeWriteByte(w, '[') n := val.Len() for i := 0; i < n; i++ { if i > 0 { encodeWriteByte(w, ',') } - encode(w, val.Index(i), false, cycleDepth, cycleSeen) + encode(w, val.Index(i), escaper, false, cycleDepth, cycleSeen) } encodeWriteByte(w, ']') } |