summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLuke Shumaker <lukeshu@lukeshu.com>2017-09-04 19:42:01 -0400
committerLuke Shumaker <lukeshu@lukeshu.com>2017-09-08 16:55:55 -0400
commit7b8aefea056f995ee2d00a79c22277c09cda5363 (patch)
tree8bc408fbcd20f698f18d82cac01960b67e49d322
parent542c732b94e0a5e7c02fd209a60bd068dbbfa03b (diff)
add tests
-rw-r--r--nslcd_systemd/misc_test.go251
-rw-r--r--nslcd_systemd/util_test.go193
2 files changed, 444 insertions, 0 deletions
diff --git a/nslcd_systemd/misc_test.go b/nslcd_systemd/misc_test.go
new file mode 100644
index 0000000..d2b2a7e
--- /dev/null
+++ b/nslcd_systemd/misc_test.go
@@ -0,0 +1,251 @@
+// Copyright (C) 2017 Luke Shumaker <lukeshu@sbcglobal.net>
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+// 02110-1301 USA
+
+package nslcd_systemd_test
+
+import (
+ "fmt"
+ "io/ioutil"
+ "net"
+ "os"
+ "strings"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+
+ "git.lukeshu.com/go/libnslcd/nslcd_proto"
+ "git.lukeshu.com/go/libnslcd/nslcd_server"
+ "git.lukeshu.com/go/libnslcd/nslcd_systemd"
+ "golang.org/x/sys/unix"
+)
+
+type testContext struct {
+ *testing.T
+ tmpdir string
+ status chan<- string
+}
+
+func testBadClient(t *testContext, backend nslcd_systemd.Backend, toclose bool) {
+ t.status <- "setting up"
+
+ errfatal := func(err error) {
+ if err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ evExitSupervisor := make(chan error)
+ evExitServer := make(chan uint8)
+ evReload := make(chan bool)
+
+ // supervisor //////////////////////////////////////////////////////////
+ notify_sock, err := sdNotifyListen(t.tmpdir + "/notify.sock")
+ errfatal(err)
+ go func() {
+ reloading := false
+ evExitSupervisor <- sdNotifyHandle(notify_sock, func(dat []byte, oob []byte) error {
+ for _, line := range strings.Split(string(dat), "\n") {
+ switch line {
+ case "RELOADING=1":
+ reloading = true
+ case "READY=1":
+ if reloading {
+ evReload <- true
+ }
+ reloading = false
+ }
+ }
+ return nil
+ })
+ }()
+
+ // server //////////////////////////////////////////////////////////////
+ errfatal(sdActivatedStream(t.tmpdir + "/nslcd.sock"))
+ go func() {
+ evExitServer <- nslcd_systemd.Main(backend)
+ }()
+
+ // client/driver ///////////////////////////////////////////////////////
+
+ t.status <- "talking with server"
+ conn, err := net.Dial("unix", t.tmpdir+"/nslcd.sock")
+ errfatal(err)
+ errfatal(nslcd_proto.Write(conn, nslcd_proto.NSLCD_VERSION))
+ errfatal(nslcd_proto.Write(conn, nslcd_proto.NSLCD_ACTION_PASSWD_ALL))
+ // Wait for NSLCD_RESULT_*, to make sure the server has made
+ // it in to backend code.
+ var n int32
+ errfatal(nslcd_proto.Read(conn, &n))
+ if n != nslcd_proto.NSLCD_VERSION {
+ t.Fatal("server version wrong")
+ }
+ errfatal(nslcd_proto.Read(conn, &n))
+ if n != nslcd_proto.NSLCD_ACTION_PASSWD_ALL {
+ t.Fatal("server action wrong")
+ }
+ errfatal(nslcd_proto.Read(conn, &n))
+ if n != nslcd_proto.NSLCD_RESULT_BEGIN && n != nslcd_proto.NSLCD_RESULT_END {
+ t.Fatal("server result malformed")
+ }
+ if toclose {
+ errfatal(conn.Close())
+ }
+
+ t.status <- "waiting for server reload"
+ errfatal(unix.Kill(unix.Getpid(), unix.SIGHUP))
+ <-evReload
+
+ // A limitation of Unix sockets is that some may get dropped
+ // if they arrive close together. So give it some (a half
+ // second is probably generous by a couple orders of
+ // magnitude) time to handle SIGHUP before sending SIGTERM, so
+ // that we are sure it gets both.
+ time.Sleep(time.Second / 2)
+
+ t.status <- "waiting for server exit"
+ errfatal(unix.Kill(unix.Getpid(), unix.SIGTERM))
+ status := <-evExitServer
+ if status != 0 {
+ t.Fatalf("Main() exited with %d", status)
+ }
+
+ t.status <- "waiting for supervisor exit"
+ errfatal(notify_sock.SetReadDeadline(time.Now()))
+ err = <-evExitSupervisor
+ if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
+ err = nil
+ }
+ errfatal(err)
+ errfatal(notify_sock.Close())
+}
+
+type NonLockingBackend struct {
+ nslcd_server.NilBackend
+}
+
+func (o *NonLockingBackend) Init() error { return nil }
+func (o *NonLockingBackend) Reload() error { return nil }
+func (o *NonLockingBackend) Close() {}
+
+func (o *NonLockingBackend) Passwd_All(cred unix.Ucred, req nslcd_proto.Request_Passwd_All) <-chan nslcd_proto.Passwd {
+ ret := make(chan nslcd_proto.Passwd)
+ go func() {
+ defer close(ret)
+
+ for i := 0; i < 500; i++ {
+ ret <- nslcd_proto.Passwd{
+ Name: fmt.Sprintf("user%d", i),
+ PwHash: "x",
+ UID: int32(1000 + i),
+ GID: 1000,
+ GECOS: fmt.Sprintf("User %d", i),
+ HomeDir: fmt.Sprintf("/home/user%d", i),
+ Shell: "/bin/sh",
+ }
+ }
+ }()
+ return ret
+}
+
+type LockingBackend struct {
+ NonLockingBackend
+ lock sync.RWMutex
+}
+
+func (o *LockingBackend) Reload() error {
+ o.lock.Lock()
+ defer o.lock.Unlock()
+ return o.NonLockingBackend.Reload()
+}
+
+func (o *LockingBackend) Close() {
+ o.lock.Lock()
+ defer o.lock.Unlock()
+ o.NonLockingBackend.Close()
+}
+
+func (o *LockingBackend) Passwd_All(cred unix.Ucred, req nslcd_proto.Request_Passwd_All) <-chan nslcd_proto.Passwd {
+ o.lock.RLock()
+ ret := make(chan nslcd_proto.Passwd)
+ go func() {
+ defer o.lock.RUnlock()
+ defer close(ret)
+
+ for i := 0; i < 500; i++ {
+ ret <- nslcd_proto.Passwd{
+ Name: fmt.Sprintf("user%d", i),
+ PwHash: "x",
+ UID: int32(1000 + i),
+ GID: 1000,
+ GECOS: fmt.Sprintf("User %d", i),
+ HomeDir: fmt.Sprintf("/home/user%d", i),
+ Shell: "/bin/sh",
+ }
+ }
+ }()
+ return ret
+}
+
+func init() {
+ if fdIsDevNull(3) == nil {
+ return
+ }
+
+ devnull, err := os.OpenFile("/dev/null", os.O_RDWR, 0666)
+ if err != nil {
+ panic(err)
+ }
+ if devnull.Fd() == 3 {
+ return
+ }
+
+ fmt.Fprintln(os.Stderr, "Could not open /dev/null on FD 3; calling dup2 and re-exec()ing")
+ // shell out to do the FD manipulation--If we made it here,
+ // there's a good chance that FD3 was managed by the go
+ // runtime, and would be changed before we execve(2).
+ panic(syscall.Exec("/bin/sh", append([]string{"sh", "-c", "exec -- \"$@\" 3<>/dev/null"}, os.Args...), os.Environ()))
+}
+
+func TestBadClient(t *testing.T) {
+ testcases := []struct {
+ name string
+ backend nslcd_systemd.Backend
+ toclose bool
+ }{
+ {"NoLocks-ClientOpen", &NonLockingBackend{}, false},
+ {"NoLocks-ClientClose", &NonLockingBackend{}, true},
+ {"Locking-ClientOpen", &LockingBackend{}, false},
+ {"Locking-ClientClose", &LockingBackend{}, true},
+ }
+ for _, testcase := range testcases {
+ func() {
+ tmpdir, err := ioutil.TempDir("", "go-test-libnslcd-bad-client.")
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer os.RemoveAll(tmpdir)
+ defer sdActivatedReset()
+ t.Run(testcase.name, func(t *testing.T) {
+ testWithTimeout(t, 2*time.Second, func(t *testing.T, s chan<- string) {
+ ctx := &testContext{T: t, tmpdir: tmpdir, status: s}
+ testBadClient(ctx, testcase.backend, testcase.toclose)
+ })
+ })
+ }()
+ }
+}
diff --git a/nslcd_systemd/util_test.go b/nslcd_systemd/util_test.go
new file mode 100644
index 0000000..15147b7
--- /dev/null
+++ b/nslcd_systemd/util_test.go
@@ -0,0 +1,193 @@
+// Copyright (C) 2017 Luke Shumaker <lukeshu@sbcglobal.net>
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+// 02110-1301 USA
+
+package nslcd_systemd_test
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/pkg/errors"
+ "golang.org/x/sys/unix"
+)
+
+func testWithTimeout(t *testing.T, timeout time.Duration, fn func(t *testing.T, s chan<- string)) {
+ finished := make(chan bool)
+ status := make(chan string)
+ cur_status := ""
+ go func() {
+ t.Run("timed", func(t *testing.T) { fn(t, status) })
+ finished <- true
+ }()
+ for {
+ select {
+ case cur_status = <-status:
+ case <-finished:
+ close(status)
+ return
+ case <-time.After(timeout):
+ close(status)
+ if cur_status != "" {
+ t.Fatal("timed out: " + cur_status)
+ } else {
+ t.Fatal("timed out")
+ }
+ return
+ }
+ }
+}
+
+var sdListenFds = uintptr(0)
+var sdListenFdNames = []string{}
+
+type filer interface {
+ Close() error
+ File() (*os.File, error)
+}
+
+func sdOpenStream(streampath string) (filer, error) {
+ // I should have this change based on type of stream
+ listener, err := net.ListenUnix("unix", &net.UnixAddr{Net: "unix", Name: streampath})
+ if err != nil {
+ return nil, errors.Wrap(err, "net.ListenUnix()")
+ }
+ listener.SetUnlinkOnClose(false)
+ return listener, nil
+}
+
+func fdIsDevNull(fd uintptr) error {
+ file := os.NewFile(fd, fmt.Sprintf("/dev/fd/%d", fd))
+ if file == nil {
+ return errors.Errorf("not a valid file descriptor: %d", fd)
+ }
+
+ statFd, err := file.Stat()
+ if err != nil {
+ return err
+ }
+
+ statNull, err := os.Stat("/dev/null")
+ if err != nil {
+ return err
+ }
+
+ if !os.SameFile(statFd, statNull) {
+ return errors.Errorf("FD %d is not /dev/null", fd)
+ }
+ return nil
+}
+
+func sdActivatedStream(streampath string) error {
+ // Set up the file descriptor
+ err := func(streampath string) error {
+ file, err := func(streampath string) (*os.File, error) {
+ listener, err := sdOpenStream(streampath)
+ if err != nil {
+ return nil, err
+ }
+ defer listener.Close()
+ file, err := listener.File()
+ if err != nil {
+ return nil, errors.Wrap(err, "listener.File()")
+ }
+ return file, nil
+ }(streampath)
+ if err != nil {
+ return err
+ }
+ defer file.Close()
+
+ fd := sdListenFds + 3
+ err = fdIsDevNull(fd)
+ if err != nil {
+ return err
+ }
+ err = unix.Dup2(int(file.Fd()), int(fd))
+ if err != nil {
+ return errors.Wrap(err, "Dup2()")
+ }
+ return nil
+ }(streampath)
+ if err != nil {
+ return err
+ }
+
+ sdListenFds++
+ sdListenFdNames = append(sdListenFdNames, streampath)
+
+ err = os.Setenv("LISTEN_PID", fmt.Sprintf("%d", os.Getpid()))
+ if err != nil {
+ return errors.Wrap(err, "os.Setenv()")
+ }
+ err = os.Setenv("LISTEN_FDS", fmt.Sprintf("%d", sdListenFds))
+ if err != nil {
+ return errors.Wrap(err, "os.Setenv()")
+ }
+ err = os.Setenv("LISTEN_FDNAMES", strings.Join(sdListenFdNames, ":"))
+ if err != nil {
+ return errors.Wrap(err, "os.Setenv()")
+ }
+ return nil
+}
+
+func sdActivatedReset() error {
+ devnull, err := os.OpenFile("/dev/null", os.O_RDWR, 0666)
+ if err != nil {
+ return err
+ }
+ defer devnull.Close()
+ for i := uintptr(0); i < sdListenFds; i++ {
+ err = unix.Dup2(int(devnull.Fd()), int(3+i))
+ if err != nil {
+ return err
+ }
+ }
+ sdListenFds = 0
+ sdListenFdNames = []string{}
+ return nil
+}
+
+func sdNotifyListen(sockname string) (*net.UnixConn, error) {
+ err := os.Setenv("NOTIFY_SOCKET", sockname)
+ if err != nil {
+ return nil, err
+ }
+ return net.ListenUnixgram("unixgram", &net.UnixAddr{Net: "unixgram", Name: sockname})
+}
+
+func sdNotifyHandle(sock *net.UnixConn, fn func(dat []byte, oob []byte) error) error {
+ var dat [4096]byte
+ oob := make([]byte, unix.CmsgSpace(unix.SizeofUcred)+unix.CmsgSpace(8*768))
+ for {
+ n, oobn, flags, _, err := sock.ReadMsgUnix(dat[:], oob[:])
+ if err != nil {
+ return err
+ }
+ if flags&unix.MSG_TRUNC != 0 {
+ // Received notify message exceeded maximum size. Ignoring."
+ continue
+ }
+ err = fn(dat[:n], oob[:oobn])
+ if err != nil {
+ return err
+ }
+ }
+}