From 166cf510c2ba5aed15dcb807ec263ea2201dcb28 Mon Sep 17 00:00:00 2001 From: Zbigniew Jędrzejewski-Szmek Date: Thu, 4 Aug 2016 21:42:23 -0400 Subject: core/socket: rework SocketPeer refcounting Make functions and definitions that don't need to be shared local to socket.c. --- src/core/socket.c | 194 ++++++++++++++++++++++++++++-------------------------- 1 file changed, 100 insertions(+), 94 deletions(-) (limited to 'src/core/socket.c') diff --git a/src/core/socket.c b/src/core/socket.c index d3b9a75547..972d494dbc 100644 --- a/src/core/socket.c +++ b/src/core/socket.c @@ -59,6 +59,13 @@ #include "user-util.h" #include "in-addr-util.h" +struct SocketPeer { + unsigned n_ref; + + Socket *socket; + union sockaddr_union peer; +}; + static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = { [SOCKET_DEAD] = UNIT_INACTIVE, [SOCKET_START_PRE] = UNIT_ACTIVATING, @@ -78,9 +85,6 @@ static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = { static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata); static int socket_dispatch_timer(sd_event_source *source, usec_t usec, void *userdata); -SocketPeer *socket_peer_new(void); -int socket_find_peer(Socket *s, int fd, SocketPeer **p); - static void socket_init(Unit *u) { Socket *s = SOCKET(u); @@ -482,10 +486,11 @@ static void peer_address_hash_func(const void *p, struct siphash *state) { const SocketPeer *s = p; assert(s); + assert(IN_SET(s->peer.sa.sa_family, AF_INET, AF_INET6)); if (s->peer.sa.sa_family == AF_INET) siphash24_compress(&s->peer.in.sin_addr, sizeof(s->peer.in.sin_addr), state); - else if (s->peer.sa.sa_family == AF_INET6) + else siphash24_compress(&s->peer.in6.sin6_addr, sizeof(s->peer.in6.sin6_addr), state); } @@ -503,8 +508,7 @@ static int peer_address_compare_func(const void *a, const void *b) { case AF_INET6: return memcmp(&x->peer.in6.sin6_addr, &y->peer.in6.sin6_addr, sizeof(x->peer.in6.sin6_addr)); } - - return -1; + assert_not_reached("Black sheep in the family!"); } const struct hash_ops peer_address_hash_ops = { @@ -537,6 +541,87 @@ static int socket_load(Unit *u) { return socket_verify(s); } +static SocketPeer *socket_peer_new(void) { + SocketPeer *p; + + p = new0(SocketPeer, 1); + if (!p) + return NULL; + + p->n_ref = 1; + + return p; +} + +SocketPeer *socket_peer_ref(SocketPeer *p) { + if (!p) + return NULL; + + assert(p->n_ref > 0); + p->n_ref++; + + return p; +} + +SocketPeer *socket_peer_unref(SocketPeer *p) { + if (!p) + return NULL; + + assert(p->n_ref > 0); + + p->n_ref--; + + if (p->n_ref > 0) + return NULL; + + if (p->socket) + set_remove(p->socket->peers_by_address, p); + + return mfree(p); +} + +static int socket_acquire_peer(Socket *s, int fd, SocketPeer **p) { + _cleanup_(socket_peer_unrefp) SocketPeer *remote = NULL; + SocketPeer sa = {}, *i; + socklen_t salen = sizeof(sa.peer); + int r; + + assert(fd >= 0); + assert(s); + + r = getpeername(fd, &sa.peer.sa, &salen); + if (r < 0) + return log_error_errno(errno, "getpeername failed: %m"); + + if (!IN_SET(sa.peer.sa.sa_family, AF_INET, AF_INET6)) { + *p = NULL; + return 0; + } + + i = set_get(s->peers_by_address, &sa); + if (i) { + *p = socket_peer_ref(i); + return 1; + } + + remote = socket_peer_new(); + if (!remote) + return log_oom(); + + remote->peer = sa.peer; + + r = set_put(s->peers_by_address, remote); + if (r < 0) + return r; + + remote->socket = s; + + *p = remote; + remote = NULL; + + return 1; +} + _const_ static const char* listen_lookup(int family, int type) { if (family == AF_NETLINK) @@ -2102,22 +2187,22 @@ static void socket_enter_running(Socket *s, int cfd) { Service *service; if (s->n_connections >= s->max_connections) { - log_unit_warning(UNIT(s), "Too many incoming connections (%u), refusing connection attempt.", s->n_connections); + log_unit_warning(UNIT(s), "Too many incoming connections (%u), refusing connection attempt.", + s->n_connections); safe_close(cfd); return; } if (s->max_connections_per_source > 0) { - r = socket_find_peer(s, cfd, &p); + r = socket_acquire_peer(s, cfd, &p); if (r < 0) { safe_close(cfd); return; - } - - if (p->n_ref > s->max_connections_per_source) { - log_unit_warning(UNIT(s), "Too many incoming connections (%u) from source, refusing connection attempt.", p->n_ref); + } else if (r > 0 && p->n_ref > s->max_connections_per_source) { + log_unit_warning(UNIT(s), + "Too many incoming connections (%u) from source, refusing connection attempt.", + p->n_ref); safe_close(cfd); - p = NULL; return; } } @@ -2163,10 +2248,8 @@ static void socket_enter_running(Socket *s, int cfd) { cfd = -1; /* We passed ownership of the fd to the service now. Forget it here. */ s->n_connections++; - if (s->max_connections_per_source > 0) { - service->peer = socket_peer_ref(p); - p = NULL; - } + service->peer = p; /* Pass ownership of the peer reference */ + p = NULL; r = manager_add_job(UNIT(s)->manager, JOB_START, UNIT(service), JOB_REPLACE, &error, NULL); if (r < 0) { @@ -2662,83 +2745,6 @@ _pure_ static bool socket_check_gc(Unit *u) { return s->n_connections > 0; } -SocketPeer *socket_peer_new(void) { - SocketPeer *p; - - p = new0(SocketPeer, 1); - if (!p) - return NULL; - - p->n_ref = 1; - - return p; -} - -SocketPeer *socket_peer_ref(SocketPeer *p) { - if (!p) - return NULL; - - assert(p->n_ref > 0); - p->n_ref++; - - return p; -} - -SocketPeer *socket_peer_unref(SocketPeer *p) { - if (!p) - return NULL; - - assert(p->n_ref > 0); - - p->n_ref--; - - if (p->n_ref > 0) - return NULL; - - if (p->socket) - set_remove(p->socket->peers_by_address, p); - - free(p); - - return NULL; -} - -int socket_find_peer(Socket *s, int fd, SocketPeer **p) { - _cleanup_free_ SocketPeer *remote = NULL; - SocketPeer sa, *i; - socklen_t salen = sizeof(sa.peer); - int r; - - assert(fd >= 0); - assert(s); - - r = getpeername(fd, &sa.peer.sa, &salen); - if (r < 0) - return log_error_errno(errno, "getpeername failed: %m"); - - i = set_get(s->peers_by_address, &sa); - if (i) { - *p = i; - return 1; - } - - remote = socket_peer_new(); - if (!remote) - return log_oom(); - - memcpy(&remote->peer, &sa.peer, sizeof(union sockaddr_union)); - remote->socket = s; - - r = set_put(s->peers_by_address, remote); - if (r < 0) - return r; - - *p = remote; - remote = NULL; - - return 0; -} - static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata) { SocketPort *p = userdata; int cfd = -1; -- cgit v1.2.3-54-g00ecf