summaryrefslogtreecommitdiff
path: root/encode.go
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@datawire.io>2022-08-14 20:52:33 -0600
committerLuke Shumaker <lukeshu@datawire.io>2022-08-17 02:02:47 -0600
commit28dc29b7b05dc9c7ea1cec577963757f75faa601 (patch)
treece7e0c4ddfeed8e2db99bf72383e71fe7fef4f20 /encode.go
parent35997d235f3bac7c3f9bcd4b8d2b26b0d88dc387 (diff)
Get the new borrowed tests passing
Diffstat (limited to 'encode.go')
-rw-r--r--encode.go61
1 files changed, 36 insertions, 25 deletions
diff --git a/encode.go b/encode.go
index f6f8f0e..a77d8aa 100644
--- a/encode.go
+++ b/encode.go
@@ -39,7 +39,12 @@ func encodeWriteString(w io.Writer, str string) {
}
}
-func Encode(w io.Writer, obj any) (err error) {
+type Encoder struct {
+ W io.Writer
+ BackslashEscape BackslashEscaper
+}
+
+func (enc *Encoder) Encode(obj any) (err error) {
defer func() {
if r := recover(); r != nil {
if e, ok := r.(encodeError); ok {
@@ -49,13 +54,18 @@ func Encode(w io.Writer, obj any) (err error) {
}
}
}()
- encode(w, reflect.ValueOf(obj), false, 0, map[any]struct{}{})
- if f, ok := w.(interface{ Flush() error }); ok {
+ encode(enc.W, reflect.ValueOf(obj), enc.BackslashEscape, false, 0, map[any]struct{}{})
+ if f, ok := enc.W.(interface{ Flush() error }); ok {
return f.Flush()
}
return nil
}
+func Encode(w io.Writer, obj any) (err error) {
+ enc := &Encoder{W: w}
+ return enc.Encode(obj)
+}
+
var (
encodableType = reflect.TypeOf((*Encodable)(nil)).Elem()
jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
@@ -64,7 +74,7 @@ var (
const startDetectingCyclesAfter = 1000
-func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) {
+func encode(w io.Writer, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) {
if !val.IsValid() {
encodeWriteString(w, "null")
return
@@ -84,7 +94,7 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
encodeWriteString(w, "null")
return
}
- validator := &ReEncoder{Out: w}
+ validator := &ReEncoder{Out: w, BackslashEscape: escaper}
if err := obj.EncodeJSON(validator); err != nil {
panic(encodeError{&EncodeMethodError{
Type: val.Type(),
@@ -117,7 +127,7 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
SourceFunc: "MarshalJSON",
}})
}
- validator := &ReEncoder{Out: w}
+ validator := &ReEncoder{Out: w, BackslashEscape: escaper}
if _, err := validator.Write(dat); err != nil {
panic(encodeError{err})
}
@@ -146,7 +156,7 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
SourceFunc: "MarshalText",
}})
}
- encodeStringFromBytes(w, text)
+ encodeStringFromBytes(w, escaper, text)
default:
switch val.Kind() {
@@ -202,17 +212,17 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
} else {
if quote {
var buf bytes.Buffer
- encodeStringFromString(&buf, val.String())
- encodeStringFromBytes(w, buf.Bytes())
+ encodeStringFromString(&buf, escaper, val.String())
+ encodeStringFromBytes(w, escaper, buf.Bytes())
} else {
- encodeStringFromString(w, val.String())
+ encodeStringFromString(w, escaper, val.String())
}
}
case reflect.Interface:
if val.IsNil() {
encodeWriteString(w, "null")
} else {
- encode(w, val.Elem(), quote, cycleDepth, cycleSeen)
+ encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen)
}
case reflect.Struct:
encodeWriteByte(w, '{')
@@ -229,9 +239,9 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
encodeWriteByte(w, ',')
}
empty = false
- encodeStringFromString(w, field.Name)
+ encodeStringFromString(w, escaper, field.Name)
encodeWriteByte(w, ':')
- encode(w, fVal, field.Quote, cycleDepth, cycleSeen)
+ encode(w, fVal, escaper, field.Quote, cycleDepth, cycleSeen)
}
encodeWriteByte(w, '}')
case reflect.Map:
@@ -263,15 +273,16 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
kvs := make([]kv, val.Len())
iter := val.MapRange()
for i := 0; iter.Next(); i++ {
+ // TODO: Avoid buffering the map key
var k strings.Builder
- encode(&k, iter.Key(), false, cycleDepth, cycleSeen)
+ encode(&k, iter.Key(), escaper, false, cycleDepth, cycleSeen)
kStr := k.String()
if kStr == "null" {
kStr = `""`
}
if !strings.HasPrefix(kStr, `"`) {
k.Reset()
- encodeStringFromString(&k, kStr)
+ encodeStringFromString(&k, escaper, kStr)
kStr = k.String()
}
kvs[i].K = kStr
@@ -287,7 +298,7 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
}
encodeWriteString(w, kv.K)
encodeWriteByte(w, ':')
- encode(w, kv.V, false, cycleDepth, cycleSeen)
+ encode(w, kv.V, escaper, false, cycleDepth, cycleSeen)
}
encodeWriteByte(w, '}')
case reflect.Slice:
@@ -341,10 +352,10 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
cycleSeen[ptr] = struct{}{}
defer delete(cycleSeen, ptr)
}
- encodeArray(w, val, cycleDepth, cycleSeen)
+ encodeArray(w, val, escaper, cycleDepth, cycleSeen)
}
case reflect.Array:
- encodeArray(w, val, cycleDepth, cycleSeen)
+ encodeArray(w, val, escaper, cycleDepth, cycleSeen)
case reflect.Pointer:
if val.IsNil() {
encodeWriteString(w, "null")
@@ -360,7 +371,7 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
cycleSeen[ptr] = struct{}{}
defer delete(cycleSeen, ptr)
}
- encode(w, val.Elem(), quote, cycleDepth, cycleSeen)
+ encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen)
}
default:
panic(encodeError{&EncodeTypeError{
@@ -370,21 +381,21 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe
}
}
-func encodeStringFromString(w io.Writer, str string) {
+func encodeStringFromString(w io.Writer, escaper BackslashEscaper, str string) {
encodeWriteByte(w, '"')
for _, c := range str {
- if _, err := writeStringChar(w, c, BackslashEscapeNone, nil); err != nil {
+ if _, err := writeStringChar(w, c, BackslashEscapeNone, escaper); err != nil {
panic(encodeError{err})
}
}
encodeWriteByte(w, '"')
}
-func encodeStringFromBytes(w io.Writer, str []byte) {
+func encodeStringFromBytes(w io.Writer, escaper BackslashEscaper, str []byte) {
encodeWriteByte(w, '"')
for i := 0; i < len(str); {
c, size := utf8.DecodeRune(str[i:])
- if _, err := writeStringChar(w, c, BackslashEscapeNone, nil); err != nil {
+ if _, err := writeStringChar(w, c, BackslashEscapeNone, escaper); err != nil {
panic(encodeError{err})
}
i += size
@@ -392,14 +403,14 @@ func encodeStringFromBytes(w io.Writer, str []byte) {
encodeWriteByte(w, '"')
}
-func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[any]struct{}) {
+func encodeArray(w io.Writer, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) {
encodeWriteByte(w, '[')
n := val.Len()
for i := 0; i < n; i++ {
if i > 0 {
encodeWriteByte(w, ',')
}
- encode(w, val.Index(i), false, cycleDepth, cycleSeen)
+ encode(w, val.Index(i), escaper, false, cycleDepth, cycleSeen)
}
encodeWriteByte(w, ']')
}