summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nslcd_server/ctx.go46
-rwxr-xr-xnslcd_server/func_handlerequest.go.gen26
-rwxr-xr-xnslcd_server/interface_backend.go.gen9
-rwxr-xr-xnslcd_server/type_nilbackend.go.gen7
-rw-r--r--nslcd_systemd/misc_test.go7
-rw-r--r--nslcd_systemd/nslcd_systemd.go23
6 files changed, 98 insertions, 20 deletions
diff --git a/nslcd_server/ctx.go b/nslcd_server/ctx.go
new file mode 100644
index 0000000..5214adc
--- /dev/null
+++ b/nslcd_server/ctx.go
@@ -0,0 +1,46 @@
+// Copyright (C) 2017 Luke Shumaker <lukeshu@sbcglobal.net>
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+// 02110-1301 USA
+
+package nslcd_server
+
+import (
+ "context"
+
+ "golang.org/x/sys/unix"
+)
+
+// contextKey is a value for use with context.WithValue. It's used as
+// a pointer so it fits in an interface{} without allocation.
+type contextKey struct {
+ name string
+}
+
+var (
+ // PeerCredKey is a context key. It can be used in backend
+ // methods to access the credentials of the client process.
+ // The associated value will be of type
+ // "golang.org/x/sys/unix".Ucred
+ PeerCredKey = &contextKey{"peercred"}
+)
+
+// PeerCredFromContext is a convenience function for
+//
+// cred, ok := ctx.Value(nslcd_server.PeerCredKey).(unix.Ucred)
+func PeerCredFromContext(ctx context.Context) (unix.Ucred, bool) {
+ cred, ok := ctx.Value(PeerCredKey).(unix.Ucred)
+ return cred, ok
+}
diff --git a/nslcd_server/func_handlerequest.go.gen b/nslcd_server/func_handlerequest.go.gen
index 00e9663..750a7b0 100755
--- a/nslcd_server/func_handlerequest.go.gen
+++ b/nslcd_server/func_handlerequest.go.gen
@@ -25,12 +25,12 @@ cat <<EOF | gofmt
package nslcd_server
import (
+ "context"
"fmt"
"io"
"os"
"time"
- "golang.org/x/sys/unix"
p "git.lukeshu.com/go/libnslcd/nslcd_proto"
)
@@ -70,8 +70,9 @@ type Conn interface {
SetWriteDeadline(t time.Time) error
}
-// Handle a request to nslcd
-func HandleRequest(backend Backend, limits Limits, conn Conn, cred unix.Ucred) (err error) {
+// Handle a request to nslcd. The caller is responsible for
+// initializing the context with PeerCredKey.
+func HandleRequest(backend Backend, limits Limits, conn Conn, ctx context.Context) (err error) {
defer func() {
if r := recover(); r != nil {
switch r := r.(type) {
@@ -89,6 +90,11 @@ func HandleRequest(backend Backend, limits Limits, conn Conn, cred unix.Ucred) (
if limits.Timeout != 0 {
deadlineAll = now.Add(limits.Timeout)
}
+ if deadline, ok := ctx.Deadline(); ok {
+ if deadlineAll.IsZero() || deadline.Before(deadlineAll) {
+ deadlineAll = deadline
+ }
+ }
if limits.ReadTimeout != 0 {
deadlineRead = now.Add(limits.ReadTimeout)
if !deadlineAll.IsZero() && deadlineAll.Before(deadlineRead) {
@@ -149,6 +155,7 @@ while read -r request; do
;;
esac
)
+
if limits.WriteTimeout != 0 {
deadlineWrite = time.Now().Add(limits.WriteTimeout)
if !deadlineAll.IsZero() && deadlineAll.Before(deadlineWrite) {
@@ -161,9 +168,18 @@ while read -r request; do
return err
}
}
+
+ var cancel context.CancelFunc
+ if deadline, ok := ctx.Deadline(); !ok || (!deadlineWrite.IsZero() && deadline.After(deadlineWrite)) {
+ ctx, cancel = context.WithDeadline(ctx, deadlineWrite)
+ } else {
+ ctx, cancel = context.WithCancel(ctx)
+ }
+ defer cancel()
+
maybePanic(p.Write(out, p.NSLCD_VERSION))
maybePanic(p.Write(out, action))
- ch := backend.${request}(cred, req)
+ ch := backend.${request}(ctx, req)
for result := range ch {
if err == nil {
err = p.Write(out, p.NSLCD_RESULT_BEGIN)
@@ -174,7 +190,7 @@ while read -r request; do
}
maybePanic(err)
maybePanic(p.Write(out, p.NSLCD_RESULT_END))
- return nil
+ return ctx.Err() // probably nil
EOT
done < "$requests"
)
diff --git a/nslcd_server/interface_backend.go.gen b/nslcd_server/interface_backend.go.gen
index 4749d0c..4ced5e7 100755
--- a/nslcd_server/interface_backend.go.gen
+++ b/nslcd_server/interface_backend.go.gen
@@ -1,5 +1,5 @@
#!/usr/bin/env bash
-# Copyright (C) 2015 Luke Shumaker <lukeshu@sbcglobal.net>
+# Copyright (C) 2015, 2017 Luke Shumaker <lukeshu@sbcglobal.net>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
@@ -24,7 +24,8 @@ cat <<EOF | gofmt
package nslcd_server
import (
- "golang.org/x/sys/unix"
+ "context"
+
p "git.lukeshu.com/go/libnslcd/nslcd_proto"
)
@@ -33,7 +34,7 @@ import (
// that the nslcd server may reply to is implemented simply as a
// method that returns a channel of the resulting values.
type Backend interface {
- $(sed -rn 's/([^_]+)(.*)/\1\2(unix.Ucred, p.Request_\1\2) <-chan p.\1/p' "$requests" | grep -v PAM)
- $(sed -rn 's/(PAM)(.*)/\1\2(unix.Ucred, p.Request_\1\2) <-chan p.\1\2/p' "$requests")
+ $(sed -rn 's/([^_]+)(.*)/\1\2(context.Context, p.Request_\1\2) <-chan p.\1/p' "$requests" | grep -v PAM)
+ $(sed -rn 's/(PAM)(.*)/\1\2(context.Context, p.Request_\1\2) <-chan p.\1\2/p' "$requests")
}
EOF
diff --git a/nslcd_server/type_nilbackend.go.gen b/nslcd_server/type_nilbackend.go.gen
index 0c6f4b5..b7ea372 100755
--- a/nslcd_server/type_nilbackend.go.gen
+++ b/nslcd_server/type_nilbackend.go.gen
@@ -24,7 +24,8 @@ cat <<EOF | gofmt
package nslcd_server
import (
- "golang.org/x/sys/unix"
+ "context"
+
p "git.lukeshu.com/go/libnslcd/nslcd_proto"
)
@@ -35,8 +36,8 @@ import (
type NilBackend struct{}
$(
- re_in='^\t([^(]+)\(unix\.Ucred, ([^)]+)\) <-chan (\S+)$'
- re_out='func (o NilBackend) \1(unix.Ucred, \2) <-chan \3 { r := make(chan \3); close(r); return r }'
+ re_in='^\t([^(]+)\(context\.Context, ([^)]+)\) <-chan (\S+)$'
+ re_out='func (o NilBackend) \1(context.Context, \2) <-chan \3 { r := make(chan \3); close(r); return r }'
< "$interface" sed -rn "s/$re_in/$re_out/p"
)
diff --git a/nslcd_systemd/misc_test.go b/nslcd_systemd/misc_test.go
index a75083d..bc7ace6 100644
--- a/nslcd_systemd/misc_test.go
+++ b/nslcd_systemd/misc_test.go
@@ -18,6 +18,7 @@
package nslcd_systemd_test
import (
+ "context"
"fmt"
"io/ioutil"
"net"
@@ -89,7 +90,7 @@ func testDriver(
// server //////////////////////////////////////////////////////////////
errfatal(sdActivatedStream(t.tmpdir + "/nslcd.sock"))
go func() {
- evExitServer <- nslcd_systemd.Main(backend, limits)
+ evExitServer <- nslcd_systemd.Main(backend, limits, context.Background())
}()
// client/driver ///////////////////////////////////////////////////////
@@ -133,7 +134,7 @@ func (o *NonLockingBackend) Init() error { return nil }
func (o *NonLockingBackend) Reload() error { return nil }
func (o *NonLockingBackend) Close() {}
-func (o *NonLockingBackend) Passwd_All(cred unix.Ucred, req nslcd_proto.Request_Passwd_All) <-chan nslcd_proto.Passwd {
+func (o *NonLockingBackend) Passwd_All(ctx context.Context, req nslcd_proto.Request_Passwd_All) <-chan nslcd_proto.Passwd {
ret := make(chan nslcd_proto.Passwd)
go func() {
defer close(ret)
@@ -170,7 +171,7 @@ func (o *LockingBackend) Close() {
o.NonLockingBackend.Close()
}
-func (o *LockingBackend) Passwd_All(cred unix.Ucred, req nslcd_proto.Request_Passwd_All) <-chan nslcd_proto.Passwd {
+func (o *LockingBackend) Passwd_All(ctx context.Context, req nslcd_proto.Request_Passwd_All) <-chan nslcd_proto.Passwd {
o.lock.RLock()
ret := make(chan nslcd_proto.Passwd)
go func() {
diff --git a/nslcd_systemd/nslcd_systemd.go b/nslcd_systemd/nslcd_systemd.go
index 0999106..b2f8e28 100644
--- a/nslcd_systemd/nslcd_systemd.go
+++ b/nslcd_systemd/nslcd_systemd.go
@@ -25,6 +25,7 @@
// package main
//
// import (
+// "context"
// "os"
//
// "git.lukeshu.com/go/libnslcd/nslcd_server"
@@ -34,11 +35,13 @@
// func main() {
// backend := ...
// limits := nslcd_server.Limits{ ... }
-// os.Exit(int(nslcd_systemd.Main(backend, limits)))
+// ctx := context.Background()
+// os.Exit(int(nslcd_systemd.Main(backend, limits, ctx)))
// }
package nslcd_systemd // import "git.lukeshu.com/go/libnslcd/nslcd_systemd"
import (
+ "context"
"fmt"
"net"
"os"
@@ -94,22 +97,26 @@ func getpeercred(conn *net.UnixConn) (cred unix.Ucred, err error) {
return
}
-func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net.UnixConn) {
+func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net.UnixConn, ctx context.Context) {
defer conn.Close()
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
cred, err := getpeercred(conn)
if err != nil {
sd_daemon.Log.Debug("Connection from unknown client")
} else {
sd_daemon.Log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v",
cred.Pid, cred.Uid, cred.Gid))
+ ctx = context.WithValue(ctx, nslcd_server.PeerCredKey, cred)
}
- err = nslcd_server.HandleRequest(backend, limits, conn, cred)
+ err = nslcd_server.HandleRequest(backend, limits, conn, ctx)
if err != nil {
sd_daemon.Log.Notice(fmt.Sprintf("Error while handling request: %v", err))
}
}
-func Main(backend Backend, limits nslcd_server.Limits) uint8 {
+func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint8 {
defer sd_daemon.Recover()
var err error = nil
@@ -137,6 +144,9 @@ func Main(backend Backend, limits nslcd_server.Limits) uint8 {
}
defer func() { socket.(*net.UnixListener).SetDeadline(time.Now()) }()
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
socket_error := make(chan error)
wg.Add(1)
go func() {
@@ -171,7 +181,7 @@ func Main(backend Backend, limits nslcd_server.Limits) uint8 {
go func() {
defer sd_daemon.Recover()
defer wg.Done()
- handler(backend, limits, conn.(*net.UnixConn))
+ handler(backend, limits, conn.(*net.UnixConn), ctx)
}()
}
}
@@ -196,6 +206,9 @@ func Main(backend Backend, limits nslcd_server.Limits) uint8 {
}
sd_daemon.Notification{State: "READY=1"}.Send(false)
}
+ case <-ctx.Done():
+ sd_daemon.Log.Err(fmt.Sprintf("Context was canceled, shutting down: %v", ctx.Err()))
+ return sd_daemon.EXIT_FAILURE
case err = <-socket_error:
sd_daemon.Log.Err(fmt.Sprintf("%v", err))
return sd_daemon.EXIT_NETWORK