From 1961b709d8d4cbf16bb6da632c78acb49289f446 Mon Sep 17 00:00:00 2001 From: Paul Bakker Date: Fri, 25 Jan 2013 14:49:24 +0100 Subject: [PATCH] Added ssl_handshake_step() to allow single stepping the handshake process Single stepping the handshake process allows for better support of non-blocking network stacks and for getting information from specific handshake messages if wanted. --- ChangeLog | 3 + include/polarssl/ssl.h | 18 +++- library/ssl_cli.c | 174 +++++++++++++++++------------------ library/ssl_srv.c | 200 ++++++++++++++++++++--------------------- library/ssl_tls.c | 30 +++++-- 5 files changed, 222 insertions(+), 203 deletions(-) diff --git a/ChangeLog b/ChangeLog index 701f86b32..f59006966 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,6 +1,9 @@ PolarSSL ChangeLog = Version Master +Changes + * Added ssl_handshake_step() to allow single stepping the handshake process + Bugfix * Memory leak when using RSA_PKCS_V21 operations fixed * Handle future version properly in ssl_write_certificate_request() diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index e5d9eb73c..9746e276b 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -970,6 +970,20 @@ const x509_cert *ssl_get_peer_cert( const ssl_context *ssl ); */ int ssl_handshake( ssl_context *ssl ); +/** + * \brief Perform a single step of the SSL handshake + * + * Note: the state of the context (ssl->state) will be at + * the following state after execution of this function. + * Do not call this function if state is SSL_HANDSHAKE_OVER. + * + * \param ssl SSL context + * + * \return 0 if successful, POLARSSL_ERR_NET_WANT_READ, + * POLARSSL_ERR_NET_WANT_WRITE, or a specific SSL error code. + */ +int ssl_handshake_step( ssl_context *ssl ); + /** * \brief Perform an SSL renegotiation on the running connection * @@ -1061,8 +1075,8 @@ void ssl_handshake_free( ssl_handshake_params *handshake ); /* * Internal functions (do not call directly) */ -int ssl_handshake_client( ssl_context *ssl ); -int ssl_handshake_server( ssl_context *ssl ); +int ssl_handshake_client_step( ssl_context *ssl ); +int ssl_handshake_server_step( ssl_context *ssl ); void ssl_handshake_wrapup( ssl_context *ssl ); int ssl_send_fatal_handshake_failure( ssl_context *ssl ); diff --git a/library/ssl_cli.c b/library/ssl_cli.c index 42ddf4131..545906a2a 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -1274,121 +1274,113 @@ static int ssl_write_certificate_verify( ssl_context *ssl ) } /* - * SSL handshake -- client side + * SSL handshake -- client side -- single step */ -int ssl_handshake_client( ssl_context *ssl ) +int ssl_handshake_client_step( ssl_context *ssl ) { int ret = 0; - SSL_DEBUG_MSG( 2, ( "=> handshake client" ) ); + if( ssl->state == SSL_HANDSHAKE_OVER ) + return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); - while( ssl->state != SSL_HANDSHAKE_OVER ) + SSL_DEBUG_MSG( 2, ( "client state: %d", ssl->state ) ); + + if( ( ret = ssl_flush_output( ssl ) ) != 0 ) + return( ret ); + + switch( ssl->state ) { - SSL_DEBUG_MSG( 2, ( "client state: %d", ssl->state ) ); - - if( ( ret = ssl_flush_output( ssl ) ) != 0 ) + case SSL_HELLO_REQUEST: + ssl->state = SSL_CLIENT_HELLO; break; - switch( ssl->state ) - { - case SSL_HELLO_REQUEST: - ssl->state = SSL_CLIENT_HELLO; - break; + /* + * ==> ClientHello + */ + case SSL_CLIENT_HELLO: + ret = ssl_write_client_hello( ssl ); + break; - /* - * ==> ClientHello - */ - case SSL_CLIENT_HELLO: - ret = ssl_write_client_hello( ssl ); - break; + /* + * <== ServerHello + * Certificate + * ( ServerKeyExchange ) + * ( CertificateRequest ) + * ServerHelloDone + */ + case SSL_SERVER_HELLO: + ret = ssl_parse_server_hello( ssl ); + break; - /* - * <== ServerHello - * Certificate - * ( ServerKeyExchange ) - * ( CertificateRequest ) - * ServerHelloDone - */ - case SSL_SERVER_HELLO: - ret = ssl_parse_server_hello( ssl ); - break; + case SSL_SERVER_CERTIFICATE: + ret = ssl_parse_certificate( ssl ); + break; - case SSL_SERVER_CERTIFICATE: - ret = ssl_parse_certificate( ssl ); - break; + case SSL_SERVER_KEY_EXCHANGE: + ret = ssl_parse_server_key_exchange( ssl ); + break; - case SSL_SERVER_KEY_EXCHANGE: - ret = ssl_parse_server_key_exchange( ssl ); - break; + case SSL_CERTIFICATE_REQUEST: + ret = ssl_parse_certificate_request( ssl ); + break; - case SSL_CERTIFICATE_REQUEST: - ret = ssl_parse_certificate_request( ssl ); - break; + case SSL_SERVER_HELLO_DONE: + ret = ssl_parse_server_hello_done( ssl ); + break; - case SSL_SERVER_HELLO_DONE: - ret = ssl_parse_server_hello_done( ssl ); - break; + /* + * ==> ( Certificate/Alert ) + * ClientKeyExchange + * ( CertificateVerify ) + * ChangeCipherSpec + * Finished + */ + case SSL_CLIENT_CERTIFICATE: + ret = ssl_write_certificate( ssl ); + break; - /* - * ==> ( Certificate/Alert ) - * ClientKeyExchange - * ( CertificateVerify ) - * ChangeCipherSpec - * Finished - */ - case SSL_CLIENT_CERTIFICATE: - ret = ssl_write_certificate( ssl ); - break; + case SSL_CLIENT_KEY_EXCHANGE: + ret = ssl_write_client_key_exchange( ssl ); + break; - case SSL_CLIENT_KEY_EXCHANGE: - ret = ssl_write_client_key_exchange( ssl ); - break; + case SSL_CERTIFICATE_VERIFY: + ret = ssl_write_certificate_verify( ssl ); + break; - case SSL_CERTIFICATE_VERIFY: - ret = ssl_write_certificate_verify( ssl ); - break; + case SSL_CLIENT_CHANGE_CIPHER_SPEC: + ret = ssl_write_change_cipher_spec( ssl ); + break; - case SSL_CLIENT_CHANGE_CIPHER_SPEC: - ret = ssl_write_change_cipher_spec( ssl ); - break; + case SSL_CLIENT_FINISHED: + ret = ssl_write_finished( ssl ); + break; - case SSL_CLIENT_FINISHED: - ret = ssl_write_finished( ssl ); - break; + /* + * <== ChangeCipherSpec + * Finished + */ + case SSL_SERVER_CHANGE_CIPHER_SPEC: + ret = ssl_parse_change_cipher_spec( ssl ); + break; - /* - * <== ChangeCipherSpec - * Finished - */ - case SSL_SERVER_CHANGE_CIPHER_SPEC: - ret = ssl_parse_change_cipher_spec( ssl ); - break; + case SSL_SERVER_FINISHED: + ret = ssl_parse_finished( ssl ); + break; - case SSL_SERVER_FINISHED: - ret = ssl_parse_finished( ssl ); - break; + case SSL_FLUSH_BUFFERS: + SSL_DEBUG_MSG( 2, ( "handshake: done" ) ); + ssl->state = SSL_HANDSHAKE_WRAPUP; + break; - case SSL_FLUSH_BUFFERS: - SSL_DEBUG_MSG( 2, ( "handshake: done" ) ); - ssl->state = SSL_HANDSHAKE_WRAPUP; - break; + case SSL_HANDSHAKE_WRAPUP: + ssl_handshake_wrapup( ssl ); + break; - case SSL_HANDSHAKE_WRAPUP: - ssl_handshake_wrapup( ssl ); - break; - - default: - SSL_DEBUG_MSG( 1, ( "invalid state %d", ssl->state ) ); - return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); - } - - if( ret != 0 ) - break; - } - - SSL_DEBUG_MSG( 2, ( "<= handshake client" ) ); + default: + SSL_DEBUG_MSG( 1, ( "invalid state %d", ssl->state ) ); + return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); + } return( ret ); } - #endif diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 38253933a..df57cb31f 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -1293,121 +1293,113 @@ static int ssl_parse_certificate_verify( ssl_context *ssl ) } /* - * SSL handshake -- server side + * SSL handshake -- server side -- single step */ -int ssl_handshake_server( ssl_context *ssl ) +int ssl_handshake_server_step( ssl_context *ssl ) { int ret = 0; - SSL_DEBUG_MSG( 2, ( "=> handshake server" ) ); + if( ssl->state == SSL_HANDSHAKE_OVER ) + return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); - while( ssl->state != SSL_HANDSHAKE_OVER ) + SSL_DEBUG_MSG( 2, ( "server state: %d", ssl->state ) ); + + if( ( ret = ssl_flush_output( ssl ) ) != 0 ) + return( ret ); + + switch( ssl->state ) { - SSL_DEBUG_MSG( 2, ( "server state: %d", ssl->state ) ); - - if( ( ret = ssl_flush_output( ssl ) ) != 0 ) + case SSL_HELLO_REQUEST: + ssl->state = SSL_CLIENT_HELLO; break; - switch( ssl->state ) - { - case SSL_HELLO_REQUEST: - ssl->state = SSL_CLIENT_HELLO; - break; - - /* - * <== ClientHello - */ - case SSL_CLIENT_HELLO: - ret = ssl_parse_client_hello( ssl ); - break; - - /* - * ==> ServerHello - * Certificate - * ( ServerKeyExchange ) - * ( CertificateRequest ) - * ServerHelloDone - */ - case SSL_SERVER_HELLO: - ret = ssl_write_server_hello( ssl ); - break; - - case SSL_SERVER_CERTIFICATE: - ret = ssl_write_certificate( ssl ); - break; - - case SSL_SERVER_KEY_EXCHANGE: - ret = ssl_write_server_key_exchange( ssl ); - break; - - case SSL_CERTIFICATE_REQUEST: - ret = ssl_write_certificate_request( ssl ); - break; - - case SSL_SERVER_HELLO_DONE: - ret = ssl_write_server_hello_done( ssl ); - break; - - /* - * <== ( Certificate/Alert ) - * ClientKeyExchange - * ( CertificateVerify ) - * ChangeCipherSpec - * Finished - */ - case SSL_CLIENT_CERTIFICATE: - ret = ssl_parse_certificate( ssl ); - break; - - case SSL_CLIENT_KEY_EXCHANGE: - ret = ssl_parse_client_key_exchange( ssl ); - break; - - case SSL_CERTIFICATE_VERIFY: - ret = ssl_parse_certificate_verify( ssl ); - break; - - case SSL_CLIENT_CHANGE_CIPHER_SPEC: - ret = ssl_parse_change_cipher_spec( ssl ); - break; - - case SSL_CLIENT_FINISHED: - ret = ssl_parse_finished( ssl ); - break; - - /* - * ==> ChangeCipherSpec - * Finished - */ - case SSL_SERVER_CHANGE_CIPHER_SPEC: - ret = ssl_write_change_cipher_spec( ssl ); - break; - - case SSL_SERVER_FINISHED: - ret = ssl_write_finished( ssl ); - break; - - case SSL_FLUSH_BUFFERS: - SSL_DEBUG_MSG( 2, ( "handshake: done" ) ); - ssl->state = SSL_HANDSHAKE_WRAPUP; - break; - - case SSL_HANDSHAKE_WRAPUP: - ssl_handshake_wrapup( ssl ); - break; - - default: - SSL_DEBUG_MSG( 1, ( "invalid state %d", ssl->state ) ); - return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); - } - - if( ret != 0 ) + /* + * <== ClientHello + */ + case SSL_CLIENT_HELLO: + ret = ssl_parse_client_hello( ssl ); break; + + /* + * ==> ServerHello + * Certificate + * ( ServerKeyExchange ) + * ( CertificateRequest ) + * ServerHelloDone + */ + case SSL_SERVER_HELLO: + ret = ssl_write_server_hello( ssl ); + break; + + case SSL_SERVER_CERTIFICATE: + ret = ssl_write_certificate( ssl ); + break; + + case SSL_SERVER_KEY_EXCHANGE: + ret = ssl_write_server_key_exchange( ssl ); + break; + + case SSL_CERTIFICATE_REQUEST: + ret = ssl_write_certificate_request( ssl ); + break; + + case SSL_SERVER_HELLO_DONE: + ret = ssl_write_server_hello_done( ssl ); + break; + + /* + * <== ( Certificate/Alert ) + * ClientKeyExchange + * ( CertificateVerify ) + * ChangeCipherSpec + * Finished + */ + case SSL_CLIENT_CERTIFICATE: + ret = ssl_parse_certificate( ssl ); + break; + + case SSL_CLIENT_KEY_EXCHANGE: + ret = ssl_parse_client_key_exchange( ssl ); + break; + + case SSL_CERTIFICATE_VERIFY: + ret = ssl_parse_certificate_verify( ssl ); + break; + + case SSL_CLIENT_CHANGE_CIPHER_SPEC: + ret = ssl_parse_change_cipher_spec( ssl ); + break; + + case SSL_CLIENT_FINISHED: + ret = ssl_parse_finished( ssl ); + break; + + /* + * ==> ChangeCipherSpec + * Finished + */ + case SSL_SERVER_CHANGE_CIPHER_SPEC: + ret = ssl_write_change_cipher_spec( ssl ); + break; + + case SSL_SERVER_FINISHED: + ret = ssl_write_finished( ssl ); + break; + + case SSL_FLUSH_BUFFERS: + SSL_DEBUG_MSG( 2, ( "handshake: done" ) ); + ssl->state = SSL_HANDSHAKE_WRAPUP; + break; + + case SSL_HANDSHAKE_WRAPUP: + ssl_handshake_wrapup( ssl ); + break; + + default: + SSL_DEBUG_MSG( 1, ( "invalid state %d", ssl->state ) ); + return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); } - SSL_DEBUG_MSG( 2, ( "<= handshake server" ) ); - return( ret ); } - #endif diff --git a/library/ssl_tls.c b/library/ssl_tls.c index e0a64ab06..94113928f 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -3513,24 +3513,42 @@ const int ssl_default_ciphersuites[] = }; /* - * Perform the SSL handshake + * Perform a single step of the SSL handshake */ -int ssl_handshake( ssl_context *ssl ) +int ssl_handshake_step( ssl_context *ssl ) { int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE; - SSL_DEBUG_MSG( 2, ( "=> handshake" ) ); - #if defined(POLARSSL_SSL_CLI_C) if( ssl->endpoint == SSL_IS_CLIENT ) - ret = ssl_handshake_client( ssl ); + ret = ssl_handshake_client_step( ssl ); #endif #if defined(POLARSSL_SSL_SRV_C) if( ssl->endpoint == SSL_IS_SERVER ) - ret = ssl_handshake_server( ssl ); + ret = ssl_handshake_server_step( ssl ); #endif + return( ret ); +} + +/* + * Perform the SSL handshake + */ +int ssl_handshake( ssl_context *ssl ) +{ + int ret = 0; + + SSL_DEBUG_MSG( 2, ( "=> handshake" ) ); + + while( ssl->state != SSL_HANDSHAKE_OVER ) + { + ret = ssl_handshake_step( ssl ); + + if( ret != 0 ) + break; + } + SSL_DEBUG_MSG( 2, ( "<= handshake" ) ); return( ret );