/*
 * This file is part of the SSH Library
 *
 * Copyright (c) 2018 by Anderson Toshiyuki Sasaki
 *
 * The SSH Library is free software; you can redistribute it and/or modify
 * it under the terms of the GNU Lesser General Public License as published by
 * the Free Software Foundation; either version 2.1 of the License, or (at your
 * option) any later version.
 *
 * The SSH Library is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
 * or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
 * License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with the SSH Library; see the file COPYING. If not, write to
 * the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
 * MA 02111-1307, USA.
 */

/*
 * This test checks if the messages accepted by the packet filter were intended
 * to be accepted.
 *
 * The process consists of two steps:
 *   - Try the filter with a message type in an arbitrary state
 *   - If the message is accepted by the filter, check if the message is in the
 *     set of accepted states.
 *
 * Only the values selected by the flag (COMPARE_*) are considered.
 */

#include "config.h"

#define LIBSSH_STATIC

#include "torture.h"
#include "libssh/priv.h"
#include "libssh/libssh.h"
#include "libssh/session.h"
#include "libssh/auth.h"
#include "libssh/ssh2.h"
#include "libssh/packet.h"

#include "packet.c"

#define COMPARE_SESSION_STATE       1
#define COMPARE_ROLE                (1 << 1)
#define COMPARE_DH_STATE            (1 << 2)
#define COMPARE_AUTH_STATE          (1 << 3)
#define COMPARE_GLOBAL_REQ_STATE    (1 << 4)
#define COMPARE_CURRENT_METHOD      (1 << 5)

#define SESSION_STATE_COUNT 11
#define DH_STATE_COUNT 4
#define AUTH_STATE_COUNT 15
#define GLOBAL_REQ_STATE_COUNT 5
#define MESSAGE_COUNT 100 // from 1 to 100

#define ROLE_CLIENT 0
#define ROLE_SERVER 1

/*
 * This is the list of currently unfiltered message types.
 * Only unrecognized types should be in this list.
 *
 * The list must be kept sorted in ascending order because check_unfiltered()
 * looks values up with bsearch().
 */
static const uint8_t unfiltered[] = {
    8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
    22, 23, 24, 25, 26, 27, 28, 29,
    35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,
    54, 55, 56, 57, 58, 59,
    62,
    67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79,
    83, 84, 85, 86, 87, 88, 89,
};

/*
 * One combination of the session-wide states inspected by this test, plus a
 * mask telling which of the fields are significant.
 */
typedef struct global_state_st {
    /* If the bit in this flag is zero, the corresponding state is not
     * considered, working as a wildcard (meaning any value is accepted) */
    uint32_t flags;
    uint8_t role;
    enum ssh_session_state_e session;
    enum ssh_dh_state_e dh;
    enum ssh_auth_state_e auth;
    enum ssh_channel_request_state_e global_req;
} global_state;

/*
 * Compare a fully specified state against an accepted-state pattern.
 *
 * key           - pointer to the global_state being tested; its flags are
 *                 not read.
 * array_element - pointer to the global_state pattern; only the fields
 *                 selected by its flags take part in the comparison.
 *
 * Returns 0 when every selected field is equal. Otherwise returns -1 or 1
 * according to the first selected field that differs, checked in the order
 * role, session, dh, auth, global_req.
 *
 * Note: COMPARE_CURRENT_METHOD is not evaluated here; global_state has no
 * field associated with it.
 */
static int cmp_state_search(const void *key, const void *array_element)
{
    const global_state *s1 = key;
    const global_state *s2 = array_element;

    if (s2->flags & COMPARE_ROLE) {
        /* Compare role (client == 0 or server == 1) */
        if (s1->role < s2->role) {
            return -1;
        }
        if (s1->role > s2->role) {
            return 1;
        }
    }

    if (s2->flags & COMPARE_SESSION_STATE) {
        /* Compare session state */
        if (s1->session < s2->session) {
            return -1;
        }
        if (s1->session > s2->session) {
            return 1;
        }
    }

    if (s2->flags & COMPARE_DH_STATE) {
        /* Compare DH state */
        if (s1->dh < s2->dh) {
            return -1;
        }
        if (s1->dh > s2->dh) {
            return 1;
        }
    }

    if (s2->flags & COMPARE_AUTH_STATE) {
        /* Compare auth */
        if (s1->auth < s2->auth) {
            return -1;
        }
        if (s1->auth > s2->auth) {
            return 1;
        }
    }

    if (s2->flags & COMPARE_GLOBAL_REQ_STATE) {
        /* Compare global_req */
        if (s1->global_req < s2->global_req) {
            return -1;
        }
        if (s1->global_req > s2->global_req) {
            return 1;
        }
    }

    /* Every selected field matched */
    return 0;
}

