diff options
-rw-r--r-- | examples/libssh_scp.c | 93 | ||||
-rw-r--r-- | include/libssh/priv.h | 9 | ||||
-rw-r--r-- | libssh/scp.c | 75 |
3 files changed, 96 insertions, 81 deletions
diff --git a/examples/libssh_scp.c b/examples/libssh_scp.c index 0cb8efe2..1a3f08ac 100644 --- a/examples/libssh_scp.c +++ b/examples/libssh_scp.c @@ -33,7 +33,7 @@ struct location { char *host; char *path; ssh_session session; - ssh_channel channel; + ssh_scp scp; FILE *file; }; @@ -150,21 +150,20 @@ static struct location *parse_location(char *loc){ } static int open_location(struct location *loc, int flag){ - char buffer[1024]; - if(loc->is_ssh && flag==WRITE){ loc->session=connect_ssh(loc->host,loc->user); if(!loc->session){ fprintf(stderr,"Couldn't connect to %s\n",loc->host); return -1; } - loc->channel=channel_new(loc->session); - channel_open_session(loc->channel); - //channel_request_pty(loc->channel); - snprintf(buffer,sizeof(buffer),"scp -vt %s",loc->path); - fprintf(stderr,"Execution of \"%s\"\n",buffer); - if(channel_request_exec(loc->channel,buffer) < 0){ - printf("error executing scp: %s\n",ssh_get_error(loc->session)); + loc->scp=ssh_scp_new(loc->session,SSH_SCP_WRITE,loc->path); + if(!loc->scp){ + fprintf(stderr,"error : %s\n",ssh_get_error(loc->session)); + return -1; + } + if(ssh_scp_init(loc->scp)==SSH_ERROR){ + fprintf(stderr,"error : %s\n",ssh_get_error(loc->session)); + ssh_scp_free(loc->scp); return -1; } return 0; @@ -193,17 +192,13 @@ static int do_copy(struct location *src, struct location *dest){ size=s.st_size; } else size=0; - - r=channel_read(dest->channel,buffer,1,0); - printf("Received %d\n", buffer[0]); - snprintf(buffer,sizeof(buffer),"C0644 %d %s\n",size,src->path); - printf("writing \"%s\"",buffer); - if(channel_write(dest->channel,buffer,strlen(buffer))<0){ - fprintf(stderr,"channel_write : %s\n",ssh_get_error(dest->session)); + r=ssh_scp_push_file(dest->scp,src->path,size,"0644"); +// snprintf(buffer,sizeof(buffer),"C0644 %d %s\n",size,src->path); + if(r==SSH_ERROR){ + fprintf(stderr,"error: %s\n",ssh_get_error(dest->session)); + ssh_scp_free(dest->scp); return -1; } - r=channel_read(dest->channel,buffer,1,0); - printf("Received %d\n", buffer[0]); do { r=fread(buffer,1,sizeof(buffer),src->file); if(r==0) @@ -212,60 +207,22 @@ static int do_copy(struct location *src, struct location *dest){ fprintf(stderr,"Error reading file: %s\n",strerror(errno)); return -1; } - w=channel_write(dest->channel,buffer,r); - //printf("."); - fflush(stdout); - //usleep(500); - if(w<0){ - fprintf(stderr,"error writing in channel: %s\n",ssh_get_error(dest->session)); - r=channel_get_exit_status(dest->channel); - if(r!=-1) - printf("Exit status : %d\n",r); + w=ssh_scp_write(dest->scp,buffer,r); + if(w == SSH_ERROR){ + fprintf(stderr,"error writing in scp: %s\n",ssh_get_error(dest->session)); + ssh_scp_free(dest->scp); return -1; } - total+=w; - if(w!=r){ - fprintf(stderr,"coulnd write %d bytes : %d\n",r,w); - } -/* if((r=channel_poll(dest->channel,0))>0){ - if((size_t)r>sizeof(buffer)) - r=sizeof(buffer); - r=channel_read(dest->channel,buffer,r,0); - buffer[r]=0; - printf("received : \"%s\"\n",buffer); - } - if((r=channel_poll(dest->channel,1))>0){ - if((size_t)r>sizeof(buffer)) - r=sizeof(buffer); - r=channel_read(dest->channel,buffer,r,1); - buffer[r]=0; - printf("received ext : \"%s\"\n",buffer); - }*/ + total+=r; } while(1); printf("wrote %d bytes\n",total); - channel_write(dest->channel,"",1); - channel_send_eof(dest->channel); - r=channel_read(dest->channel, buffer, 1, 0); - if(r>0) - printf("Received %d\n", buffer[0]); - //channel_close(dest->channel); - do{ - if((r=channel_poll(dest->channel,0))>0){ - r=channel_read(dest->channel,buffer,r,0); - if(r>0) - write(1,buffer,r); - } - if((r=channel_poll(dest->channel,1)) > 0){ - r=channel_read(dest->channel,buffer,r,1); - if(r<=0) - break; - write(1,buffer,r); - } - } while(!channel_is_eof(dest->channel) && r != SSH_ERROR); - r=channel_get_exit_status(dest->channel); - if(r!=-1) - printf("Exit status : %d\n",r); + r=ssh_scp_close(dest->scp); + if(r == SSH_ERROR){ + fprintf(stderr,"Error closing scp: %s\n",ssh_get_error(dest->session)); + ssh_scp_free(dest->scp); + return -1; + } return 0; } diff --git a/include/libssh/priv.h b/include/libssh/priv.h index 4d65652d..1c8628e2 100644 --- a/include/libssh/priv.h +++ b/include/libssh/priv.h @@ -342,11 +342,20 @@ struct ssh_keys_struct { const char *publickey; }; +enum ssh_scp_states { + SSH_SCP_NEW, //Data structure just created + SSH_SCP_WRITE_INITED, //Gave our intention to write + SSH_SCP_WRITE_WRITING,//File was opened and currently writing + SSH_SCP_READ_INITED, //Gave our intention to read + SSH_SCP_READ_READING, //File is opened and reading + SSH_SCP_ERROR //Something bad happened +}; struct ssh_scp_struct { ssh_session session; int mode; ssh_channel channel; char *location; + enum ssh_scp_states state; size_t filelen; size_t processed; }; diff --git a/libssh/scp.c b/libssh/scp.c index b2fc2c91..878b6cb6 100644 --- a/libssh/scp.c +++ b/libssh/scp.c @@ -47,6 +47,7 @@ ssh_scp ssh_scp_new(ssh_session session, int mode, const char *location){ scp->mode=mode; scp->location=strdup(location); scp->channel=NULL; + scp->state=SSH_SCP_NEW; return scp; } @@ -54,11 +55,18 @@ int ssh_scp_init(ssh_scp scp){ int r; char execbuffer[1024]; u_int8_t code; + if(scp->state != SSH_SCP_NEW){ + ssh_set_error(scp->session,SSH_FATAL,"ssh_scp_init called under invalid state"); + return SSH_ERROR; + } scp->channel=channel_new(scp->session); - if(scp->channel == NULL) + if(scp->channel == NULL){ + scp->state=SSH_SCP_ERROR; return SSH_ERROR; + } r= channel_open_session(scp->channel); if(r==SSH_ERROR){ + scp->state=SSH_SCP_ERROR; return SSH_ERROR; } if(scp->mode == SSH_SCP_WRITE) @@ -66,27 +74,42 @@ int ssh_scp_init(ssh_scp scp){ else snprintf(execbuffer,sizeof(execbuffer),"scp -f %s",scp->location); if(channel_request_exec(scp->channel,execbuffer) == SSH_ERROR){ + scp->state=SSH_SCP_ERROR; return SSH_ERROR; } r=channel_read(scp->channel,&code,1,0); if(code != 0){ ssh_set_error(scp->session,SSH_FATAL, "scp status code %ud not valid", code); + scp->state=SSH_SCP_ERROR; return SSH_ERROR; } + if(scp->mode == SSH_SCP_WRITE) + scp->state=SSH_SCP_WRITE_INITED; + else + scp->state=SSH_SCP_READ_INITED; return SSH_OK; } int ssh_scp_close(ssh_scp scp){ - if(channel_send_eof(scp->channel) == SSH_ERROR) - return SSH_ERROR; - if(channel_close(scp->channel) == SSH_ERROR) - return SSH_ERROR; - channel_free(scp->channel); - scp->channel=NULL; + if(scp->channel != NULL){ + if(channel_send_eof(scp->channel) == SSH_ERROR){ + scp->state=SSH_SCP_ERROR; + return SSH_ERROR; + } + if(channel_close(scp->channel) == SSH_ERROR){ + scp->state=SSH_SCP_ERROR; + return SSH_ERROR; + } + channel_free(scp->channel); + scp->channel=NULL; + } + scp->state=SSH_SCP_NEW; return SSH_OK; } void ssh_scp_free(ssh_scp scp){ + if(scp->state != SSH_SCP_NEW) + ssh_scp_close(scp); if(scp->channel) channel_free(scp->channel); SAFE_FREE(scp->location); @@ -104,17 +127,25 @@ int ssh_scp_push_file(ssh_scp scp, const char *filename, size_t size, const char char buffer[1024]; int r; u_int8_t code; + if(scp->state != SSH_SCP_WRITE_INITED){ + ssh_set_error(scp->session,SSH_FATAL,"ssh_scp_push_file called under invalid state"); + return SSH_ERROR; + } snprintf(buffer,sizeof(buffer),"C%s %ld %s\n",perms, size, filename); r=channel_write(scp->channel,buffer,strlen(buffer)); - if(r==SSH_ERROR) + if(r==SSH_ERROR){ + scp->state=SSH_SCP_ERROR; return SSH_ERROR; + } r=channel_read(scp->channel,&code,1,0); if(code != 0){ ssh_set_error(scp->session,SSH_FATAL, "scp status code %ud not valid", code); + scp->state=SSH_SCP_ERROR; return SSH_ERROR; } scp->filelen = size; scp->processed = 0; + scp->state=SSH_SCP_WRITE_WRITING; return SSH_OK; } @@ -125,22 +156,40 @@ int ssh_scp_push_file(ssh_scp scp, const char *filename, size_t size, const char * @returns SSH_ERROR an error happened while writing */ int ssh_scp_write(ssh_scp scp, const void *buffer, size_t len){ - int w,r; - u_int8_t code; + int w; + //int r; + //u_int8_t code; + if(scp->state != SSH_SCP_WRITE_WRITING){ + ssh_set_error(scp->session,SSH_FATAL,"ssh_scp_write called under invalid state"); + return SSH_ERROR; + } if(scp->processed + len > scp->filelen) len = scp->filelen - scp->processed; + /* hack to avoid waiting for window change */ + channel_poll(scp->channel,0); w=channel_write(scp->channel,buffer,len); if(w != SSH_ERROR) scp->processed += w; - else - return w; + else { + scp->state=SSH_SCP_ERROR; + //return=channel_get_exit_status(scp->channel); + return SSH_ERROR; + } + /* Check if we arrived at end of file */ if(scp->processed == scp->filelen) { - r=channel_read(scp->channel,&code,1,0); +/* r=channel_read(scp->channel,&code,1,0); + if(r==SSH_ERROR){ + scp->state=SSH_SCP_ERROR; + return SSH_ERROR; + } if(code != 0){ ssh_set_error(scp->session,SSH_FATAL, "scp status code %ud not valid", code); + scp->state=SSH_SCP_ERROR; return SSH_ERROR; } +*/ scp->processed=scp->filelen=0; + scp->state=SSH_SCP_WRITE_INITED; } return SSH_OK; } |