summaryrefslogtreecommitdiff
path: root/nslcd_systemd/util_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'nslcd_systemd/util_test.go')
-rw-r--r--nslcd_systemd/util_test.go193
1 files changed, 193 insertions, 0 deletions
diff --git a/nslcd_systemd/util_test.go b/nslcd_systemd/util_test.go
new file mode 100644
index 0000000..15147b7
--- /dev/null
+++ b/nslcd_systemd/util_test.go
@@ -0,0 +1,193 @@
+// Copyright (C) 2017 Luke Shumaker <lukeshu@sbcglobal.net>
+//
+// This library is free software; you can redistribute it and/or
+// modify it under the terms of the GNU Lesser General Public
+// License as published by the Free Software Foundation; either
+// version 2.1 of the License, or (at your option) any later version.
+//
+// This library is distributed in the hope that it will be useful,
+// but WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// Lesser General Public License for more details.
+//
+// You should have received a copy of the GNU Lesser General Public
+// License along with this library; if not, write to the Free Software
+// Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
+// 02110-1301 USA
+
+package nslcd_systemd_test
+
+import (
+ "fmt"
+ "net"
+ "os"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/pkg/errors"
+ "golang.org/x/sys/unix"
+)
+
+func testWithTimeout(t *testing.T, timeout time.Duration, fn func(t *testing.T, s chan<- string)) {
+ finished := make(chan bool)
+ status := make(chan string)
+ cur_status := ""
+ go func() {
+ t.Run("timed", func(t *testing.T) { fn(t, status) })
+ finished <- true
+ }()
+ for {
+ select {
+ case cur_status = <-status:
+ case <-finished:
+ close(status)
+ return
+ case <-time.After(timeout):
+ close(status)
+ if cur_status != "" {
+ t.Fatal("timed out: " + cur_status)
+ } else {
+ t.Fatal("timed out")
+ }
+ return
+ }
+ }
+}
+
+var sdListenFds = uintptr(0)
+var sdListenFdNames = []string{}
+
+type filer interface {
+ Close() error
+ File() (*os.File, error)
+}
+
+func sdOpenStream(streampath string) (filer, error) {
+ // I should have this change based on type of stream
+ listener, err := net.ListenUnix("unix", &net.UnixAddr{Net: "unix", Name: streampath})
+ if err != nil {
+ return nil, errors.Wrap(err, "net.ListenUnix()")
+ }
+ listener.SetUnlinkOnClose(false)
+ return listener, nil
+}
+
+func fdIsDevNull(fd uintptr) error {
+ file := os.NewFile(fd, fmt.Sprintf("/dev/fd/%d", fd))
+ if file == nil {
+ return errors.Errorf("not a valid file descriptor: %d", fd)
+ }
+
+ statFd, err := file.Stat()
+ if err != nil {
+ return err
+ }
+
+ statNull, err := os.Stat("/dev/null")
+ if err != nil {
+ return err
+ }
+
+ if !os.SameFile(statFd, statNull) {
+ return errors.Errorf("FD %d is not /dev/null", fd)
+ }
+ return nil
+}
+
+func sdActivatedStream(streampath string) error {
+ // Set up the file descriptor
+ err := func(streampath string) error {
+ file, err := func(streampath string) (*os.File, error) {
+ listener, err := sdOpenStream(streampath)
+ if err != nil {
+ return nil, err
+ }
+ defer listener.Close()
+ file, err := listener.File()
+ if err != nil {
+ return nil, errors.Wrap(err, "listener.File()")
+ }
+ return file, nil
+ }(streampath)
+ if err != nil {
+ return err
+ }
+ defer file.Close()
+
+ fd := sdListenFds + 3
+ err = fdIsDevNull(fd)
+ if err != nil {
+ return err
+ }
+ err = unix.Dup2(int(file.Fd()), int(fd))
+ if err != nil {
+ return errors.Wrap(err, "Dup2()")
+ }
+ return nil
+ }(streampath)
+ if err != nil {
+ return err
+ }
+
+ sdListenFds++
+ sdListenFdNames = append(sdListenFdNames, streampath)
+
+ err = os.Setenv("LISTEN_PID", fmt.Sprintf("%d", os.Getpid()))
+ if err != nil {
+ return errors.Wrap(err, "os.Setenv()")
+ }
+ err = os.Setenv("LISTEN_FDS", fmt.Sprintf("%d", sdListenFds))
+ if err != nil {
+ return errors.Wrap(err, "os.Setenv()")
+ }
+ err = os.Setenv("LISTEN_FDNAMES", strings.Join(sdListenFdNames, ":"))
+ if err != nil {
+ return errors.Wrap(err, "os.Setenv()")
+ }
+ return nil
+}
+
+func sdActivatedReset() error {
+ devnull, err := os.OpenFile("/dev/null", os.O_RDWR, 0666)
+ if err != nil {
+ return err
+ }
+ defer devnull.Close()
+ for i := uintptr(0); i < sdListenFds; i++ {
+ err = unix.Dup2(int(devnull.Fd()), int(3+i))
+ if err != nil {
+ return err
+ }
+ }
+ sdListenFds = 0
+ sdListenFdNames = []string{}
+ return nil
+}
+
+func sdNotifyListen(sockname string) (*net.UnixConn, error) {
+ err := os.Setenv("NOTIFY_SOCKET", sockname)
+ if err != nil {
+ return nil, err
+ }
+ return net.ListenUnixgram("unixgram", &net.UnixAddr{Net: "unixgram", Name: sockname})
+}
+
+func sdNotifyHandle(sock *net.UnixConn, fn func(dat []byte, oob []byte) error) error {
+ var dat [4096]byte
+ oob := make([]byte, unix.CmsgSpace(unix.SizeofUcred)+unix.CmsgSpace(8*768))
+ for {
+ n, oobn, flags, _, err := sock.ReadMsgUnix(dat[:], oob[:])
+ if err != nil {
+ return err
+ }
+ if flags&unix.MSG_TRUNC != 0 {
+ // Received notify message exceeded maximum size. Ignoring."
+ continue
+ }
+ err = fn(dat[:n], oob[:oobn])
+ if err != nil {
+ return err
+ }
+ }
+}