/*
 * Tell whether a state matches at least one accepted pattern.
 *
 * tested       - the fully specified state to look for.
 * accepted     - array of accepted patterns, in any order.
 * accepted_len - number of elements in accepted; values <= 0 match nothing.
 *
 * Returns 1 if any pattern matches, 0 otherwise.
 *
 * The patterns are scanned linearly. A binary search is not reliable here:
 * cmp_state_search() ignores the fields a pattern leaves as wildcards, so it
 * does not define a single ordering over patterns with different flags, and
 * a matching pattern could be skipped.
 */
static int is_state_accepted(global_state *tested,
                             global_state *accepted,
                             int accepted_len)
{
    int i;

    for (i = 0; i < accepted_len; i++) {
        if (cmp_state_search(tested, &accepted[i]) == 0) {
            return 1;
        }
    }

    return 0;
}

/*
 * bsearch() comparator for two uint8_t values.
 *
 * Returns -1, 0 or 1 when *i is lower than, equal to or greater than *j.
 */
static int cmp_uint8(const void *i, const void *j)
{
    uint8_t e1 = *((const uint8_t *)i);
    uint8_t e2 = *((const uint8_t *)j);

    if (e1 < e2) {
        return -1;
    }
    if (e1 > e2) {
        return 1;
    }

    return 0;
}

/*
 * Tell whether a message type is listed in the unfiltered table.
 *
 * Returns 1 if msg_type is in unfiltered[], 0 otherwise.
 */
static int check_unfiltered(uint8_t msg_type)
{
    const uint8_t *found = NULL;

    found = bsearch(&msg_type,
                    unfiltered,
                    sizeof(unfiltered) / sizeof(unfiltered[0]),
                    sizeof(unfiltered[0]),
                    cmp_uint8);

    return (found != NULL) ? 1 : 0;
}

/*
 * For every message type from 1 to MESSAGE_COUNT and every combination of
 * role, session, DH, auth and global request state, check that the filter
 * answers SSH_PACKET_UNKNOWN if and only if the type is listed in
 * unfiltered[].
 */
static void torture_packet_filter_check_unfiltered(void **state)
{
    ssh_session session;

    int role_c;
    int auth_c;
    int session_c;
    int dh_c;
    int global_req_c;

    uint8_t msg_type;

    enum ssh_packet_filter_result_e rc;
    int in_unfiltered;

    /* Unused */
    (void) state;

    session = ssh_new();
    /* Every loop below dereferences the session */
    assert_non_null(session);

    for (msg_type = 1; msg_type <= MESSAGE_COUNT; msg_type++) {
        session->in_packet.type = msg_type;

        /* The lookup depends on the message type only */
        in_unfiltered = check_unfiltered(msg_type);

        for (role_c = 0; role_c < 2; role_c++) {
            session->server = role_c;
            for (session_c = 0; session_c < SESSION_STATE_COUNT; session_c++) {
                session->session_state = session_c;
                for (dh_c = 0; dh_c < DH_STATE_COUNT; dh_c++) {
                    session->dh_handshake_state = dh_c;
                    for (auth_c = 0; auth_c < AUTH_STATE_COUNT; auth_c++) {
                        session->auth.state = auth_c;
                        for (global_req_c = 0;
                             global_req_c < GLOBAL_REQ_STATE_COUNT;
                             global_req_c++)
                        {
                            session->global_req_state = global_req_c;

                            rc = ssh_packet_incoming_filter(session);

                            if (rc == SSH_PACKET_UNKNOWN) {
                                if (!in_unfiltered) {
                                    fprintf(stderr, "Message type %d UNFILTERED "
                                            "in state: role %d, session %d, dh %d, auth %d\n",
                                            msg_type, role_c, session_c, dh_c, auth_c);
                                }
                                assert_int_equal(in_unfiltered, 1);
                            } else {
                                if (in_unfiltered) {
                                    fprintf(stderr, "Message type %d NOT UNFILTERED "
                                            "in state: role %d, session %d, dh %d, auth %d\n",
                                            msg_type, role_c, session_c, dh_c, auth_c);
                                }
                                assert_int_equal(in_unfiltered, 0);
                            }
                        }
                    }
                }
            }
        }
    }

    ssh_free(session);
}

