summaryrefslogtreecommitdiff
path: root/src/core/socket.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/socket.c')
-rw-r--r--src/core/socket.c194
1 files changed, 100 insertions, 94 deletions
diff --git a/src/core/socket.c b/src/core/socket.c
index d3b9a75547..972d494dbc 100644
--- a/src/core/socket.c
+++ b/src/core/socket.c
@@ -59,6 +59,13 @@
#include "user-util.h"
#include "in-addr-util.h"
+struct SocketPeer {
+ unsigned n_ref;
+
+ Socket *socket;
+ union sockaddr_union peer;
+};
+
static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = {
[SOCKET_DEAD] = UNIT_INACTIVE,
[SOCKET_START_PRE] = UNIT_ACTIVATING,
@@ -78,9 +85,6 @@ static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = {
static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata);
static int socket_dispatch_timer(sd_event_source *source, usec_t usec, void *userdata);
-SocketPeer *socket_peer_new(void);
-int socket_find_peer(Socket *s, int fd, SocketPeer **p);
-
static void socket_init(Unit *u) {
Socket *s = SOCKET(u);
@@ -482,10 +486,11 @@ static void peer_address_hash_func(const void *p, struct siphash *state) {
const SocketPeer *s = p;
assert(s);
+ assert(IN_SET(s->peer.sa.sa_family, AF_INET, AF_INET6));
if (s->peer.sa.sa_family == AF_INET)
siphash24_compress(&s->peer.in.sin_addr, sizeof(s->peer.in.sin_addr), state);
- else if (s->peer.sa.sa_family == AF_INET6)
+ else
siphash24_compress(&s->peer.in6.sin6_addr, sizeof(s->peer.in6.sin6_addr), state);
}
@@ -503,8 +508,7 @@ static int peer_address_compare_func(const void *a, const void *b) {
case AF_INET6:
return memcmp(&x->peer.in6.sin6_addr, &y->peer.in6.sin6_addr, sizeof(x->peer.in6.sin6_addr));
}
-
- return -1;
+ assert_not_reached("Black sheep in the family!");
}
const struct hash_ops peer_address_hash_ops = {
@@ -537,6 +541,87 @@ static int socket_load(Unit *u) {
return socket_verify(s);
}
+static SocketPeer *socket_peer_new(void) {
+ SocketPeer *p;
+
+ p = new0(SocketPeer, 1);
+ if (!p)
+ return NULL;
+
+ p->n_ref = 1;
+
+ return p;
+}
+
+SocketPeer *socket_peer_ref(SocketPeer *p) {
+ if (!p)
+ return NULL;
+
+ assert(p->n_ref > 0);
+ p->n_ref++;
+
+ return p;
+}
+
+SocketPeer *socket_peer_unref(SocketPeer *p) {
+ if (!p)
+ return NULL;
+
+ assert(p->n_ref > 0);
+
+ p->n_ref--;
+
+ if (p->n_ref > 0)
+ return NULL;
+
+ if (p->socket)
+ set_remove(p->socket->peers_by_address, p);
+
+ return mfree(p);
+}
+
+static int socket_acquire_peer(Socket *s, int fd, SocketPeer **p) {
+ _cleanup_(socket_peer_unrefp) SocketPeer *remote = NULL;
+ SocketPeer sa = {}, *i;
+ socklen_t salen = sizeof(sa.peer);
+ int r;
+
+ assert(fd >= 0);
+ assert(s);
+
+ r = getpeername(fd, &sa.peer.sa, &salen);
+ if (r < 0)
+ return log_error_errno(errno, "getpeername failed: %m");
+
+ if (!IN_SET(sa.peer.sa.sa_family, AF_INET, AF_INET6)) {
+ *p = NULL;
+ return 0;
+ }
+
+ i = set_get(s->peers_by_address, &sa);
+ if (i) {
+ *p = socket_peer_ref(i);
+ return 1;
+ }
+
+ remote = socket_peer_new();
+ if (!remote)
+ return log_oom();
+
+ remote->peer = sa.peer;
+
+ r = set_put(s->peers_by_address, remote);
+ if (r < 0)
+ return r;
+
+ remote->socket = s;
+
+ *p = remote;
+ remote = NULL;
+
+ return 1;
+}
+
_const_ static const char* listen_lookup(int family, int type) {
if (family == AF_NETLINK)
@@ -2102,22 +2187,22 @@ static void socket_enter_running(Socket *s, int cfd) {
Service *service;
if (s->n_connections >= s->max_connections) {
- log_unit_warning(UNIT(s), "Too many incoming connections (%u), refusing connection attempt.", s->n_connections);
+ log_unit_warning(UNIT(s), "Too many incoming connections (%u), refusing connection attempt.",
+ s->n_connections);
safe_close(cfd);
return;
}
if (s->max_connections_per_source > 0) {
- r = socket_find_peer(s, cfd, &p);
+ r = socket_acquire_peer(s, cfd, &p);
if (r < 0) {
safe_close(cfd);
return;
- }
-
- if (p->n_ref > s->max_connections_per_source) {
- log_unit_warning(UNIT(s), "Too many incoming connections (%u) from source, refusing connection attempt.", p->n_ref);
+ } else if (r > 0 && p->n_ref > s->max_connections_per_source) {
+ log_unit_warning(UNIT(s),
+ "Too many incoming connections (%u) from source, refusing connection attempt.",
+ p->n_ref);
safe_close(cfd);
- p = NULL;
return;
}
}
@@ -2163,10 +2248,8 @@ static void socket_enter_running(Socket *s, int cfd) {
cfd = -1; /* We passed ownership of the fd to the service now. Forget it here. */
s->n_connections++;
- if (s->max_connections_per_source > 0) {
- service->peer = socket_peer_ref(p);
- p = NULL;
- }
+ service->peer = p; /* Pass ownership of the peer reference */
+ p = NULL;
r = manager_add_job(UNIT(s)->manager, JOB_START, UNIT(service), JOB_REPLACE, &error, NULL);
if (r < 0) {
@@ -2662,83 +2745,6 @@ _pure_ static bool socket_check_gc(Unit *u) {
return s->n_connections > 0;
}
-SocketPeer *socket_peer_new(void) {
- SocketPeer *p;
-
- p = new0(SocketPeer, 1);
- if (!p)
- return NULL;
-
- p->n_ref = 1;
-
- return p;
-}
-
-SocketPeer *socket_peer_ref(SocketPeer *p) {
- if (!p)
- return NULL;
-
- assert(p->n_ref > 0);
- p->n_ref++;
-
- return p;
-}
-
-SocketPeer *socket_peer_unref(SocketPeer *p) {
- if (!p)
- return NULL;
-
- assert(p->n_ref > 0);
-
- p->n_ref--;
-
- if (p->n_ref > 0)
- return NULL;
-
- if (p->socket)
- set_remove(p->socket->peers_by_address, p);
-
- free(p);
-
- return NULL;
-}
-
-int socket_find_peer(Socket *s, int fd, SocketPeer **p) {
- _cleanup_free_ SocketPeer *remote = NULL;
- SocketPeer sa, *i;
- socklen_t salen = sizeof(sa.peer);
- int r;
-
- assert(fd >= 0);
- assert(s);
-
- r = getpeername(fd, &sa.peer.sa, &salen);
- if (r < 0)
- return log_error_errno(errno, "getpeername failed: %m");
-
- i = set_get(s->peers_by_address, &sa);
- if (i) {
- *p = i;
- return 1;
- }
-
- remote = socket_peer_new();
- if (!remote)
- return log_oom();
-
- memcpy(&remote->peer, &sa.peer, sizeof(union sockaddr_union));
- remote->socket = s;
-
- r = set_put(s->peers_by_address, remote);
- if (r < 0)
- return r;
-
- *p = remote;
- remote = NULL;
-
- return 0;
-}
-
static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata) {
SocketPort *p = userdata;
int cfd = -1;