diff options
Diffstat (limited to 'encode.go')
-rw-r--r-- | encode.go | 42 |
1 files changed, 35 insertions, 7 deletions
@@ -48,7 +48,7 @@ func Encode(w io.Writer, obj any) (err error) { } } }() - encode(w, reflect.ValueOf(obj), false, 0, map[unsafe.Pointer]struct{}{}) + encode(w, reflect.ValueOf(obj), false, 0, map[any]struct{}{}) if f, ok := w.(interface{ Flush() error }); ok { return f.Flush() } @@ -63,7 +63,7 @@ var ( const startDetectingCyclesAfter = 1000 -func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) { +func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) { if !val.IsValid() { encodeWriteString(w, "null") return @@ -83,7 +83,15 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe encodeWriteString(w, "null") return } - if err := obj.EncodeJSON(w); err != nil { + validator := &ReEncoder{Out: w} + if err := obj.EncodeJSON(validator); err != nil { + panic(encodeError{&EncodeMethodError{ + Type: val.Type(), + Err: err, + SourceFunc: "EncodeJSON", + }}) + } + if err := validator.Close(); err != nil { panic(encodeError{err}) } @@ -102,9 +110,17 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe } dat, err := obj.MarshalJSON() if err != nil { + panic(encodeError{&EncodeMethodError{ + Type: val.Type(), + Err: err, + SourceFunc: "MarshalJSON", + }}) + } + validator := &ReEncoder{Out: w} + if _, err := validator.Write(dat); err != nil { panic(encodeError{err}) } - if _, err := w.Write(dat); err != nil { + if err := validator.Close(); err != nil { panic(encodeError{err}) } @@ -123,7 +139,11 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe } text, err := obj.MarshalText() if err != nil { - panic(encodeError{err}) + panic(encodeError{&EncodeMethodError{ + Type: val.Type(), + Err: err, + SourceFunc: "MarshalText", + }}) } encodeString(w, text) @@ -302,7 +322,15 @@ func encode(w io.Writer, val reflect.Value, quote bool, cycleDepth uint, cycleSe encodeWriteByte(w, '"') default: if cycleDepth++; cycleDepth > startDetectingCyclesAfter { - ptr := val.UnsafePointer() + // For slices, val.UnsafePointer() doesn't return a pointer to the slice header + // or anything like that, it returns a pointer *to the first element in the + // slice*. That means that the pointer isn't enough to uniquely identify the + // slice! So we pair the pointer with the length of the slice, which is + // sufficient. + ptr := struct { + ptr unsafe.Pointer + len int + }{val.UnsafePointer(), val.Len()} if _, seen := cycleSeen[ptr]; seen { panic(encodeError{&EncodeValueError{ Value: val, @@ -353,7 +381,7 @@ func encodeString[T interface{ []byte | string }](w io.Writer, str T) { encodeWriteByte(w, '"') } -func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[unsafe.Pointer]struct{}) { +func encodeArray(w io.Writer, val reflect.Value, cycleDepth uint, cycleSeen map[any]struct{}) { encodeWriteByte(w, '[') n := val.Len() for i := 0; i < n; i++ { |