// Copyright (C) 2015-2017 Luke Shumaker // // This library is free software; you can redistribute it and/or // modify it under the terms of the GNU Lesser General Public // License as published by the Free Software Foundation; either // version 2.1 of the License, or (at your option) any later version. // // This library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public // License along with this library; if not, write to the Free Software // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA // 02110-1301 USA // Package nslcd_systemd does the legwork for implementing a systemd // socket-activated nslcd server. // // You just need to implement the Backend interface, then pass it to // Main, which will return the exit code for the process. Everything // but the backend is taken care of for you! // // package main // // import ( // "context" // "os" // // "git.lukeshu.com/go/libnslcd/nslcd_server" // "git.lukeshu.com/go/libnslcd/nslcd_systemd" // ) // // func main() { // backend := ... // limits := nslcd_server.Limits{ ... } // ctx := context.Background() // os.Exit(int(nslcd_systemd.Main(backend, limits, ctx))) // } package nslcd_systemd // import "git.lukeshu.com/go/libnslcd/nslcd_systemd" import ( "context" "fmt" "net" "os" "os/signal" "sync" "time" "git.lukeshu.com/go/libnslcd/nslcd_server" "git.lukeshu.com/go/libsystemd/sd_daemon" "golang.org/x/sys/unix" ) type Backend interface { nslcd_server.Backend Init() error Reload() error Close() } func get_socket() (socket net.Listener, err error) { fds := sd_daemon.ListenFds(true) if fds == nil { err = fmt.Errorf("Failed to aquire sockets from systemd") return } if len(fds) != 1 { err = fmt.Errorf("Wrong number of sockets from systemd: expected %d but got %d", 1, len(fds)) return } socket, err = net.FileListener(fds[0]) fds[0].Close() return } func getpeercred(conn *net.UnixConn) (cred unix.Ucred, err error) { rawconn, err := conn.SyscallConn() if err != nil { return } var _cred *unix.Ucred var _err error err = rawconn.Control(func(fd uintptr) { _cred, _err = unix.GetsockoptUcred(int(fd), unix.SOL_SOCKET, unix.SO_PEERCRED) }) if err != nil { return } if _err != nil { err = _err return } cred = *_cred return } func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net.UnixConn, ctx context.Context) { defer conn.Close() ctx, cancel := context.WithCancel(ctx) defer cancel() cred, err := getpeercred(conn) if err != nil { sd_daemon.Log.Debug("Connection from unknown client") } else { sd_daemon.Log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v", cred.Pid, cred.Uid, cred.Gid)) ctx = context.WithValue(ctx, nslcd_server.PeerCredKey, cred) } err = nslcd_server.HandleRequest(backend, limits, conn, ctx) if err != nil { sd_daemon.Log.Notice(fmt.Sprintf("Error while handling request: %v", err)) } } func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint8 { defer sd_daemon.Recover() var err error = nil sigs := make(chan os.Signal) signal.Notify(sigs, unix.SIGTERM, unix.SIGHUP) disable_nss_module() err = backend.Init() if err != nil { sd_daemon.Log.Err(fmt.Sprintf("Could not initialize backend: %v", err)) sd_daemon.Notification{State: "STOPPING=1"}.Send(false) return sd_daemon.EXIT_FAILURE } defer backend.Close() var wg sync.WaitGroup defer wg.Wait() socket, err := get_socket() if err != nil { sd_daemon.Log.Err(fmt.Sprintf("%v", err)) sd_daemon.Notification{State: "STOPPING=1"}.Send(false) return sd_daemon.EXIT_NOTRUNNING } defer func() { socket.(*net.UnixListener).SetDeadline(time.Now()) }() ctx, cancel := context.WithCancel(ctx) defer cancel() socket_error := make(chan error) wg.Add(1) go func() { defer sd_daemon.Recover() defer wg.Done() var tempDelay time.Duration last := false for !last { conn, err := socket.Accept() if err != nil { if ne, ok := err.(net.Error); ok && ne.Timeout() { last = true } else if ne, ok := err.(net.Error); ok && ne.Temporary() { sd_daemon.Log.Notice(fmt.Sprintf("temporary error %v", err)) if tempDelay == 0 { tempDelay = 5 * time.Millisecond } else { tempDelay *= 2 } if max := 1 * time.Second; tempDelay > max { tempDelay = max } time.Sleep(tempDelay) } else { socket_error <- err last = true } } if conn != nil { wg.Add(1) go func() { defer sd_daemon.Recover() defer wg.Done() handler(backend, limits, conn.(*net.UnixConn), ctx) }() } } }() defer sd_daemon.Notification{State: "STOPPING=1"}.Send(false) sd_daemon.Notification{State: "READY=1"}.Send(false) for { select { case sig := <-sigs: switch sig { case unix.SIGTERM: sd_daemon.Log.Notice("Received SIGTERM, shutting down") return sd_daemon.EXIT_SUCCESS case unix.SIGHUP: sd_daemon.Log.Notice("Received SIGHUP, reloading") sd_daemon.Notification{State: "RELOADING=1"}.Send(false) err := backend.Reload() if err != nil { sd_daemon.Log.Notice(fmt.Sprintf("Could not reload backend: %v", err)) return sd_daemon.EXIT_NOTRUNNING } sd_daemon.Notification{State: "READY=1"}.Send(false) } case <-ctx.Done(): sd_daemon.Log.Err(fmt.Sprintf("Context was canceled, shutting down: %v", ctx.Err())) return sd_daemon.EXIT_FAILURE case err = <-socket_error: sd_daemon.Log.Err(fmt.Sprintf("%v", err)) return sd_daemon.EXIT_NETWORK } } }