summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nslcd_server/ctx.go30
-rwxr-xr-xnslcd_server/func_handlerequest.go.gen11
-rw-r--r--nslcd_systemd/disable_nss_module.go12
-rw-r--r--nslcd_systemd/nslcd_systemd.go44
4 files changed, 73 insertions, 24 deletions
diff --git a/nslcd_server/ctx.go b/nslcd_server/ctx.go
index 5214adc..9722017 100644
--- a/nslcd_server/ctx.go
+++ b/nslcd_server/ctx.go
@@ -20,9 +20,26 @@ package nslcd_server
import (
"context"
+ "git.lukeshu.com/go/libsystemd/sd_daemon"
"golang.org/x/sys/unix"
)
+// Logger is the common interface between
+// `"git.lukeshu.com/go/libsystemd/sd_daemon".Logger` and
+// `"log/syslog".Writer`.
+type Logger interface {
+ Emerg(m string) error
+ Alert(m string) error
+ Crit(m string) error
+ Err(m string) error
+ Warning(m string) error
+ Notice(m string) error
+ Info(m string) error
+ Debug(m string) error
+
+ Write(m []byte) (int, error)
+}
+
// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation.
type contextKey struct {
@@ -35,6 +52,11 @@ var (
// The associated value will be of type
// "golang.org/x/sys/unix".Ucred
PeerCredKey = &contextKey{"peercred"}
+
+ // LoggerKey is a context key. It can be used in the backend
+ // methods to access a logger. The associated value will be
+ // an implementation of Logger.
+ LoggerKey = &contextKey{"log"}
)
// PeerCredFromContext is a convenience function for
@@ -44,3 +66,11 @@ func PeerCredFromContext(ctx context.Context) (unix.Ucred, bool) {
cred, ok := ctx.Value(PeerCredKey).(unix.Ucred)
return cred, ok
}
+
+func LoggerFromContext(ctx context.Context) Logger {
+ logger, ok := ctx.Value(LoggerKey).(Logger)
+ if !ok {
+ return sd_daemon.Log
+ }
+ return logger
+}
diff --git a/nslcd_server/func_handlerequest.go.gen b/nslcd_server/func_handlerequest.go.gen
index 750a7b0..8abfb01 100755
--- a/nslcd_server/func_handlerequest.go.gen
+++ b/nslcd_server/func_handlerequest.go.gen
@@ -28,7 +28,6 @@ import (
"context"
"fmt"
"io"
- "os"
"time"
p "git.lukeshu.com/go/libnslcd/nslcd_proto"
@@ -109,6 +108,8 @@ func HandleRequest(backend Backend, limits Limits, conn Conn, ctx context.Contex
}
}
+ log := LoggerFromContext(ctx)
+
var in io.Reader = conn
if limits.RequestMaxSize > 0 {
in = &io.LimitedReader{R: in, N: limits.RequestMaxSize}
@@ -135,7 +136,7 @@ while read -r request; do
PAM_Authentication)
echo '_req := req'
echo '_req.Password = sensitive'
- echo 'fmt.Fprintf(os.Stderr, "Request: %#v\n", _req)'
+ echo 'log.Info(fmt.Sprintf("Request: %#v\n", _req))'
;;
PAM_PwMod)
echo '_req := req'
@@ -143,15 +144,15 @@ while read -r request; do
echo ' _req.OldPassword = sensitive'
echo '}'
echo '_req.NewPassword = sensitive'
- echo 'fmt.Fprintf(os.Stderr, "Request: %#v\n", _req)'
+ echo 'log.Info(fmt.Sprintf("Request: %#v", _req))'
;;
PAM_UserMod)
echo '_req := req'
echo '_req.Password = sensitive'
- echo 'fmt.Fprintf(os.Stderr, "Request: %#v\n", _req)'
+ echo 'log.Info(fmt.Sprintf("Request: %#v", _req))'
;;
*)
- echo 'fmt.Fprintf(os.Stderr, "Request: %#v\n", req)'
+ echo 'log.Info(fmt.Sprintf("Request: %#v", req))'
;;
esac
)
diff --git a/nslcd_systemd/disable_nss_module.go b/nslcd_systemd/disable_nss_module.go
index df22360..32b105a 100644
--- a/nslcd_systemd/disable_nss_module.go
+++ b/nslcd_systemd/disable_nss_module.go
@@ -23,7 +23,7 @@ import (
"fmt"
"git.lukeshu.com/go/libgnulinux/dl"
- "git.lukeshu.com/go/libsystemd/sd_daemon"
+ "git.lukeshu.com/go/libnslcd/nslcd_server"
)
//static char *strary(char **ary, unsigned int n) { return ary[n]; }
@@ -35,27 +35,27 @@ const (
nss_module_sym_enablelookups = "_nss_ldap_enablelookups"
)
-func disable_nss_module() {
+func disable_nss_module(log nslcd_server.Logger) {
handle, err := dl.Open(nss_module_soname, dl.RTLD_LAZY|dl.RTLD_NODELETE)
if err == nil {
defer handle.Close()
} else {
- sd_daemon.Log.Warning(fmt.Sprintf("NSS module %s not loaded: %v", nss_module_soname, err))
+ log.Warning(fmt.Sprintf("NSS module %s not loaded: %v", nss_module_soname, err))
return
}
c_version_info, err := handle.Sym(nss_module_sym_version)
if err == nil {
g_version_info := (**C.char)(c_version_info)
- sd_daemon.Log.Debug(fmt.Sprintf("NSS module %s version %s %s", nss_module_soname,
+ log.Debug(fmt.Sprintf("NSS module %s version %s %s", nss_module_soname,
C.GoString(C.strary(g_version_info, 0)),
C.GoString(C.strary(g_version_info, 1))))
} else {
- sd_daemon.Log.Warning(fmt.Sprintf("NSS module %s version missing: %v", nss_module_soname, err))
+ log.Warning(fmt.Sprintf("NSS module %s version missing: %v", nss_module_soname, err))
}
c_enable_flag, err := handle.Sym(nss_module_sym_enablelookups)
if err != nil {
- sd_daemon.Log.Warning(fmt.Sprintf("Unable to disable NSS ldap module for nslcd process: %v", err))
+ log.Warning(fmt.Sprintf("Unable to disable NSS ldap module for nslcd process: %v", err))
return
}
g_enable_flag := (*C.int)(c_enable_flag)
diff --git a/nslcd_systemd/nslcd_systemd.go b/nslcd_systemd/nslcd_systemd.go
index b2f8e28..29d49d6 100644
--- a/nslcd_systemd/nslcd_systemd.go
+++ b/nslcd_systemd/nslcd_systemd.go
@@ -61,6 +61,15 @@ type Backend interface {
Close()
}
+type contextKey struct {
+ name string
+}
+
+var (
+ // ConnectionIDKey is a context key.
+ ConnectionIDKey = &contextKey{"connection-id"}
+)
+
func get_socket() (socket net.Listener, err error) {
fds := sd_daemon.ListenFds(true)
if fds == nil {
@@ -102,17 +111,20 @@ func handler(backend nslcd_server.Backend, limits nslcd_server.Limits, conn *net
ctx, cancel := context.WithCancel(ctx)
defer cancel()
+ // TODO: override nslcd_server.LoggerKey with a logger that includes ConnectionIDKey
+ log := nslcd_server.LoggerFromContext(ctx)
+
cred, err := getpeercred(conn)
if err != nil {
- sd_daemon.Log.Debug("Connection from unknown client")
+ log.Debug("Connection from unknown client")
} else {
- sd_daemon.Log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v",
+ log.Debug(fmt.Sprintf("Connection from pid=%v uid=%v gid=%v",
cred.Pid, cred.Uid, cred.Gid))
ctx = context.WithValue(ctx, nslcd_server.PeerCredKey, cred)
}
err = nslcd_server.HandleRequest(backend, limits, conn, ctx)
if err != nil {
- sd_daemon.Log.Notice(fmt.Sprintf("Error while handling request: %v", err))
+ log.Notice(fmt.Sprintf("Error while handling request: %v", err))
}
}
@@ -123,11 +135,13 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint
sigs := make(chan os.Signal)
signal.Notify(sigs, unix.SIGTERM, unix.SIGHUP)
- disable_nss_module()
+ log := nslcd_server.LoggerFromContext(ctx)
+
+ disable_nss_module(log)
err = backend.Init()
if err != nil {
- sd_daemon.Log.Err(fmt.Sprintf("Could not initialize backend: %v", err))
+ log.Err(fmt.Sprintf("Could not initialize backend: %v", err))
sd_daemon.Notification{State: "STOPPING=1"}.Send(false)
return sd_daemon.EXIT_FAILURE
}
@@ -138,7 +152,7 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint
socket, err := get_socket()
if err != nil {
- sd_daemon.Log.Err(fmt.Sprintf("%v", err))
+ log.Err(fmt.Sprintf("%v", err))
sd_daemon.Notification{State: "STOPPING=1"}.Send(false)
return sd_daemon.EXIT_NOTRUNNING
}
@@ -153,6 +167,8 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint
defer sd_daemon.Recover()
defer wg.Done()
+ id := 0
+
var tempDelay time.Duration
last := false
for !last {
@@ -161,7 +177,7 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint
if ne, ok := err.(net.Error); ok && ne.Timeout() {
last = true
} else if ne, ok := err.(net.Error); ok && ne.Temporary() {
- sd_daemon.Log.Notice(fmt.Sprintf("temporary error %v", err))
+ log.Notice(fmt.Sprintf("temporary error %v", err))
if tempDelay == 0 {
tempDelay = 5 * time.Millisecond
} else {
@@ -178,10 +194,12 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint
}
if conn != nil {
wg.Add(1)
+ id++
+ hctx := context.WithValue(ctx, ConnectionIDKey, id)
go func() {
defer sd_daemon.Recover()
defer wg.Done()
- handler(backend, limits, conn.(*net.UnixConn), ctx)
+ handler(backend, limits, conn.(*net.UnixConn), hctx)
}()
}
}
@@ -194,23 +212,23 @@ func Main(backend Backend, limits nslcd_server.Limits, ctx context.Context) uint
case sig := <-sigs:
switch sig {
case unix.SIGTERM:
- sd_daemon.Log.Notice("Received SIGTERM, shutting down")
+ log.Notice("Received SIGTERM, shutting down")
return sd_daemon.EXIT_SUCCESS
case unix.SIGHUP:
- sd_daemon.Log.Notice("Received SIGHUP, reloading")
+ log.Notice("Received SIGHUP, reloading")
sd_daemon.Notification{State: "RELOADING=1"}.Send(false)
err := backend.Reload()
if err != nil {
- sd_daemon.Log.Notice(fmt.Sprintf("Could not reload backend: %v", err))
+ log.Notice(fmt.Sprintf("Could not reload backend: %v", err))
return sd_daemon.EXIT_NOTRUNNING
}
sd_daemon.Notification{State: "READY=1"}.Send(false)
}
case <-ctx.Done():
- sd_daemon.Log.Err(fmt.Sprintf("Context was canceled, shutting down: %v", ctx.Err()))
+ log.Err(fmt.Sprintf("Context was canceled, shutting down: %v", ctx.Err()))
return sd_daemon.EXIT_FAILURE
case err = <-socket_error:
- sd_daemon.Log.Err(fmt.Sprintf("%v", err))
+ log.Err(fmt.Sprintf("%v", err))
return sd_daemon.EXIT_NETWORK
}
}