/*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/ /*** This file is part of systemd. Copyright 2013 David Strauss systemd is free software; you can redistribute it and/or modify it under the terms of the GNU Lesser General Public License as published by the Free Software Foundation; either version 2.1 of the License, or (at your option) any later version. systemd is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details. You should have received a copy of the GNU Lesser General Public License along with systemd; If not, see <http://www.gnu.org/licenses/>. ***/ #include <arpa/inet.h> #include <errno.h> #include <getopt.h> #include <stdio.h> #include <stdlib.h> #include <string.h> #include <netdb.h> #include <sys/fcntl.h> #include <sys/socket.h> #include <sys/un.h> #include <unistd.h> #include "sd-daemon.h" #include "sd-event.h" #include "log.h" #include "socket-util.h" #include "util.h" #include "event-util.h" #define BUFFER_SIZE 16384 #define _cleanup_freeaddrinfo_ _cleanup_(freeaddrinfop) unsigned int total_clients = 0; DEFINE_TRIVIAL_CLEANUP_FUNC(struct addrinfo *, freeaddrinfo); struct proxy { int listen_fd; bool ignore_env; bool remote_is_inet; const char *remote_host; const char *remote_service; }; struct connection { int fd; uint32_t events; sd_event_source *w; struct connection *c_destination; size_t buffer_filled_len; size_t buffer_sent_len; char buffer[BUFFER_SIZE]; }; static void free_connection(struct connection *c) { if (c != NULL) { log_debug("Freeing fd=%d (conn %p).", c->fd, c); sd_event_source_unref(c->w); if (c->fd > 0) close_nointr_nofail(c->fd); free(c); } } static int add_event_to_connection(struct connection *c, uint32_t events) { int r; log_debug("Have revents=%d. Adding revents=%d.", c->events, events); c->events |= events; r = sd_event_source_set_io_events(c->w, c->events); if (r < 0) { log_error("Error %d setting revents: %s", r, strerror(-r)); return r; } r = sd_event_source_set_enabled(c->w, SD_EVENT_ON); if (r < 0) { log_error("Error %d enabling source: %s", r, strerror(-r)); return r; } return 0; } static int remove_event_from_connection(struct connection *c, uint32_t events) { int r; log_debug("Have revents=%d. Removing revents=%d.", c->events, events); c->events &= ~events; r = sd_event_source_set_io_events(c->w, c->events); if (r < 0) { log_error("Error %d setting revents: %s", r, strerror(-r)); return r; } if (c->events == 0) { r = sd_event_source_set_enabled(c->w, SD_EVENT_OFF); if (r < 0) { log_error("Error %d disabling source: %s", r, strerror(-r)); return r; } } return 0; } static int send_buffer(struct connection *sender) { struct connection *receiver = sender->c_destination; ssize_t len; int r = 0; /* We cannot assume that even a partial send() indicates that * the next send() will return EAGAIN or EWOULDBLOCK. Loop until * it does. */ while (sender->buffer_filled_len > sender->buffer_sent_len) { len = send(receiver->fd, sender->buffer + sender->buffer_sent_len, sender->buffer_filled_len - sender->buffer_sent_len, 0); log_debug("send(%d, ...)=%ld", receiver->fd, len); if (len < 0) { if (errno != EWOULDBLOCK && errno != EAGAIN) { log_error("Error %d in send to fd=%d: %m", errno, receiver->fd); return -errno; } else { /* send() is in a would-block state. */ break; } } /* len < 0 can't occur here. len == 0 is possible but * undefined behavior for nonblocking send(). */ assert(len > 0); sender->buffer_sent_len += len; } log_debug("send(%d, ...) completed with %lu bytes still buffered.", receiver->fd, sender->buffer_filled_len - sender->buffer_sent_len); /* Detect a would-block state or partial send. */ if (sender->buffer_filled_len > sender->buffer_sent_len) { /* If the buffer is full, disable events coming for recv. */ if (sender->buffer_filled_len == BUFFER_SIZE) { r = remove_event_from_connection(sender, EPOLLIN); if (r < 0) { log_error("Error %d disabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r)); return r; } } /* Watch for when the recipient can be sent data again. */ r = add_event_to_connection(receiver, EPOLLOUT); if (r < 0) { log_error("Error %d enabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r)); return r; } log_debug("Done with recv for fd=%d. Waiting on send for fd=%d.", sender->fd, receiver->fd); return r; } /* If we sent everything without any issues (would-block or * partial send), the buffer is now empty. */ sender->buffer_filled_len = 0; sender->buffer_sent_len = 0; /* Enable the sender's receive watcher, in case the buffer was * full and we disabled it. */ r = add_event_to_connection(sender, EPOLLIN); if (r < 0) { log_error("Error %d enabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r)); return r; } /* Disable the other side's send watcher, as we have no data to send now. */ r = remove_event_from_connection(receiver, EPOLLOUT); if (r < 0) { log_error("Error %d disabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r)); return r; } return 0; } static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { struct connection *c = (struct connection *) userdata; int r = 0; ssize_t len; assert(revents & (EPOLLIN | EPOLLOUT)); assert(fd == c->fd); assert(s == c->w); log_debug("Got event revents=%d from fd=%d (conn %p).", revents, fd, c); if (revents & EPOLLIN) { log_debug("About to recv up to %lu bytes from fd=%d (%lu/BUFFER_SIZE).", BUFFER_SIZE - c->buffer_filled_len, fd, c->buffer_filled_len); /* Receive until the buffer's full, there's no more data, * or the client/server disconnects. */ while (c->buffer_filled_len < BUFFER_SIZE) { len = recv(fd, c->buffer + c->buffer_filled_len, BUFFER_SIZE - c->buffer_filled_len, 0); log_debug("recv(%d, ...)=%ld", fd, len); if (len < 0) { if (errno != EWOULDBLOCK && errno != EAGAIN) { log_error("Error %d in recv from fd=%d: %m", errno, fd); return -errno; } else { /* recv() is in a blocking state. */ break; } } else if (len == 0) { log_debug("Clean disconnection from fd=%d", fd); total_clients--; free_connection(c->c_destination); free_connection(c); return 0; } assert(len > 0); log_debug("Recording that the buffer got %ld more bytes full.", len); c->buffer_filled_len += len; log_debug("Buffer now has %ld bytes full.", c->buffer_filled_len); } /* Try sending the data immediately. */ return send_buffer(c); } else { return send_buffer(c->c_destination); } return r; } /* Once sending to the server is ready, set up the real watchers. */ static int connected_to_server_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { struct sd_event *e = NULL; struct connection *c_server_to_client = (struct connection *) userdata; struct connection *c_client_to_server = c_server_to_client->c_destination; int r; assert(revents & EPOLLOUT); e = sd_event_get(s); /* Cancel the initial write watcher for the server. */ sd_event_source_unref(s); log_debug("Connected to server. Initializing watchers for receiving data."); /* A recv watcher for the server. */ r = sd_event_add_io(e, c_server_to_client->fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_server_to_client->w); if (r < 0) { log_error("Error %d creating recv watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r)); goto fail; } c_server_to_client->events = EPOLLIN; /* A recv watcher for the client. */ r = sd_event_add_io(e, c_client_to_server->fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_client_to_server->w); if (r < 0) { log_error("Error %d creating recv watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r)); goto fail; } c_client_to_server->events = EPOLLIN; goto finish; fail: free_connection(c_client_to_server); free_connection(c_server_to_client); finish: return r; } static int get_server_connection_fd(const struct proxy *proxy) { int server_fd; int r = -EBADF; int len; if (proxy->remote_is_inet) { int s; _cleanup_freeaddrinfo_ struct addrinfo *result = NULL; struct addrinfo hints = {.ai_family = AF_UNSPEC, .ai_socktype = SOCK_STREAM, .ai_flags = AI_PASSIVE}; log_debug("Looking up address info for %s:%s", proxy->remote_host, proxy->remote_service); s = getaddrinfo(proxy->remote_host, proxy->remote_service, &hints, &result); if (s != 0) { log_error("getaddrinfo error (%d): %s", s, gai_strerror(s)); return r; } if (result == NULL) { log_error("getaddrinfo: no result"); return r; } /* @TODO: Try connecting to all results instead of just the first. */ server_fd = socket(result->ai_family, result->ai_socktype | SOCK_NONBLOCK, result->ai_protocol); if (server_fd < 0) { log_error("Error %d creating socket: %m", errno); return r; } r = connect(server_fd, result->ai_addr, result->ai_addrlen); /* Ignore EINPROGRESS errors because they're expected for a nonblocking socket. */ if (r < 0 && errno != EINPROGRESS) { log_error("Error %d while connecting to socket %s:%s: %m", errno, proxy->remote_host, proxy->remote_service); return r; } } else { struct sockaddr_un remote; server_fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0); if (server_fd < 0) { log_error("Error %d creating socket: %m", errno); return -EBADFD; } remote.sun_family = AF_UNIX; strncpy(remote.sun_path, proxy->remote_host, sizeof(remote.sun_path)); len = strlen(remote.sun_path) + sizeof(remote.sun_family); r = connect(server_fd, (struct sockaddr *) &remote, len); if (r < 0 && errno != EINPROGRESS) { log_error("Error %d while connecting to Unix domain socket %s: %m", errno, proxy->remote_host); return -EBADFD; } } log_debug("Server connection is fd=%d", server_fd); return server_fd; } static int do_accept(sd_event *e, struct proxy *p, int fd) { struct connection *c_server_to_client = NULL; struct connection *c_client_to_server = NULL; int r = 0; union sockaddr_union sa; socklen_t salen = sizeof(sa); int client_fd, server_fd; client_fd = accept4(fd, (struct sockaddr *) &sa, &salen, SOCK_NONBLOCK|SOCK_CLOEXEC); if (client_fd < 0) { if (errno == EAGAIN || errno == EWOULDBLOCK) return -errno; log_error("Error %d accepting client connection: %m", errno); r = -errno; goto fail; } server_fd = get_server_connection_fd(p); if (server_fd < 0) { log_error("Error initiating server connection."); r = server_fd; goto fail; } c_client_to_server = new0(struct connection, 1); if (c_client_to_server == NULL) { log_oom(); goto fail; } c_server_to_client = new0(struct connection, 1); if (c_server_to_client == NULL) { log_oom(); goto fail; } c_client_to_server->fd = client_fd; c_server_to_client->fd = server_fd; if (sa.sa.sa_family == AF_INET || sa.sa.sa_family == AF_INET6) { char sa_str[INET6_ADDRSTRLEN]; const char *success; success = inet_ntop(sa.sa.sa_family, &sa.in6.sin6_addr, sa_str, INET6_ADDRSTRLEN); if (success == NULL) log_warning("Error %d calling inet_ntop: %m", errno); else log_debug("Accepted client connection from %s as fd=%d", sa_str, c_client_to_server->fd); } else { log_debug("Accepted client connection (non-IP) as fd=%d", c_client_to_server->fd); } total_clients++; log_debug("Client fd=%d (conn %p) successfully connected. Total clients: %u", c_client_to_server->fd, c_client_to_server, total_clients); log_debug("Server fd=%d (conn %p) successfully initialized.", c_server_to_client->fd, c_server_to_client); /* Initialize watcher for send to server; this shows connectivity. */ r = sd_event_add_io(e, c_server_to_client->fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &c_server_to_client->w); if (r < 0) { log_error("Error %d creating connectivity watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r)); goto fail; } /* Allow lookups of the opposite connection. */ c_server_to_client->c_destination = c_client_to_server; c_client_to_server->c_destination = c_server_to_client; goto finish; fail: log_warning("Accepting a client connection or connecting to the server failed."); free_connection(c_client_to_server); free_connection(c_server_to_client); finish: return r; } static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) { struct proxy *p = (struct proxy *) userdata; sd_event *e = NULL; int r = 0; assert(revents & EPOLLIN); e = sd_event_get(s); for (;;) { r = do_accept(e, p, fd); if (r == -EAGAIN || r == -EWOULDBLOCK) break; if (r < 0) { log_error("Error %d while trying to accept: %s", r, strerror(-r)); break; } } /* Re-enable the watcher. */ r = sd_event_source_set_enabled(s, SD_EVENT_ONESHOT); if (r < 0) { log_error("Error %d while re-enabling listener with ONESHOT: %s", r, strerror(-r)); return r; } /* Preserve the main loop even if a single accept() fails. */ return 1; } static int run_main_loop(struct proxy *proxy) { _cleanup_event_source_unref_ sd_event_source *w_accept = NULL; _cleanup_event_unref_ sd_event *e = NULL; int r = EXIT_SUCCESS; r = sd_event_new(&e); if (r < 0) { log_error("Failed to allocate event loop: %s", strerror(-r)); return r; } r = fd_nonblock(proxy->listen_fd, true); if (r < 0) { log_error("Failed to make listen file descriptor nonblocking: %s", strerror(-r)); return r; } log_debug("Initializing main listener fd=%d", proxy->listen_fd); r = sd_event_add_io(e, proxy->listen_fd, EPOLLIN, accept_cb, proxy, &w_accept); if (r < 0) { log_error("Error %d while adding event IO source: %s", r, strerror(-r)); return r; } /* Set the watcher to oneshot in case other processes are also * watching to accept(). */ r = sd_event_source_set_enabled(w_accept, SD_EVENT_ONESHOT); if (r < 0) { log_error("Error %d while setting event IO source to ONESHOT: %s", r, strerror(-r)); return r; } log_debug("Initialized main listener. Entering loop."); return sd_event_loop(e); } static int help(void) { printf("%s hostname-or-ip port-or-service\n" "%s unix-domain-socket-path\n\n" "Inherit a socket. Bidirectionally proxy.\n\n" " -h --help Show this help\n" " --version Print version and exit\n" " --ignore-env Ignore expected systemd environment\n", program_invocation_short_name, program_invocation_short_name); return 0; } static void version(void) { puts(PACKAGE_STRING " socket-proxyd"); } static int parse_argv(int argc, char *argv[], struct proxy *p) { enum { ARG_VERSION = 0x100, ARG_IGNORE_ENV }; static const struct option options[] = { { "help", no_argument, NULL, 'h' }, { "version", no_argument, NULL, ARG_VERSION }, { "ignore-env", no_argument, NULL, ARG_IGNORE_ENV}, { NULL, 0, NULL, 0 } }; int c; assert(argc >= 0); assert(argv); while ((c = getopt_long(argc, argv, "h", options, NULL)) >= 0) { switch (c) { case 'h': help(); return 0; case '?': return -EINVAL; case ARG_VERSION: version(); return 0; case ARG_IGNORE_ENV: p->ignore_env = true; continue; default: log_error("Unknown option code %c", c); return -EINVAL; } } if (optind + 1 != argc && optind + 2 != argc) { log_error("Incorrect number of positional arguments."); help(); return -EINVAL; } p->remote_host = argv[optind]; assert(p->remote_host); p->remote_is_inet = p->remote_host[0] != '/'; if (optind == argc - 2) { if (!p->remote_is_inet) { log_error("A port or service is not allowed for Unix socket destinations."); help(); return -EINVAL; } p->remote_service = argv[optind + 1]; assert(p->remote_service); } else if (p->remote_is_inet) { log_error("A port or service is required for IP destinations."); help(); return -EINVAL; } return 1; } int main(int argc, char *argv[]) { struct proxy p = {}; int r; log_parse_environment(); log_open(); r = parse_argv(argc, argv, &p); if (r <= 0) goto finish; p.listen_fd = SD_LISTEN_FDS_START; if (!p.ignore_env) { int n; n = sd_listen_fds(1); if (n == 0) { log_error("Found zero inheritable sockets. Are you sure this is running as a socket-activated service?"); r = EXIT_FAILURE; goto finish; } else if (n < 0) { log_error("Error %d while finding inheritable sockets: %s", n, strerror(-n)); r = EXIT_FAILURE; goto finish; } else if (n > 1) { log_error("Can't listen on more than one socket."); r = EXIT_FAILURE; goto finish; } } r = sd_is_socket(p.listen_fd, 0, SOCK_STREAM, 1); if (r < 0) { log_error("Error %d while checking inherited socket: %s", r, strerror(-r)); goto finish; } log_info("Starting the socket activation proxy with listener fd=%d.", p.listen_fd); r = run_main_loop(&p); finish: return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS; }