diff options
author | Lennart Poettering <lennart@poettering.net> | 2013-11-06 22:40:54 +0100 |
---|---|---|
committer | Lennart Poettering <lennart@poettering.net> | 2013-11-06 23:03:12 +0100 |
commit | 8569a77629949b7818d00eba8eea1d05e2d1fc32 (patch) | |
tree | 0f807b151ed4649032012084c83a8ed95669e8ca /src/socket-proxy/socket-proxyd.c | |
parent | 175a3d25d0e8596d4ba0759aea3f89ee228e7d6d (diff) |
socket-proxyd: rework to support multiple sockets and splice()-based zero-copy network IO
This also drops --ignore-env, which can't really work anymore if we
allow multiple fds. Also adds support for pretty printing of peer
identities for debug purposes, and abstract namespace UNIX sockets. Also
ensures that we never take more connections than a certain limit.
Diffstat (limited to 'src/socket-proxy/socket-proxyd.c')
-rw-r--r-- | src/socket-proxy/socket-proxyd.c | 792 |
1 files changed, 409 insertions, 383 deletions
diff --git a/src/socket-proxy/socket-proxyd.c b/src/socket-proxy/socket-proxyd.c index a449b0eec4..1c64c0e2e5 100644 --- a/src/socket-proxy/socket-proxyd.c +++ b/src/socket-proxy/socket-proxyd.c @@ -38,480 +38,523 @@ #include "util.h" #include "event-util.h" #include "build.h" +#include "set.h" +#include "path-util.h" + +#define BUFFER_SIZE (256 * 1024) +#define CONNECTIONS_MAX 256 -#define BUFFER_SIZE 16384 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop) +DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo); -unsigned int total_clients = 0; +typedef struct Context { + Set *listen; + Set *connections; +} Context; -DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo); +typedef struct Connection { + int server_fd, client_fd; + int server_to_client_buffer[2]; /* a pipe */ + int client_to_server_buffer[2]; /* a pipe */ -struct proxy { - int listen_fd; - bool ignore_env; - bool remote_is_inet; - const char *remote_host; - const char *remote_service; -}; + size_t server_to_client_buffer_full, client_to_server_buffer_full; + size_t server_to_client_buffer_size, client_to_server_buffer_size; + + sd_event_source *server_event_source, *client_event_source; +} Connection; -struct connection { - int fd; - uint32_t events; - sd_event_source *w; - struct connection *c_destination; - size_t buffer_filled_len; - size_t buffer_sent_len; - char buffer[BUFFER_SIZE]; +union sockaddr_any { + struct sockaddr sa; + struct sockaddr_un un; + struct sockaddr_in in; + struct sockaddr_in6 in6; + struct sockaddr_storage storage; }; -static void free_connection(struct connection *c) { - if (c != NULL) { - log_debug("Freeing fd=%d (conn %p).", c->fd, c); - sd_event_source_unref(c->w); - if (c->fd > 0) - close_nointr_nofail(c->fd); - free(c); - } -} +static const char *arg_remote_host = NULL; -static int add_event_to_connection(struct connection *c, uint32_t events) { - int r; +static void connection_free(Connection *c) { + assert(c); - log_debug("Have revents=%d. Adding revents=%d.", c->events, events); + sd_event_source_unref(c->server_event_source); + sd_event_source_unref(c->client_event_source); - c->events |= events; + if (c->server_fd >= 0) + close_nointr_nofail(c->server_fd); + if (c->client_fd >= 0) + close_nointr_nofail(c->client_fd); - r = sd_event_source_set_io_events(c->w, c->events); - if (r < 0) { - log_error("Error %d setting revents: %s", r, strerror(-r)); - return r; - } + close_pipe(c->server_to_client_buffer); + close_pipe(c->client_to_server_buffer); - r = sd_event_source_set_enabled(c->w, SD_EVENT_ON); - if (r < 0) { - log_error("Error %d enabling source: %s", r, strerror(-r)); - return r; - } + free(c); +} - return 0; +static void context_free(Context *context) { + sd_event_source *es; + Connection *c; + + assert(context); + + while ((es = set_steal_first(context->listen))) + sd_event_source_unref(es); + + while ((c = set_steal_first(context->connections))) + connection_free(c); + + set_free(context->listen); + set_free(context->connections); } -static int remove_event_from_connection(struct connection *c, uint32_t events) { +static int get_remote_sockaddr(union sockaddr_any *sa, socklen_t *salen) { int r; - log_debug("Have revents=%d. Removing revents=%d.", c->events, events); + assert(sa); + assert(salen); - c->events &= ~events; + if (path_is_absolute(arg_remote_host)) { + sa->un.sun_family = AF_UNIX; + strncpy(sa->un.sun_path, arg_remote_host, sizeof(sa->un.sun_path)-1); + sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0; - r = sd_event_source_set_io_events(c->w, c->events); - if (r < 0) { - log_error("Error %d setting revents: %s", r, strerror(-r)); - return r; - } + *salen = offsetof(union sockaddr_any, un.sun_path) + strlen(sa->un.sun_path); - if (c->events == 0) { - r = sd_event_source_set_enabled(c->w, SD_EVENT_OFF); - if (r < 0) { - log_error("Error %d disabling source: %s", r, strerror(-r)); - return r; + } else if (arg_remote_host[0] == '@') { + sa->un.sun_family = AF_UNIX; + sa->un.sun_path[0] = 0; + strncpy(sa->un.sun_path+1, arg_remote_host+1, sizeof(sa->un.sun_path)-2); + sa->un.sun_path[sizeof(sa->un.sun_path)-1] = 0; + + *salen = offsetof(union sockaddr_any, un.sun_path) + 1 + strlen(sa->un.sun_path + 1); + + } else { + _cleanup_freeaddrinfo_ struct addrinfo *result = NULL; + const char *node, *service; + + struct addrinfo hints = { + .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM, + .ai_flags = AI_ADDRCONFIG + }; + + service = strrchr(arg_remote_host, ':'); + if (service) { + node = strndupa(arg_remote_host, service - arg_remote_host); + service ++; + } else { + node = arg_remote_host; + service = "80"; } - } - return 0; -} + log_debug("Looking up address info for %s:%s", node, service); + r = getaddrinfo(node, service, &hints, &result); + if (r != 0) { + log_error("Failed to resolve host %s:%s: %s", node, service, gai_strerror(r)); + return -EHOSTUNREACH; + } -static int send_buffer(struct connection *sender) { - struct connection *receiver = sender->c_destination; - ssize_t len; - int r = 0; - - /* We cannot assume that even a partial send() indicates that - * the next send() will return EAGAIN or EWOULDBLOCK. Loop until - * it does. */ - while (sender->buffer_filled_len > sender->buffer_sent_len) { - len = send(receiver->fd, sender->buffer + sender->buffer_sent_len, sender->buffer_filled_len - sender->buffer_sent_len, 0); - log_debug("send(%d, ...)=%zd", receiver->fd, len); - if (len < 0) { - if (errno != EWOULDBLOCK && errno != EAGAIN) { - log_error("Error %d in send to fd=%d: %m", errno, receiver->fd); - return -errno; - } - else { - /* send() is in a would-block state. */ - break; - } + assert(result); + if (result->ai_addrlen > sizeof(union sockaddr_any)) { + log_error("Address too long."); + return -E2BIG; } - /* len < 0 can't occur here. len == 0 is possible but - * undefined behavior for nonblocking send(). */ - assert(len > 0); - sender->buffer_sent_len += len; + memcpy(sa, result->ai_addr, result->ai_addrlen); + *salen = result->ai_addrlen; } - log_debug("send(%d, ...) completed with %zu bytes still buffered.", receiver->fd, sender->buffer_filled_len - sender->buffer_sent_len); - - /* Detect a would-block state or partial send. */ - if (sender->buffer_filled_len > sender->buffer_sent_len) { + return 0; +} - /* If the buffer is full, disable events coming for recv. */ - if (sender->buffer_filled_len == BUFFER_SIZE) { - r = remove_event_from_connection(sender, EPOLLIN); - if (r < 0) { - log_error("Error %d disabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r)); - return r; - } - } +static int connection_create_pipes(Connection *c, int buffer[2], size_t *sz) { + int r; - /* Watch for when the recipient can be sent data again. */ - r = add_event_to_connection(receiver, EPOLLOUT); - if (r < 0) { - log_error("Error %d enabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r)); - return r; - } - log_debug("Done with recv for fd=%d. Waiting on send for fd=%d.", sender->fd, receiver->fd); - return r; - } + assert(c); + assert(buffer); + assert(sz); - /* If we sent everything without any issues (would-block or - * partial send), the buffer is now empty. */ - sender->buffer_filled_len = 0; - sender->buffer_sent_len = 0; + if (buffer[0] >= 0) + return 0; - /* Enable the sender's receive watcher, in case the buffer was - * full and we disabled it. */ - r = add_event_to_connection(sender, EPOLLIN); + r = pipe2(buffer, O_CLOEXEC|O_NONBLOCK); if (r < 0) { - log_error("Error %d enabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r)); - return r; + log_error("Failed to allocate pipe buffer: %m"); + return -errno; } - /* Disable the other side's send watcher, as we have no data to send now. */ - r = remove_event_from_connection(receiver, EPOLLOUT); + fcntl(buffer[0], F_SETPIPE_SZ, BUFFER_SIZE); + + r = fcntl(buffer[0], F_GETPIPE_SZ); if (r < 0) { - log_error("Error %d disabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r)); - return r; + log_error("Failed to get pipe buffer size: %m"); + return -errno; } + assert(r > 0); + *sz = r; + return 0; } -static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { - struct connection *c = (struct connection *) userdata; - int r = 0; - ssize_t len; - - assert(revents & (EPOLLIN | EPOLLOUT)); - assert(fd == c->fd); - assert(s == c->w); - - log_debug("Got event revents=%d from fd=%d (conn %p).", revents, fd, c); - - if (revents & EPOLLIN) { - log_debug("About to recv up to %zu bytes from fd=%d (%zu/BUFFER_SIZE).", BUFFER_SIZE - c->buffer_filled_len, fd, c->buffer_filled_len); - - /* Receive until the buffer's full, there's no more data, - * or the client/server disconnects. */ - while (c->buffer_filled_len < BUFFER_SIZE) { - len = recv(fd, c->buffer + c->buffer_filled_len, BUFFER_SIZE - c->buffer_filled_len, 0); - log_debug("recv(%d, ...)=%zd", fd, len); - if (len < 0) { - if (errno != EWOULDBLOCK && errno != EAGAIN) { - log_error("Error %d in recv from fd=%d: %m", errno, fd); - return -errno; - } else { - /* recv() is in a blocking state. */ - break; - } - } - else if (len == 0) { - log_debug("Clean disconnection from fd=%d", fd); - total_clients--; - free_connection(c->c_destination); - free_connection(c); - return 0; +static int connection_shovel( + Connection *c, + int *from, int buffer[2], int *to, + size_t *full, size_t *sz, + sd_event_source **from_source, sd_event_source **to_source) { + + bool shoveled; + + assert(c); + assert(from); + assert(buffer); + assert(buffer[0] >= 0); + assert(buffer[1] >= 0); + assert(to); + assert(full); + assert(sz); + assert(from_source); + assert(to_source); + + do { + ssize_t z; + + shoveled = false; + + if (*full < *sz && *from >= 0 && *to >= 0) { + z = splice(*from, NULL, buffer[1], NULL, *sz - *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK); + if (z > 0) { + *full += z; + shoveled = true; + } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) { + *from_source = sd_event_source_unref(*from_source); + close_nointr_nofail(*from); + *from = -1; + } else if (errno != EAGAIN && errno != EINTR) { + log_error("Failed to splice: %m"); + return -errno; } - - assert(len > 0); - log_debug("Recording that the buffer got %zd more bytes full.", len); - c->buffer_filled_len += len; - log_debug("Buffer now has %zu bytes full.", c->buffer_filled_len); } - /* Try sending the data immediately. */ - return send_buffer(c); - } - else { - return send_buffer(c->c_destination); - } + if (*full > 0 && *to >= 0) { + z = splice(buffer[0], NULL, *to, NULL, *full, SPLICE_F_MOVE|SPLICE_F_NONBLOCK); + if (z > 0) { + *full -= z; + shoveled = true; + } else if (z == 0 || errno == EPIPE || errno == ECONNRESET) { + *to_source = sd_event_source_unref(*to_source); + close_nointr_nofail(*to); + *to = -1; + } else if (errno != EAGAIN && errno != EINTR) { + log_error("Failed to splice: %m"); + return -errno; + } + } + } while (shoveled); - return r; + return 0; } -/* Once sending to the server is ready, set up the real watchers. */ -static int connected_to_server_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { - struct sd_event *e = NULL; - struct connection *c_server_to_client = (struct connection *) userdata; - struct connection *c_client_to_server = c_server_to_client->c_destination; +static int connection_enable_event_sources(Connection *c, sd_event *event); + +static int traffic_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { + Connection *c = userdata; int r; - assert(revents & EPOLLOUT); + assert(s); + assert(fd >= 0); + assert(c); - e = sd_event_get(s); + r = connection_shovel(c, + &c->server_fd, c->server_to_client_buffer, &c->client_fd, + &c->server_to_client_buffer_full, &c->server_to_client_buffer_size, + &c->server_event_source, &c->client_event_source); + if (r < 0) + goto quit; - /* Cancel the initial write watcher for the server. */ - sd_event_source_unref(s); + r = connection_shovel(c, + &c->client_fd, c->client_to_server_buffer, &c->server_fd, + &c->client_to_server_buffer_full, &c->client_to_server_buffer_size, + &c->client_event_source, &c->server_event_source); + if (r < 0) + goto quit; - log_debug("Connected to server. Initializing watchers for receiving data."); + /* EOF on both sides? */ + if (c->server_fd == -1 && c->client_fd == -1) + goto quit; - /* A recv watcher for the server. */ - r = sd_event_add_io(e, c_server_to_client->fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_server_to_client->w); - if (r < 0) { - log_error("Error %d creating recv watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r)); - goto fail; - } - c_server_to_client->events = EPOLLIN; + /* Server closed, and all data written to client? */ + if (c->server_fd == -1 && c->server_to_client_buffer_full <= 0) + goto quit; - /* A recv watcher for the client. */ - r = sd_event_add_io(e, c_client_to_server->fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_client_to_server->w); - if (r < 0) { - log_error("Error %d creating recv watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r)); - goto fail; - } - c_client_to_server->events = EPOLLIN; + /* Client closed, and all data written to server? */ + if (c->client_fd == -1 && c->client_to_server_buffer_full <= 0) + goto quit; - goto finish; + r = connection_enable_event_sources(c, sd_event_get(s)); + if (r < 0) + goto quit; -fail: - free_connection(c_client_to_server); - free_connection(c_server_to_client); + return 1; -finish: - return r; +quit: + connection_free(c); + return 0; /* ignore errors, continue serving */ } -static int get_server_connection_fd(const struct proxy *proxy) { - int server_fd; - int r = -EBADF; - int len; +static int connection_enable_event_sources(Connection *c, sd_event *event) { + uint32_t a = 0, b = 0; + int r; - if (proxy->remote_is_inet) { - int s; - _cleanup_freeaddrinfo_ struct addrinfo *result = NULL; - struct addrinfo hints = {.ai_family = AF_UNSPEC, - .ai_socktype = SOCK_STREAM, - .ai_flags = AI_PASSIVE}; - - log_debug("Looking up address info for %s:%s", proxy->remote_host, proxy->remote_service); - s = getaddrinfo(proxy->remote_host, proxy->remote_service, &hints, &result); - if (s != 0) { - log_error("getaddrinfo error (%d): %s", s, gai_strerror(s)); - return r; - } + assert(c); + assert(event); - if (result == NULL) { - log_error("getaddrinfo: no result"); - return r; - } + if (c->server_to_client_buffer_full > 0) + b |= EPOLLOUT; + if (c->server_to_client_buffer_full < c->server_to_client_buffer_size) + a |= EPOLLIN; - /* @TODO: Try connecting to all results instead of just the first. */ - server_fd = socket(result->ai_family, result->ai_socktype | SOCK_NONBLOCK, result->ai_protocol); - if (server_fd < 0) { - log_error("Error %d creating socket: %m", errno); - return r; - } + if (c->client_to_server_buffer_full > 0) + a |= EPOLLOUT; + if (c->client_to_server_buffer_full < c->client_to_server_buffer_size) + b |= EPOLLIN; - r = connect(server_fd, result->ai_addr, result->ai_addrlen); - /* Ignore EINPROGRESS errors because they're expected for a nonblocking socket. */ - if (r < 0 && errno != EINPROGRESS) { - log_error("Error %d while connecting to socket %s:%s: %m", errno, proxy->remote_host, proxy->remote_service); - return r; - } + if (c->server_event_source) + r = sd_event_source_set_io_events(c->server_event_source, a); + else if (c->server_fd >= 0) + r = sd_event_add_io(event, c->server_fd, a, traffic_cb, c, &c->server_event_source); + else + r = 0; + + if (r < 0) { + log_error("Failed to set up server event source: %s", strerror(-r)); + return r; } - else { - struct sockaddr_un remote; - server_fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0); - if (server_fd < 0) { - log_error("Error %d creating socket: %m", errno); - return -EBADFD; - } + if (c->client_event_source) + r = sd_event_source_set_io_events(c->client_event_source, b); + else if (c->client_fd >= 0) + r = sd_event_add_io(event, c->client_fd, b, traffic_cb, c, &c->client_event_source); + else + r = 0; - remote.sun_family = AF_UNIX; - strncpy(remote.sun_path, proxy->remote_host, sizeof(remote.sun_path)); - len = strlen(remote.sun_path) + sizeof(remote.sun_family); - r = connect(server_fd, (struct sockaddr *) &remote, len); - if (r < 0 && errno != EINPROGRESS) { - log_error("Error %d while connecting to Unix domain socket %s: %m", errno, proxy->remote_host); - return -EBADFD; - } + if (r < 0) { + log_error("Failed to set up server event source: %s", strerror(-r)); + return r; } - log_debug("Server connection is fd=%d", server_fd); - return server_fd; + return 0; } -static int do_accept(sd_event *e, struct proxy *p, int fd) { - struct connection *c_server_to_client = NULL; - struct connection *c_client_to_server = NULL; - int r = 0; - union sockaddr_union sa; - socklen_t salen = sizeof(sa); - int client_fd, server_fd; - - client_fd = accept4(fd, (struct sockaddr *) &sa, &salen, SOCK_NONBLOCK|SOCK_CLOEXEC); - if (client_fd < 0) { - if (errno == EAGAIN || errno == EWOULDBLOCK) - return -errno; - log_error("Error %d accepting client connection: %m", errno); - r = -errno; +static int connect_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { + Connection *c = userdata; + socklen_t solen; + int error, r; + + assert(s); + assert(fd >= 0); + assert(c); + + solen = sizeof(error); + r = getsockopt(fd, SOL_SOCKET, SO_ERROR, &error, &solen); + if (r < 0) { + log_error("Failed to issue SO_ERROR: %m"); goto fail; } - server_fd = get_server_connection_fd(p); - if (server_fd < 0) { - log_error("Error initiating server connection."); - r = server_fd; + if (error != 0) { + log_error("Failed to connect to remote host: %s", strerror(error)); goto fail; } - c_client_to_server = new0(struct connection, 1); - if (c_client_to_server == NULL) { - log_oom(); + c->client_event_source = sd_event_source_unref(c->client_event_source); + + r = connection_create_pipes(c, c->server_to_client_buffer, &c->server_to_client_buffer_size); + if (r < 0) goto fail; - } - c_server_to_client = new0(struct connection, 1); - if (c_server_to_client == NULL) { - log_oom(); + r = connection_create_pipes(c, c->client_to_server_buffer, &c->client_to_server_buffer_size); + if (r < 0) goto fail; - } - c_client_to_server->fd = client_fd; - c_server_to_client->fd = server_fd; + r = connection_enable_event_sources(c, sd_event_get(s)); + if (r < 0) + goto fail; - if (sa.sa.sa_family == AF_INET || sa.sa.sa_family == AF_INET6) { - char sa_str[INET6_ADDRSTRLEN]; - const char *success; + return 0; - success = inet_ntop(sa.sa.sa_family, &sa.in6.sin6_addr, sa_str, INET6_ADDRSTRLEN); - if (success == NULL) - log_warning("Error %d calling inet_ntop: %m", errno); - else - log_debug("Accepted client connection from %s as fd=%d", sa_str, c_client_to_server->fd); - } - else { - log_debug("Accepted client connection (non-IP) as fd=%d", c_client_to_server->fd); +fail: + connection_free(c); + return 0; /* ignore errors, continue serving */ +} + +static int add_connection_socket(Context *context, sd_event *event, int fd) { + union sockaddr_any sa = {}; + socklen_t salen; + Connection *c; + int r; + + assert(context); + assert(event); + assert(fd >= 0); + + if (set_size(context->connections) > CONNECTIONS_MAX) { + log_warning("Hit connection limit, refusing connection."); + close_nointr_nofail(fd); + return 0; } - total_clients++; - log_debug("Client fd=%d (conn %p) successfully connected. Total clients: %u", c_client_to_server->fd, c_client_to_server, total_clients); - log_debug("Server fd=%d (conn %p) successfully initialized.", c_server_to_client->fd, c_server_to_client); + r = set_ensure_allocated(&context->connections, trivial_hash_func, trivial_compare_func); + if (r < 0) + return log_oom(); - /* Initialize watcher for send to server; this shows connectivity. */ - r = sd_event_add_io(e, c_server_to_client->fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &c_server_to_client->w); - if (r < 0) { - log_error("Error %d creating connectivity watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r)); + c = new0(Connection, 1); + if (!c) + return log_oom(); + + c->server_fd = fd; + c->client_fd = -1; + c->server_to_client_buffer[0] = c->server_to_client_buffer[1] = -1; + c->client_to_server_buffer[0] = c->client_to_server_buffer[1] = -1; + + r = get_remote_sockaddr(&sa, &salen); + if (r < 0) + goto fail; + + c->client_fd = socket(sa.sa.sa_family, SOCK_STREAM|SOCK_NONBLOCK|SOCK_CLOEXEC, 0); + if (c->client_fd < 0) { + log_error("Failed to get remote socket: %m"); goto fail; } - /* Allow lookups of the opposite connection. */ - c_server_to_client->c_destination = c_client_to_server; - c_client_to_server->c_destination = c_server_to_client; + r = connect(c->client_fd, &sa.sa, salen); + if (r < 0) { + if (errno == EINPROGRESS) { + r = sd_event_add_io(event, c->client_fd, EPOLLOUT, connect_cb, c, &c->client_event_source); + if (r < 0) { + log_error("Failed to add connection socket: %s", strerror(-r)); + goto fail; + } + } else { + log_error("Failed to connect to remote host: %m"); + goto fail; + } + } else { + r = connection_enable_event_sources(c, event); + if (r < 0) + goto fail; + } - goto finish; + return 0; fail: - log_warning("Accepting a client connection or connecting to the server failed."); - free_connection(c_client_to_server); - free_connection(c_server_to_client); - -finish: - return r; + connection_free(c); + return 0; /* ignore non-OOM errors, continue serving */ } static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { - struct proxy *p = (struct proxy *) userdata; - sd_event *e = NULL; - int r = 0; + Context *context = userdata; + int nfd = -1, r; + assert(s); + assert(fd >= 0); assert(revents & EPOLLIN); + assert(context); + + nfd = accept4(fd, NULL, NULL, SOCK_NONBLOCK|SOCK_CLOEXEC); + if (nfd >= 0) { + _cleanup_free_ char *peer = NULL; - e = sd_event_get(s); + getpeername_pretty(nfd, &peer); + log_debug("New connection from %s", strna(peer)); - for (;;) { - r = do_accept(e, p, fd); - if (r == -EAGAIN || r == -EWOULDBLOCK) - break; + r = add_connection_socket(context, sd_event_get(s), nfd); if (r < 0) { - log_error("Error %d while trying to accept: %s", r, strerror(-r)); - break; + close_nointr_nofail(fd); + return r; } - } - /* Re-enable the watcher. */ + } else if (errno != -EAGAIN) + log_warning("Failed to accept() socket: %m"); + r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT); if (r < 0) { log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r)); return r; } - /* Preserve the main loop even if a single accept() fails. */ return 1; } -static int run_main_loop(struct proxy *proxy) { - _cleanup_event_source_unref_ sd_event_source *w_accept = NULL; - _cleanup_event_unref_ sd_event *e = NULL; - int r = EXIT_SUCCESS; +static int add_listen_socket(Context *context, sd_event *event, int fd) { + sd_event_source *source; + int r; + + assert(context); + assert(event); + assert(fd >= 0); - r = sd_event_new(&e); + log_info("Listening on %i", fd); + + r = set_ensure_allocated(&context->listen, trivial_hash_func, trivial_compare_func); if (r < 0) { - log_error("Failed to allocate event loop: %s", strerror(-r)); + log_oom(); + return r; + } + + r = sd_is_socket(fd, 0, SOCK_STREAM, 1); + if (r < 0) { + log_error("Failed to determine socket type: %s", strerror(-r)); return r; } + if (r == 0) { + log_error("Passed in socket is not a stream socket."); + return -EINVAL; + } - r = fd_nonblock(proxy->listen_fd, true); + r = fd_nonblock(fd, true); if (r < 0) { - log_error("Failed to make listen file descriptor nonblocking: %s", strerror(-r)); + log_error("Failed to mark file descriptor non-blocking: %s", strerror(-r)); return r; } - log_debug("Initializing main listener fd=%d", proxy->listen_fd); + r = sd_event_add_io(event, fd, EPOLLIN, accept_cb, context, &source); + if (r < 0) { + log_error("Failed to add event source: %s", strerror(-r)); + return r; + } - r = sd_event_add_io(e, proxy->listen_fd, EPOLLIN, accept_cb, proxy, &w_accept); + r = set_put(context->listen, source); if (r < 0) { - log_error("Error %d while adding event IO source: %s", r, strerror(-r)); + log_error("Failed to add source to set: %s", strerror(-r)); + sd_event_source_unref(source); return r; } /* Set the watcher to oneshot in case other processes are also * watching to accept(). */ - r = sd_event_source_set_enabled(w_accept, SD_EVENT_ONESHOT); + r = sd_event_source_set_enabled(source, SD_EVENT_ONESHOT); if (r < 0) { - log_error("Error %d while setting event IO source to ONESHOT: %s", r, strerror(-r)); + log_error("Failed to enable oneshot mode: %s", strerror(-r)); return r; } - log_debug("Initialized main listener. Entering loop."); - - return sd_event_loop(e); + return 0; } static int help(void) { - printf("%s hostname-or-ip port-or-service\n" - "%s unix-domain-socket-path\n\n" - "Inherit a socket. Bidirectionally proxy.\n\n" - " -h --help Show this help\n" - " --version Print version and exit\n" - " --ignore-env Ignore expected systemd environment\n", + printf("%s [HOST:PORT]\n" + "%s [SOCKET]\n\n" + "Bidirectionally proxy local sockets to another (possibly remote) socket.\n\n" + " -h --help Show this help\n" + " --version Show package version\n", program_invocation_short_name, program_invocation_short_name); return 0; } -static int parse_argv(int argc, char *argv[], struct proxy *p) { +static int parse_argv(int argc, char *argv[]) { enum { ARG_VERSION = 0x100, @@ -521,7 +564,6 @@ static int parse_argv(int argc, char *argv[], struct proxy *p) { static const struct option options[] = { { "help", no_argument, NULL, 'h' }, { "version", no_argument, NULL, ARG_VERSION }, - { "ignore-env", no_argument, NULL, ARG_IGNORE_ENV}, {} }; @@ -542,10 +584,6 @@ static int parse_argv(int argc, char *argv[], struct proxy *p) { puts(SYSTEMD_FEATURES); return 0; - case ARG_IGNORE_ENV: - p->ignore_env = true; - continue; - case '?': return -EINVAL; @@ -554,75 +592,63 @@ static int parse_argv(int argc, char *argv[], struct proxy *p) { } } - if (optind + 1 != argc && optind + 2 != argc) { - log_error("Incorrect number of positional arguments."); - help(); + if (optind >= argc) { + log_error("Not enough parameters."); return -EINVAL; } - p->remote_host = argv[optind]; - assert(p->remote_host); - - p->remote_is_inet = p->remote_host[0] != '/'; - - if (optind == argc - 2) { - if (!p->remote_is_inet) { - log_error("A port or service is not allowed for Unix socket destinations."); - help(); - return -EINVAL; - } - p->remote_service = argv[optind + 1]; - assert(p->remote_service); - } else if (p->remote_is_inet) { - log_error("A port or service is required for IP destinations."); - help(); + if (argc != optind+1) { + log_error("Too many parameters."); return -EINVAL; } + arg_remote_host = argv[optind]; return 1; } int main(int argc, char *argv[]) { - struct proxy p = {}; - int r; + _cleanup_event_unref_ sd_event *event = NULL; + Context context = {}; + int r, n, fd; log_parse_environment(); log_open(); - r = parse_argv(argc, argv, &p); + r = parse_argv(argc, argv); if (r <= 0) goto finish; - p.listen_fd = SD_LISTEN_FDS_START; + r = sd_event_new(&event); + if (r < 0) { + log_error("Failed to allocate event loop: %s", strerror(-r)); + goto finish; + } - if (!p.ignore_env) { - int n; - n = sd_listen_fds(1); - if (n == 0) { - log_error("Found zero inheritable sockets. Are you sure this is running as a socket-activated service?"); - r = EXIT_FAILURE; - goto finish; - } else if (n < 0) { - log_error("Error %d while finding inheritable sockets: %s", n, strerror(-n)); - r = EXIT_FAILURE; - goto finish; - } else if (n > 1) { - log_error("Can't listen on more than one socket."); - r = EXIT_FAILURE; + n = sd_listen_fds(1); + if (n < 0) { + log_error("Failed to receive sockets from parent."); + r = n; + goto finish; + } else if (n == 0) { + log_error("Didn't get any sockets passed in."); + r = -EINVAL; + goto finish; + } + + for (fd = SD_LISTEN_FDS_START; fd < SD_LISTEN_FDS_START + n; fd++) { + r = add_listen_socket(&context, event, fd); + if (r < 0) goto finish; - } } - r = sd_is_socket(p.listen_fd, 0, SOCK_STREAM, 1); + r = sd_event_loop(event); if (r < 0) { - log_error("Error %d while checking inherited socket: %s", r, strerror(-r)); + log_error("Failed to run event loop: %s", strerror(-r)); goto finish; } - log_info("Starting the socket activation proxy with listener fd=%d.", p.listen_fd); - - r = run_main_loop(&p); - finish: + context_free(&context); + return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS; } |