summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2023-01-27 01:24:02 -0700
committerLuke Shumaker <lukeshu@lukeshu.com>2023-01-30 22:00:25 -0700
commitd5b1b73eaaa060ef468f20d8b9eed029eb60ce45 (patch)
tree6b85c8e28ae10cd5ae5de1242bcc34cc91b6a183
parent2828fa21c0ffd2a32a108b37c0417b01abc42929 (diff)
encode: Don't use panic for flow-control
-rw-r--r--encode.go308
-rw-r--r--encode_string.go38
2 files changed, 216 insertions, 130 deletions
diff --git a/encode.go b/encode.go
index c5a29b3..57f3852 100644
--- a/encode.go
+++ b/encode.go
@@ -32,22 +32,6 @@ type Encodable interface {
EncodeJSON(w io.Writer) error
}
-type encodeError struct {
- Err error
-}
-
-func encodeWriteByte(w io.ByteWriter, b byte) {
- if err := w.WriteByte(b); err != nil {
- panic(encodeError{err})
- }
-}
-
-func encodeWriteString(w io.StringWriter, str string) {
- if _, err := w.WriteString(str); err != nil {
- panic(encodeError{err})
- }
-}
-
// An Encoder encodes and writes values to a stream of JSON elements.
//
// Encoder is analogous to, and has a similar API to the standar
@@ -93,22 +77,19 @@ func NewEncoder(w io.Writer) *Encoder {
//
// [documentation for encoding/json.Marshal]: https://pkg.go.dev/encoding/json@go1.18#Marshal
func (enc *Encoder) Encode(obj any) (err error) {
- defer func() {
- if r := recover(); r != nil {
- if e, ok := r.(encodeError); ok {
- err = e.Err
- } else {
- panic(r)
- }
- }
- }()
- encode(enc.w, reflect.ValueOf(obj), enc.w.BackslashEscape, false, 0, map[any]struct{}{})
+ if err := encode(enc.w, reflect.ValueOf(obj), enc.w.BackslashEscape, false, 0, map[any]struct{}{}); err != nil {
+ return err
+ }
if enc.closeAfterEncode {
return enc.w.Close()
}
return nil
}
+func discardInt(_ int, err error) error {
+ return err
+}
+
var (
encodableType = reflect.TypeOf((*Encodable)(nil)).Elem()
jsonMarshalerType = reflect.TypeOf((*json.Marshaler)(nil)).Elem()
@@ -117,10 +98,9 @@ var (
const startDetectingCyclesAfter = 1000
-func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) {
+func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, quote bool, cycleDepth uint, cycleSeen map[any]struct{}) error {
if !val.IsValid() {
- encodeWriteString(w, "null")
- return
+ return discardInt(w.WriteString("null"))
}
switch {
@@ -129,29 +109,27 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q
fallthrough
case val.Type().Implements(encodableType):
if val.Kind() == reflect.Pointer && val.IsNil() {
- encodeWriteString(w, "null")
- return
+ return discardInt(w.WriteString("null"))
}
obj, ok := val.Interface().(Encodable)
if !ok {
- encodeWriteString(w, "null")
- return
+ return discardInt(w.WriteString("null"))
}
// Use a sub-ReEncoder to check that it's a full element.
validator := NewReEncoder(w, ReEncoderConfig{BackslashEscape: escaper})
if err := obj.EncodeJSON(validator); err != nil {
- panic(encodeError{&EncodeMethodError{
+ return &EncodeMethodError{
Type: val.Type(),
SourceFunc: "EncodeJSON",
Err: err,
- }})
+ }
}
if err := validator.Close(); err != nil && !errors.Is(err, iofs.ErrClosed) {
- panic(encodeError{&EncodeMethodError{
+ return &EncodeMethodError{
Type: val.Type(),
SourceFunc: "EncodeJSON",
Err: err,
- }})
+ }
}
case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(jsonMarshalerType):
@@ -159,37 +137,35 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q
fallthrough
case val.Type().Implements(jsonMarshalerType):
if val.Kind() == reflect.Pointer && val.IsNil() {
- encodeWriteString(w, "null")
- return
+ return discardInt(w.WriteString("null"))
}
obj, ok := val.Interface().(json.Marshaler)
if !ok {
- encodeWriteString(w, "null")
- return
+ return discardInt(w.WriteString("null"))
}
dat, err := obj.MarshalJSON()
if err != nil {
- panic(encodeError{&EncodeMethodError{
+ return &EncodeMethodError{
Type: val.Type(),
SourceFunc: "MarshalJSON",
Err: err,
- }})
+ }
}
// Use a sub-ReEncoder to check that it's a full element.
validator := NewReEncoder(w, ReEncoderConfig{BackslashEscape: escaper})
if _, err := validator.Write(dat); err != nil {
- panic(encodeError{&EncodeMethodError{
+ return &EncodeMethodError{
Type: val.Type(),
SourceFunc: "MarshalJSON",
Err: err,
- }})
+ }
}
if err := validator.Close(); err != nil {
- panic(encodeError{&EncodeMethodError{
+ return &EncodeMethodError{
Type: val.Type(),
SourceFunc: "MarshalJSON",
Err: err,
- }})
+ }
}
case val.Kind() != reflect.Pointer && val.CanAddr() && reflect.PointerTo(val.Type()).Implements(textMarshalerType):
@@ -197,61 +173,86 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q
fallthrough
case val.Type().Implements(textMarshalerType):
if val.Kind() == reflect.Pointer && val.IsNil() {
- encodeWriteString(w, "null")
- return
+ return discardInt(w.WriteString("null"))
}
obj, ok := val.Interface().(encoding.TextMarshaler)
if !ok {
- encodeWriteString(w, "null")
- return
+ return discardInt(w.WriteString("null"))
}
text, err := obj.MarshalText()
if err != nil {
- panic(encodeError{&EncodeMethodError{
+ return &EncodeMethodError{
Type: val.Type(),
SourceFunc: "MarshalText",
Err: err,
- }})
+ }
+ }
+ if err := encodeStringFromBytes(w, escaper, text); err != nil {
+ return err
}
- encodeStringFromBytes(w, escaper, text)
-
default:
switch val.Kind() {
case reflect.Bool:
if quote {
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
}
if val.Bool() {
- encodeWriteString(w, "true")
+ if _, err := w.WriteString("true"); err != nil {
+ return err
+ }
} else {
- encodeWriteString(w, "false")
+ if _, err := w.WriteString("false"); err != nil {
+ return err
+ }
}
if quote {
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
if quote {
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
+ }
+ if _, err := w.WriteString(strconv.FormatInt(val.Int(), 10)); err != nil {
+ return err
}
- encodeWriteString(w, strconv.FormatInt(val.Int(), 10))
if quote {
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
}
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
if quote {
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
+ }
+ if _, err := w.WriteString(strconv.FormatUint(val.Uint(), 10)); err != nil {
+ return err
}
- encodeWriteString(w, strconv.FormatUint(val.Uint(), 10))
if quote {
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
}
case reflect.Float32, reflect.Float64:
if quote {
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
+ }
+ if err := encodeTODO(w, val); err != nil {
+ return err
}
- encodeTODO(w, val)
if quote {
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
}
case reflect.String:
if val.Type() == numberType {
@@ -260,29 +261,47 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q
numStr = "0"
}
if quote {
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
+ }
+ if _, err := w.WriteString(numStr); err != nil {
+ return err
}
- encodeWriteString(w, numStr)
if quote {
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
}
} else {
if quote {
var buf bytes.Buffer
- encodeStringFromString(&buf, escaper, val.String())
- encodeStringFromBytes(w, escaper, buf.Bytes())
+ if err := encodeStringFromString(&buf, escaper, val.String()); err != nil {
+ return err
+ }
+ if err := encodeStringFromBytes(w, escaper, buf.Bytes()); err != nil {
+ return err
+ }
} else {
- encodeStringFromString(w, escaper, val.String())
+ if err := encodeStringFromString(w, escaper, val.String()); err != nil {
+ return err
+ }
}
}
case reflect.Interface:
if val.IsNil() {
- encodeWriteString(w, "null")
+ if _, err := w.WriteString("null"); err != nil {
+ return err
+ }
} else {
- encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen)
+ if err := encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen); err != nil {
+ return err
+ }
}
case reflect.Struct:
- encodeWriteByte(w, '{')
+ if err := w.WriteByte('{'); err != nil {
+ return err
+ }
empty := true
for _, field := range indexStruct(val.Type()).byPos {
fVal, err := val.FieldByIndexErr(field.Path)
@@ -293,35 +312,45 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q
continue
}
if !empty {
- encodeWriteByte(w, ',')
+ if err := w.WriteByte(','); err != nil {
+ return err
+ }
}
empty = false
- encodeStringFromString(w, escaper, field.Name)
- encodeWriteByte(w, ':')
- encode(w, fVal, escaper, field.Quote, cycleDepth, cycleSeen)
+ if err := encodeStringFromString(w, escaper, field.Name); err != nil {
+ return err
+ }
+ if err := w.WriteByte(':'); err != nil {
+ return err
+ }
+ if err := encode(w, fVal, escaper, field.Quote, cycleDepth, cycleSeen); err != nil {
+ return err
+ }
+ }
+ if err := w.WriteByte('}'); err != nil {
+ return err
}
- encodeWriteByte(w, '}')
case reflect.Map:
if val.IsNil() {
- encodeWriteString(w, "null")
- return
+ return discardInt(w.WriteString("null"))
}
if val.Len() == 0 {
- encodeWriteString(w, "{}")
- return
+ return discardInt(w.WriteString("{}"))
}
if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
ptr := val.UnsafePointer()
if _, seen := cycleSeen[ptr]; seen {
- panic(encodeError{&EncodeValueError{
+ return &EncodeValueError{
Value: val,
Str: fmt.Sprintf("encountered a cycle via %s", val.Type()),
- }})
+ }
}
cycleSeen[ptr] = struct{}{}
defer delete(cycleSeen, ptr)
}
- encodeWriteByte(w, '{')
+ if err := w.WriteByte('{'); err != nil {
+ return err
+ }
type kv struct {
K string
@@ -332,14 +361,18 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q
for i := 0; iter.Next(); i++ {
// TODO: Avoid buffering the map key
var k strings.Builder
- encode(&k, iter.Key(), escaper, false, cycleDepth, cycleSeen)
+ if err := encode(&k, iter.Key(), escaper, false, cycleDepth, cycleSeen); err != nil {
+ return err
+ }
kStr := k.String()
if kStr == "null" {
kStr = `""`
}
if !strings.HasPrefix(kStr, `"`) {
k.Reset()
- encodeStringFromString(&k, escaper, kStr)
+ if err := encodeStringFromString(&k, escaper, kStr); err != nil {
+ return err
+ }
kStr = k.String()
}
kvs[i].K = kStr
@@ -351,17 +384,29 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q
for i, kv := range kvs {
if i > 0 {
- encodeWriteByte(w, ',')
+ if err := w.WriteByte(','); err != nil {
+ return err
+ }
+ }
+ if _, err := w.WriteString(kv.K); err != nil {
+ return err
+ }
+ if err := w.WriteByte(':'); err != nil {
+ return err
+ }
+ if err := encode(w, kv.V, escaper, false, cycleDepth, cycleSeen); err != nil {
+ return err
}
- encodeWriteString(w, kv.K)
- encodeWriteByte(w, ':')
- encode(w, kv.V, escaper, false, cycleDepth, cycleSeen)
}
- encodeWriteByte(w, '}')
+ if err := w.WriteByte('}'); err != nil {
+ return err
+ }
case reflect.Slice:
switch {
case val.IsNil():
- encodeWriteString(w, "null")
+ if _, err := w.WriteString("null"); err != nil {
+ return err
+ }
case val.Type().Elem().Kind() == reflect.Uint8 && !(false ||
val.Type().Elem().Implements(encodableType) ||
reflect.PointerTo(val.Type().Elem()).Implements(encodableType) ||
@@ -369,11 +414,13 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q
reflect.PointerTo(val.Type().Elem()).Implements(jsonMarshalerType) ||
val.Type().Elem().Implements(textMarshalerType) ||
reflect.PointerTo(val.Type().Elem()).Implements(textMarshalerType)):
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
enc := base64.NewEncoder(base64.StdEncoding, w)
if val.CanConvert(byteSliceType) {
if _, err := enc.Write(val.Convert(byteSliceType).Interface().([]byte)); err != nil {
- panic(encodeError{err})
+ return err
}
} else {
// TODO: Surely there's a better way.
@@ -381,14 +428,16 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q
var buf [1]byte
buf[0] = val.Index(i).Convert(byteType).Interface().(byte)
if _, err := enc.Write(buf[:]); err != nil {
- panic(encodeError{err})
+ return err
}
}
}
if err := enc.Close(); err != nil {
- panic(encodeError{err})
+ return err
+ }
+ if err := w.WriteByte('"'); err != nil {
+ return err
}
- encodeWriteByte(w, '"')
default:
if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
// For slices, val.UnsafePointer() doesn't return a pointer to the slice header
@@ -401,61 +450,80 @@ func encode(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, q
len int
}{val.UnsafePointer(), val.Len()}
if _, seen := cycleSeen[ptr]; seen {
- panic(encodeError{&EncodeValueError{
+ return &EncodeValueError{
Value: val,
Str: fmt.Sprintf("encountered a cycle via %s", val.Type()),
- }})
+ }
}
cycleSeen[ptr] = struct{}{}
defer delete(cycleSeen, ptr)
}
- encodeArray(w, val, escaper, cycleDepth, cycleSeen)
+ if err := encodeArray(w, val, escaper, cycleDepth, cycleSeen); err != nil {
+ return err
+ }
}
case reflect.Array:
- encodeArray(w, val, escaper, cycleDepth, cycleSeen)
+ if err := encodeArray(w, val, escaper, cycleDepth, cycleSeen); err != nil {
+ return err
+ }
case reflect.Pointer:
if val.IsNil() {
- encodeWriteString(w, "null")
+ if _, err := w.WriteString("null"); err != nil {
+ return err
+ }
} else {
if cycleDepth++; cycleDepth > startDetectingCyclesAfter {
ptr := val.UnsafePointer()
if _, seen := cycleSeen[ptr]; seen {
- panic(encodeError{&EncodeValueError{
+ return &EncodeValueError{
Value: val,
Str: fmt.Sprintf("encountered a cycle via %s", val.Type()),
- }})
+ }
}
cycleSeen[ptr] = struct{}{}
defer delete(cycleSeen, ptr)
}
- encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen)
+ if err := encode(w, val.Elem(), escaper, quote, cycleDepth, cycleSeen); err != nil {
+ return err
+ }
}
default:
- panic(encodeError{&EncodeTypeError{
+ return &EncodeTypeError{
Type: val.Type(),
- }})
+ }
}
}
+ return nil
}
-func encodeArray(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) {
- encodeWriteByte(w, '[')
+func encodeArray(w internal.AllWriter, val reflect.Value, escaper BackslashEscaper, cycleDepth uint, cycleSeen map[any]struct{}) error {
+ if err := w.WriteByte('['); err != nil {
+ return err
+ }
n := val.Len()
for i := 0; i < n; i++ {
if i > 0 {
- encodeWriteByte(w, ',')
+ if err := w.WriteByte(','); err != nil {
+ return err
+ }
}
- encode(w, val.Index(i), escaper, false, cycleDepth, cycleSeen)
+ if err := encode(w, val.Index(i), escaper, false, cycleDepth, cycleSeen); err != nil {
+ return err
+ }
+ }
+ if err := w.WriteByte(']'); err != nil {
+ return err
}
- encodeWriteByte(w, ']')
+ return nil
}
-func encodeTODO(w io.Writer, val reflect.Value) {
+func encodeTODO(w io.Writer, val reflect.Value) error {
bs, err := json.Marshal(val.Interface())
if err != nil {
- panic(encodeError{err})
+ return err
}
if _, err := w.Write(bs); err != nil {
- panic(encodeError{err})
+ return err
}
+ return nil
}
diff --git a/encode_string.go b/encode_string.go
index 831a038..12f934e 100644
--- a/encode_string.go
+++ b/encode_string.go
@@ -83,29 +83,47 @@ func writeStringChar(w internal.AllWriter, c rune, wasEscaped BackslashEscapeMod
}
}
-func encodeStringFromString(w internal.AllWriter, escaper BackslashEscaper, str string) {
- encodeWriteByte(w, '"')
+func encodeStringFromString(w internal.AllWriter, escaper BackslashEscaper, str string) error {
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
for _, c := range str {
if _, err := writeStringChar(w, c, BackslashEscapeNone, escaper); err != nil {
- panic(encodeError{err})
+ return err
}
}
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
+ return nil
}
-func encodeStringFromBytes(w internal.AllWriter, escaper BackslashEscaper, str []byte) {
- encodeWriteByte(w, '"')
+func encodeStringFromBytes(w internal.AllWriter, escaper BackslashEscaper, str []byte) error {
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
for i := 0; i < len(str); {
c, size := utf8.DecodeRune(str[i:])
if _, err := writeStringChar(w, c, BackslashEscapeNone, escaper); err != nil {
- panic(encodeError{err})
+ return err
}
i += size
}
- encodeWriteByte(w, '"')
+ if err := w.WriteByte('"'); err != nil {
+ return err
+ }
+ return nil
}
func init() {
- internal.EncodeStringFromString = func(w io.Writer, s string) { encodeStringFromString(internal.NewAllWriter(w), nil, s) }
- internal.EncodeStringFromBytes = func(w io.Writer, s []byte) { encodeStringFromBytes(internal.NewAllWriter(w), nil, s) }
+ internal.EncodeStringFromString = func(w io.Writer, s string) {
+ if err := encodeStringFromString(internal.NewAllWriter(w), nil, s); err != nil {
+ panic(err)
+ }
+ }
+ internal.EncodeStringFromBytes = func(w io.Writer, s []byte) {
+ if err := encodeStringFromBytes(internal.NewAllWriter(w), nil, s); err != nil {
+ panic(err)
+ }
+ }
}