// Copyright (C) 2017 Luke Shumaker // // This library is free software; you can redistribute it and/or // modify it under the terms of the GNU Lesser General Public // License as published by the Free Software Foundation; either // version 2.1 of the License, or (at your option) any later version. // // This library is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU // Lesser General Public License for more details. // // You should have received a copy of the GNU Lesser General Public // License along with this library; if not, write to the Free Software // Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA // 02110-1301 USA package nslcd_systemd_test import ( "fmt" "net" "os" "strings" "testing" "time" "github.com/pkg/errors" "golang.org/x/sys/unix" ) func testWithTimeout(t *testing.T, timeout time.Duration, fn func(t *testing.T, s chan<- string)) { finished := make(chan bool) status := make(chan string) cur_status := "" go func() { t.Run("timed", func(t *testing.T) { fn(t, status) }) finished <- true }() for { select { case cur_status = <-status: case <-finished: close(status) return case <-time.After(timeout): close(status) if cur_status != "" { t.Fatal("timed out: " + cur_status) } else { t.Fatal("timed out") } return } } } var sdListenFds = uintptr(0) var sdListenFdNames = []string{} type filer interface { Close() error File() (*os.File, error) } func sdOpenStream(streampath string) (filer, error) { // I should have this change based on type of stream listener, err := net.ListenUnix("unix", &net.UnixAddr{Net: "unix", Name: streampath}) if err != nil { return nil, errors.Wrap(err, "net.ListenUnix()") } listener.SetUnlinkOnClose(false) return listener, nil } func fdIsDevNull(fd uintptr) error { file := os.NewFile(fd, fmt.Sprintf("/dev/fd/%d", fd)) if file == nil { return errors.Errorf("not a valid file descriptor: %d", fd) } statFd, err := file.Stat() if err != nil { return err } statNull, err := os.Stat("/dev/null") if err != nil { return err } if !os.SameFile(statFd, statNull) { return errors.Errorf("FD %d is not /dev/null", fd) } return nil } func sdActivatedStream(streampath string) error { // Set up the file descriptor err := func(streampath string) error { file, err := func(streampath string) (*os.File, error) { listener, err := sdOpenStream(streampath) if err != nil { return nil, err } defer listener.Close() file, err := listener.File() if err != nil { return nil, errors.Wrap(err, "listener.File()") } return file, nil }(streampath) if err != nil { return err } defer file.Close() fd := sdListenFds + 3 err = fdIsDevNull(fd) if err != nil { return err } err = unix.Dup2(int(file.Fd()), int(fd)) if err != nil { return errors.Wrap(err, "Dup2()") } return nil }(streampath) if err != nil { return err } sdListenFds++ sdListenFdNames = append(sdListenFdNames, streampath) err = os.Setenv("LISTEN_PID", fmt.Sprintf("%d", os.Getpid())) if err != nil { return errors.Wrap(err, "os.Setenv()") } err = os.Setenv("LISTEN_FDS", fmt.Sprintf("%d", sdListenFds)) if err != nil { return errors.Wrap(err, "os.Setenv()") } err = os.Setenv("LISTEN_FDNAMES", strings.Join(sdListenFdNames, ":")) if err != nil { return errors.Wrap(err, "os.Setenv()") } return nil } func sdActivatedReset() error { devnull, err := os.OpenFile("/dev/null", os.O_RDWR, 0666) if err != nil { return err } defer devnull.Close() for i := uintptr(0); i < sdListenFds; i++ { err = unix.Dup2(int(devnull.Fd()), int(3+i)) if err != nil { return err } } sdListenFds = 0 sdListenFdNames = []string{} return nil } func sdNotifyListen(sockname string) (*net.UnixConn, error) { err := os.Setenv("NOTIFY_SOCKET", sockname) if err != nil { return nil, err } return net.ListenUnixgram("unixgram", &net.UnixAddr{Net: "unixgram", Name: sockname}) } func sdNotifyHandle(sock *net.UnixConn, fn func(dat []byte, oob []byte) error) error { var dat [4096]byte oob := make([]byte, unix.CmsgSpace(unix.SizeofUcred)+unix.CmsgSpace(8*768)) for { n, oobn, flags, _, err := sock.ReadMsgUnix(dat[:], oob[:]) if err != nil { return err } if flags&unix.MSG_TRUNC != 0 { // Received notify message exceeded maximum size. Ignoring." continue } err = fn(dat[:n], oob[:oobn]) if err != nil { return err } } } func humanizeU64(n uint64) string { str := fmt.Sprintf("%d", n) bts := make([]byte, len(str)+(len(str)-1)/3) s := 0 b := 0 for s < len(str) && b < len(bts) { if (s % 3 == 0 && s > 0) { bts[len(bts)-1-b] = ',' b++ } bts[len(bts)-1-b] = str[len(str)-1-s] b++ s++ } return string(bts) }