summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2017-09-07 23:28:47 -0400
committerLuke Shumaker <lukeshu@lukeshu.com>2017-09-08 16:55:55 -0400
commitb58ea042394c66eabe67c3f58906c5d76b1e119d (patch)
treedb1f55fb187504c7866b81c33ce0dc1489135da5
parente7b6b3a7ae2e53d807e14697708c4110c038303b (diff)
nslcd_{server,systemd}: FIX, BREAKING CHANGE: add limits
Added types: nslcd_server: type Limits struct { ...} nslcd_server: type Conn interface{ ... } // a subset of net.Conn nslcd_server.HandleRequest() signature change: -func HandleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred) (err error) { +func HandleRequest(backend Backend, limits Limits, conn Conn, cred unix.Ucred) (err error) { The `limits Limits` argument is added, and `conn Conn` replaces `in io.Reader` and `out io.Writer`. nslcd_systemd.Main() signature change: -func Main(backend Backend) uint8 { +func Main(backend Backend, limits nslcd_server.Limits) uint8 { The `limits Limits` argument is added.
-rwxr-xr-xnslcd_server/func_handlerequest.go.gen64
-rw-r--r--nslcd_systemd/misc_test.go4
-rw-r--r--nslcd_systemd/nslcd_systemd.go16
3 files changed, 67 insertions, 17 deletions
diff --git a/nslcd_server/func_handlerequest.go.gen b/nslcd_server/func_handlerequest.go.gen
index d34db88..af36e84 100755
--- a/nslcd_server/func_handlerequest.go.gen
+++ b/nslcd_server/func_handlerequest.go.gen
@@ -26,8 +26,8 @@ package nslcd_server
import (
"fmt"
- "io"
"os"
+ "time"
"golang.org/x/sys/unix"
p "git.lukeshu.com/go/libnslcd/nslcd_proto"
@@ -41,8 +41,32 @@ func maybePanic(err error) {
}
}
+type Limits struct {
+ // What is the maximum total amount of time that we spend
+ // handling a single request. This includes both the time
+ // reading the request and the time creating and writing the
+ // response.
+ Timeout time.Duration
+
+ // How long can we spend reading a request?
+ ReadTimeout time.Duration
+
+ // How long can we spend writing a response?
+ WriteTimeout time.Duration
+}
+
+type Conn interface {
+ // This is a subset of net.Conn; semantics are the same.
+
+ Read(b []byte) (n int, err error)
+ Write(b []byte) (n int, err error)
+ SetDeadline(t time.Time) error
+ SetReadDeadline(t time.Time) error
+ SetWriteDeadline(t time.Time) error
+}
+
// Handle a request to nslcd
-func HandleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred) (err error) {
+func HandleRequest(backend Backend, limits Limits, conn Conn, cred unix.Ucred) (err error) {
defer func() {
if r := recover(); r != nil {
switch r := r.(type) {
@@ -54,13 +78,27 @@ func HandleRequest(backend Backend, in io.Reader, out io.Writer, cred unix.Ucred
}
}()
+ now := time.Now()
+ if limits.Timeout != 0 {
+ err = conn.SetDeadline(now.Add(limits.Timeout))
+ if err != nil {
+ return err
+ }
+ }
+ if limits.ReadTimeout != 0 {
+ err = conn.SetReadDeadline(now.Add(limits.ReadTimeout))
+ if err != nil {
+ return err
+ }
+ }
+
var version int32
- maybePanic(p.Read(in, &version))
+ maybePanic(p.Read(conn, &version))
if version != p.NSLCD_VERSION {
return p.NslcdError(fmt.Sprintf("Version mismatch: server=%#08x client=%#08x", p.NSLCD_VERSION, version))
}
var action int32
- maybePanic(p.Read(in, &action))
+ maybePanic(p.Read(conn, &action))
switch action {
$(
@@ -68,7 +106,7 @@ while read -r request; do
cat <<EOT
case p.NSLCD_ACTION_${request^^}:
var req p.Request_${request}
- maybePanic(p.Read(in, &req))
+ maybePanic(p.Read(conn, &req))
$(
case "$request" in
PAM_Authentication)
@@ -94,19 +132,25 @@ while read -r request; do
;;
esac
)
- maybePanic(p.Write(out, p.NSLCD_VERSION))
- maybePanic(p.Write(out, action))
+ if limits.WriteTimeout != 0 {
+ err = conn.SetWriteDeadline(time.Now().Add(limits.WriteTimeout))
+ if err != nil {
+ return err
+ }
+ }
+ maybePanic(p.Write(conn, p.NSLCD_VERSION))
+ maybePanic(p.Write(conn, action))
ch := backend.${request}(cred, req)
for result := range ch {
if err == nil {
- err = p.Write(out, p.NSLCD_RESULT_BEGIN)
+ err = p.Write(conn, p.NSLCD_RESULT_BEGIN)
}
if err == nil {
- err = p.Write(out, result)
+ err = p.Write(conn, result)
}
}
maybePanic(err)
- maybePanic(p.Write(out, p.NSLCD_RESULT_END))
+ maybePanic(p.Write(conn, p.NSLCD_RESULT_END))
return nil
EOT
done < "$requests"
diff --git a/nslcd_systemd/misc_test.go b/nslcd_systemd/misc_test.go
index d2b2a7e..a910cd9 100644
--- a/nslcd_systemd/misc_test.go
+++ b/nslcd_systemd/misc_test.go
@@ -77,7 +77,9 @@ func testBadClient(t *testContext, backend nslcd_systemd.Backend, toclose bool)
// server //////////////////////////////////////////////////////////////
errfatal(sdActivatedStream(t.tmpdir + "/nslcd.sock"))
go func() {
- evExitServer <- nslcd_systemd.Main(backend)
+ evExitServer <- nslcd_systemd.Main(backend, nslcd_server.Limits{
+ Timeout: 1 * time.Second,
+ })
}()
// client/driver ///////////////////////////////////////////////////////
diff --git a/nslcd_systemd/nslcd_systemd.go b/nslcd_systemd/nslcd_systemd.go
index 97991c8..8bae046 100644
--- a/nslcd_systemd/nslcd_systemd.go
+++ b/nslcd_systemd/nslcd_systemd.go
@@ -24,11 +24,15 @@
//
// package main
//
-// import "nslcd/systemd"
+// import (
+// "git.lukeshu.com/go/libnslcd/nslcd_server"
+// "git.lukeshu.com/go/libnslcd/nslcd_systemd"
+// )
//
// func main() {
// backend := ...
-// os.Exit(int(nslcd_systemd.Main(backend)))
+// limits := nslcd_server.Limits{ ... }
+// os.Exit(int(nslcd_systemd.Main(backend, limits)))
// }
package nslcd_systemd // import "git.lukeshu.com/go/libnslcd/nslcd_systemd"
@@ -88,7 +92,7 @@ func getpeercred(conn *net.UnixConn) (cred unix.Ucred, err error) {
return
}
-func handler(conn *net.UnixConn, backend nslcd_server.Backend) {
+func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net.UnixConn) {
defer conn.Close()
cred, err := getpeercred(conn)
if err != nil {
@@ -97,13 +101,13 @@ func handler(conn *net.UnixConn, backend nslcd_server.Backend) {
sd_daemon.Log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v",
cred.Pid, cred.Uid, cred.Gid))
}
- err = nslcd_server.HandleRequest(backend, conn, conn, cred)
+ err = nslcd_server.HandleRequest(backend, limits, conn, cred)
if err != nil {
sd_daemon.Log.Notice(fmt.Sprintf("Error while handling request: %v", err))
}
}
-func Main(backend Backend) uint8 {
+func Main(backend Backend, limits nslcd_server.Limits) uint8 {
defer sd_daemon.Recover()
var err error = nil
@@ -165,7 +169,7 @@ func Main(backend Backend) uint8 {
go func() {
defer sd_daemon.Recover()
defer wg.Done()
- handler(conn.(*net.UnixConn), backend)
+ handler(backend, limits, conn.(*net.UnixConn))
}()
}
}