aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authormilo <milo@r0ot.me>2012-03-09 10:55:08 +0100
committermilo <milo@r0ot.me>2012-05-05 17:48:11 +0200
commit8460f714f5e9586eca18086438190d30b9c6147b (patch)
tree7c95692ae50879b9711dca72e218661c669f3f86
parent88d6e6253c6db815d5f58888aaa4d680029c09b2 (diff)
downloadlibssh-8460f714f5e9586eca18086438190d30b9c6147b.tar.gz
libssh-8460f714f5e9586eca18086438190d30b9c6147b.tar.xz
libssh-8460f714f5e9586eca18086438190d30b9c6147b.zip
Split ssh_packet_socket_callback() in two parts to make it reentrant
-rw-r--r--include/libssh/callbacks.h6
-rw-r--r--include/libssh/packet.h3
-rw-r--r--include/libssh/priv.h1
-rw-r--r--src/client.c4
-rw-r--r--src/packet.c90
-rw-r--r--src/server.c4
-rw-r--r--src/socket.c65
7 files changed, 123 insertions, 50 deletions
diff --git a/include/libssh/callbacks.h b/include/libssh/callbacks.h
index 980480d..1bc0f5e 100644
--- a/include/libssh/callbacks.h
+++ b/include/libssh/callbacks.h
@@ -57,6 +57,7 @@ typedef void (*ssh_callback_int) (int code, void *user);
* @returns number of bytes processed by the callee. The remaining bytes will
* be sent in the next callback message, when more data is available.
*/
+typedef size_t (*ssh_callback_data_header) (const void *data, size_t len, void *user);
typedef int (*ssh_callback_data) (const void *data, size_t len, void *user);
typedef void (*ssh_callback_int_int) (int code, int errno_code, void *user);
@@ -138,6 +139,11 @@ struct ssh_socket_callbacks_struct {
* This function will be called each time data appears on socket. The data
* not consumed will appear on the next data event.
*/
+ ssh_callback_data_header data_header;
+ /**
+ * This function will be called each time data appears on socket. The data
+ * not consumed will appear on the next data event.
+ */
ssh_callback_data data;
/** This function will be called each time a controlflow state changes, i.e.
* the socket is available for reading or writing.
diff --git a/include/libssh/packet.h b/include/libssh/packet.h
index 9d934c6..a313ade 100644
--- a/include/libssh/packet.h
+++ b/include/libssh/packet.h
@@ -70,11 +70,12 @@ int ssh_packet_send_unimplemented(ssh_session session, uint32_t seqnum);
int ssh_packet_parse_type(ssh_session session);
//int packet_flush(ssh_session session, int enforce_blocking);
+size_t ssh_packet_header_socket_callback(const void *data, size_t len, void *user);
int ssh_packet_socket_callback(const void *data, size_t len, void *user);
void ssh_packet_register_socket_callback(ssh_session session, struct ssh_socket_struct *s);
void ssh_packet_set_callbacks(ssh_session session, ssh_packet_callbacks callbacks);
void ssh_packet_set_default_callbacks(ssh_session session);
-void ssh_packet_process(ssh_session session, uint8_t type);
+void ssh_packet_process(ssh_session session, void *in_buffer, uint8_t type);
/* PACKET CRYPT */
uint32_t packet_decrypt_len(ssh_session session, char *crypted);
diff --git a/include/libssh/priv.h b/include/libssh/priv.h
index 2e323b9..1ca59f1 100644
--- a/include/libssh/priv.h
+++ b/include/libssh/priv.h
@@ -175,6 +175,7 @@ void _ssh_set_error_invalid(void *error, const char *function);
int ssh_send_banner(ssh_session session, int is_server);
+void ssh_packet_process(ssh_session session, void *in_buffer, uint8_t type);
/* connect.c */
socket_t ssh_connect_host(ssh_session session, const char *host,const char
*bind_addr, int port, long timeout, long usec);
diff --git a/src/client.c b/src/client.c
index ca827eb..f270941 100644
--- a/src/client.c
+++ b/src/client.c
@@ -364,8 +364,10 @@ static void ssh_client_connection_callback(ssh_session session){
goto error;
}
/* from now, the packet layer is handling incoming packets */
- if(session->version==2)
+ if(session->version==2) {
+ session->socket_callbacks.data_header=ssh_packet_header_socket_callback;
session->socket_callbacks.data=ssh_packet_socket_callback;
+ }
#ifdef WITH_SSH1
else
session->socket_callbacks.data=ssh_packet_socket_callback1;
diff --git a/src/packet.c b/src/packet.c
index f06967a..6d8c3c9 100644
--- a/src/packet.c
+++ b/src/packet.c
@@ -120,30 +120,27 @@ static ssh_packet_callback default_packet_handlers[]= {
* @len length of data received. It might not be enough for a complete packet
* @returns number of bytes read and processed.
*/
-int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user){
+size_t ssh_packet_header_socket_callback(const void *data, size_t receivedlen, void *user){
ssh_session session=(ssh_session) user;
unsigned int blocksize = (session->current_crypto ?
session->current_crypto->in_cipher->blocksize : 8);
- int current_macsize = session->current_crypto ? MACSIZE : 0;
- unsigned char mac[30] = {0};
char buffer[16] = {0};
- void *packet=NULL;
int to_be_read;
- int rc;
- uint32_t len, compsize, payloadsize;
- uint8_t padding;
+ uint32_t len;
+ int current_macsize = session->current_crypto ? MACSIZE : 0;
size_t processed=0; /* number of byte processed from the callback */
+
enter_function();
if (session->session_state == SSH_SESSION_STATE_ERROR)
- goto error;
+ goto error;
switch(session->packet_state) {
case PACKET_STATE_INIT:
- if(receivedlen < blocksize){
- /* We didn't receive enough data to read at least one block size, give up */
- leave_function();
- return 0;
- }
+ if(receivedlen < blocksize){
+ /* We didn't receive enough data to read at least one block size, give up */
+ leave_function();
+ return 0;
+ }
memset(&session->in_packet, 0, sizeof(PACKET));
if (session->in_buffer) {
@@ -168,6 +165,9 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
if(len > MAX_PACKET_LEN) {
ssh_set_error(session, SSH_FATAL,
"read_packet(): Packet len too high(%u %.4x)", len, len);
+ ssh_set_error(session, SSH_FATAL,
+ "read_packet(): Packet len too high(%ld %.8x)", receivedlen, (unsigned)receivedlen);
+ *((char*)NULL) = 0;
goto error;
}
@@ -182,6 +182,43 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
/* saves the status of the current operations */
session->in_packet.len = len;
session->packet_state = PACKET_STATE_SIZEREAD;
+ return len + sizeof(uint32_t) + current_macsize;
+ }
+
+ ssh_set_error(session, SSH_FATAL,
+ "Invalid state into packet_read2(): %d",
+ session->packet_state);
+
+error:
+ session->session_state= SSH_SESSION_STATE_ERROR;
+ leave_function();
+ return 0;
+}
+
+/** @internal
+ * @handles a data received event. It then calls the handlers for the different packet types
+ * or and exception handler callback.
+ * @param user pointer to current ssh_session
+ * @param data pointer to the data received
+ * @len length of data received. It might not be enough for a complete packet
+ * @returns number of bytes read and processed.
+ */
+int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user){
+ ssh_session session=(ssh_session) user;
+ unsigned int blocksize = (session->current_crypto ?
+ session->current_crypto->in_cipher->blocksize : 8);
+ int current_macsize = session->current_crypto ? MACSIZE : 0;
+ unsigned char mac[30] = {0};
+ void *packet=NULL;
+ int to_be_read;
+ uint32_t len, compsize, payloadsize;
+ uint8_t padding;
+ size_t processed=blocksize; /* number of byte processed from the callback */
+
+ enter_function();
+ if (session->session_state == SSH_SESSION_STATE_ERROR)
+ goto error;
+ switch(session->packet_state) {
case PACKET_STATE_SIZEREAD:
len = session->in_packet.len;
to_be_read = len - blocksize + sizeof(uint32_t) + current_macsize;
@@ -194,7 +231,6 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
}
packet = (unsigned char *)data + processed;
-// ssh_socket_read(session->socket,packet,to_be_read-current_macsize);
if (buffer_add_data(session->in_buffer, packet,
to_be_read - current_macsize) < 0) {
@@ -254,27 +290,22 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
payloadsize=buffer_get_rest_len(session->in_buffer);
session->recv_seq++;
/* We don't want to rewrite a new packet while still executing the packet callbacks */
- session->packet_state = PACKET_STATE_PROCESSING;
ssh_packet_parse_type(session);
ssh_log(session,SSH_LOG_PACKET,
"packet: read type %hhd [len=%d,padding=%hhd,comp=%d,payload=%d]",
session->in_packet.type, len, padding, compsize, payloadsize);
/* execute callbacks */
- ssh_packet_process(session, session->in_packet.type);
- session->packet_state = PACKET_STATE_INIT;
- if(processed < receivedlen){
- /* Handle a potential packet left in socket buffer */
- ssh_log(session,SSH_LOG_PACKET,"Processing %" PRIdS " bytes left in socket buffer",
- receivedlen-processed);
- rc = ssh_packet_socket_callback((char *)data + processed,
- receivedlen - processed,user);
- processed += rc;
+ {
+ ssh_buffer in_buffer = ssh_buffer_new();
+ size_t bufsize = buffer_get_rest_len(session->in_buffer);
+ buffer_add_data(in_buffer, buffer_get_rest(session->in_buffer), bufsize);
+ session->packet_state = PACKET_STATE_INIT;
+ ssh_packet_process(session, in_buffer, session->in_packet.type);
+ buffer_pass_bytes(session->in_buffer, bufsize - buffer_get_rest_len(in_buffer));
+ ssh_buffer_free(in_buffer);
}
leave_function();
return processed;
- case PACKET_STATE_PROCESSING:
- ssh_log(session, SSH_LOG_RARE, "Nested packet processing. Delaying.");
- return 0;
}
ssh_set_error(session, SSH_FATAL,
@@ -288,6 +319,7 @@ error:
}
void ssh_packet_register_socket_callback(ssh_session session, ssh_socket s){
+ session->socket_callbacks.data_header=ssh_packet_header_socket_callback;
session->socket_callbacks.data=ssh_packet_socket_callback;
session->socket_callbacks.connected=NULL;
session->socket_callbacks.controlflow=NULL;
@@ -327,7 +359,7 @@ void ssh_packet_set_default_callbacks(ssh_session session){
* @brief dispatch the call of packet handlers callbacks for a received packet
* @param type type of packet
*/
-void ssh_packet_process(ssh_session session, uint8_t type){
+void ssh_packet_process(ssh_session session, void *in_buffer, uint8_t type){
struct ssh_iterator *i;
int r=SSH_PACKET_NOT_USED;
ssh_packet_callbacks cb;
@@ -349,7 +381,7 @@ void ssh_packet_process(ssh_session session, uint8_t type){
continue;
if(cb->callbacks[type - cb->start]==NULL)
continue;
- r=cb->callbacks[type - cb->start](session,type,session->in_buffer,cb->user);
+ r=cb->callbacks[type - cb->start](session,type,in_buffer,cb->user);
if(r==SSH_PACKET_USED)
break;
}
diff --git a/src/server.c b/src/server.c
index ac3fec3..2e5c0ce 100644
--- a/src/server.c
+++ b/src/server.c
@@ -374,8 +374,10 @@ static void ssh_server_connection_callback(ssh_session session){
goto error;
}
/* from now, the packet layer is handling incoming packets */
- if(session->version==2)
+ if(session->version==2) {
+ session->socket_callbacks.data_header=ssh_packet_header_socket_callback;
session->socket_callbacks.data=ssh_packet_socket_callback;
+ }
#ifdef WITH_SSH1
else
session->socket_callbacks.data=ssh_packet_socket_callback1;
diff --git a/src/socket.c b/src/socket.c
index 5bde836..0a799f8 100644
--- a/src/socket.c
+++ b/src/socket.c
@@ -82,6 +82,7 @@ struct ssh_socket_struct {
not block */
int write_wontblock;
int data_except;
+ size_t packet_length;
enum ssh_socket_states_e state;
ssh_buffer out_buffer;
ssh_buffer in_buffer;
@@ -259,10 +260,7 @@ int ssh_socket_pollcallback(struct ssh_poll_handle_struct *p, socket_t fd, int r
return -2;
}
}
- if(r==0){
- if(p != NULL) {
- ssh_poll_remove_events(p, POLLIN);
- }
+ if(r==0 && buffer_get_rest_len(s->in_buffer) == 0) {
if(p != NULL) {
ssh_poll_remove_events(p, POLLIN);
}
@@ -276,20 +274,51 @@ int ssh_socket_pollcallback(struct ssh_poll_handle_struct *p, socket_t fd, int r
return -2;
}
}
- if(r>0){
- /* Bufferize the data and then call the callback */
- buffer_add_data(s->in_buffer,buffer,r);
- if(s->callbacks && s->callbacks->data){
- r= s->callbacks->data(buffer_get_rest(s->in_buffer),
- buffer_get_rest_len(s->in_buffer),
- s->callbacks->userdata);
- buffer_pass_bytes(s->in_buffer,r);
- /* p may have been freed, so don't use it
- * anymore in this function */
- p = NULL;
- }
- }
- }
+ if(r>0){
+ /* Bufferize the data and then call the callback */
+ buffer_add_data(s->in_buffer,buffer,r);
+ do {
+ if(s->callbacks && s->callbacks->data_header){
+ if(s->packet_length <= 0) {
+ r= s->callbacks->data_header(buffer_get_rest(s->in_buffer),
+ buffer_get_rest_len(s->in_buffer),
+ s->callbacks->userdata);
+ /* p may have been freed, so don't use it
+ * anymore in this function */
+ p = NULL;
+ if(r > 0) {
+ s->packet_length = r;
+ } else {
+ break;
+ }
+ }
+ } else {
+ s->packet_length = buffer_get_rest_len(s->in_buffer);
+ }
+ if(s->packet_length > 0) {
+ if(s->callbacks && s->callbacks->data){
+ if(!s->callbacks->data_header || (s->packet_length > 0
+ && buffer_get_rest_len(s->in_buffer) >= s->packet_length)) {
+ ssh_buffer in_buffer = ssh_buffer_new();
+ buffer_add_data(in_buffer, buffer_get_rest(s->in_buffer),
+ s->packet_length);
+ buffer_pass_bytes(s->in_buffer,s->packet_length);
+ s->packet_length = 0;
+ r= s->callbacks->data(buffer_get_rest(in_buffer),
+ buffer_get_rest_len(in_buffer),
+ s->callbacks->userdata);
+ ssh_buffer_free(in_buffer);
+ /* p may have been freed, so don't use it
+ * anymore in this function */
+ p = NULL;
+ } else {
+ break;
+ }
+ }
+ }
+ } while(buffer_get_rest_len(s->in_buffer) > 0);
+ }
+ }
#ifdef _WIN32
if(revents & POLLOUT || revents & POLLWRNORM){
#else