summaryrefslogtreecommitdiff
path: root/src/core/socket.c
diff options
context:
space:
mode:
authorSusant Sahani <ssahani@users.noreply.github.com>2016-08-02 23:18:23 +0530
committerZbigniew Jędrzejewski-Szmek <zbyszek@in.waw.pl>2016-08-02 13:48:23 -0400
commit9d565427647a874b7eda64ef0cb41ea74a96659b (patch)
tree0dc7f6b19040c1d6fbec0612cd9732e4fc83ffa0 /src/core/socket.c
parent87edd2b116c107b54c844e88de06be62ab55e29d (diff)
socket: add support to control no. of connections from one source (#3607)
Introduce MaxConnectionsPerSource= that is number of concurrent connections allowed per IP. RFE: 1939
Diffstat (limited to 'src/core/socket.c')
-rw-r--r--src/core/socket.c185
1 files changed, 185 insertions, 0 deletions
diff --git a/src/core/socket.c b/src/core/socket.c
index 1ce41a1f07..ff55885fb3 100644
--- a/src/core/socket.c
+++ b/src/core/socket.c
@@ -57,6 +57,7 @@
#include "unit-printf.h"
#include "unit.h"
#include "user-util.h"
+#include "in-addr-util.h"
static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = {
[SOCKET_DEAD] = UNIT_INACTIVE,
@@ -77,6 +78,9 @@ static const UnitActiveState state_translation_table[_SOCKET_STATE_MAX] = {
static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata);
static int socket_dispatch_timer(sd_event_source *source, usec_t usec, void *userdata);
+SocketPeer *socket_peer_new(void);
+int socket_find_peer(Socket *s, int fd, SocketPeer **p);
+
static void socket_init(Unit *u) {
Socket *s = SOCKET(u);
@@ -141,11 +145,17 @@ void socket_free_ports(Socket *s) {
static void socket_done(Unit *u) {
Socket *s = SOCKET(u);
+ SocketPeer *p;
assert(s);
socket_free_ports(s);
+ while ((p = hashmap_steal_first(s->peers_by_address)))
+ p->socket = NULL;
+
+ s->peers_by_address = hashmap_free(s->peers_by_address);
+
s->exec_runtime = exec_runtime_unref(s->exec_runtime);
exec_command_free_array(s->exec_command, _SOCKET_EXEC_COMMAND_MAX);
s->control_command = NULL;
@@ -468,6 +478,40 @@ static int socket_verify(Socket *s) {
return 0;
}
+static void peer_address_hash_func(const void *p, struct siphash *state) {
+ const SocketPeer *s = p;
+
+ assert(s);
+
+ if (s->peer.sa.sa_family == AF_INET)
+ siphash24_compress(&s->peer.in.sin_addr, sizeof(s->peer.in.sin_addr), state);
+ else if (s->peer.sa.sa_family == AF_INET6)
+ siphash24_compress(&s->peer.in6.sin6_addr, sizeof(s->peer.in6.sin6_addr), state);
+}
+
+static int peer_address_compare_func(const void *a, const void *b) {
+ const SocketPeer *x = a, *y = b;
+
+ if (x->peer.sa.sa_family < y->peer.sa.sa_family)
+ return -1;
+ if (x->peer.sa.sa_family > y->peer.sa.sa_family)
+ return 1;
+
+ switch(x->peer.sa.sa_family) {
+ case AF_INET:
+ return memcmp(&x->peer.in.sin_addr, &y->peer.in.sin_addr, sizeof(x->peer.in.sin_addr));
+ case AF_INET6:
+ return memcmp(&x->peer.in6.sin6_addr, &y->peer.in6.sin6_addr, sizeof(x->peer.in6.sin6_addr));
+ }
+
+ return -1;
+}
+
+const struct hash_ops peer_address_hash_ops = {
+ .hash = peer_address_hash_func,
+ .compare = peer_address_compare_func
+};
+
static int socket_load(Unit *u) {
Socket *s = SOCKET(u);
int r;
@@ -475,6 +519,10 @@ static int socket_load(Unit *u) {
assert(u);
assert(u->load_state == UNIT_STUB);
+ r = hashmap_ensure_allocated(&s->peers_by_address, &peer_address_hash_ops);
+ if (r < 0)
+ return r;
+
r = unit_load_fragment_and_dropin(u);
if (r < 0)
return r;
@@ -2050,6 +2098,7 @@ static void socket_enter_running(Socket *s, int cfd) {
socket_set_state(s, SOCKET_RUNNING);
} else {
_cleanup_free_ char *prefix = NULL, *instance = NULL, *name = NULL;
+ _cleanup_(socket_peer_unrefp) SocketPeer *p = NULL;
Service *service;
if (s->n_connections >= s->max_connections) {
@@ -2058,6 +2107,21 @@ static void socket_enter_running(Socket *s, int cfd) {
return;
}
+ if (s->max_connections_per_source > 0) {
+ r = socket_find_peer(s, cfd, &p);
+ if (r < 0) {
+ safe_close(cfd);
+ return;
+ }
+
+ if (p->n_ref > s->max_connections_per_source) {
+ log_unit_warning(UNIT(s), "Too many incoming connections (%u) from source, refusing connection attempt.", p->n_ref);
+ safe_close(cfd);
+ p = NULL;
+ return;
+ }
+ }
+
r = socket_instantiate_service(s);
if (r < 0)
goto fail;
@@ -2099,6 +2163,11 @@ static void socket_enter_running(Socket *s, int cfd) {
cfd = -1; /* We passed ownership of the fd to the service now. Forget it here. */
s->n_connections++;
+ if (s->max_connections_per_source > 0) {
+ service->peer = socket_peer_ref(p);
+ p = NULL;
+ }
+
r = manager_add_job(UNIT(s)->manager, JOB_START, UNIT(service), JOB_REPLACE, &error, NULL);
if (r < 0) {
/* We failed to activate the new service, but it still exists. Let's make sure the service
@@ -2244,7 +2313,9 @@ static int socket_stop(Unit *u) {
static int socket_serialize(Unit *u, FILE *f, FDSet *fds) {
Socket *s = SOCKET(u);
+ SocketPeer *k;
SocketPort *p;
+ Iterator i;
int r;
assert(u);
@@ -2295,6 +2366,16 @@ static int socket_serialize(Unit *u, FILE *f, FDSet *fds) {
}
}
+ HASHMAP_FOREACH(k, s->peers_by_address, i) {
+ _cleanup_free_ char *t = NULL;
+
+ r = sockaddr_pretty(&k->peer.sa, FAMILY_ADDRESS_SIZE(k->peer.sa.sa_family), true, true, &t);
+ if (r < 0)
+ return r;
+
+ unit_serialize_item_format(u, f, "peer", "%u %s", k->n_ref, t);
+ }
+
return 0;
}
@@ -2458,6 +2539,33 @@ static int socket_deserialize_item(Unit *u, const char *key, const char *value,
}
}
+ } else if (streq(key, "peer")) {
+ _cleanup_(socket_peer_unrefp) SocketPeer *p;
+ int n_ref, skip = 0;
+ SocketAddress a;
+ int r;
+
+ if (sscanf(value, "%u %n", &n_ref, &skip) < 1 || n_ref < 1)
+ log_unit_debug(u, "Failed to parse socket peer value: %s", value);
+ else {
+ r = socket_address_parse(&a, value+skip);
+ if (r < 0)
+ return r;
+
+ p = socket_peer_new();
+ if (!p)
+ return log_oom();
+
+ p->n_ref = n_ref;
+ memcpy(&p->peer, &a.sockaddr, sizeof(a.sockaddr));
+ p->socket = s;
+
+ r = hashmap_put(s->peers_by_address, p, p);
+ if (r < 0)
+ return r;
+
+ p = NULL;
+ }
} else
log_unit_debug(UNIT(s), "Unknown serialization key: %s", key);
@@ -2554,6 +2662,83 @@ _pure_ static bool socket_check_gc(Unit *u) {
return s->n_connections > 0;
}
+SocketPeer *socket_peer_new(void) {
+ SocketPeer *p;
+
+ p = new0(SocketPeer, 1);
+ if (!p)
+ return NULL;
+
+ p->n_ref = 1;
+
+ return p;
+}
+
+SocketPeer *socket_peer_ref(SocketPeer *p) {
+ if (!p)
+ return NULL;
+
+ assert(p->n_ref > 0);
+ p->n_ref++;
+
+ return p;
+}
+
+SocketPeer *socket_peer_unref(SocketPeer *p) {
+ if (!p)
+ return NULL;
+
+ assert(p->n_ref > 0);
+
+ p->n_ref--;
+
+ if (p->n_ref > 0)
+ return NULL;
+
+ if (p->socket)
+ (void) hashmap_remove(p->socket->peers_by_address, p);
+
+ free(p);
+
+ return NULL;
+}
+
+int socket_find_peer(Socket *s, int fd, SocketPeer **p) {
+ _cleanup_free_ SocketPeer *remote = NULL;
+ SocketPeer sa, *i;
+ socklen_t salen = sizeof(sa.peer);
+ int r;
+
+ assert(fd >= 0);
+ assert(s);
+
+ r = getpeername(fd, &sa.peer.sa, &salen);
+ if (r < 0)
+ return log_error_errno(errno, "getpeername failed: %m");
+
+ i = hashmap_get(s->peers_by_address, &sa);
+ if (i) {
+ *p = i;
+ return 1;
+ }
+
+ remote = socket_peer_new();
+ if (!remote)
+ return log_oom();
+
+ memcpy(&remote->peer, &sa.peer, sizeof(union sockaddr_union));
+ remote->socket = s;
+
+ r = hashmap_put(s->peers_by_address, remote, remote);
+ if (r < 0)
+ return r;
+
+ *p = remote;
+ remote = NULL;
+
+ return 0;
+}
+
static int socket_dispatch_io(sd_event_source *source, int fd, uint32_t revents, void *userdata) {
SocketPort *p = userdata;
int cfd = -1;