/*
 * Run the packet filter for one message type in every combination of role,
 * session, DH, auth and global request state and compare the verdict with
 * the list of accepted states.
 *
 * accepted       - patterns describing the states in which the message is
 *                  expected to be allowed.
 * accepted_count - number of elements in accepted.
 * msg_type       - the message type to test.
 *
 * SSH_PACKET_ALLOWED must happen only in a state matching a pattern and
 * SSH_PACKET_DENIED only in a state matching none; both are enforced with
 * cmocka assertions. Any other verdict is only reported on stderr.
 *
 * Always returns 0.
 */
static int check_message_in_all_states(global_state accepted[],
                                       int accepted_count,
                                       uint8_t msg_type)
{
    ssh_session session;

    int role_c;
    int auth_c;
    int session_c;
    int dh_c;
    int global_req_c;

    enum ssh_packet_filter_result_e rc;
    int in_accepted;

    /* The flags of the key are not used by the lookup */
    global_state key = {0};

    session = ssh_new();
    /* Every loop below dereferences the session */
    assert_non_null(session);

    session->in_packet.type = msg_type;

    for (role_c = 0; role_c < 2; role_c++) {
        session->server = role_c;
        key.role = role_c;
        for (session_c = 0; session_c < SESSION_STATE_COUNT; session_c++) {
            session->session_state = session_c;
            key.session = session_c;
            for (dh_c = 0; dh_c < DH_STATE_COUNT; dh_c++) {
                session->dh_handshake_state = dh_c;
                key.dh = dh_c;
                for (auth_c = 0; auth_c < AUTH_STATE_COUNT; auth_c++) {
                    session->auth.state = auth_c;
                    key.auth = auth_c;
                    for (global_req_c = 0;
                         global_req_c < GLOBAL_REQ_STATE_COUNT;
                         global_req_c++)
                    {
                        session->global_req_state = global_req_c;
                        key.global_req = global_req_c;

                        rc = ssh_packet_incoming_filter(session);

                        if (rc == SSH_PACKET_ALLOWED) {
                            in_accepted = is_state_accepted(&key,
                                                            accepted,
                                                            accepted_count);
                            if (!in_accepted) {
                                fprintf(stderr, "Message type %d ALLOWED "
                                        "in state: role %d, session %d, dh %d, auth %d\n",
                                        msg_type, role_c, session_c, dh_c, auth_c);
                            }
                            assert_int_equal(in_accepted, 1);
                        } else if (rc == SSH_PACKET_DENIED) {
                            in_accepted = is_state_accepted(&key,
                                                            accepted,
                                                            accepted_count);
                            if (in_accepted) {
                                fprintf(stderr, "Message type %d DENIED "
                                        "in state: role %d, session %d, dh %d, auth %d\n",
                                        msg_type, role_c, session_c, dh_c, auth_c);
                            }
                            assert_int_equal(in_accepted, 0);
                        } else {
                            fprintf(stderr, "Message type %d UNFILTERED "
                                    "in state: role %d, session %d, dh %d, auth %d\n",
                                    msg_type, role_c, session_c, dh_c, auth_c);
                        }
                    }
                }
            }
        }
    }

    ssh_free(session);
    return 0;
}

/*
 * SSH2_MSG_USERAUTH_SUCCESS is expected to be allowed only for the states
 * listed in accepted[].
 */
