summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2017-09-04 19:11:14 -0400
committerLuke Shumaker <lukeshu@lukeshu.com>2017-09-08 16:55:55 -0400
commit542c732b94e0a5e7c02fd209a60bd068dbbfa03b (patch)
tree780a49011174abc3f01f5110fc590d107989518a
parent3e1d4d0c562ab99e5a81030a08d224e8174152b9 (diff)
nslcd_proto: BREAKING CHANGE: Rethink the panic strategy
nslcd_proto.Read() and .Write() panic() with any errors that they may need to emit. This made composition really simple, I was OK with it being against the normal Go style. But, I'm not happy with it anymore; have them return errors now. This leads us in to nslcd_server.HandleRequest() using those panics for control flow. Add a maybePanic(error) function to wrap all of the proto.Read() and proto.Write() calls to restore the panicing behavior.
-rw-r--r--nslcd_proto/io.go100
-rw-r--r--nslcd_proto/nslcd_h.go24
-rwxr-xr-xnslcd_server/func_handlerequest.go.gen26
3 files changed, 98 insertions, 52 deletions
diff --git a/nslcd_proto/io.go b/nslcd_proto/io.go
index a2adade..daced37 100644
--- a/nslcd_proto/io.go
+++ b/nslcd_proto/io.go
@@ -37,6 +37,10 @@ func (o NslcdError) Error() string {
return string(o)
}
+func npanic(err NslcdError) {
+ panic(err)
+}
+
// An nslcdObject is an object with a different network representation
// than a naive structure.
type nslcdObject interface {
@@ -53,9 +57,45 @@ type nslcdObjectPtr interface {
nslcdRead(fd io.Reader)
}
+// Write an object to a stream. Any errors returned are of type
+// NslcdError.
+func Write(fd io.Writer, data interface{}) (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ switch r := r.(type) {
+ case NslcdError:
+ err = r
+ default:
+ panic(r)
+ }
+ }
+ }()
+ write(fd, data)
+
+ return err
+}
+
+// Read an object from a stream. Any errors returned are of type
+// NslcdError.
+func Read(fd io.Reader, data interface{}) (err error) {
+ defer func() {
+ if r := recover(); r != nil {
+ switch r := r.(type) {
+ case NslcdError:
+ err = r
+ default:
+ panic(r)
+ }
+ }
+ }()
+ read(fd, data)
+
+ return err
+}
+
// Write an object to a stream. In the event of an error, this
-// function may panic! Handle it!
-func Write(fd io.Writer, data interface{}) {
+// function may panic(NslcdError)! Handle it!
+func write(fd io.Writer, data interface{}) {
switch data := data.(type) {
// basic data types
case nslcdObject:
@@ -64,22 +104,22 @@ func Write(fd io.Writer, data interface{}) {
if len(data) > 0 {
_, err := fd.Write(data)
if err != nil {
- panic(err)
+ npanic(NslcdError(err.Error()))
}
}
case int32:
err := binary.Write(fd, binary.BigEndian, data)
if err != nil {
- panic(err)
+ npanic(NslcdError(err.Error()))
}
// composite datatypes
case string:
- Write(fd, int32(len(data)))
- Write(fd, []byte(data))
+ write(fd, int32(len(data)))
+ write(fd, []byte(data))
case []string:
- Write(fd, int32(len(data)))
+ write(fd, int32(len(data)))
for _, item := range data {
- Write(fd, item)
+ write(fd, item)
}
case net.IP:
var af int32 = -1
@@ -95,20 +135,20 @@ func Write(fd io.Writer, data interface{}) {
} else {
bytes = data
}
- Write(fd, af)
- Write(fd, int32(len(bytes)))
- Write(fd, bytes)
+ write(fd, af)
+ write(fd, int32(len(bytes)))
+ write(fd, bytes)
case []net.IP:
- Write(fd, int32(len(data)))
+ write(fd, int32(len(data)))
for _, item := range data {
- Write(fd, item)
+ write(fd, item)
}
default:
v := reflect.ValueOf(data)
switch v.Kind() {
case reflect.Struct:
for i, n := 0, v.NumField(); i < n; i++ {
- Write(fd, v.Field(i).Interface())
+ write(fd, v.Field(i).Interface())
}
default:
panic(fmt.Sprintf("Invalid structure to write NSLCD protocol data from: %T ( %#v )", data, data))
@@ -117,8 +157,8 @@ func Write(fd io.Writer, data interface{}) {
}
// Read an object from a stream. In the event of an error, this
-// function may panic! Handle it!
-func Read(fd io.Reader, data interface{}) {
+// function may panic(NslcdError)! Handle it!
+func read(fd io.Reader, data interface{}) {
switch data := data.(type) {
// basic data types
case nslcdObjectPtr:
@@ -127,31 +167,31 @@ func Read(fd io.Reader, data interface{}) {
if len(*data) > 0 {
_, err := io.ReadFull(fd, *data)
if err != nil {
- panic(err)
+ npanic(NslcdError(err.Error()))
}
}
case *int32:
err := binary.Read(fd, binary.BigEndian, data)
if err != nil {
- panic(err)
+ npanic(NslcdError(err.Error()))
}
// composite datatypes
case *string:
var len int32
- Read(fd, &len)
+ read(fd, &len)
buf := make([]byte, len) // BUG(lukeshu): Read: `string` length needs sanity checked
- Read(fd, &buf)
+ read(fd, &buf)
*data = string(buf)
case *[]string:
var num int32
- Read(fd, &num)
+ read(fd, &num)
*data = make([]string, num) // BUG(lukeshu): Read: `[]string` length needs sanity checked
for i := 0; i < int(num); i++ {
- Read(fd, &((*data)[i]))
+ read(fd, &((*data)[i]))
}
case *net.IP:
var af int32
- Read(fd, &af)
+ read(fd, &af)
var _len int32
switch af {
case unix.AF_INET:
@@ -159,22 +199,22 @@ func Read(fd io.Reader, data interface{}) {
case unix.AF_INET6:
_len = net.IPv6len
default:
- panic(NslcdError(fmt.Sprintf("incorrect address family specified: %d", af)))
+ npanic(NslcdError(fmt.Sprintf("incorrect address family specified: %d", af)))
}
var len int32
- Read(fd, &len)
+ read(fd, &len)
if len != _len {
- panic(NslcdError(fmt.Sprintf("address length incorrect: %d", len)))
+ npanic(NslcdError(fmt.Sprintf("address length incorrect: %d", len)))
}
buf := make([]byte, len)
- Read(fd, &buf)
+ read(fd, &buf)
*data = buf
case *[]net.IP:
var num int32
- Read(fd, &num)
+ read(fd, &num)
*data = make([]net.IP, num) // BUG(lukeshu): Read: `[]net.IP` length needs sanity checked
for i := 0; i < int(num); i++ {
- Read(fd, &((*data)[i]))
+ read(fd, &((*data)[i]))
}
default:
p := reflect.ValueOf(data)
@@ -183,7 +223,7 @@ func Read(fd io.Reader, data interface{}) {
panic(fmt.Sprintf("The argument to nslcd_proto.Read() must be a pointer: %T ( %#v )", data, data))
}
for i, n := 0, v.NumField(); i < n; i++ {
- Read(fd, v.Field(i).Addr().Interface())
+ read(fd, v.Field(i).Addr().Interface())
}
}
}
diff --git a/nslcd_proto/nslcd_h.go b/nslcd_proto/nslcd_h.go
index cb210cd..d8eee9f 100644
--- a/nslcd_proto/nslcd_h.go
+++ b/nslcd_proto/nslcd_h.go
@@ -1,5 +1,5 @@
// This file is based heavily on nslcd.h from nss-pam-ldapd
-// Copyright (C) 2015 Luke Shumaker
+// Copyright (C) 2015, 2017 Luke Shumaker
/*
nslcd.h - file describing client/server protocol
@@ -171,19 +171,19 @@ func (data Netgroup_PartList) nslcdWrite(fd io.Writer) {
t = NSLCD_NETGROUP_TYPE_TRIPLE
}
if t < 0 {
- panic("unrecognized netgroup type")
+ panic(fmt.Sprintf("unrecognized netgroup type: %#08x", t))
}
- Write(fd, t)
- Write(fd, part)
+ write(fd, t)
+ write(fd, part)
}
- Write(fd, NSLCD_NETGROUP_TYPE_END)
+ write(fd, NSLCD_NETGROUP_TYPE_END)
}
func (data *Netgroup_PartList) nslcdRead(fd io.Reader) {
*data = make([]interface{}, 0)
for {
var t int32
var v interface{}
- Read(fd, &t)
+ read(fd, &t)
switch t {
case NSLCD_NETGROUP_TYPE_NETGROUP:
v = Netgroup_Netgroup{}
@@ -192,9 +192,9 @@ func (data *Netgroup_PartList) nslcdRead(fd io.Reader) {
case NSLCD_NETGROUP_TYPE_END:
return
default:
- panic(NslcdError(fmt.Sprintf("unrecognized netgroup type: %#08x", t)))
+ npanic(NslcdError(fmt.Sprintf("unrecognized netgroup type: %#08x", t)))
}
- Read(fd, &v)
+ read(fd, &v)
*data = append(*data, v)
}
}
@@ -384,20 +384,20 @@ type UserMod_Item struct {
type UserMod_ItemList []UserMod_Item
func (data UserMod_ItemList) nslcdWrite(fd io.Writer) {
for _, item := range data {
- Write(fd, item)
+ write(fd, item)
}
- Write(fd, NSLCD_USERMOD_END)
+ write(fd, NSLCD_USERMOD_END)
}
func (data *UserMod_ItemList) nslcdRead(fd io.Reader) {
*data = make([]UserMod_Item, 0)
for {
var t int32
- Read(fd, &t)
+ read(fd, &t)
if t == NSLCD_USERMOD_END {
return
}
var v UserMod_Item
- Read(fd, &v)
+ read(fd, &v)
*data = append(*data, v)
}
}
diff --git a/nslcd_server/func_handlerequest.go.gen b/nslcd_server/func_handlerequest.go.gen
index e7e2dcc..40e00c0 100755
--- a/nslcd_server/func_handlerequest.go.gen
+++ b/nslcd_server/func_handlerequest.go.gen
@@ -1,6 +1,6 @@
#!/usr/bin/env bash
# -*- Mode: Go -*-
-# Copyright (C) 2015-2016 Luke Shumaker <lukeshu@sbcglobal.net>
+# Copyright (C) 2015-2017 Luke Shumaker <lukeshu@sbcglobal.net>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
@@ -35,12 +35,18 @@ import (
const sensitive = "<omitted-from-log>"
+func maybePanic(err error) {
+ if err != nil {
+ panic(err)
+ }
+}
+
// Handle a request to nslcd
func HandleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred) (err error) {
defer func() {
if r := recover(); r != nil {
switch r := r.(type) {
- case error:
+ case p.NslcdError:
err = r
default:
panic(r)
@@ -53,12 +59,12 @@ func HandleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred
func handleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred) {
var version int32
- p.Read(in, &version)
+ maybePanic(p.Read(in, &version))
if version != p.NSLCD_VERSION {
panic(p.NslcdError(fmt.Sprintf("Version mismatch: server=%#08x client=%#08x", p.NSLCD_VERSION, version)))
}
var action int32
- p.Read(in, &action)
+ maybePanic(p.Read(in, &action))
ch := make(chan interface{})
switch action {
@@ -67,7 +73,7 @@ while read -r request; do
cat <<EOT
case p.NSLCD_ACTION_${request^^}:
var req p.Request_${request}
- p.Read(in, &req)
+ maybePanic(p.Read(in, &req))
$(
case "$request" in
PAM_Authentication)
@@ -107,13 +113,13 @@ done < "$requests"
close(ch)
panic(p.NslcdError(fmt.Sprintf("Unknown request action: %#08x", action)))
}
- p.Write(out, p.NSLCD_VERSION)
- p.Write(out, action)
+ maybePanic(p.Write(out, p.NSLCD_VERSION))
+ maybePanic(p.Write(out, action))
for result := range ch {
- p.Write(out, p.NSLCD_RESULT_BEGIN)
- p.Write(out, result)
+ maybePanic(p.Write(out, p.NSLCD_RESULT_BEGIN))
+ maybePanic(p.Write(out, result))
}
- p.Write(out, p.NSLCD_RESULT_END)
+ maybePanic(p.Write(out, p.NSLCD_RESULT_END))
}
EOF