diff options
Diffstat (limited to 'src/poll.c')
-rw-r--r-- | src/poll.c | 302 |
1 files changed, 179 insertions, 123 deletions
@@ -44,14 +44,14 @@ #endif /** - * @defgroup libssh_poll The SSH poll functions. + * @defgroup libssh_poll The SSH poll functions * @ingroup libssh * * Add a generic way to handle sockets asynchronously. * * It's based on poll objects, each of which store a socket, its events and a * callback, which gets called whenever an event is set. The poll objects are - * attached to a poll context, which should be allocated on per thread basis. + * attached to a poll context, which should be allocated on a per thread basis. * * Polling the poll context will poll all the attached poll objects and call * their callbacks (handlers) if any of the socket events are set. This should @@ -68,7 +68,7 @@ struct ssh_poll_handle_struct { size_t idx; } x; short events; - int lock; + uint32_t lock_cnt; ssh_poll_callback cb; void *cb_data; }; @@ -84,15 +84,18 @@ struct ssh_poll_ctx_struct { #ifdef HAVE_POLL #include <poll.h> -void ssh_poll_init(void) { +void ssh_poll_init(void) +{ return; } -void ssh_poll_cleanup(void) { +void ssh_poll_cleanup(void) +{ return; } -int ssh_poll(ssh_pollfd_t *fds, nfds_t nfds, int timeout) { +int ssh_poll(ssh_pollfd_t *fds, nfds_t nfds, int timeout) +{ return poll((struct pollfd *) fds, nfds, timeout); } @@ -210,8 +213,8 @@ static short bsd_socket_compute_revents(int fd, short events) * poll implementation. * * Keep in mind that select is terribly inefficient. The interface is simply not - * meant to be used with maximum descriptor value greater, say, 32 or so. With - * a value as high as 1024 on Linux you'll pay dearly in every single call. + * meant to be used with maximum descriptor value greater than, say, 32 or so. + * With a value as high as 1024 on Linux you'll pay dearly in every single call. * poll() will be orders of magnitude faster. */ static int bsd_poll(ssh_pollfd_t *fds, nfds_t nfds, int timeout) @@ -246,19 +249,17 @@ static int bsd_poll(ssh_pollfd_t *fds, nfds_t nfds, int timeout) } #endif - if (fds[i].events & (POLLIN | POLLRDNORM)) { - FD_SET (fds[i].fd, &readfds); - } + // we use the readfds to get POLLHUP and POLLERR, which are provided even when not requested + FD_SET (fds[i].fd, &readfds); + if (fds[i].events & (POLLOUT | POLLWRNORM | POLLWRBAND)) { FD_SET (fds[i].fd, &writefds); } if (fds[i].events & (POLLPRI | POLLRDBAND)) { FD_SET (fds[i].fd, &exceptfds); } - if (fds[i].fd > max_fd && - (fds[i].events & (POLLIN | POLLOUT | POLLPRI | - POLLRDNORM | POLLRDBAND | - POLLWRNORM | POLLWRBAND))) { + + if (fds[i].fd > max_fd) { max_fd = fds[i].fd; rc = 0; } @@ -286,7 +287,7 @@ static int bsd_poll(ssh_pollfd_t *fds, nfds_t nfds, int timeout) if (rc < 0) { return -1; } - /* A timeout occured */ + /* A timeout occurred */ if (rc == 0) { return 0; } @@ -335,21 +336,24 @@ int ssh_poll(ssh_pollfd_t *fds, nfds_t nfds, int timeout) { /** * @brief Allocate a new poll object, which could be used within a poll context. * - * @param fd Socket that will be polled. - * @param events Poll events that will be monitored for the socket. i.e. - * POLLIN, POLLPRI, POLLOUT - * @param cb Function to be called if any of the events are set. - * The prototype of cb is: - * int (*ssh_poll_callback)(ssh_poll_handle p, socket_t fd, - * int revents, void *userdata); - * @param userdata Userdata to be passed to the callback function. NULL if - * not needed. + * @param[in] fd Socket that will be polled. + * @param[in] events Poll events that will be monitored for the socket. + * i.e. POLLIN, POLLPRI, POLLOUT + * @param[in] cb Function to be called if any of the events are set. + * The prototype of cb is: + * int (*ssh_poll_callback)(ssh_poll_handle p, + * socket_t fd, + * int revents, + * void *userdata); + * @param[in] userdata Userdata to be passed to the callback function. + * NULL if not needed. * - * @return A new poll object, NULL on error + * @return A new poll object, NULL on error */ -ssh_poll_handle ssh_poll_new(socket_t fd, short events, ssh_poll_callback cb, - void *userdata) { +ssh_poll_handle +ssh_poll_new(socket_t fd, short events, ssh_poll_callback cb, void *userdata) +{ ssh_poll_handle p; p = malloc(sizeof(struct ssh_poll_handle_struct)); @@ -373,12 +377,13 @@ ssh_poll_handle ssh_poll_new(socket_t fd, short events, ssh_poll_callback cb, * @param p Pointer to an already allocated poll object. */ -void ssh_poll_free(ssh_poll_handle p) { - if(p->ctx != NULL){ - ssh_poll_ctx_remove(p->ctx,p); - p->ctx=NULL; - } - SAFE_FREE(p); +void ssh_poll_free(ssh_poll_handle p) +{ + if (p->ctx != NULL) { + ssh_poll_ctx_remove(p->ctx, p); + p->ctx = NULL; + } + SAFE_FREE(p); } /** @@ -388,8 +393,9 @@ void ssh_poll_free(ssh_poll_handle p) { * * @return Poll context or NULL if the poll object isn't attached. */ -ssh_poll_ctx ssh_poll_get_ctx(ssh_poll_handle p) { - return p->ctx; +ssh_poll_ctx ssh_poll_get_ctx(ssh_poll_handle p) +{ + return p->ctx; } /** @@ -399,22 +405,31 @@ ssh_poll_ctx ssh_poll_get_ctx(ssh_poll_handle p) { * * @return Poll events. */ -short ssh_poll_get_events(ssh_poll_handle p) { - return p->events; +short ssh_poll_get_events(ssh_poll_handle p) +{ + return p->events; } /** * @brief Set the events of a poll object. The events will also be propagated - * to an associated poll context. + * to an associated poll context unless the fd is locked. In that case, + * only the POLLOUT can be set. * * @param p Pointer to an already allocated poll object. * @param events Poll events. */ -void ssh_poll_set_events(ssh_poll_handle p, short events) { - p->events = events; - if (p->ctx != NULL && !p->lock) { - p->ctx->pollfds[p->x.idx].events = events; - } +void ssh_poll_set_events(ssh_poll_handle p, short events) +{ + p->events = events; + if (p->ctx != NULL) { + if (p->lock_cnt == 0) { + p->ctx->pollfds[p->x.idx].events = events; + } else if (!(p->ctx->pollfds[p->x.idx].events & POLLOUT)) { + /* if locked, allow only setting POLLOUT to prevent recursive + * callbacks */ + p->ctx->pollfds[p->x.idx].events = events & POLLOUT; + } + } } /** @@ -424,12 +439,13 @@ void ssh_poll_set_events(ssh_poll_handle p, short events) { * @param p Pointer to an already allocated poll object. * @param fd New file descriptor. */ -void ssh_poll_set_fd(ssh_poll_handle p, socket_t fd) { - if (p->ctx != NULL) { - p->ctx->pollfds[p->x.idx].fd = fd; - } else { - p->x.fd = fd; - } +void ssh_poll_set_fd(ssh_poll_handle p, socket_t fd) +{ + if (p->ctx != NULL) { + p->ctx->pollfds[p->x.idx].fd = fd; + } else { + p->x.fd = fd; + } } /** @@ -439,8 +455,9 @@ void ssh_poll_set_fd(ssh_poll_handle p, socket_t fd) { * @param p Pointer to an already allocated poll object. * @param events Poll events. */ -void ssh_poll_add_events(ssh_poll_handle p, short events) { - ssh_poll_set_events(p, ssh_poll_get_events(p) | events); +void ssh_poll_add_events(ssh_poll_handle p, short events) +{ + ssh_poll_set_events(p, ssh_poll_get_events(p) | events); } /** @@ -450,8 +467,9 @@ void ssh_poll_add_events(ssh_poll_handle p, short events) { * @param p Pointer to an already allocated poll object. * @param events Poll events. */ -void ssh_poll_remove_events(ssh_poll_handle p, short events) { - ssh_poll_set_events(p, ssh_poll_get_events(p) & ~events); +void ssh_poll_remove_events(ssh_poll_handle p, short events) +{ + ssh_poll_set_events(p, ssh_poll_get_events(p) & ~events); } /** @@ -462,12 +480,13 @@ void ssh_poll_remove_events(ssh_poll_handle p, short events) { * @return Raw socket. */ -socket_t ssh_poll_get_fd(ssh_poll_handle p) { - if (p->ctx != NULL) { - return p->ctx->pollfds[p->x.idx].fd; - } +socket_t ssh_poll_get_fd(ssh_poll_handle p) +{ + if (p->ctx != NULL) { + return p->ctx->pollfds[p->x.idx].fd; + } - return p->x.fd; + return p->x.fd; } /** * @brief Set the callback of a poll object. @@ -477,11 +496,12 @@ socket_t ssh_poll_get_fd(ssh_poll_handle p) { * @param userdata Userdata to be passed to the callback function. NULL if * not needed. */ -void ssh_poll_set_callback(ssh_poll_handle p, ssh_poll_callback cb, void *userdata) { - if (cb != NULL) { - p->cb = cb; - p->cb_data = userdata; - } +void ssh_poll_set_callback(ssh_poll_handle p, ssh_poll_callback cb, void *userdata) +{ + if (cb != NULL) { + p->cb = cb; + p->cb_data = userdata; + } } /** @@ -495,7 +515,8 @@ void ssh_poll_set_callback(ssh_poll_handle p, ssh_poll_callback cb, void *userda * for the next 5. Set it to 0 if you want to use the * library's default value. */ -ssh_poll_ctx ssh_poll_ctx_new(size_t chunk_size) { +ssh_poll_ctx ssh_poll_ctx_new(size_t chunk_size) +{ ssh_poll_ctx ctx; ctx = malloc(sizeof(struct ssh_poll_ctx_struct)); @@ -518,25 +539,27 @@ ssh_poll_ctx ssh_poll_ctx_new(size_t chunk_size) { * * @param ctx Pointer to an already allocated poll context. */ -void ssh_poll_ctx_free(ssh_poll_ctx ctx) { - if (ctx->polls_allocated > 0) { - while (ctx->polls_used > 0){ - ssh_poll_handle p = ctx->pollptrs[0]; - /* - * The free function calls ssh_poll_ctx_remove() and decrements - * ctx->polls_used - */ - ssh_poll_free(p); - } +void ssh_poll_ctx_free(ssh_poll_ctx ctx) +{ + if (ctx->polls_allocated > 0) { + while (ctx->polls_used > 0){ + ssh_poll_handle p = ctx->pollptrs[0]; + /* + * The free function calls ssh_poll_ctx_remove() and decrements + * ctx->polls_used + */ + ssh_poll_free(p); + } - SAFE_FREE(ctx->pollptrs); - SAFE_FREE(ctx->pollfds); - } + SAFE_FREE(ctx->pollptrs); + SAFE_FREE(ctx->pollfds); + } - SAFE_FREE(ctx); + SAFE_FREE(ctx); } -static int ssh_poll_ctx_resize(ssh_poll_ctx ctx, size_t new_size) { +static int ssh_poll_ctx_resize(ssh_poll_ctx ctx, size_t new_size) +{ ssh_poll_handle *pollptrs; ssh_pollfd_t *pollfds; @@ -570,7 +593,8 @@ static int ssh_poll_ctx_resize(ssh_poll_ctx ctx, size_t new_size) { * * @return 0 on success, < 0 on error */ -int ssh_poll_ctx_add(ssh_poll_ctx ctx, ssh_poll_handle p) { +int ssh_poll_ctx_add(ssh_poll_ctx ctx, ssh_poll_handle p) +{ socket_t fd; if (p->ctx != NULL) { @@ -604,7 +628,7 @@ int ssh_poll_ctx_add(ssh_poll_ctx ctx, ssh_poll_handle p) { */ int ssh_poll_ctx_add_socket (ssh_poll_ctx ctx, ssh_socket s) { - ssh_poll_handle p; + ssh_poll_handle p = NULL; int ret; p = ssh_socket_get_poll_handle(s); @@ -622,7 +646,8 @@ int ssh_poll_ctx_add_socket (ssh_poll_ctx ctx, ssh_socket s) * @param ctx Pointer to an already allocated poll context. * @param p Pointer to an already allocated poll object. */ -void ssh_poll_ctx_remove(ssh_poll_ctx ctx, ssh_poll_handle p) { +void ssh_poll_ctx_remove(ssh_poll_ctx ctx, ssh_poll_handle p) +{ size_t i; i = p->x.idx; @@ -648,7 +673,7 @@ void ssh_poll_ctx_remove(ssh_poll_ctx ctx, ssh_poll_handle p) { * @brief Poll all the sockets associated through a poll object with a * poll context. If any of the events are set after the poll, the * call back function of the socket will be called. - * This function should be called once within the programs main loop. + * This function should be called once within the program's main loop. * * @param ctx Pointer to an already allocated poll context. * @param timeout An upper limit on the time for which ssh_poll_ctx() will @@ -657,7 +682,7 @@ void ssh_poll_ctx_remove(ssh_poll_ctx ctx, ssh_poll_handle p) { * the poll() function. * @returns SSH_OK No error. * SSH_ERROR Error happened during the poll. - * SSH_AGAIN Timeout occured + * SSH_AGAIN Timeout occurred */ int ssh_poll_ctx_dopoll(ssh_poll_ctx ctx, int timeout) @@ -673,6 +698,15 @@ int ssh_poll_ctx_dopoll(ssh_poll_ctx ctx, int timeout) return SSH_ERROR; } + /* Allow only POLLOUT events on locked sockets as that means we are called + * recursively and we only want process the POLLOUT events here to flush + * output buffer */ + for (i = 0; i < ctx->polls_used; i++) { + /* The lock allows only POLLOUT events: drop the rest */ + if (ctx->pollptrs[i]->lock_cnt > 0) { + ctx->pollfds[i].events &= POLLOUT; + } + } ssh_timestamp_init(&ts); do { int tm = ssh_timeout_update(&ts, timeout); @@ -688,17 +722,24 @@ int ssh_poll_ctx_dopoll(ssh_poll_ctx ctx, int timeout) used = ctx->polls_used; for (i = 0; i < used && rc > 0; ) { - if (!ctx->pollfds[i].revents || ctx->pollptrs[i]->lock) { + revents = ctx->pollfds[i].revents; + /* Do not pass any other events except for POLLOUT to callback when + * called recursively more than 2 times. On s390x the poll will be + * spammed with POLLHUP events causing infinite recursion when the user + * callback issues some write/flush/poll calls. */ + if (ctx->pollptrs[i]->lock_cnt > 2) { + revents &= POLLOUT; + } + if (revents == 0) { i++; } else { int ret; p = ctx->pollptrs[i]; fd = ctx->pollfds[i].fd; - revents = ctx->pollfds[i].revents; /* avoid having any event caught during callback */ ctx->pollfds[i].events = 0; - p->lock = 1; + p->lock_cnt++; if (p->cb && (ret = p->cb(p, fd, revents, p->cb_data)) < 0) { if (ret == -2) { return -1; @@ -709,7 +750,7 @@ int ssh_poll_ctx_dopoll(ssh_poll_ctx ctx, int timeout) } else { ctx->pollfds[i].revents = 0; ctx->pollfds[i].events = p->events; - p->lock = 0; + p->lock_cnt--; i++; } @@ -727,12 +768,13 @@ int ssh_poll_ctx_dopoll(ssh_poll_ctx ctx, int timeout) * @param session SSH session * @returns the default ssh_poll_ctx */ -ssh_poll_ctx ssh_poll_get_default_ctx(ssh_session session){ - if(session->default_poll_ctx != NULL) - return session->default_poll_ctx; - /* 2 is enough for the default one */ - session->default_poll_ctx = ssh_poll_ctx_new(2); - return session->default_poll_ctx; +ssh_poll_ctx ssh_poll_get_default_ctx(ssh_session session) +{ + if(session->default_poll_ctx != NULL) + return session->default_poll_ctx; + /* 2 is enough for the default one */ + session->default_poll_ctx = ssh_poll_ctx_new(2); + return session->default_poll_ctx; } /* public event API */ @@ -754,10 +796,11 @@ struct ssh_event_struct { * ssh_session objects and socket fd which are going to be polled at the * same time as the event context. You would need a single event context * per thread. - * + * * @return The ssh_event object on success, NULL on failure. */ -ssh_event ssh_event_new(void) { +ssh_event ssh_event_new(void) +{ ssh_event event; event = malloc(sizeof(struct ssh_event_struct)); @@ -784,12 +827,14 @@ ssh_event ssh_event_new(void) { return event; } -static int ssh_event_fd_wrapper_callback(ssh_poll_handle p, socket_t fd, int revents, - void *userdata) { +static int +ssh_event_fd_wrapper_callback(ssh_poll_handle p, socket_t fd, int revents, + void *userdata) +{ struct ssh_event_fd_wrapper *pw = (struct ssh_event_fd_wrapper *)userdata; (void)p; - if(pw->cb != NULL) { + if (pw->cb != NULL) { return pw->cb(fd, revents, pw->userdata); } return 0; @@ -812,11 +857,13 @@ static int ssh_event_fd_wrapper_callback(ssh_poll_handle p, socket_t fd, int rev * @returns SSH_OK on success * SSH_ERROR on failure */ -int ssh_event_add_fd(ssh_event event, socket_t fd, short events, - ssh_event_callback cb, void *userdata) { +int +ssh_event_add_fd(ssh_event event, socket_t fd, short events, + ssh_event_callback cb, void *userdata) +{ ssh_poll_handle p; struct ssh_event_fd_wrapper *pw; - + if(event == NULL || event->ctx == NULL || cb == NULL || fd == SSH_INVALID_SOCKET) { return SSH_ERROR; @@ -872,7 +919,7 @@ void ssh_event_remove_poll(ssh_event event, ssh_poll_handle p) } /** - * @brief remove the poll handle from session and assign them to a event, + * @brief remove the poll handle from session and assign them to an event, * when used in blocking mode. * * @param event The ssh_event object @@ -881,7 +928,8 @@ void ssh_event_remove_poll(ssh_event event, ssh_poll_handle p) * @returns SSH_OK on success * SSH_ERROR on failure */ -int ssh_event_add_session(ssh_event event, ssh_session session) { +int ssh_event_add_session(ssh_event event, ssh_session session) +{ ssh_poll_handle p; #ifdef WITH_SERVER struct ssh_iterator *iterator; @@ -933,16 +981,19 @@ int ssh_event_add_session(ssh_event event, ssh_session session) { * * @return SSH_ERROR in case of error */ -int ssh_event_add_connector(ssh_event event, ssh_connector connector){ +int ssh_event_add_connector(ssh_event event, ssh_connector connector) +{ return ssh_connector_set_event(connector, event); } /** - * @brief Poll all the sockets and sessions associated through an event object.i + * @brief Poll all the sockets and sessions associated through an event object. * * If any of the events are set after the poll, the call back functions of the * sessions or sockets will be called. * This function should be called once within the programs main loop. + * In case of failure, the errno should be consulted to find more information + * about the failure set by underlying poll imlpementation. * * @param event The ssh_event object to poll. * @@ -951,13 +1002,15 @@ int ssh_event_add_connector(ssh_event event, ssh_connector connector){ * means an infinite timeout. This parameter is passed to * the poll() function. * @returns SSH_OK on success. - * SSH_ERROR Error happened during the poll. - * SSH_AGAIN Timeout occured + * SSH_ERROR Error happened during the poll. Check errno to get more + * details about why it failed. + * SSH_AGAIN Timeout occurred */ -int ssh_event_dopoll(ssh_event event, int timeout) { +int ssh_event_dopoll(ssh_event event, int timeout) +{ int rc; - if(event == NULL || event->ctx == NULL) { + if (event == NULL || event->ctx == NULL) { return SSH_ERROR; } rc = ssh_poll_ctx_dopoll(event->ctx, timeout); @@ -973,7 +1026,8 @@ int ssh_event_dopoll(ssh_event event, int timeout) { * @returns SSH_OK on success * SSH_ERROR on failure */ -int ssh_event_remove_fd(ssh_event event, socket_t fd) { +int ssh_event_remove_fd(ssh_event event, socket_t fd) +{ register size_t i, used; int rc = SSH_ERROR; @@ -1019,7 +1073,8 @@ int ssh_event_remove_fd(ssh_event event, socket_t fd) { * @returns SSH_OK on success * SSH_ERROR on failure */ -int ssh_event_remove_session(ssh_event event, ssh_session session) { +int ssh_event_remove_session(ssh_event event, ssh_session session) +{ ssh_poll_handle p; register size_t i, used; int rc = SSH_ERROR; @@ -1027,14 +1082,14 @@ int ssh_event_remove_session(ssh_event event, ssh_session session) { struct ssh_iterator *iterator; #endif - if(event == NULL || event->ctx == NULL || session == NULL) { + if (event == NULL || event->ctx == NULL || session == NULL) { return SSH_ERROR; } used = event->ctx->polls_used; - for(i = 0; i < used; i++) { - p = event->ctx->pollptrs[i]; - if(p->session == session){ + for (i = 0; i < used; i++) { + p = event->ctx->pollptrs[i]; + if (p->session == session) { /* * ssh_poll_ctx_remove() decrements * event->ctx->polls_used @@ -1054,8 +1109,8 @@ int ssh_event_remove_session(ssh_event event, ssh_session session) { } #ifdef WITH_SERVER iterator = ssh_list_get_iterator(event->sessions); - while(iterator != NULL) { - if((ssh_session)iterator->data == session) { + while (iterator != NULL) { + if ((ssh_session)iterator->data == session) { ssh_list_remove(event->sessions, iterator); /* there should be only one instance of this session */ break; @@ -1073,7 +1128,8 @@ int ssh_event_remove_session(ssh_event event, ssh_session session) { * @return SSH_OK on success * @return SSH_ERROR on failure */ -int ssh_event_remove_connector(ssh_event event, ssh_connector connector){ +int ssh_event_remove_connector(ssh_event event, ssh_connector connector) +{ (void)event; return ssh_connector_remove_event(connector); } @@ -1091,13 +1147,13 @@ void ssh_event_free(ssh_event event) size_t used, i; ssh_poll_handle p; - if(event == NULL) { + if (event == NULL) { return; } if (event->ctx != NULL) { used = event->ctx->polls_used; - for(i = 0; i < used; i++) { + for (i = 0; i < used; i++) { p = event->ctx->pollptrs[i]; if (p->session != NULL) { ssh_poll_ctx_remove(event->ctx, p); @@ -1110,7 +1166,7 @@ void ssh_event_free(ssh_event event) ssh_poll_ctx_free(event->ctx); } #ifdef WITH_SERVER - if(event->sessions != NULL) { + if (event->sessions != NULL) { ssh_list_free(event->sessions); } #endif |