aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/curve25519.c49
-rw-r--r--src/dh.c149
-rw-r--r--src/ecdh.c21
-rw-r--r--src/ecdh_crypto.c37
-rw-r--r--src/ecdh_gcrypt.c14
-rw-r--r--src/ecdh_mbedcrypto.c13
-rw-r--r--src/server.c151
-rw-r--r--src/wrapper.c32
8 files changed, 291 insertions, 175 deletions
diff --git a/src/curve25519.c b/src/curve25519.c
index 6738da61..73551542 100644
--- a/src/curve25519.c
+++ b/src/curve25519.c
@@ -205,10 +205,31 @@ error:
#ifdef WITH_SERVER
+static SSH_PACKET_CALLBACK(ssh_packet_server_curve25519_init);
+
+static ssh_packet_callback dh_server_callbacks[]= {
+ ssh_packet_server_curve25519_init
+};
+
+static struct ssh_packet_callbacks_struct ssh_curve25519_server_callbacks = {
+ .start = SSH2_MSG_KEX_ECDH_INIT,
+ .n_callbacks = 1,
+ .callbacks = dh_server_callbacks,
+ .user = NULL
+};
+
+/** @internal
+ * @brief sets up the curve25519-sha256@libssh.org kex callbacks
+ */
+void ssh_server_curve25519_init(ssh_session session){
+ /* register the packet callbacks */
+ ssh_packet_set_callbacks(session, &ssh_curve25519_server_callbacks);
+}
+
/** @brief Parse a SSH_MSG_KEXDH_INIT packet (server) and send a
* SSH_MSG_KEXDH_REPLY
*/
-int ssh_server_curve25519_init(ssh_session session, ssh_buffer packet){
+static SSH_PACKET_CALLBACK(ssh_packet_server_curve25519_init){
/* ECDH keys */
ssh_string q_c_string;
ssh_string q_s_string;
@@ -219,18 +240,24 @@ int ssh_server_curve25519_init(ssh_session session, ssh_buffer packet){
ssh_string sig_blob = NULL;
int ok;
int rc;
+ (void)type;
+ (void)user;
+
+ ssh_packet_remove_callbacks(session, &ssh_curve25519_server_callbacks);
/* Extract the client pubkey from the init packet */
q_c_string = ssh_buffer_get_ssh_string(packet);
if (q_c_string == NULL) {
ssh_set_error(session,SSH_FATAL, "No Q_C ECC point in packet");
- return SSH_ERROR;
+ goto error;
}
if (ssh_string_len(q_c_string) != CURVE25519_PUBKEY_SIZE){
- ssh_set_error(session, SSH_FATAL, "Incorrect size for server Curve25519 public key: %d",
- (int)ssh_string_len(q_c_string));
- ssh_string_free(q_c_string);
- return SSH_ERROR;
+ ssh_set_error(session,
+ SSH_FATAL,
+ "Incorrect size for server Curve25519 public key: %zu",
+ ssh_string_len(q_c_string));
+ ssh_string_free(q_c_string);
+ goto error;
}
memcpy(session->next_crypto->curve25519_client_pubkey,
@@ -241,7 +268,7 @@ int ssh_server_curve25519_init(ssh_session session, ssh_buffer packet){
ok = ssh_get_random(session->next_crypto->curve25519_privkey, CURVE25519_PRIVKEY_SIZE, 1);
if (!ok) {
ssh_set_error(session, SSH_FATAL, "PRNG error");
- return SSH_ERROR;
+ goto error;
}
crypto_scalarmult_base(session->next_crypto->curve25519_server_pubkey,
@@ -331,12 +358,16 @@ int ssh_server_curve25519_init(ssh_session session, ssh_buffer packet){
session->dh_handshake_state = DH_STATE_NEWKEYS_SENT;
rc = ssh_packet_send(session);
+ if (rc == SSH_ERROR) {
+ goto error;
+ }
SSH_LOG(SSH_LOG_PROTOCOL, "SSH_MSG_NEWKEYS sent");
- return rc;
+ return SSH_PACKET_USED;
error:
ssh_buffer_reinit(session->out_buffer);
- return SSH_ERROR;
+ session->session_state=SSH_SESSION_STATE_ERROR;
+ return SSH_PACKET_USED;
}
#endif /* WITH_SERVER */
diff --git a/src/dh.c b/src/dh.c
index 474f52f3..6158b465 100644
--- a/src/dh.c
+++ b/src/dh.c
@@ -772,6 +772,155 @@ error:
return SSH_PACKET_USED;
}
+#ifdef WITH_SERVER
+
+static SSH_PACKET_CALLBACK(ssh_packet_server_dh_init);
+
+static ssh_packet_callback dh_server_callbacks[] = {
+ ssh_packet_server_dh_init,
+};
+
+static struct ssh_packet_callbacks_struct ssh_dh_server_callbacks = {
+ .start = SSH2_MSG_KEXDH_INIT,
+ .n_callbacks = 1,
+ .callbacks = dh_server_callbacks,
+ .user = NULL
+};
+
+/** @internal
+ * @brief sets up the diffie-hellman-groupx kex callbacks
+ */
+void ssh_server_dh_init(ssh_session session){
+ /* register the packet callbacks */
+ ssh_packet_set_callbacks(session, &ssh_dh_server_callbacks);
+}
+
+static int dh_handshake_server(ssh_session session)
+{
+ ssh_key privkey = NULL;
+ ssh_string sig_blob = NULL;
+ ssh_string f = NULL;
+ ssh_string pubkey_blob = NULL;
+ int rc;
+
+ rc = ssh_dh_generate_y(session);
+ if (rc < 0) {
+ ssh_set_error(session, SSH_FATAL, "Could not create y number");
+ return -1;
+ }
+ rc = ssh_dh_generate_f(session);
+ if (rc < 0) {
+ ssh_set_error(session, SSH_FATAL, "Could not create f number");
+ return -1;
+ }
+
+ f = ssh_dh_get_f(session);
+ if (f == NULL) {
+ ssh_set_error(session, SSH_FATAL, "Could not get the f number");
+ return -1;
+ }
+
+ if (ssh_get_key_params(session,&privkey) != SSH_OK){
+ ssh_string_free(f);
+ return -1;
+ }
+
+ rc = ssh_dh_build_k(session);
+ if (rc < 0) {
+ ssh_set_error(session, SSH_FATAL, "Could not import the public key");
+ ssh_string_free(f);
+ return -1;
+ }
+
+ rc = ssh_make_sessionid(session);
+ if (rc != SSH_OK) {
+ ssh_set_error(session, SSH_FATAL, "Could not create a session id");
+ ssh_string_free(f);
+ return -1;
+ }
+
+ sig_blob = ssh_srv_pki_do_sign_sessionid(session, privkey);
+ if (sig_blob == NULL) {
+ ssh_set_error(session, SSH_FATAL, "Could not sign the session id");
+ ssh_string_free(f);
+ return -1;
+ }
+ rc = ssh_dh_get_next_server_publickey_blob(session, &pubkey_blob);
+ if (rc != SSH_OK){
+ ssh_set_error_oom(session);
+ ssh_string_free(f);
+ ssh_string_free(sig_blob);
+ return -1;
+ }
+ rc = ssh_buffer_pack(session->out_buffer,
+ "bSSS",
+ SSH2_MSG_KEXDH_REPLY,
+ pubkey_blob,
+ f,
+ sig_blob);
+ ssh_string_free(f);
+ ssh_string_free(sig_blob);
+ if (rc != SSH_OK) {
+ ssh_set_error_oom(session);
+ ssh_buffer_reinit(session->out_buffer);
+ return -1;
+ }
+
+ rc = ssh_packet_send(session);
+ if (rc == SSH_ERROR) {
+ return -1;
+ }
+
+ rc = ssh_buffer_add_u8(session->out_buffer, SSH2_MSG_NEWKEYS);
+ if (rc < 0) {
+ ssh_buffer_reinit(session->out_buffer);
+ return -1;
+ }
+
+ rc = ssh_packet_send(session);
+ if (rc == SSH_ERROR) {
+ return -1;
+ }
+ SSH_LOG(SSH_LOG_PACKET, "SSH_MSG_NEWKEYS sent");
+ session->dh_handshake_state=DH_STATE_NEWKEYS_SENT;
+
+ return 0;
+}
+
+/** @internal
+ * @brief parse an incoming SSH_MSG_KEXDH_INIT packet and complete
+ * Diffie-Hellman key exchange
+ **/
+static SSH_PACKET_CALLBACK(ssh_packet_server_dh_init)
+{
+ ssh_string e = NULL;
+ int rc;
+
+ (void)type;
+ (void)user;
+
+ ssh_packet_remove_callbacks(session, &ssh_dh_server_callbacks);
+ e = ssh_buffer_get_ssh_string(packet);
+ if (e == NULL) {
+ ssh_set_error(session, SSH_FATAL, "No e number in client request");
+ return -1;
+ }
+ rc = ssh_dh_import_e(session, e);
+ if (rc < 0) {
+ ssh_set_error(session, SSH_FATAL, "Cannot import e number");
+ goto error;
+ }
+ session->dh_handshake_state = DH_STATE_INIT_SENT;
+ dh_handshake_server(session);
+ ssh_string_free(e);
+ return SSH_PACKET_USED;
+error:
+ session->session_state = SSH_SESSION_STATE_ERROR;
+ return SSH_PACKET_USED;
+}
+
+#endif /* WITH_SERVER */
+
int ssh_make_sessionid(ssh_session session) {
ssh_string num = NULL;
ssh_buffer server_hash = NULL;
diff --git a/src/ecdh.c b/src/ecdh.c
index 71779da9..1be1d927 100644
--- a/src/ecdh.c
+++ b/src/ecdh.c
@@ -107,4 +107,25 @@ error:
return SSH_PACKET_USED;
}
+#ifdef WITH_SERVER
+
+static ssh_packet_callback ecdh_server_callbacks[] = {
+ ssh_packet_server_ecdh_init
+};
+
+struct ssh_packet_callbacks_struct ssh_ecdh_server_callbacks = {
+ .start = SSH2_MSG_KEX_ECDH_INIT,
+ .n_callbacks = 1,
+ .callbacks = ecdh_server_callbacks,
+ .user = NULL
+};
+
+/** @internal
+ * @brief sets up the ecdh kex callbacks
+ */
+void ssh_server_ecdh_init(ssh_session session){
+ ssh_packet_set_callbacks(session, &ssh_ecdh_server_callbacks);
+}
+
+#endif /* WITH_SERVER */
#endif /* HAVE_ECDH */
diff --git a/src/ecdh_crypto.c b/src/ecdh_crypto.c
index 10cc6a5f..950578a1 100644
--- a/src/ecdh_crypto.c
+++ b/src/ecdh_crypto.c
@@ -195,11 +195,10 @@ int ecdh_build_k(ssh_session session) {
#ifdef WITH_SERVER
-/** @brief Parse a SSH_MSG_KEXDH_INIT packet (server) and send a
+/** @brief Handle a SSH_MSG_KEXDH_INIT packet (server) and send a
* SSH_MSG_KEXDH_REPLY
*/
-
-int ssh_server_ecdh_init(ssh_session session, ssh_buffer packet){
+SSH_PACKET_CALLBACK(ssh_packet_server_ecdh_init){
/* ECDH keys */
ssh_string q_c_string;
ssh_string q_s_string;
@@ -214,12 +213,15 @@ int ssh_server_ecdh_init(ssh_session session, ssh_buffer packet){
int curve;
int len;
int rc;
+ (void)type;
+ (void)user;
+ ssh_packet_remove_callbacks(session, &ssh_ecdh_server_callbacks);
/* Extract the client pubkey from the init packet */
q_c_string = ssh_buffer_get_ssh_string(packet);
if (q_c_string == NULL) {
ssh_set_error(session,SSH_FATAL, "No Q_C ECC point in packet");
- return SSH_ERROR;
+ goto error;
}
session->next_crypto->ecdh_client_pubkey = q_c_string;
@@ -237,7 +239,7 @@ int ssh_server_ecdh_init(ssh_session session, ssh_buffer packet){
if (ecdh_key == NULL) {
ssh_set_error_oom(session);
BN_CTX_free(ctx);
- return SSH_ERROR;
+ goto error;
}
group = EC_KEY_get0_group(ecdh_key);
@@ -255,7 +257,7 @@ int ssh_server_ecdh_init(ssh_session session, ssh_buffer packet){
if (q_s_string == NULL) {
EC_KEY_free(ecdh_key);
BN_CTX_free(ctx);
- return SSH_ERROR;
+ goto error;
}
EC_POINT_point2oct(group,
@@ -273,25 +275,25 @@ int ssh_server_ecdh_init(ssh_session session, ssh_buffer packet){
rc = ecdh_build_k(session);
if (rc < 0) {
ssh_set_error(session, SSH_FATAL, "Cannot build k number");
- return SSH_ERROR;
+ goto error;
}
/* privkey is not allocated */
rc = ssh_get_key_params(session, &privkey);
if (rc == SSH_ERROR) {
- return SSH_ERROR;
+ goto error;
}
rc = ssh_make_sessionid(session);
if (rc != SSH_OK) {
ssh_set_error(session, SSH_FATAL, "Could not create a session id");
- return SSH_ERROR;
+ goto error;
}
sig_blob = ssh_srv_pki_do_sign_sessionid(session, privkey);
if (sig_blob == NULL) {
ssh_set_error(session, SSH_FATAL, "Could not sign the session id");
- return SSH_ERROR;
+ goto error;
}
rc = ssh_dh_get_next_server_publickey_blob(session, &pubkey_blob);
@@ -313,26 +315,33 @@ int ssh_server_ecdh_init(ssh_session session, ssh_buffer packet){
if (rc != SSH_OK) {
ssh_set_error_oom(session);
- return SSH_ERROR;
+ goto error;
}
SSH_LOG(SSH_LOG_PROTOCOL, "SSH_MSG_KEXDH_REPLY sent");
rc = ssh_packet_send(session);
if (rc == SSH_ERROR) {
- return SSH_ERROR;
+ goto error;
}
/* Send the MSG_NEWKEYS */
rc = ssh_buffer_add_u8(session->out_buffer, SSH2_MSG_NEWKEYS);
if (rc < 0) {
- return SSH_ERROR;;
+ goto error;
}
session->dh_handshake_state = DH_STATE_NEWKEYS_SENT;
rc = ssh_packet_send(session);
+ if (rc == SSH_ERROR){
+ goto error;
+ }
SSH_LOG(SSH_LOG_PROTOCOL, "SSH_MSG_NEWKEYS sent");
- return rc;
+ return SSH_PACKET_USED;
+error:
+ ssh_buffer_reinit(session->out_buffer);
+ session->session_state = SSH_SESSION_STATE_ERROR;
+ return SSH_PACKET_USED;
}
#endif /* WITH_SERVER */
diff --git a/src/ecdh_gcrypt.c b/src/ecdh_gcrypt.c
index 96dbd1a0..913855c0 100644
--- a/src/ecdh_gcrypt.c
+++ b/src/ecdh_gcrypt.c
@@ -259,10 +259,11 @@ int ecdh_build_k(ssh_session session)
#ifdef WITH_SERVER
-/** @brief Parse a SSH_MSG_KEXDH_INIT packet (server) and send a
+
+/** @brief Handle a SSH_MSG_KEXDH_INIT packet (server) and send a
* SSH_MSG_KEXDH_REPLY
*/
-int ssh_server_ecdh_init(ssh_session session, ssh_buffer packet) {
+SSH_PACKET_CALLBACK(ssh_packet_server_ecdh_init){
gpg_error_t err;
/* ECDH keys */
ssh_string q_c_string;
@@ -275,7 +276,10 @@ int ssh_server_ecdh_init(ssh_session session, ssh_buffer packet) {
ssh_string pubkey_blob = NULL;
int rc = SSH_ERROR;
const char *curve = NULL;
+ (void)type;
+ (void)user;
+ ssh_packet_remove_callbacks(session, &ssh_ecdh_server_callbacks);
curve = ecdh_kex_type_to_curve(session->next_crypto->kex_type);
if (curve == NULL) {
goto out;
@@ -380,7 +384,11 @@ int ssh_server_ecdh_init(ssh_session session, ssh_buffer packet) {
out:
gcry_sexp_release(param);
gcry_sexp_release(key);
- return rc;
+ if (rc == SSH_ERROR) {
+ ssh_buffer_reinit(session->out_buffer);
+ session->session_state = SSH_SESSION_STATE_ERROR;
+ }
+ return SSH_PACKET_USED;
}
#endif /* WITH_SERVER */
diff --git a/src/ecdh_mbedcrypto.c b/src/ecdh_mbedcrypto.c
index 3ff93ee8..68033f7d 100644
--- a/src/ecdh_mbedcrypto.c
+++ b/src/ecdh_mbedcrypto.c
@@ -182,8 +182,8 @@ out:
}
#ifdef WITH_SERVER
-int ssh_server_ecdh_init(ssh_session session, ssh_buffer packet)
-{
+
+SSH_PACKET_CALLBACK(ssh_packet_server_ecdh_init){
ssh_string q_c_string = NULL;
ssh_string q_s_string = NULL;
mbedtls_ecp_group grp;
@@ -192,7 +192,10 @@ int ssh_server_ecdh_init(ssh_session session, ssh_buffer packet)
ssh_string pubkey_blob = NULL;
int rc;
mbedtls_ecp_group_id curve;
+ (void)type;
+ (void)user;
+ ssh_packet_remove_callbacks(session, &ssh_ecdh_server_callbacks);
curve = ecdh_kex_type_to_curve(session->next_crypto->kex_type);
if (curve == MBEDTLS_ECP_DP_NONE) {
return SSH_ERROR;
@@ -308,7 +311,11 @@ int ssh_server_ecdh_init(ssh_session session, ssh_buffer packet)
out:
mbedtls_ecp_group_free(&grp);
- return rc;
+ if (rc == SSH_ERROR) {
+ ssh_buffer_reinit(session->out_buffer);
+ session->session_state = SSH_SESSION_STATE_ERROR;
+ }
+ return SSH_PACKET_USED;
}
#endif /* WITH_SERVER */
diff --git a/src/server.c b/src/server.c
index a586964f..472e2c01 100644
--- a/src/server.c
+++ b/src/server.c
@@ -65,8 +65,6 @@
session->common.callbacks->connect_status_function(session->common.callbacks->userdata, status); \
} while (0)
-static int dh_handshake_server(ssh_session session);
-
/**
* @addtogroup libssh_server
*
@@ -177,28 +175,6 @@ int ssh_server_init_kex(ssh_session session) {
return server_set_kex(session);
}
-/** @internal
- * @brief parse an incoming SSH_MSG_KEXDH_INIT packet and complete
- * key exchange
- **/
-static int ssh_server_kexdh_init(ssh_session session, ssh_buffer packet){
- ssh_string e;
- e = ssh_buffer_get_ssh_string(packet);
- if (e == NULL) {
- ssh_set_error(session, SSH_FATAL, "No e number in client request");
- return -1;
- }
- if (ssh_dh_import_e(session, e) < 0) {
- ssh_set_error(session, SSH_FATAL, "Cannot import e number");
- session->session_state=SSH_SESSION_STATE_ERROR;
- } else {
- session->dh_handshake_state=DH_STATE_INIT_SENT;
- dh_handshake_server(session);
- }
- ssh_string_free(e);
- return SSH_OK;
-}
-
static int ssh_server_send_extensions(ssh_session session) {
int rc;
const char *hostkey_algorithms;
@@ -231,14 +207,15 @@ error:
}
SSH_PACKET_CALLBACK(ssh_packet_kexdh_init){
- int rc = SSH_ERROR;
+ (void)packet;
(void)type;
(void)user;
SSH_LOG(SSH_LOG_PACKET,"Received SSH_MSG_KEXDH_INIT");
if(session->dh_handshake_state != DH_STATE_INIT){
SSH_LOG(SSH_LOG_RARE,"Invalid state for SSH_MSG_KEXDH_INIT");
- goto error;
+ session->session_state = SSH_SESSION_STATE_ERROR;
+ return SSH_PACKET_USED;
}
/* If first_kex_packet_follows guess was wrong, ignore this message. */
@@ -246,41 +223,10 @@ SSH_PACKET_CALLBACK(ssh_packet_kexdh_init){
SSH_LOG(SSH_LOG_RARE, "first_kex_packet_follows guess was wrong, "
"ignoring first SSH_MSG_KEXDH_INIT message");
session->first_kex_follows_guess_wrong = 0;
- rc = SSH_OK;
- goto error;
- }
-
- switch(session->next_crypto->kex_type){
- case SSH_KEX_DH_GROUP1_SHA1:
- case SSH_KEX_DH_GROUP14_SHA1:
- case SSH_KEX_DH_GROUP16_SHA512:
- case SSH_KEX_DH_GROUP18_SHA512:
- rc=ssh_server_kexdh_init(session, packet);
- break;
- #ifdef HAVE_ECDH
- case SSH_KEX_ECDH_SHA2_NISTP256:
- case SSH_KEX_ECDH_SHA2_NISTP384:
- case SSH_KEX_ECDH_SHA2_NISTP521:
- rc = ssh_server_ecdh_init(session, packet);
- break;
- #endif
- #ifdef HAVE_CURVE25519
- case SSH_KEX_CURVE25519_SHA256:
- case SSH_KEX_CURVE25519_SHA256_LIBSSH_ORG:
- rc = ssh_server_curve25519_init(session, packet);
- break;
- #endif
- default:
- ssh_set_error(session,SSH_FATAL,"Wrong kex type in ssh_packet_kexdh_init");
- rc = SSH_ERROR;
- }
-error:
- if (rc == SSH_ERROR) {
- session->session_state = SSH_SESSION_STATE_ERROR;
+ return SSH_PACKET_USED;
}
-
- return SSH_PACKET_USED;
+ return SSH_PACKET_NOT_USED;
}
int ssh_get_key_params(ssh_session session, ssh_key *privkey){
@@ -334,93 +280,6 @@ int ssh_get_key_params(ssh_session session, ssh_key *privkey){
return SSH_OK;
}
-static int dh_handshake_server(ssh_session session) {
- ssh_key privkey;
- ssh_string sig_blob;
- ssh_string f;
- ssh_string pubkey_blob = NULL;
- int rc;
-
- if (ssh_dh_generate_y(session) < 0) {
- ssh_set_error(session, SSH_FATAL, "Could not create y number");
- return -1;
- }
- if (ssh_dh_generate_f(session) < 0) {
- ssh_set_error(session, SSH_FATAL, "Could not create f number");
- return -1;
- }
-
- f = ssh_dh_get_f(session);
- if (f == NULL) {
- ssh_set_error(session, SSH_FATAL, "Could not get the f number");
- return -1;
- }
-
- if (ssh_get_key_params(session,&privkey) != SSH_OK){
- ssh_string_free(f);
- return -1;
- }
-
- if (ssh_dh_build_k(session) < 0) {
- ssh_set_error(session, SSH_FATAL, "Could not import the public key");
- ssh_string_free(f);
- return -1;
- }
-
- if (ssh_make_sessionid(session) != SSH_OK) {
- ssh_set_error(session, SSH_FATAL, "Could not create a session id");
- ssh_string_free(f);
- return -1;
- }
-
- sig_blob = ssh_srv_pki_do_sign_sessionid(session, privkey);
- if (sig_blob == NULL) {
- ssh_set_error(session, SSH_FATAL, "Could not sign the session id");
- ssh_string_free(f);
- return -1;
- }
-
- rc = ssh_dh_get_next_server_publickey_blob(session, &pubkey_blob);
- if (rc != SSH_OK) {
- ssh_set_error_oom(session);
- ssh_string_free(f);
- ssh_string_free(sig_blob);
- return -1;
- }
-
- rc = ssh_buffer_pack(session->out_buffer,
- "bSSS",
- SSH2_MSG_KEXDH_REPLY,
- pubkey_blob,
- f,
- sig_blob);
- ssh_string_free(f);
- ssh_string_free(sig_blob);
- ssh_string_free(pubkey_blob);
- if(rc != SSH_OK){
- ssh_set_error_oom(session);
- ssh_buffer_reinit(session->out_buffer);
- return -1;
- }
-
- if (ssh_packet_send(session) == SSH_ERROR) {
- return -1;
- }
-
- if (ssh_buffer_add_u8(session->out_buffer, SSH2_MSG_NEWKEYS) < 0) {
- ssh_buffer_reinit(session->out_buffer);
- return -1;
- }
-
- if (ssh_packet_send(session) == SSH_ERROR) {
- return -1;
- }
- SSH_LOG(SSH_LOG_PACKET, "SSH_MSG_NEWKEYS sent");
- session->dh_handshake_state=DH_STATE_NEWKEYS_SENT;
-
- return 0;
-}
-
/**
* @internal
*
diff --git a/src/wrapper.c b/src/wrapper.c
index 7724dcf6..b4429e47 100644
--- a/src/wrapper.c
+++ b/src/wrapper.c
@@ -48,6 +48,9 @@
#include "libssh/wrapper.h"
#include "libssh/pki.h"
#include "libssh/poly1305.h"
+#include "libssh/dh.h"
+#include "libssh/ecdh.h"
+#include "libssh/curve25519.h"
static struct ssh_hmac_struct ssh_hmac_tab[] = {
{ "hmac-sha1", SSH_HMAC_SHA1 },
@@ -530,6 +533,35 @@ int crypt_set_algorithms_server(ssh_session session){
method = session->next_crypto->kex_methods[SSH_HOSTKEYS];
session->srv.hostkey = ssh_key_type_from_signature_name(method);
+ /* setup DH key exchange type */
+ switch (session->next_crypto->kex_type) {
+ case SSH_KEX_DH_GROUP1_SHA1:
+ case SSH_KEX_DH_GROUP14_SHA1:
+ case SSH_KEX_DH_GROUP16_SHA512:
+ case SSH_KEX_DH_GROUP18_SHA512:
+ ssh_server_dh_init(session);
+ break;
+#ifdef HAVE_ECDH
+ case SSH_KEX_ECDH_SHA2_NISTP256:
+ case SSH_KEX_ECDH_SHA2_NISTP384:
+ case SSH_KEX_ECDH_SHA2_NISTP521:
+ ssh_server_ecdh_init(session);
+ break;
+#endif
+#ifdef HAVE_CURVE25519
+ case SSH_KEX_CURVE25519_SHA256:
+ case SSH_KEX_CURVE25519_SHA256_LIBSSH_ORG:
+ ssh_server_curve25519_init(session);
+ break;
+#endif
+ default:
+ ssh_set_error(session,
+ SSH_FATAL,
+ "crypt_set_algorithms_server: could not find init "
+ "handler for kex type %d",
+ session->next_crypto->kex_type);
+ return SSH_ERROR;
+ }
return SSH_OK;
}