diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function index 2b0988d10..6529bbe9e 100644 --- a/tests/suites/test_suite_ssl.function +++ b/tests/suites/test_suite_ssl.function @@ -756,17 +756,26 @@ exit: * * \p endpoint_type must be set as MBEDTLS_SSL_IS_SERVER or * MBEDTLS_SSL_IS_CLIENT. + * \p pk_alg the algorithm to use, currently only MBEDTLS_PK_RSA and + * MBEDTLS_PK_ECDSA are supported. + * \p dtls_context - in case of DTLS - this is the context handling metadata. + * \p input_queue - used only in case of DTLS. + * \p output_queue - used only in case of DTLS. * * \retval 0 on success, otherwise error code. */ -int mbedtls_endpoint_init( mbedtls_endpoint *ep, int endpoint_type, int pk_alg ) +int mbedtls_endpoint_init( mbedtls_endpoint *ep, int endpoint_type, int pk_alg, + mbedtls_test_message_socket_context *dtls_context, + mbedtls_test_message_queue *input_queue, + mbedtls_test_message_queue *output_queue ) { int ret = -1; - if( ep == NULL ) - { + if( dtls_context != NULL && ( input_queue == NULL || output_queue == NULL ) ) + return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; + + if( ep == NULL ) return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; - } memset( ep, 0, sizeof( *ep ) ); @@ -779,7 +788,16 @@ int mbedtls_endpoint_init( mbedtls_endpoint *ep, int endpoint_type, int pk_alg ) mbedtls_ctr_drbg_random, &( ep->ctr_drbg ) ); mbedtls_entropy_init( &( ep->entropy ) ); - mbedtls_mock_socket_init( &( ep->socket ) ); + if( dtls_context != NULL ) + { + TEST_ASSERT( mbedtls_message_socket_setup( input_queue, output_queue, + 100, &( ep->socket ), + dtls_context ) == 0 ); + } + else + { + mbedtls_mock_socket_init( &( ep->socket ) ); + } ret = mbedtls_ctr_drbg_seed( &( ep->ctr_drbg ), mbedtls_entropy_func, &( ep->entropy ), (const unsigned char *) ( ep->name ), @@ -787,18 +805,36 @@ int mbedtls_endpoint_init( mbedtls_endpoint *ep, int endpoint_type, int pk_alg ) TEST_ASSERT( ret == 0 ); /* Non-blocking callbacks without timeout */ - mbedtls_ssl_set_bio( &( ep->ssl ), &( ep->socket ), - mbedtls_mock_tcp_send_nb, - mbedtls_mock_tcp_recv_nb, - NULL ); + if( dtls_context != NULL ) + { + mbedtls_ssl_set_bio( &( ep->ssl ), dtls_context, + mbedtls_mock_tcp_send_msg, + mbedtls_mock_tcp_recv_msg, + NULL ); + } + else + { + mbedtls_ssl_set_bio( &( ep->ssl ), &( ep->socket ), + mbedtls_mock_tcp_send_nb, + mbedtls_mock_tcp_recv_nb, + NULL ); + } ret = mbedtls_ssl_config_defaults( &( ep->conf ), endpoint_type, - MBEDTLS_SSL_TRANSPORT_STREAM, - MBEDTLS_SSL_PRESET_DEFAULT ); + ( dtls_context != NULL ) ? + MBEDTLS_SSL_TRANSPORT_DATAGRAM : + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT ); TEST_ASSERT( ret == 0 ); ret = mbedtls_ssl_setup( &( ep->ssl ), &( ep->conf ) ); TEST_ASSERT( ret == 0 ); + +#if defined(MBEDTLS_SSL_PROTO_DTLS) && defined(MBEDTLS_SSL_SRV_C) + if( endpoint_type == MBEDTLS_SSL_IS_SERVER && dtls_context != NULL ) + mbedtls_ssl_conf_dtls_cookies( &( ep->conf ), NULL, NULL, NULL ); +#endif + ret = mbedtls_endpoint_certificate_init( ep, pk_alg ); TEST_ASSERT( ret == 0 ); @@ -820,7 +856,8 @@ void mbedtls_endpoint_certificate_free( mbedtls_endpoint *ep ) /* * Deinitializes endpoint represented by \p ep. */ -void mbedtls_endpoint_free( mbedtls_endpoint *ep ) +void mbedtls_endpoint_free( mbedtls_endpoint *ep, + mbedtls_test_message_socket_context *context ) { mbedtls_endpoint_certificate_free( ep ); @@ -828,7 +865,15 @@ void mbedtls_endpoint_free( mbedtls_endpoint *ep ) mbedtls_ssl_config_free( &( ep->conf ) ); mbedtls_ctr_drbg_free( &( ep->ctr_drbg ) ); mbedtls_entropy_free( &( ep->entropy ) ); - mbedtls_mock_socket_close( &( ep->socket ) ); + + if( context != NULL ) + { + mbedtls_message_socket_close( context ); + } + else + { + mbedtls_mock_socket_close( &( ep->socket ) ); + } } /* @@ -2987,17 +3032,19 @@ void mbedtls_endpoint_sanity( int endpoint_type ) mbedtls_endpoint ep; int ret = -1; - ret = mbedtls_endpoint_init( NULL, endpoint_type, MBEDTLS_PK_RSA ); + ret = mbedtls_endpoint_init( NULL, endpoint_type, MBEDTLS_PK_RSA, + NULL, NULL, NULL ); TEST_ASSERT( MBEDTLS_ERR_SSL_BAD_INPUT_DATA == ret ); ret = mbedtls_endpoint_certificate_init( NULL, MBEDTLS_PK_RSA ); TEST_ASSERT( MBEDTLS_ERR_SSL_BAD_INPUT_DATA == ret ); - ret = mbedtls_endpoint_init( &ep, endpoint_type, MBEDTLS_PK_RSA ); + ret = mbedtls_endpoint_init( &ep, endpoint_type, MBEDTLS_PK_RSA, + NULL, NULL, NULL ); TEST_ASSERT( ret == 0 ); exit: - mbedtls_endpoint_free( &ep ); + mbedtls_endpoint_free( &ep, NULL ); } /* END_CASE */ @@ -3008,13 +3055,14 @@ void move_handshake_to_state(int endpoint_type, int state, int need_pass) mbedtls_endpoint base_ep, second_ep; int ret = -1; - ret = mbedtls_endpoint_init( &base_ep, endpoint_type, MBEDTLS_PK_RSA ); + ret = mbedtls_endpoint_init( &base_ep, endpoint_type, MBEDTLS_PK_RSA, + NULL, NULL, NULL ); TEST_ASSERT( ret == 0 ); ret = mbedtls_endpoint_init( &second_ep, ( endpoint_type == MBEDTLS_SSL_IS_SERVER ) ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER, - MBEDTLS_PK_RSA ); + MBEDTLS_PK_RSA, NULL, NULL, NULL ); TEST_ASSERT( ret == 0 ); ret = mbedtls_mock_socket_connect( &(base_ep.socket), @@ -3037,8 +3085,8 @@ void move_handshake_to_state(int endpoint_type, int state, int need_pass) } exit: - mbedtls_endpoint_free( &base_ep ); - mbedtls_endpoint_free( &second_ep ); + mbedtls_endpoint_free( &base_ep, NULL ); + mbedtls_endpoint_free( &second_ep, NULL ); } /* END_CASE */ @@ -3055,9 +3103,10 @@ void handshake( const char *cipher, int version, int pk_alg, #else (void) psk_str; #endif + /* Client side */ TEST_ASSERT( mbedtls_endpoint_init( &client, MBEDTLS_SSL_IS_CLIENT, - pk_alg ) == 0 ); + pk_alg, NULL, NULL, NULL ) == 0 ); mbedtls_ssl_conf_min_version( &client.conf, MBEDTLS_SSL_MAJOR_VERSION_3, version ); @@ -3070,7 +3119,7 @@ void handshake( const char *cipher, int version, int pk_alg, } /* Server side */ TEST_ASSERT( mbedtls_endpoint_init( &server, MBEDTLS_SSL_IS_SERVER, - pk_alg ) == 0 ); + pk_alg, NULL, NULL, NULL ) == 0 ); mbedtls_ssl_conf_min_version( &server.conf, MBEDTLS_SSL_MAJOR_VERSION_3, version ); @@ -3102,8 +3151,8 @@ void handshake( const char *cipher, int version, int pk_alg, TEST_ASSERT( server.ssl.state == MBEDTLS_SSL_HANDSHAKE_OVER ); exit: - mbedtls_endpoint_free( &client ); - mbedtls_endpoint_free( &server ); + mbedtls_endpoint_free( &client, NULL ); + mbedtls_endpoint_free( &server, NULL ); } /* END_CASE */ @@ -3120,10 +3169,12 @@ void send_application_data( int mfl, int cli_msg_len, int srv_msg_len, unsigned char *srv_in_buf = malloc( cli_msg_len ); int ret = -1; - ret = mbedtls_endpoint_init( &server, MBEDTLS_SSL_IS_SERVER, MBEDTLS_PK_RSA ); + ret = mbedtls_endpoint_init( &server, MBEDTLS_SSL_IS_SERVER, MBEDTLS_PK_RSA, + NULL, NULL, NULL ); TEST_ASSERT( ret == 0 ); - ret = mbedtls_endpoint_init( &client, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_PK_RSA ); + ret = mbedtls_endpoint_init( &client, MBEDTLS_SSL_IS_CLIENT, MBEDTLS_PK_RSA, + NULL, NULL, NULL ); TEST_ASSERT( ret == 0 ); #if defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH) @@ -3219,8 +3270,8 @@ void send_application_data( int mfl, int cli_msg_len, int srv_msg_len, } exit: - mbedtls_endpoint_free( &client ); - mbedtls_endpoint_free( &server ); + mbedtls_endpoint_free( &client, NULL ); + mbedtls_endpoint_free( &server, NULL ); free( cli_msg_buf ); free( cli_in_buf ); free( srv_msg_buf );