summaryrefslogtreecommitdiff
path: root/nslcd_systemd/nslcd_systemd.go
diff options
context:
space:
mode:
Diffstat (limited to 'nslcd_systemd/nslcd_systemd.go')
-rw-r--r--nslcd_systemd/nslcd_systemd.go23
1 files changed, 18 insertions, 5 deletions
diff --git a/nslcd_systemd/nslcd_systemd.go b/nslcd_systemd/nslcd_systemd.go
index 0999106..b2f8e28 100644
--- a/nslcd_systemd/nslcd_systemd.go
+++ b/nslcd_systemd/nslcd_systemd.go
@@ -25,6 +25,7 @@
// package main
//
// import (
+// "context"
// "os"
//
// "git.lukeshu.com/go/libnslcd/nslcd_server"
@@ -34,11 +35,13 @@
// func main() {
// backend := ...
// limits := nslcd_server.Limits{ ... }
-// os.Exit(int(nslcd_systemd.Main(backend, limits)))
+// ctx := context.Background()
+// os.Exit(int(nslcd_systemd.Main(backend, limits, ctx)))
// }
package nslcd_systemd // import "git.lukeshu.com/go/libnslcd/nslcd_systemd"
import (
+ "context"
"fmt"
"net"
"os"
@@ -94,22 +97,26 @@ func getpeercred(conn *net.UnixConn) (cred unix.Ucred, err error) {
return
}
-func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net.UnixConn) {
+func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net.UnixConn, ctx context.Context) {
defer conn.Close()
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
cred, err := getpeercred(conn)
if err != nil {
sd_daemon.Log.Debug("Connection from unknown client")
} else {
sd_daemon.Log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v",
cred.Pid, cred.Uid, cred.Gid))
+ ctx = context.WithValue(ctx, nslcd_server.PeerCredKey, cred)
}
- err = nslcd_server.HandleRequest(backend, limits, conn, cred)
+ err = nslcd_server.HandleRequest(backend, limits, conn, ctx)
if err != nil {
sd_daemon.Log.Notice(fmt.Sprintf("Error while handling request: %v", err))
}
}
-func Main(backend Backend, limits nslcd_server.Limits) uint8 {
+func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint8 {
defer sd_daemon.Recover()
var err error = nil
@@ -137,6 +144,9 @@ func Main(backend Backend, limits nslcd_server.Limits) uint8 {
}
defer func() { socket.(*net.UnixListener).SetDeadline(time.Now()) }()
+ ctx, cancel := context.WithCancel(ctx)
+ defer cancel()
+
socket_error := make(chan error)
wg.Add(1)
go func() {
@@ -171,7 +181,7 @@ func Main(backend Backend, limits nslcd_server.Limits) uint8 {
go func() {
defer sd_daemon.Recover()
defer wg.Done()
- handler(backend, limits, conn.(*net.UnixConn))
+ handler(backend, limits, conn.(*net.UnixConn), ctx)
}()
}
}
@@ -196,6 +206,9 @@ func Main(backend Backend, limits nslcd_server.Limits) uint8 {
}
sd_daemon.Notification{State: "READY=1"}.Send(false)
}
+ case <-ctx.Done():
+ sd_daemon.Log.Err(fmt.Sprintf("Context was canceled, shutting down: %v", ctx.Err()))
+ return sd_daemon.EXIT_FAILURE
case err = <-socket_error:
sd_daemon.Log.Err(fmt.Sprintf("%v", err))
return sd_daemon.EXIT_NETWORK