diff options
Diffstat (limited to 'src/connector.c')
-rw-r--r-- | src/connector.c | 127 |
1 files changed, 120 insertions, 7 deletions
diff --git a/src/connector.c b/src/connector.c index 407aa522..c4f6af54 100644 --- a/src/connector.c +++ b/src/connector.c @@ -26,6 +26,10 @@ #include "libssh/callbacks.h" #include "libssh/session.h" #include <stdlib.h> +#include <errno.h> +#include <stdbool.h> +#include <sys/stat.h> + #define CHUNKSIZE 4096 #ifdef _WIN32 @@ -40,6 +44,9 @@ # undef unlink # define unlink _unlink # endif /* HAVE_IO_H */ +#else +# include <sys/types.h> +# include <sys/socket.h> #endif struct ssh_connector_struct { @@ -51,6 +58,8 @@ struct ssh_connector_struct { socket_t in_fd; socket_t out_fd; + bool fd_is_socket; + ssh_poll_handle in_poll; ssh_poll_handle out_poll; @@ -76,6 +85,13 @@ static int ssh_connector_channel_write_wontblock_cb(ssh_session session, ssh_channel channel, size_t bytes, void *userdata); +static ssize_t ssh_connector_fd_read(ssh_connector connector, + void *buffer, + uint32_t len); +static ssize_t ssh_connector_fd_write(ssh_connector connector, + const void *buffer, + uint32_t len); +static bool ssh_connector_fd_is_socket(socket_t socket); ssh_connector ssh_connector_new(ssh_session session) { @@ -91,6 +107,8 @@ ssh_connector ssh_connector_new(ssh_session session) connector->in_fd = SSH_INVALID_SOCKET; connector->out_fd = SSH_INVALID_SOCKET; + connector->fd_is_socket = false; + ssh_callbacks_init(&connector->in_channel_cb); ssh_callbacks_init(&connector->out_channel_cb); @@ -167,12 +185,14 @@ int ssh_connector_set_out_channel(ssh_connector connector, void ssh_connector_set_in_fd(ssh_connector connector, socket_t fd) { connector->in_fd = fd; + connector->fd_is_socket = ssh_connector_fd_is_socket(fd); connector->in_channel = NULL; } void ssh_connector_set_out_fd(ssh_connector connector, socket_t fd) { connector->out_fd = fd; + connector->fd_is_socket = ssh_connector_fd_is_socket(fd); connector->out_channel = NULL; } @@ -223,9 +243,9 @@ static void ssh_connector_reset_pollevents(ssh_connector connector) static void ssh_connector_fd_in_cb(ssh_connector connector) { unsigned char buffer[CHUNKSIZE]; - int r; - int toread = CHUNKSIZE; - int w; + uint32_t toread = CHUNKSIZE; + ssize_t r; + ssize_t w; int total = 0; int rc; @@ -239,7 +259,7 @@ static void ssh_connector_fd_in_cb(ssh_connector connector) toread = MIN(size, CHUNKSIZE); } - r = read(connector->in_fd, buffer, toread); + r = ssh_connector_fd_read(connector, buffer, toread); if (r < 0) { ssh_connector_except(connector, connector->in_fd); return; @@ -277,7 +297,7 @@ static void ssh_connector_fd_in_cb(ssh_connector connector) * bytes */ while (total != r) { - w = write(connector->out_fd, buffer + total, r - total); + w = ssh_connector_fd_write(connector, buffer + total, r - total); if (w < 0){ ssh_connector_except(connector, connector->out_fd); return; @@ -319,7 +339,7 @@ static void ssh_connector_fd_out_cb(ssh_connector connector){ } else if(r>0) { /* loop around write in case the write blocks even for CHUNKSIZE bytes */ while (total != r){ - w = write(connector->out_fd, buffer + total, r - total); + w = ssh_connector_fd_write(connector, buffer + total, r - total); if (w < 0){ ssh_connector_except(connector, connector->out_fd); return; @@ -451,7 +471,7 @@ static int ssh_connector_channel_data_cb(ssh_session session, ssh_connector_except_channel(connector, connector->out_channel); } } else if (connector->out_fd != SSH_INVALID_SOCKET) { - w = write(connector->out_fd, data, len); + w = ssh_connector_fd_write(connector, data, len); if (w < 0) ssh_connector_except(connector, connector->out_fd); } else { @@ -634,3 +654,96 @@ int ssh_connector_remove_event(ssh_connector connector) { return SSH_OK; } + +/** + * @internal + * + * @brief Check the file descriptor to check if it is a Windows socket handle. + * + */ +static bool ssh_connector_fd_is_socket(socket_t s) +{ +#ifdef _WIN32 + struct sockaddr_storage ss; + int len = sizeof(struct sockaddr_storage); + int rc; + + rc = getsockname(s, (struct sockaddr *)&ss, &len); + if (rc == 0) { + return true; + } + + SSH_LOG(SSH_LOG_TRACE, + "Error %i in getsockname() for fd %d", + WSAGetLastError(), + s); + + return false; +#else + struct stat sb; + int rc; + + rc = fstat(s, &sb); + if (rc != 0) { + SSH_LOG(SSH_LOG_TRACE, + "error %i in fstat() for fd %d", + errno, + s); + return false; + } + + /* The descriptor is a socket */ + if (S_ISSOCK(sb.st_mode)) { + return true; + } + + return false; +#endif /* _WIN32 */ +} + +/** + * @internal + * + * @brief read len bytes from socket into buffer + * + */ +static ssize_t ssh_connector_fd_read(ssh_connector connector, + void *buffer, + uint32_t len) +{ + ssize_t nread = -1; + + if (connector->fd_is_socket) { + nread = recv(connector->in_fd,buffer, len, 0); + } else { + nread = read(connector->in_fd,buffer, len); + } + + return nread; +} + +/** + * @internal + * + * @brief brief writes len bytes from buffer to socket + * + */ +static ssize_t ssh_connector_fd_write(ssh_connector connector, + const void *buffer, + uint32_t len) +{ + ssize_t bwritten = -1; + int flags = 0; + +#ifdef MSG_NOSIGNAL + flags |= MSG_NOSIGNAL; +#endif + + if (connector->fd_is_socket) { + bwritten = send(connector->out_fd,buffer, len, flags); + } else { + bwritten = write(connector->out_fd, buffer, len); + } + + return bwritten; +} |