diff options
-rw-r--r-- | src/saproxy/saproxy.c | 110 |
1 files changed, 66 insertions, 44 deletions
diff --git a/src/saproxy/saproxy.c b/src/saproxy/saproxy.c index 91599a8550..6504d09251 100644 --- a/src/saproxy/saproxy.c +++ b/src/saproxy/saproxy.c @@ -54,8 +54,8 @@ struct proxy { struct connection { int fd; - sd_event_source *w_recv; - sd_event_source *w_send; + uint32_t events; + sd_event_source *w; struct connection *c_destination; size_t buffer_filled_len; size_t buffer_sent_len; @@ -64,12 +64,57 @@ struct connection { static void free_connection(struct connection *c) { log_debug("Freeing fd=%d (conn %p).", c->fd, c); - sd_event_source_unref(c->w_recv); - sd_event_source_unref(c->w_send); + sd_event_source_unref(c->w); close(c->fd); free(c); } +static int add_event_to_connection(struct connection *c, uint32_t events) { + int r; + + log_debug("Have revents=%d. Adding revents=%d.", c->events, events); + + c->events |= events; + + r = sd_event_source_set_io_events(c->w, c->events); + if (r < 0) { + log_error("Error %d setting revents: %s", r, strerror(-r)); + return r; + } + + r = sd_event_source_set_enabled(c->w, SD_EVENT_ON); + if (r < 0) { + log_error("Error %d enabling source: %s", r, strerror(-r)); + return r; + } + + return 0; +} + +static int remove_event_from_connection(struct connection *c, uint32_t events) { + int r; + + log_debug("Have revents=%d. Removing revents=%d.", c->events, events); + + c->events &= ~events; + + r = sd_event_source_set_io_events(c->w, c->events); + if (r < 0) { + log_error("Error %d setting revents: %s", r, strerror(-r)); + return r; + } + + if (c->events == 0) { + r = sd_event_source_set_enabled(c->w, SD_EVENT_OFF); + if (r < 0) { + log_error("Error %d disabling source: %s", r, strerror(-r)); + return r; + } + } + + return 0; +} + static int send_buffer(struct connection *sender) { struct connection *receiver = sender->c_destination; ssize_t len; @@ -104,17 +149,17 @@ static int send_buffer(struct connection *sender) { /* If the buffer is full, disable events coming for recv. */ if (sender->buffer_filled_len == BUFFER_SIZE) { - r = sd_event_source_set_enabled(sender->w_recv, SD_EVENT_OFF); + r = remove_event_from_connection(sender, EPOLLIN); if (r < 0) { - log_error("Error %d disabling recv for fd=%d: %s", r, sender->fd, strerror(-r)); + log_error("Error %d disabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r)); return r; } } /* Watch for when the recipient can be sent data again. */ - r = sd_event_source_set_enabled(receiver->w_send, SD_EVENT_ON); + r = add_event_to_connection(receiver, EPOLLOUT); if (r < 0) { - log_error("Error %d enabling send for fd=%d: %s", r, receiver->fd, strerror(-r)); + log_error("Error %d enabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r)); return r; } log_debug("Done with recv for fd=%d. Waiting on send for fd=%d.", sender->fd, receiver->fd); @@ -125,17 +170,18 @@ static int send_buffer(struct connection *sender) { sender->buffer_filled_len = 0; sender->buffer_sent_len = 0; - /* Unmute the sender, in case the buffer was full. */ - r = sd_event_source_set_enabled(sender->w_recv, SD_EVENT_ON); + /* Enable the sender's receive watcher, in case the buffer was + * full and we disabled it. */ + r = add_event_to_connection(sender, EPOLLIN); if (r < 0) { - log_error("Error %d enabling recv for fd=%d: %s", r, sender->fd, strerror(-r)); + log_error("Error %d enabling EPOLLIN for fd=%d: %s", r, sender->fd, strerror(-r)); return r; } - /* Mute the recipient, as we have no data to send now. */ - r = sd_event_source_set_enabled(receiver->w_send, SD_EVENT_OFF); + /* Disable the other side's send watcher, as we have no data to send now. */ + r = remove_event_from_connection(receiver, EPOLLOUT); if (r < 0) { - log_error("Error %d disabling send for fd=%d: %s", r, receiver->fd, strerror(-r)); + log_error("Error %d disabling EPOLLOUT for fd=%d: %s", r, receiver->fd, strerror(-r)); return r; } @@ -149,14 +195,13 @@ static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void * assert(revents & (EPOLLIN | EPOLLOUT)); assert(fd == c->fd); + assert(s == c->w); log_debug("Got event revents=%d from fd=%d (conn %p).", revents, fd, c); if (revents & EPOLLIN) { log_debug("About to recv up to %lu bytes from fd=%d (%lu/BUFFER_SIZE).", BUFFER_SIZE - c->buffer_filled_len, fd, c->buffer_filled_len); - assert(s == c->w_recv); - /* Receive until the buffer's full, there's no more data, * or the client/server disconnects. */ while (c->buffer_filled_len < BUFFER_SIZE) { @@ -190,7 +235,6 @@ static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void * return send_buffer(c); } else { - assert(s == c->w_send); return send_buffer(c->c_destination); } @@ -213,43 +257,21 @@ static int connected_to_server_cb(sd_event_source *s, int fd, uint32_t revents, log_debug("Connected to server. Initializing watchers for receiving data."); - /* A disabled send watcher for the server. */ - r = sd_event_add_io(e, c_server_to_client->fd, EPOLLOUT, transfer_data_cb, c_server_to_client, &c_server_to_client->w_send); - if (r < 0) { - log_error("Error %d creating send watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r)); - goto fail; - } - r = sd_event_source_set_enabled(c_server_to_client->w_send, SD_EVENT_OFF); - if (r < 0) { - log_error("Error %d muting send for fd=%d: %s", r, c_server_to_client->fd, strerror(-r)); - goto finish; - } - /* A recv watcher for the server. */ - r = sd_event_add_io(e, c_server_to_client->fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_server_to_client->w_recv); + r = sd_event_add_io(e, c_server_to_client->fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_server_to_client->w); if (r < 0) { log_error("Error %d creating recv watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r)); goto fail; } - - /* A disabled send watcher for the client. */ - r = sd_event_add_io(e, c_client_to_server->fd, EPOLLOUT, transfer_data_cb, c_client_to_server, &c_client_to_server->w_send); - if (r < 0) { - log_error("Error %d creating send watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r)); - goto fail; - } - r = sd_event_source_set_enabled(c_client_to_server->w_send, SD_EVENT_OFF); - if (r < 0) { - log_error("Error %d muting send for fd=%d: %s", r, c_client_to_server->fd, strerror(-r)); - goto finish; - } + c_server_to_client->events = EPOLLIN; /* A recv watcher for the client. */ - r = sd_event_add_io(e, c_client_to_server->fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_client_to_server->w_recv); + r = sd_event_add_io(e, c_client_to_server->fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_client_to_server->w); if (r < 0) { log_error("Error %d creating recv watcher for fd=%d: %s", r, c_client_to_server->fd, strerror(-r)); goto fail; } + c_client_to_server->events = EPOLLIN; goto finish; @@ -383,7 +405,7 @@ static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdat log_debug("Server fd=%d (conn %p) successfully initialized.", c_server_to_client->fd, c_server_to_client); /* Initialize watcher for send to server; this shows connectivity. */ - r = sd_event_add_io(sd_event_get(s), c_server_to_client->fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &c_server_to_client->w_send); + r = sd_event_add_io(sd_event_get(s), c_server_to_client->fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &c_server_to_client->w); if (r < 0) { log_error("Error %d creating connectivity watcher for fd=%d: %s", r, c_server_to_client->fd, strerror(-r)); goto fail; |