aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/libssh/socket.h5
-rw-r--r--src/agent.c2
-rw-r--r--src/bind.c2
-rw-r--r--src/pcap.c2
-rw-r--r--src/poll.c22
-rw-r--r--src/session.c16
-rw-r--r--src/socket.c189
-rw-r--r--tests/unittests/torture_packet.c6
8 files changed, 85 insertions, 159 deletions
diff --git a/include/libssh/socket.h b/include/libssh/socket.h
index 8e1eac21..5c296e0d 100644
--- a/include/libssh/socket.h
+++ b/include/libssh/socket.h
@@ -34,7 +34,7 @@ ssh_socket ssh_socket_new(ssh_session session);
void ssh_socket_reset(ssh_socket s);
void ssh_socket_free(ssh_socket s);
void ssh_socket_set_fd(ssh_socket s, socket_t fd);
-socket_t ssh_socket_get_fd_in(ssh_socket s);
+socket_t ssh_socket_get_fd(ssh_socket s);
#ifndef _WIN32
int ssh_socket_unix(ssh_socket s, const char *path);
void ssh_execute_command(const char *command, socket_t in, socket_t out);
@@ -61,8 +61,7 @@ int ssh_socket_set_blocking(socket_t fd);
void ssh_socket_set_callbacks(ssh_socket s, ssh_socket_callbacks callbacks);
int ssh_socket_pollcallback(struct ssh_poll_handle_struct *p, socket_t fd, int revents, void *v_s);
-struct ssh_poll_handle_struct * ssh_socket_get_poll_handle_in(ssh_socket s);
-struct ssh_poll_handle_struct * ssh_socket_get_poll_handle_out(ssh_socket s);
+struct ssh_poll_handle_struct * ssh_socket_get_poll_handle(ssh_socket s);
int ssh_socket_connect(ssh_socket s, const char *host, int port, const char *bind_addr);
diff --git a/src/agent.c b/src/agent.c
index d1ea5c08..78be33e6 100644
--- a/src/agent.c
+++ b/src/agent.c
@@ -73,7 +73,7 @@ static size_t atomicio(struct ssh_agent_struct *agent, void *buf, size_t n, int
/* Using a socket ? */
if (channel == NULL) {
- fd = ssh_socket_get_fd_in(agent->sock);
+ fd = ssh_socket_get_fd(agent->sock);
pfd.fd = fd;
pfd.events = do_read ? POLLIN : POLLOUT;
diff --git a/src/bind.c b/src/bind.c
index b07dd574..63ef3a94 100644
--- a/src/bind.c
+++ b/src/bind.c
@@ -447,7 +447,7 @@ int ssh_bind_accept_fd(ssh_bind sshbind, ssh_session session, socket_t fd){
return SSH_ERROR;
}
ssh_socket_set_fd(session->socket, fd);
- ssh_socket_get_poll_handle_out(session->socket);
+ ssh_socket_get_poll_handle(session->socket);
/* We must try to import any keys that could be imported in case
* we are not using ssh_bind_listen (which is the other place
diff --git a/src/pcap.c b/src/pcap.c
index ffb074cf..333e1c4e 100644
--- a/src/pcap.c
+++ b/src/pcap.c
@@ -312,7 +312,7 @@ static int ssh_pcap_context_connect(ssh_pcap_context ctx){
return SSH_ERROR;
if(session->socket==NULL)
return SSH_ERROR;
- fd=ssh_socket_get_fd_in(session->socket);
+ fd = ssh_socket_get_fd(session->socket);
/* TODO: adapt for windows */
if(fd<0)
return SSH_ERROR;
diff --git a/src/poll.c b/src/poll.c
index 8f4a0764..0ee8db20 100644
--- a/src/poll.c
+++ b/src/poll.c
@@ -533,19 +533,17 @@ int ssh_poll_ctx_add(ssh_poll_ctx ctx, ssh_poll_handle p) {
*
* @return 0 on success, < 0 on error
*/
-int ssh_poll_ctx_add_socket (ssh_poll_ctx ctx, ssh_socket s) {
- ssh_poll_handle p_in, p_out;
- int ret;
- p_in=ssh_socket_get_poll_handle_in(s);
- if(p_in==NULL)
- return -1;
- ret = ssh_poll_ctx_add(ctx,p_in);
- if(ret != 0)
+int ssh_poll_ctx_add_socket (ssh_poll_ctx ctx, ssh_socket s)
+{
+ ssh_poll_handle p;
+ int ret;
+
+ p = ssh_socket_get_poll_handle(s);
+ if (p == NULL) {
+ return -1;
+ }
+ ret = ssh_poll_ctx_add(ctx,p);
return ret;
- p_out=ssh_socket_get_poll_handle_out(s);
- if(p_in != p_out)
- ret = ssh_poll_ctx_add(ctx,p_out);
- return ret;
}
diff --git a/src/session.c b/src/session.c
index f9c45a06..3953fe76 100644
--- a/src/session.c
+++ b/src/session.c
@@ -536,7 +536,7 @@ socket_t ssh_get_fd(ssh_session session) {
return -1;
}
- return ssh_socket_get_fd_in(session->socket);
+ return ssh_socket_get_fd(session->socket);
}
/**
@@ -599,7 +599,7 @@ void ssh_set_fd_except(ssh_session session) {
* @return SSH_OK on success, SSH_ERROR otherwise.
*/
int ssh_handle_packets(ssh_session session, int timeout) {
- ssh_poll_handle spoll_in,spoll_out;
+ ssh_poll_handle spoll;
ssh_poll_ctx ctx;
int tm = timeout;
int rc;
@@ -608,17 +608,13 @@ int ssh_handle_packets(ssh_session session, int timeout) {
return SSH_ERROR;
}
- spoll_in = ssh_socket_get_poll_handle_in(session->socket);
- spoll_out = ssh_socket_get_poll_handle_out(session->socket);
- ssh_poll_add_events(spoll_in, POLLIN);
- ctx = ssh_poll_get_ctx(spoll_in);
+ spoll = ssh_socket_get_poll_handle(session->socket);
+ ssh_poll_add_events(spoll, POLLIN);
+ ctx = ssh_poll_get_ctx(spoll);
if (!ctx) {
ctx = ssh_poll_get_default_ctx(session);
- ssh_poll_ctx_add(ctx, spoll_in);
- if (spoll_in != spoll_out) {
- ssh_poll_ctx_add(ctx, spoll_out);
- }
+ ssh_poll_ctx_add(ctx, spoll);
}
if (timeout == SSH_TIMEOUT_USER) {
diff --git a/src/socket.c b/src/socket.c
index 2c72566d..8c3e68ec 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -74,8 +74,7 @@ enum ssh_socket_states_e {
};
struct ssh_socket_struct {
- socket_t fd_in;
- socket_t fd_out;
+ socket_t fd;
int fd_is_socket;
int last_errno;
int read_wontblock; /* reading now on socket will
@@ -87,8 +86,7 @@ struct ssh_socket_struct {
ssh_buffer in_buffer;
ssh_session session;
ssh_socket_callbacks callbacks;
- ssh_poll_handle poll_in;
- ssh_poll_handle poll_out;
+ ssh_poll_handle poll_handle;
};
static int sockets_initialized = 0;
@@ -149,8 +147,7 @@ ssh_socket ssh_socket_new(ssh_session session) {
ssh_set_error_oom(session);
return NULL;
}
- s->fd_in = SSH_INVALID_SOCKET;
- s->fd_out= SSH_INVALID_SOCKET;
+ s->fd = SSH_INVALID_SOCKET;
s->last_errno = -1;
s->fd_is_socket = 1;
s->session = session;
@@ -170,7 +167,7 @@ ssh_socket ssh_socket_new(ssh_session session) {
s->read_wontblock = 0;
s->write_wontblock = 0;
s->data_except = 0;
- s->poll_in=s->poll_out=NULL;
+ s->poll_handle = NULL;
s->state=SSH_SOCKET_NONE;
return s;
}
@@ -181,8 +178,7 @@ ssh_socket ssh_socket_new(ssh_session session) {
* @param[in] s socket to rest
*/
void ssh_socket_reset(ssh_socket s){
- s->fd_in = SSH_INVALID_SOCKET;
- s->fd_out= SSH_INVALID_SOCKET;
+ s->fd = SSH_INVALID_SOCKET;
s->last_errno = -1;
s->fd_is_socket = 1;
ssh_buffer_reinit(s->in_buffer);
@@ -190,7 +186,7 @@ void ssh_socket_reset(ssh_socket s){
s->read_wontblock = 0;
s->write_wontblock = 0;
s->data_except = 0;
- s->poll_in=s->poll_out=NULL;
+ s->poll_handle = NULL;
s->state=SSH_SOCKET_NONE;
}
@@ -337,7 +333,7 @@ int ssh_socket_pollcallback(struct ssh_poll_handle_struct *p,
ssh_poll_set_events(p, POLLOUT | POLLIN);
}
- rc = ssh_socket_set_blocking(ssh_socket_get_fd_in(s));
+ rc = ssh_socket_set_blocking(ssh_socket_get_fd(s));
if (rc < 0) {
return -1;
}
@@ -370,8 +366,8 @@ int ssh_socket_pollcallback(struct ssh_poll_handle_struct *p,
/* TODO: Find a way to put back POLLOUT when buffering occurs */
}
- /* Return -1 if one of the poll handlers disappeared */
- if (s->poll_in == NULL || s->poll_out == NULL) {
+ /* Return -1 if the poll handler disappeared */
+ if (s->poll_handle == NULL) {
return -1;
}
@@ -379,31 +375,17 @@ int ssh_socket_pollcallback(struct ssh_poll_handle_struct *p,
}
/** @internal
- * @brief returns the input poll handle corresponding to the socket,
+ * @brief returns the poll handle corresponding to the socket,
* creates it if it does not exist.
* @returns allocated and initialized ssh_poll_handle object
*/
-ssh_poll_handle ssh_socket_get_poll_handle_in(ssh_socket s){
- if(s->poll_in)
- return s->poll_in;
- s->poll_in=ssh_poll_new(s->fd_in,0,ssh_socket_pollcallback,s);
- if(s->fd_in == s->fd_out && s->poll_out == NULL)
- s->poll_out=s->poll_in;
- return s->poll_in;
-}
-
-/** @internal
- * @brief returns the output poll handle corresponding to the socket,
- * creates it if it does not exist.
- * @returns allocated and initialized ssh_poll_handle object
- */
-ssh_poll_handle ssh_socket_get_poll_handle_out(ssh_socket s){
- if(s->poll_out)
- return s->poll_out;
- s->poll_out=ssh_poll_new(s->fd_out,0,ssh_socket_pollcallback,s);
- if(s->fd_in == s->fd_out && s->poll_in == NULL)
- s->poll_in=s->poll_out;
- return s->poll_out;
+ssh_poll_handle ssh_socket_get_poll_handle(ssh_socket s)
+{
+ if (s->poll_handle) {
+ return s->poll_handle;
+ }
+ s->poll_handle = ssh_poll_new(s->fd,0,ssh_socket_pollcallback,s);
+ return s->poll_handle;
}
/** \internal
@@ -460,27 +442,17 @@ int ssh_socket_unix(ssh_socket s, const char *path) {
void ssh_socket_close(ssh_socket s){
if (ssh_socket_is_open(s)) {
#ifdef _WIN32
- CLOSE_SOCKET(s->fd_in);
- /* fd_in = fd_out under win32 */
+ CLOSE_SOCKET(s->fd);
s->last_errno = WSAGetLastError();
#else
- if (s->fd_out != s->fd_in && s->fd_out != -1) {
- CLOSE_SOCKET(s->fd_out);
- }
- CLOSE_SOCKET(s->fd_in);
+ CLOSE_SOCKET(s->fd);
s->last_errno = errno;
#endif
}
- if(s->poll_in != NULL){
- if(s->poll_out == s->poll_in)
- s->poll_out = NULL;
- ssh_poll_free(s->poll_in);
- s->poll_in=NULL;
- }
- if(s->poll_out != NULL){
- ssh_poll_free(s->poll_out);
- s->poll_out=NULL;
+ if(s->poll_handle != NULL){
+ ssh_poll_free(s->poll_handle);
+ s->poll_handle=NULL;
}
s->state = SSH_SOCKET_CLOSED;
@@ -495,59 +467,34 @@ void ssh_socket_close(ssh_socket s){
* file descriptors
*/
void ssh_socket_set_fd(ssh_socket s, socket_t fd) {
- s->fd_in = s->fd_out = fd;
+ s->fd = fd;
- if (s->poll_in) {
- ssh_poll_set_fd(s->poll_in,fd);
+ if (s->poll_handle) {
+ ssh_poll_set_fd(s->poll_handle,fd);
} else {
s->state = SSH_SOCKET_CONNECTING;
/* POLLOUT is the event to wait for in a nonblocking connect */
- ssh_poll_set_events(ssh_socket_get_poll_handle_in(s), POLLOUT);
+ ssh_poll_set_events(ssh_socket_get_poll_handle(s), POLLOUT);
#ifdef _WIN32
- ssh_poll_add_events(ssh_socket_get_poll_handle_in(s), POLLWRNORM);
+ ssh_poll_add_events(ssh_socket_get_poll_handle(s), POLLWRNORM);
#endif
}
}
-/**
- * @internal
- * @brief sets the input file descriptor of the socket.
- * @param[out] s ssh_socket to update
- * @param[in] fd file descriptor to set
- */
-void ssh_socket_set_fd_in(ssh_socket s, socket_t fd) {
- s->fd_in = fd;
- if(s->poll_in)
- ssh_poll_set_fd(s->poll_in,fd);
-}
-
-/**
- * @internal
- * @brief sets the output file descriptor of the socket.
- * @param[out] s ssh_socket to update
- * @param[in] fd file descriptor to set
- */
-void ssh_socket_set_fd_out(ssh_socket s, socket_t fd) {
- s->fd_out = fd;
- if(s->poll_out)
- ssh_poll_set_fd(s->poll_out,fd);
-}
-
-
-
/** \internal
* \brief returns the input file descriptor of the socket
*/
-socket_t ssh_socket_get_fd_in(ssh_socket s) {
- return s->fd_in;
+socket_t ssh_socket_get_fd(ssh_socket s)
+{
+ return s->fd;
}
/** \internal
* \brief returns nonzero if the socket is open
*/
int ssh_socket_is_open(ssh_socket s) {
- return s->fd_in != SSH_INVALID_SOCKET;
+ return s->fd != SSH_INVALID_SOCKET;
}
/** \internal
@@ -563,9 +510,9 @@ static ssize_t ssh_socket_unbuffered_read(ssh_socket s,
return -1;
}
if (s->fd_is_socket) {
- rc = recv(s->fd_in,buffer, len, 0);
+ rc = recv(s->fd,buffer, len, 0);
} else {
- rc = read(s->fd_in,buffer, len);
+ rc = read(s->fd,buffer, len);
}
#ifdef _WIN32
s->last_errno = WSAGetLastError();
@@ -600,9 +547,9 @@ static ssize_t ssh_socket_unbuffered_write(ssh_socket s,
}
if (s->fd_is_socket) {
- w = send(s->fd_out, buffer, len, flags);
+ w = send(s->fd, buffer, len, flags);
} else {
- w = write(s->fd_out, buffer, len);
+ w = write(s->fd, buffer, len);
}
#ifdef _WIN32
s->last_errno = WSAGetLastError();
@@ -611,9 +558,9 @@ static ssize_t ssh_socket_unbuffered_write(ssh_socket s,
#endif
s->write_wontblock = 0;
/* Reactive the POLLOUT detector in the poll multiplexer system */
- if (s->poll_out) {
+ if (s->poll_handle) {
SSH_LOG(SSH_LOG_PACKET, "Enabling POLLOUT for socket");
- ssh_poll_set_events(s->poll_out,ssh_poll_get_events(s->poll_out) | POLLOUT);
+ ssh_poll_set_events(s->poll_handle,ssh_poll_get_events(s->poll_handle) | POLLOUT);
}
if (w < 0) {
s->data_except = 1;
@@ -626,10 +573,10 @@ static ssize_t ssh_socket_unbuffered_write(ssh_socket s,
* \brief returns nonzero if the current socket is in the fd_set
*/
int ssh_socket_fd_isset(ssh_socket s, fd_set *set) {
- if(s->fd_in == SSH_INVALID_SOCKET) {
+ if(s->fd == SSH_INVALID_SOCKET) {
return 0;
}
- return FD_ISSET(s->fd_in,set) || FD_ISSET(s->fd_out,set);
+ return FD_ISSET(s->fd,set);
}
/** \internal
@@ -637,22 +584,16 @@ int ssh_socket_fd_isset(ssh_socket s, fd_set *set) {
*/
void ssh_socket_fd_set(ssh_socket s, fd_set *set, socket_t *max_fd) {
- if (s->fd_in == SSH_INVALID_SOCKET) {
+ if (s->fd == SSH_INVALID_SOCKET) {
return;
}
- FD_SET(s->fd_in,set);
- FD_SET(s->fd_out,set);
+ FD_SET(s->fd,set);
- if (s->fd_in >= 0 &&
- s->fd_in >= *max_fd &&
- s->fd_in != SSH_INVALID_SOCKET) {
- *max_fd = s->fd_in + 1;
- }
- if (s->fd_out >= 0 &&
- s->fd_out >= *max_fd &&
- s->fd_out != SSH_INVALID_SOCKET) {
- *max_fd = s->fd_out + 1;
+ if (s->fd >= 0 &&
+ s->fd >= *max_fd &&
+ s->fd != SSH_INVALID_SOCKET) {
+ *max_fd = s->fd + 1;
}
}
@@ -701,9 +642,9 @@ int ssh_socket_nonblocking_flush(ssh_socket s)
}
len = ssh_buffer_get_len(s->out_buffer);
- if (!s->write_wontblock && s->poll_out && len > 0) {
+ if (!s->write_wontblock && s->poll_handle && len > 0) {
/* force the poll system to catch pollout events */
- ssh_poll_add_events(s->poll_out, POLLOUT);
+ ssh_poll_add_events(s->poll_handle, POLLOUT);
return SSH_AGAIN;
}
@@ -741,9 +682,9 @@ int ssh_socket_nonblocking_flush(ssh_socket s)
/* Is there some data pending? */
len = ssh_buffer_get_len(s->out_buffer);
- if (s->poll_out && len > 0) {
+ if (s->poll_handle && len > 0) {
/* force the poll system to catch pollout events */
- ssh_poll_add_events(s->poll_out, POLLOUT);
+ ssh_poll_add_events(s->poll_handle, POLLOUT);
return SSH_AGAIN;
}
@@ -804,10 +745,10 @@ int ssh_socket_get_status(ssh_socket s) {
int ssh_socket_get_poll_flags(ssh_socket s) {
int r = 0;
- if (s->poll_in != NULL && (ssh_poll_get_events (s->poll_in) & POLLIN) > 0) {
+ if (s->poll_handle != NULL && (ssh_poll_get_events (s->poll_handle) & POLLIN) > 0) {
r |= SSH_READ_PENDING;
}
- if (s->poll_out != NULL && (ssh_poll_get_events (s->poll_out) & POLLOUT) > 0) {
+ if (s->poll_handle != NULL && (ssh_poll_get_events (s->poll_handle) & POLLOUT) > 0) {
r |= SSH_WRITE_PENDING;
}
return r;
@@ -897,19 +838,15 @@ void ssh_execute_command(const char *command, socket_t in, socket_t out){
*/
int ssh_socket_connect_proxycommand(ssh_socket s, const char *command){
- socket_t in_pipe[2];
- socket_t out_pipe[2];
+ socket_t pair[2];
int pid;
int rc;
- if(s->state != SSH_SOCKET_NONE)
+ if (s->state != SSH_SOCKET_NONE) {
return SSH_ERROR;
-
- rc = pipe(in_pipe);
- if (rc < 0) {
- return SSH_ERROR;
}
- rc = pipe(out_pipe);
+
+ rc = socketpair(PF_LOCAL, SOCK_STREAM, 0, pair);
if (rc < 0) {
return SSH_ERROR;
}
@@ -917,20 +854,18 @@ int ssh_socket_connect_proxycommand(ssh_socket s, const char *command){
SSH_LOG(SSH_LOG_PROTOCOL,"Executing proxycommand '%s'",command);
pid = fork();
if(pid == 0){
- ssh_execute_command(command,out_pipe[0],in_pipe[1]);
+ ssh_execute_command(command,pair[0],pair[0]);
}
- close(in_pipe[1]);
- close(out_pipe[0]);
- SSH_LOG(SSH_LOG_PROTOCOL,"ProxyCommand connection pipe: [%d,%d]",in_pipe[0],out_pipe[1]);
- ssh_socket_set_fd_in(s,in_pipe[0]);
- ssh_socket_set_fd_out(s,out_pipe[1]);
+ close(pair[0]);
+ SSH_LOG(SSH_LOG_PROTOCOL,"ProxyCommand connection pipe: [%d,%d]",pair[0],pair[1]);
+ ssh_socket_set_fd(s, pair[1]);
s->state=SSH_SOCKET_CONNECTED;
s->fd_is_socket=0;
/* POLLOUT is the event to wait for in a nonblocking connect */
- ssh_poll_set_events(ssh_socket_get_poll_handle_in(s),POLLIN);
- ssh_poll_set_events(ssh_socket_get_poll_handle_out(s),POLLOUT);
- if(s->callbacks && s->callbacks->connected)
+ ssh_poll_set_events(ssh_socket_get_poll_handle(s), POLLIN | POLLOUT);
+ if(s->callbacks && s->callbacks->connected) {
s->callbacks->connected(SSH_SOCKET_CONNECTED_OK,0,s->callbacks->userdata);
+ }
return SSH_OK;
}
diff --git a/tests/unittests/torture_packet.c b/tests/unittests/torture_packet.c
index 61f4b049..f94e4a10 100644
--- a/tests/unittests/torture_packet.c
+++ b/tests/unittests/torture_packet.c
@@ -104,8 +104,7 @@ static void torture_packet(const char *cipher,
assert_non_null(session->out_buffer);
ssh_buffer_add_data(session->out_buffer, test_data, payload_len);
- session->socket->fd_out = sockets[0];
- session->socket->fd_in = -2;
+ session->socket->fd = sockets[0];
session->socket->write_wontblock = 1;
rc = ssh_packet_send(session);
assert_int_equal(rc, SSH_OK);
@@ -126,8 +125,7 @@ static void torture_packet(const char *cipher,
}
close(sockets[0]);
close(sockets[1]);
- session->socket->fd_in = SSH_INVALID_SOCKET;
- session->socket->fd_out = SSH_INVALID_SOCKET;
+ session->socket->fd = SSH_INVALID_SOCKET;
ssh_free(session);
}