aboutsummaryrefslogtreecommitdiff
path: root/src/connector.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/connector.c')
-rw-r--r--src/connector.c127
1 files changed, 120 insertions, 7 deletions
diff --git a/src/connector.c b/src/connector.c
index 407aa522..c4f6af54 100644
--- a/src/connector.c
+++ b/src/connector.c
@@ -26,6 +26,10 @@
#include "libssh/callbacks.h"
#include "libssh/session.h"
#include <stdlib.h>
+#include <errno.h>
+#include <stdbool.h>
+#include <sys/stat.h>
+
#define CHUNKSIZE 4096
#ifdef _WIN32
@@ -40,6 +44,9 @@
# undef unlink
# define unlink _unlink
# endif /* HAVE_IO_H */
+#else
+# include <sys/types.h>
+# include <sys/socket.h>
#endif
struct ssh_connector_struct {
@@ -51,6 +58,8 @@ struct ssh_connector_struct {
socket_t in_fd;
socket_t out_fd;
+ bool fd_is_socket;
+
ssh_poll_handle in_poll;
ssh_poll_handle out_poll;
@@ -76,6 +85,13 @@ static int ssh_connector_channel_write_wontblock_cb(ssh_session session,
ssh_channel channel,
size_t bytes,
void *userdata);
+static ssize_t ssh_connector_fd_read(ssh_connector connector,
+ void *buffer,
+ uint32_t len);
+static ssize_t ssh_connector_fd_write(ssh_connector connector,
+ const void *buffer,
+ uint32_t len);
+static bool ssh_connector_fd_is_socket(socket_t socket);
ssh_connector ssh_connector_new(ssh_session session)
{
@@ -91,6 +107,8 @@ ssh_connector ssh_connector_new(ssh_session session)
connector->in_fd = SSH_INVALID_SOCKET;
connector->out_fd = SSH_INVALID_SOCKET;
+ connector->fd_is_socket = false;
+
ssh_callbacks_init(&connector->in_channel_cb);
ssh_callbacks_init(&connector->out_channel_cb);
@@ -167,12 +185,14 @@ int ssh_connector_set_out_channel(ssh_connector connector,
void ssh_connector_set_in_fd(ssh_connector connector, socket_t fd)
{
connector->in_fd = fd;
+ connector->fd_is_socket = ssh_connector_fd_is_socket(fd);
connector->in_channel = NULL;
}
void ssh_connector_set_out_fd(ssh_connector connector, socket_t fd)
{
connector->out_fd = fd;
+ connector->fd_is_socket = ssh_connector_fd_is_socket(fd);
connector->out_channel = NULL;
}
@@ -223,9 +243,9 @@ static void ssh_connector_reset_pollevents(ssh_connector connector)
static void ssh_connector_fd_in_cb(ssh_connector connector)
{
unsigned char buffer[CHUNKSIZE];
- int r;
- int toread = CHUNKSIZE;
- int w;
+ uint32_t toread = CHUNKSIZE;
+ ssize_t r;
+ ssize_t w;
int total = 0;
int rc;
@@ -239,7 +259,7 @@ static void ssh_connector_fd_in_cb(ssh_connector connector)
toread = MIN(size, CHUNKSIZE);
}
- r = read(connector->in_fd, buffer, toread);
+ r = ssh_connector_fd_read(connector, buffer, toread);
if (r < 0) {
ssh_connector_except(connector, connector->in_fd);
return;
@@ -277,7 +297,7 @@ static void ssh_connector_fd_in_cb(ssh_connector connector)
* bytes
*/
while (total != r) {
- w = write(connector->out_fd, buffer + total, r - total);
+ w = ssh_connector_fd_write(connector, buffer + total, r - total);
if (w < 0){
ssh_connector_except(connector, connector->out_fd);
return;
@@ -319,7 +339,7 @@ static void ssh_connector_fd_out_cb(ssh_connector connector){
} else if(r>0) {
/* loop around write in case the write blocks even for CHUNKSIZE bytes */
while (total != r){
- w = write(connector->out_fd, buffer + total, r - total);
+ w = ssh_connector_fd_write(connector, buffer + total, r - total);
if (w < 0){
ssh_connector_except(connector, connector->out_fd);
return;
@@ -451,7 +471,7 @@ static int ssh_connector_channel_data_cb(ssh_session session,
ssh_connector_except_channel(connector, connector->out_channel);
}
} else if (connector->out_fd != SSH_INVALID_SOCKET) {
- w = write(connector->out_fd, data, len);
+ w = ssh_connector_fd_write(connector, data, len);
if (w < 0)
ssh_connector_except(connector, connector->out_fd);
} else {
@@ -634,3 +654,96 @@ int ssh_connector_remove_event(ssh_connector connector) {
return SSH_OK;
}
+
+/**
+ * @internal
+ *
+ * @brief Check the file descriptor to check if it is a Windows socket handle.
+ *
+ */
+static bool ssh_connector_fd_is_socket(socket_t s)
+{
+#ifdef _WIN32
+ struct sockaddr_storage ss;
+ int len = sizeof(struct sockaddr_storage);
+ int rc;
+
+ rc = getsockname(s, (struct sockaddr *)&ss, &len);
+ if (rc == 0) {
+ return true;
+ }
+
+ SSH_LOG(SSH_LOG_TRACE,
+ "Error %i in getsockname() for fd %d",
+ WSAGetLastError(),
+ s);
+
+ return false;
+#else
+ struct stat sb;
+ int rc;
+
+ rc = fstat(s, &sb);
+ if (rc != 0) {
+ SSH_LOG(SSH_LOG_TRACE,
+ "error %i in fstat() for fd %d",
+ errno,
+ s);
+ return false;
+ }
+
+ /* The descriptor is a socket */
+ if (S_ISSOCK(sb.st_mode)) {
+ return true;
+ }
+
+ return false;
+#endif /* _WIN32 */
+}
+
+/**
+ * @internal
+ *
+ * @brief read len bytes from socket into buffer
+ *
+ */
+static ssize_t ssh_connector_fd_read(ssh_connector connector,
+ void *buffer,
+ uint32_t len)
+{
+ ssize_t nread = -1;
+
+ if (connector->fd_is_socket) {
+ nread = recv(connector->in_fd,buffer, len, 0);
+ } else {
+ nread = read(connector->in_fd,buffer, len);
+ }
+
+ return nread;
+}
+
+/**
+ * @internal
+ *
+ * @brief brief writes len bytes from buffer to socket
+ *
+ */
+static ssize_t ssh_connector_fd_write(ssh_connector connector,
+ const void *buffer,
+ uint32_t len)
+{
+ ssize_t bwritten = -1;
+ int flags = 0;
+
+#ifdef MSG_NOSIGNAL
+ flags |= MSG_NOSIGNAL;
+#endif
+
+ if (connector->fd_is_socket) {
+ bwritten = send(connector->out_fd,buffer, len, flags);
+ } else {
+ bwritten = write(connector->out_fd, buffer, len);
+ }
+
+ return bwritten;
+}