diff options
Diffstat (limited to 'src/basic/util.c')
| -rw-r--r-- | src/basic/util.c | 210 | 
1 files changed, 152 insertions, 58 deletions
| diff --git a/src/basic/util.c b/src/basic/util.c index e3b2af8e02..18be0bfd5a 100644 --- a/src/basic/util.c +++ b/src/basic/util.c @@ -19,49 +19,48 @@    along with systemd; If not, see <http://www.gnu.org/licenses/>.  ***/ -#include <string.h> -#include <unistd.h> +#include <ctype.h> +#include <dirent.h>  #include <errno.h> -#include <stdlib.h> -#include <signal.h> +#include <fcntl.h> +#include <glob.h> +#include <grp.h> +#include <langinfo.h>  #include <libintl.h> -#include <stdio.h> -#include <syslog.h> -#include <sched.h> -#include <sys/resource.h> +#include <limits.h> +#include <linux/magic.h>  #include <linux/sched.h> -#include <sys/types.h> -#include <sys/stat.h> -#include <fcntl.h> -#include <dirent.h> -#include <sys/ioctl.h> -#include <stdarg.h> +#include <locale.h> +#include <netinet/ip.h>  #include <poll.h> -#include <ctype.h> -#include <sys/prctl.h> -#include <sys/utsname.h>  #include <pwd.h> -#include <netinet/ip.h> -#include <sys/wait.h> -#include <sys/time.h> -#include <glob.h> -#include <grp.h> +#include <sched.h> +#include <signal.h> +#include <stdarg.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/file.h> +#include <sys/ioctl.h>  #include <sys/mman.h> -#include <sys/vfs.h>  #include <sys/mount.h> -#include <linux/magic.h> -#include <limits.h> -#include <langinfo.h> -#include <locale.h>  #include <sys/personality.h> -#include <sys/xattr.h> +#include <sys/prctl.h> +#include <sys/resource.h> +#include <sys/stat.h>  #include <sys/statvfs.h> -#include <sys/file.h> -#include <linux/fs.h> +#include <sys/time.h> +#include <sys/types.h> +#include <sys/utsname.h> +#include <sys/vfs.h> +#include <sys/wait.h> +#include <sys/xattr.h> +#include <syslog.h> +#include <unistd.h>  /* When we include libgen.h because we need dirname() we immediately - * undefine basename() since libgen.h defines it as a macro to the POSIX - * version which is really broken. We prefer GNU basename(). */ + * undefine basename() since libgen.h defines it as a macro to the + * POSIX version which is really broken. We prefer GNU basename(). */  #include <libgen.h>  #undef basename @@ -69,31 +68,34 @@  #include <sys/auxv.h>  #endif -#include "config.h" -#include "macro.h" -#include "util.h" +/* We include linux/fs.h as last of the system headers, as it + * otherwise conflicts with sys/mount.h. Yay, Linux is great! */ +#include <linux/fs.h> + +#include "def.h" +#include "device-nodes.h" +#include "env-util.h" +#include "exit-status.h" +#include "fileio.h" +#include "formats-util.h" +#include "gunicode.h" +#include "hashmap.h" +#include "hostname-util.h"  #include "ioprio.h" -#include "missing.h"  #include "log.h" -#include "strv.h" +#include "macro.h" +#include "missing.h"  #include "mkdir.h"  #include "path-util.h" -#include "exit-status.h" -#include "hashmap.h" -#include "env-util.h" -#include "fileio.h" -#include "device-nodes.h" -#include "utf8.h" -#include "gunicode.h" -#include "virt.h" -#include "def.h" -#include "sparse-endian.h" -#include "formats-util.h"  #include "process-util.h"  #include "random-util.h" -#include "terminal-util.h" -#include "hostname-util.h"  #include "signal-util.h" +#include "sparse-endian.h" +#include "strv.h" +#include "terminal-util.h" +#include "utf8.h" +#include "util.h" +#include "virt.h"  /* Put this test here for a lack of better place */  assert_cc(EAGAIN == EWOULDBLOCK); @@ -354,6 +356,17 @@ FILE* safe_fclose(FILE *f) {          return NULL;  } +DIR* safe_closedir(DIR *d) { + +        if (d) { +                PROTECT_ERRNO; + +                assert_se(closedir(d) >= 0 || errno != EBADF); +        } + +        return NULL; +} +  int unlink_noerrno(const char *path) {          PROTECT_ERRNO;          int r; @@ -2133,7 +2146,13 @@ ssize_t loop_read(int fd, void *buf, size_t nbytes, bool do_poll) {          assert(fd >= 0);          assert(buf); -        while (nbytes > 0) { +        /* If called with nbytes == 0, let's call read() at least +         * once, to validate the operation */ + +        if (nbytes > (size_t) SSIZE_MAX) +                return -EINVAL; + +        do {                  ssize_t k;                  k = read(fd, p, nbytes); @@ -2147,7 +2166,7 @@ ssize_t loop_read(int fd, void *buf, size_t nbytes, bool do_poll) {                                   * and expect that any error/EOF is reported                                   * via read() */ -                                fd_wait_for_event(fd, POLLIN, USEC_INFINITY); +                                (void) fd_wait_for_event(fd, POLLIN, USEC_INFINITY);                                  continue;                          } @@ -2157,10 +2176,12 @@ ssize_t loop_read(int fd, void *buf, size_t nbytes, bool do_poll) {                  if (k == 0)                          return n; +                assert((size_t) k <= nbytes); +                  p += k;                  nbytes -= k;                  n += k; -        } +        } while (nbytes > 0);          return n;  } @@ -2170,9 +2191,10 @@ int loop_read_exact(int fd, void *buf, size_t nbytes, bool do_poll) {          n = loop_read(fd, buf, nbytes, do_poll);          if (n < 0) -                return n; +                return (int) n;          if ((size_t) n != nbytes)                  return -EIO; +          return 0;  } @@ -2182,7 +2204,8 @@ int loop_write(int fd, const void *buf, size_t nbytes, bool do_poll) {          assert(fd >= 0);          assert(buf); -        errno = 0; +        if (nbytes > (size_t) SSIZE_MAX) +                return -EINVAL;          do {                  ssize_t k; @@ -2197,16 +2220,18 @@ int loop_write(int fd, const void *buf, size_t nbytes, bool do_poll) {                                   * and expect that any error/EOF is reported                                   * via write() */ -                                fd_wait_for_event(fd, POLLOUT, USEC_INFINITY); +                                (void) fd_wait_for_event(fd, POLLOUT, USEC_INFINITY);                                  continue;                          }                          return -errno;                  } -                if (nbytes > 0 && k == 0) /* Can't really happen */ +                if (_unlikely_(nbytes > 0 && k == 0)) /* Can't really happen */                          return -EIO; +                assert((size_t) k <= nbytes); +                  p += k;                  nbytes -= k;          } while (nbytes > 0); @@ -6538,7 +6563,7 @@ ssize_t string_table_lookup(const char * const *table, size_t len, const char *k          for (i = 0; i < len; ++i)                  if (streq_ptr(table[i], key)) -                        return (ssize_t)i; +                        return (ssize_t) i;          return -1;  } @@ -6775,3 +6800,72 @@ int fgetxattr_malloc(int fd, const char *name, char **value) {                          return -errno;          }  } + +int send_one_fd(int transport_fd, int fd) { +        union { +                struct cmsghdr cmsghdr; +                uint8_t buf[CMSG_SPACE(sizeof(int))]; +        } control = {}; +        struct msghdr mh = { +                .msg_control = &control, +                .msg_controllen = sizeof(control), +        }; +        struct cmsghdr *cmsg; +        ssize_t k; + +        assert(transport_fd >= 0); +        assert(fd >= 0); + +        cmsg = CMSG_FIRSTHDR(&mh); +        cmsg->cmsg_level = SOL_SOCKET; +        cmsg->cmsg_type = SCM_RIGHTS; +        cmsg->cmsg_len = CMSG_LEN(sizeof(int)); +        memcpy(CMSG_DATA(cmsg), &fd, sizeof(int)); + +        mh.msg_controllen = CMSG_SPACE(sizeof(int)); +        k = sendmsg(transport_fd, &mh, MSG_NOSIGNAL); +        if (k < 0) +                return -errno; + +        return 0; +} + +int receive_one_fd(int transport_fd) { +        union { +                struct cmsghdr cmsghdr; +                uint8_t buf[CMSG_SPACE(sizeof(int))]; +        } control = {}; +        struct msghdr mh = { +                .msg_control = &control, +                .msg_controllen = sizeof(control), +        }; +        struct cmsghdr *cmsg; +        ssize_t k; + +        assert(transport_fd >= 0); + +        /* +         * Receive a single FD via @transport_fd. We don't care for the +         * transport-type, but the caller must assure that no other CMSG types +         * than SCM_RIGHTS is enabled. We also retrieve a single FD at most, so +         * for packet-based transports, the caller must ensure to send only a +         * single FD per packet. +         * This is best used in combination with send_one_fd(). +         */ + +        k = recvmsg(transport_fd, &mh, MSG_NOSIGNAL | MSG_CMSG_CLOEXEC); +        if (k < 0) +                return -errno; + +        cmsg = CMSG_FIRSTHDR(&mh); +        if (!cmsg || CMSG_NXTHDR(&mh, cmsg) || +            cmsg->cmsg_level != SOL_SOCKET || +            cmsg->cmsg_type != SCM_RIGHTS || +            cmsg->cmsg_len != CMSG_LEN(sizeof(int)) || +            *(const int *)CMSG_DATA(cmsg) < 0) { +                cmsg_close_all(&mh); +                return -EIO; +        } + +        return *(const int *)CMSG_DATA(cmsg); +} | 
