diff options
Diffstat (limited to 'src/systemd-socket-proxyd/socket-proxyd.c')
| -rw-r--r-- | src/systemd-socket-proxyd/socket-proxyd.c | 679 | 
1 files changed, 679 insertions, 0 deletions
| diff --git a/src/systemd-socket-proxyd/socket-proxyd.c b/src/systemd-socket-proxyd/socket-proxyd.c new file mode 100644 index 0000000000..6a11804694 --- /dev/null +++ b/src/systemd-socket-proxyd/socket-proxyd.c @@ -0,0 +1,679 @@ +/*** +  This file is part of systemd. + +  Copyright 2013 David Strauss + +  systemd is free software; you can redistribute it and/or modify it +  under the terms of the GNU Lesser General Public License as published by +  the Free Software Foundation; either version 2.1 of the License, or +  (at your option) any later version. + +  systemd is distributed in the hope that it will be useful, but +  WITHOUT ANY WARRANTY; without even the implied warranty of +  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +  Lesser General Public License for more details. + +  You should have received a copy of the GNU Lesser General Public License +  along with systemd; If not, see <http://www.gnu.org/licenses/>. + ***/ + +#include <errno.h> +#include <fcntl.h> +#include <getopt.h> +#include <netdb.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/socket.h> +#include <sys/un.h> +#include <unistd.h> + +#include <systemd/sd-daemon.h> +#include <systemd/sd-event.h> + +#include "systemd-basic/alloc-util.h" +#include "systemd-basic/fd-util.h" +#include "systemd-basic/log.h" +#include "systemd-basic/path-util.h" +#include "systemd-basic/set.h" +#include "systemd-basic/socket-util.h" +#include "systemd-basic/string-util.h" +#include "systemd-basic/util.h" +#include "systemd-staging/sd-resolve.h" + +#define BUFFER_SIZE (256 * 1024) +#define CONNECTIONS_MAX 256 + +static const char *arg_remote_host = NULL; + +typedef struct Context { +        sd_event *event; +        sd_resolve *resolve; + +        Set *listen; +        Set *connections; +} Context; + +typedef struct Connection { +        Context *context; + +        int server_fd, client_fd; +        int server_to_client_buffer[2]; /* a pipe */ +        int client_to_server_buffer[2]; /* a pipe */ + +        size_t server_to_client_buffer_full, client_to_server_buffer_full; +        size_t server_to_client_buffer_size, client_to_server_buffer_size; + +        sd_event_source *server_event_source, *client_event_source; + +        sd_resolve_query *resolve_query; +} Connection; + +static void connection_free(Connection *c) { +        assert(c); + +        if (c->context) +                set_remove(c->context->connections, c); + +        sd_event_source_unref(c->server_event_source); +        sd_event_source_unref(c->client_event_source); + +        safe_close(c->server_fd); +        safe_close(c->client_fd); + +        safe_close_pair(c->server_to_client_buffer); +        safe_close_pair(c->client_to_server_buffer); + +        sd_resolve_query_unref(c->resolve_query); + +        free(c); +} + +static void context_free(Context *context) { +        sd_event_source *es; +        Connection *c; + +        assert(context); + +        while ((es = set_steal_first(context->listen))) +                sd_event_source_unref(es); + +        while ((c = set_first(context->connections))) +                connection_free(c); + +        set_free(context->listen); +        set_free(context->connections); + +        sd_event_unref(context->event); +        sd_resolve_unref(context->resolve); +} + +static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) { +        int r; + +        assert(c); +        assert(buffer); +        assert(sz); + +        if (buffer[0] >= 0) +                return 0; + +        r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK); +        if (r < 0) +                return log_error_errno(errno, "Failed to allocate pipe buffer: %m"); + +        (void) fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE); + +        r = fcntl(buffer[0], F_GETPIPE_SZ); +        if (r < 0) +                return log_error_errno(errno, "Failed to get pipe buffer size: %m"); + +        assert(r > 0); +        *sz = r; + +        return 0; +} + +static int connection_shovel( +                Connection *c, +                int *from, int buffer[2], int *to, +                size_t *full, size_t *sz, +                sd_event_source **from_source, sd_event_source **to_source) { + +        bool shoveled; + +        assert(c); +        assert(from); +        assert(buffer); +        assert(buffer[0] >= 0); +        assert(buffer[1] >= 0); +        assert(to); +        assert(full); +        assert(sz); +        assert(from_source); +        assert(to_source); + +        do { +                ssize_t z; + +                shoveled = false; + +                if (*full < *sz && *from >= 0 && *to >= 0) { +                        z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK); +                        if (z > 0) { +                                *full += z; +                                shoveled = true; +                        } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) { +                                *from_source = sd_event_source_unref(*from_source); +                                *from = safe_close(*from); +                        } else if (errno != EAGAIN && errno != EINTR) +                                return log_error_errno(errno, "Failed to splice: %m"); +                } + +                if (*full > 0 && *to >= 0) { +                        z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK); +                        if (z > 0) { +                                *full -= z; +                                shoveled = true; +                        } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) { +                                *to_source = sd_event_source_unref(*to_source); +                                *to = safe_close(*to); +                        } else if (errno != EAGAIN && errno != EINTR) +                                return log_error_errno(errno, "Failed to splice: %m"); +                } +        } while (shoveled); + +        return 0; +} + +static int connection_enable_event_sources(Connection *c); + +static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { +        Connection *c = userdata; +        int r; + +        assert(s); +        assert(fd >= 0); +        assert(c); + +        r = connection_shovel(c, +                              &c->server_fd, c->server_to_client_buffer, &c->client_fd, +                              &c->server_to_client_buffer_full, &c->server_to_client_buffer_size, +                              &c->server_event_source, &c->client_event_source); +        if (r < 0) +                goto quit; + +        r = connection_shovel(c, +                              &c->client_fd, c->client_to_server_buffer, &c->server_fd, +                              &c->client_to_server_buffer_full, &c->client_to_server_buffer_size, +                              &c->client_event_source, &c->server_event_source); +        if (r < 0) +                goto quit; + +        /* EOF on both sides? */ +        if (c->server_fd == -1 && c->client_fd == -1) +                goto quit; + +        /* Server closed, and all data written to client? */ +        if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0) +                goto quit; + +        /* Client closed, and all data written to server? */ +        if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0) +                goto quit; + +        r = connection_enable_event_sources(c); +        if (r < 0) +                goto quit; + +        return 1; + +quit: +        connection_free(c); +        return 0; /* ignore errors, continue serving */ +} + +static int connection_enable_event_sources(Connection *c) { +        uint32_t a = 0, b = 0; +        int r; + +        assert(c); + +        if (c->server_to_client_buffer_full > 0) +                b |= EPOLLOUT; +        if (c->server_to_client_buffer_full < c->server_to_client_buffer_size) +                a |= EPOLLIN; + +        if (c->client_to_server_buffer_full > 0) +                a |= EPOLLOUT; +        if (c->client_to_server_buffer_full < c->client_to_server_buffer_size) +                b |= EPOLLIN; + +        if (c->server_event_source) +                r = sd_event_source_set_io_events(c->server_event_source, a); +        else if (c->server_fd >= 0) +                r = sd_event_add_io(c->context->event, &c->server_event_source, c->server_fd, a, traffic_cb, c); +        else +                r = 0; + +        if (r < 0) +                return log_error_errno(r, "Failed to set up server event source: %m"); + +        if (c->client_event_source) +                r = sd_event_source_set_io_events(c->client_event_source, b); +        else if (c->client_fd >= 0) +                r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, b, traffic_cb, c); +        else +                r = 0; + +        if (r < 0) +                return log_error_errno(r, "Failed to set up client event source: %m"); + +        return 0; +} + +static int connection_complete(Connection *c) { +        int r; + +        assert(c); + +        r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size); +        if (r < 0) +                goto fail; + +        r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size); +        if (r < 0) +                goto fail; + +        r = connection_enable_event_sources(c); +        if (r < 0) +                goto fail; + +        return 0; + +fail: +        connection_free(c); +        return 0; /* ignore errors, continue serving */ +} + +static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { +        Connection *c = userdata; +        socklen_t solen; +        int error, r; + +        assert(s); +        assert(fd >= 0); +        assert(c); + +        solen = sizeof(error); +        r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen); +        if (r < 0) { +                log_error_errno(errno, "Failed to issue SO_ERROR: %m"); +                goto fail; +        } + +        if (error != 0) { +                log_error_errno(error, "Failed to connect to remote host: %m"); +                goto fail; +        } + +        c->client_event_source = sd_event_source_unref(c->client_event_source); + +        return connection_complete(c); + +fail: +        connection_free(c); +        return 0; /* ignore errors, continue serving */ +} + +static int connection_start(Connection *c, struct sockaddr *sa, socklen_t salen) { +        int r; + +        assert(c); +        assert(sa); +        assert(salen); + +        c->client_fd = socket(sa->sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0); +        if (c->client_fd < 0) { +                log_error_errno(errno, "Failed to get remote socket: %m"); +                goto fail; +        } + +        r = connect(c->client_fd, sa, salen); +        if (r < 0) { +                if (errno == EINPROGRESS) { +                        r = sd_event_add_io(c->context->event, &c->client_event_source, c->client_fd, EPOLLOUT, connect_cb, c); +                        if (r < 0) { +                                log_error_errno(r, "Failed to add connection socket: %m"); +                                goto fail; +                        } + +                        r = sd_event_source_set_enabled(c->client_event_source, SD_EVENT_ONESHOT); +                        if (r < 0) { +                                log_error_errno(r, "Failed to enable oneshot event source: %m"); +                                goto fail; +                        } +                } else { +                        log_error_errno(errno, "Failed to connect to remote host: %m"); +                        goto fail; +                } +        } else { +                r = connection_complete(c); +                if (r < 0) +                        goto fail; +        } + +        return 0; + +fail: +        connection_free(c); +        return 0; /* ignore errors, continue serving */ +} + +static int resolve_cb(sd_resolve_query *q, int ret, const struct addrinfo *ai, void *userdata) { +        Connection *c = userdata; + +        assert(q); +        assert(c); + +        if (ret != 0) { +                log_error("Failed to resolve host: %s", gai_strerror(ret)); +                goto fail; +        } + +        c->resolve_query = sd_resolve_query_unref(c->resolve_query); + +        return connection_start(c, ai->ai_addr, ai->ai_addrlen); + +fail: +        connection_free(c); +        return 0; /* ignore errors, continue serving */ +} + +static int resolve_remote(Connection *c) { + +        static const struct addrinfo hints = { +                .ai_family = AF_UNSPEC, +                .ai_socktype = SOCK_STREAM, +                .ai_flags = AI_ADDRCONFIG +        }; + +        union sockaddr_union sa = {}; +        const char *node, *service; +        int r; + +        if (path_is_absolute(arg_remote_host)) { +                sa.un.sun_family = AF_UNIX; +                strncpy(sa.un.sun_path, arg_remote_host, sizeof(sa.un.sun_path)); +                return connection_start(c, &sa.sa, SOCKADDR_UN_LEN(sa.un)); +        } + +        if (arg_remote_host[0] == '@') { +                sa.un.sun_family = AF_UNIX; +                sa.un.sun_path[0] = 0; +                strncpy(sa.un.sun_path+1, arg_remote_host+1, sizeof(sa.un.sun_path)-1); +                return connection_start(c, &sa.sa, SOCKADDR_UN_LEN(sa.un)); +        } + +        service = strrchr(arg_remote_host, ':'); +        if (service) { +                node = strndupa(arg_remote_host, service - arg_remote_host); +                service++; +        } else { +                node = arg_remote_host; +                service = "80"; +        } + +        log_debug("Looking up address info for %s:%s", node, service); +        r = sd_resolve_getaddrinfo(c->context->resolve, &c->resolve_query, node, service, &hints, resolve_cb, c); +        if (r < 0) { +                log_error_errno(r, "Failed to resolve remote host: %m"); +                goto fail; +        } + +        return 0; + +fail: +        connection_free(c); +        return 0; /* ignore errors, continue serving */ +} + +static int add_connection_socket(Context *context, int fd) { +        Connection *c; +        int r; + +        assert(context); +        assert(fd >= 0); + +        if (set_size(context->connections) > CONNECTIONS_MAX) { +                log_warning("Hit connection limit, refusing connection."); +                safe_close(fd); +                return 0; +        } + +        r = set_ensure_allocated(&context->connections, NULL); +        if (r < 0) { +                log_oom(); +                return 0; +        } + +        c = new0(Connection, 1); +        if (!c) { +                log_oom(); +                return 0; +        } + +        c->context = context; +        c->server_fd = fd; +        c->client_fd = -1; +        c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1; +        c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1; + +        r = set_put(context->connections, c); +        if (r < 0) { +                free(c); +                log_oom(); +                return 0; +        } + +        return resolve_remote(c); +} + +static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { +        _cleanup_free_ char *peer = NULL; +        Context *context = userdata; +        int nfd = -1, r; + +        assert(s); +        assert(fd >= 0); +        assert(revents & EPOLLIN); +        assert(context); + +        nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC); +        if (nfd < 0) { +                if (errno != -EAGAIN) +                        log_warning_errno(errno, "Failed to accept() socket: %m"); +        } else { +                getpeername_pretty(nfd, true, &peer); +                log_debug("New connection from %s", strna(peer)); + +                r = add_connection_socket(context, nfd); +                if (r < 0) { +                        log_error_errno(r, "Failed to accept connection, ignoring: %m"); +                        safe_close(fd); +                } +        } + +        r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT); +        if (r < 0) { +                log_error_errno(r, "Error while re-enabling listener with ONESHOT: %m"); +                sd_event_exit(context->event, r); +                return r; +        } + +        return 1; +} + +static int add_listen_socket(Context *context, int fd) { +        sd_event_source *source; +        int r; + +        assert(context); +        assert(fd >= 0); + +        r = set_ensure_allocated(&context->listen, NULL); +        if (r < 0) { +                log_oom(); +                return r; +        } + +        r = sd_is_socket(fd, 0, SOCK_STREAM, 1); +        if (r < 0) +                return log_error_errno(r, "Failed to determine socket type: %m"); +        if (r == 0) { +                log_error("Passed in socket is not a stream socket."); +                return -EINVAL; +        } + +        r = fd_nonblock(fd, true); +        if (r < 0) +                return log_error_errno(r, "Failed to mark file descriptor non-blocking: %m"); + +        r = sd_event_add_io(context->event, &source, fd, EPOLLIN, accept_cb, context); +        if (r < 0) +                return log_error_errno(r, "Failed to add event source: %m"); + +        r = set_put(context->listen, source); +        if (r < 0) { +                log_error_errno(r, "Failed to add source to set: %m"); +                sd_event_source_unref(source); +                return r; +        } + +        /* Set the watcher to oneshot in case other processes are also +         * watching to accept(). */ +        r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT); +        if (r < 0) +                return log_error_errno(r, "Failed to enable oneshot mode: %m"); + +        return 0; +} + +static void help(void) { +        printf("%1$s [HOST:PORT]\n" +               "%1$s [SOCKET]\n\n" +               "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n" +               "  -h --help              Show this help\n" +               "     --version           Show package version\n", +               program_invocation_short_name); +} + +static int parse_argv(int argc, char *argv[]) { + +        enum { +                ARG_VERSION = 0x100, +                ARG_IGNORE_ENV +        }; + +        static const struct option options[] = { +                { "help",       no_argument, NULL, 'h'           }, +                { "version",    no_argument, NULL, ARG_VERSION   }, +                {} +        }; + +        int c; + +        assert(argc >= 0); +        assert(argv); + +        while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) + +                switch (c) { + +                case 'h': +                        help(); +                        return 0; + +                case ARG_VERSION: +                        return version(); + +                case '?': +                        return -EINVAL; + +                default: +                        assert_not_reached("Unhandled option"); +                } + +        if (optind >= argc) { +                log_error("Not enough parameters."); +                return -EINVAL; +        } + +        if (argc != optind+1) { +                log_error("Too many parameters."); +                return -EINVAL; +        } + +        arg_remote_host = argv[optind]; +        return 1; +} + +int main(int argc, char *argv[]) { +        Context context = {}; +        int r, n, fd; + +        log_parse_environment(); +        log_open(); + +        r = parse_argv(argc, argv); +        if (r <= 0) +                goto finish; + +        r = sd_event_default(&context.event); +        if (r < 0) { +                log_error_errno(r, "Failed to allocate event loop: %m"); +                goto finish; +        } + +        r = sd_resolve_default(&context.resolve); +        if (r < 0) { +                log_error_errno(r, "Failed to allocate resolver: %m"); +                goto finish; +        } + +        r = sd_resolve_attach_event(context.resolve, context.event, 0); +        if (r < 0) { +                log_error_errno(r, "Failed to attach resolver: %m"); +                goto finish; +        } + +        sd_event_set_watchdog(context.event, true); + +        n = sd_listen_fds(1); +        if (n < 0) { +                log_error("Failed to receive sockets from parent."); +                r = n; +                goto finish; +        } else if (n == 0) { +                log_error("Didn't get any sockets passed in."); +                r = -EINVAL; +                goto finish; +        } + +        for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) { +                r = add_listen_socket(&context, fd); +                if (r < 0) +                        goto finish; +        } + +        r = sd_event_loop(context.event); +        if (r < 0) { +                log_error_errno(r, "Failed to run event loop: %m"); +                goto finish; +        } + +finish: +        context_free(&context); + +        return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS; +} | 