static void torture_packet_filter_check_auth_success(void **state)
{
    int rc;

    global_state accepted[] = {
        {
            .flags = (COMPARE_SESSION_STATE |
                      COMPARE_ROLE |
                      COMPARE_AUTH_STATE |
                      COMPARE_DH_STATE),
            .role = ROLE_CLIENT,
            .session = SSH_SESSION_STATE_AUTHENTICATING,
            .dh = DH_STATE_FINISHED,
            .auth = SSH_AUTH_STATE_PUBKEY_AUTH_SENT,
        },
        {
            .flags = (COMPARE_SESSION_STATE |
                      COMPARE_ROLE |
                      COMPARE_AUTH_STATE |
                      COMPARE_DH_STATE),
            .role = ROLE_CLIENT,
            .session = SSH_SESSION_STATE_AUTHENTICATING,
            .dh = DH_STATE_FINISHED,
            .auth = SSH_AUTH_STATE_PASSWORD_AUTH_SENT,
        },
        {
            .flags = (COMPARE_SESSION_STATE |
                      COMPARE_ROLE |
                      COMPARE_AUTH_STATE |
                      COMPARE_DH_STATE),
            .role = ROLE_CLIENT,
            .session = SSH_SESSION_STATE_AUTHENTICATING,
            .dh = DH_STATE_FINISHED,
            .auth = SSH_AUTH_STATE_GSSAPI_MIC_SENT,
        },
        {
            .flags = (COMPARE_SESSION_STATE |
                      COMPARE_ROLE |
                      COMPARE_AUTH_STATE |
                      COMPARE_DH_STATE),
            .role = ROLE_CLIENT,
            .session = SSH_SESSION_STATE_AUTHENTICATING,
            .dh = DH_STATE_FINISHED,
            .auth = SSH_AUTH_STATE_KBDINT_SENT,
        },
        {
            .flags = (COMPARE_SESSION_STATE |
                      COMPARE_ROLE |
                      COMPARE_AUTH_STATE |
                      COMPARE_DH_STATE |
                      COMPARE_CURRENT_METHOD),
            .role = ROLE_CLIENT,
            .session = SSH_SESSION_STATE_AUTHENTICATING,
            .dh = DH_STATE_FINISHED,
            .auth = SSH_AUTH_STATE_AUTH_NONE_SENT,
        }
    };

    /* Derived from the array so it cannot drift from the list above */
    int accepted_count = (int)(sizeof(accepted) / sizeof(accepted[0]));

    /* Unused */
    (void) state;

    rc = check_message_in_all_states(accepted,
                                     accepted_count,
                                     SSH2_MSG_USERAUTH_SUCCESS);

    assert_int_equal(rc, 0);
}

/*
 * SSH2_MSG_CHANNEL_OPEN is expected to be allowed only in the state listed
 * in accepted[].
 */
static void torture_packet_filter_check_channel_open(void **state)
{
    int rc;

    /* The only condition to accept a CHANNEL_OPEN is to be authenticated */
    global_state accepted[] = {
        {
            .flags = COMPARE_SESSION_STATE,
            .session = SSH_SESSION_STATE_AUTHENTICATED,
        }
    };

    /* Derived from the array so it cannot drift from the list above */
    int accepted_count = (int)(sizeof(accepted) / sizeof(accepted[0]));

    /* Unused */
    (void) state;

    rc = check_message_in_all_states(accepted,
                                     accepted_count,
                                     SSH2_MSG_CHANNEL_OPEN);

    assert_int_equal(rc, 0);
}

/*
 * Test entry point: runs the packet filter tests between ssh_init() and
 * ssh_finalize() and returns the result of cmocka_run_group_tests().
 */
int torture_run_tests(void)
{
    int rc;
    struct CMUnitTest tests[] = {
        cmocka_unit_test(torture_packet_filter_check_auth_success),
        cmocka_unit_test(torture_packet_filter_check_channel_open),
        cmocka_unit_test(torture_packet_filter_check_unfiltered),
    };

    ssh_init();
    torture_filter_tests(tests);
    rc = cmocka_run_group_tests(tests, NULL, NULL);
    ssh_finalize();
    return rc;
}