From 1961b709d8d4cbf16bb6da632c78acb49289f446 Mon Sep 17 00:00:00 2001
From: Paul Bakker
Date: Fri, 25 Jan 2013 14:49:24 +0100
Subject: [PATCH] Added ssl_handshake_step() to allow single stepping the
handshake process
Single stepping the handshake process allows for better support of
non-blocking network stacks and for getting information from specific
handshake messages if wanted.
---
ChangeLog | 3 +
include/polarssl/ssl.h | 18 +++-
library/ssl_cli.c | 174 +++++++++++++++++------------------
library/ssl_srv.c | 200 ++++++++++++++++++++---------------------
library/ssl_tls.c | 30 +++++--
5 files changed, 222 insertions(+), 203 deletions(-)
diff --git a/ChangeLog b/ChangeLog
index 701f86b32..f59006966 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,6 +1,9 @@
PolarSSL ChangeLog
= Version Master
+Changes
+ * Added ssl_handshake_step() to allow single stepping the handshake process
+
Bugfix
* Memory leak when using RSA_PKCS_V21 operations fixed
* Handle future version properly in ssl_write_certificate_request()
diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h
index e5d9eb73c..9746e276b 100644
--- a/include/polarssl/ssl.h
+++ b/include/polarssl/ssl.h
@@ -970,6 +970,20 @@ const x509_cert *ssl_get_peer_cert( const ssl_context *ssl );
*/
int ssl_handshake( ssl_context *ssl );
+/**
+ * \brief Perform a single step of the SSL handshake
+ *
+ * Note: the state of the context (ssl->state) will be at
+ * the following state after execution of this function.
+ * Do not call this function if state is SSL_HANDSHAKE_OVER.
+ *
+ * \param ssl SSL context
+ *
+ * \return 0 if successful, POLARSSL_ERR_NET_WANT_READ,
+ * POLARSSL_ERR_NET_WANT_WRITE, or a specific SSL error code.
+ */
+int ssl_handshake_step( ssl_context *ssl );
+
/**
* \brief Perform an SSL renegotiation on the running connection
*
@@ -1061,8 +1075,8 @@ void ssl_handshake_free( ssl_handshake_params *handshake );
/*
* Internal functions (do not call directly)
*/
-int ssl_handshake_client( ssl_context *ssl );
-int ssl_handshake_server( ssl_context *ssl );
+int ssl_handshake_client_step( ssl_context *ssl );
+int ssl_handshake_server_step( ssl_context *ssl );
void ssl_handshake_wrapup( ssl_context *ssl );
int ssl_send_fatal_handshake_failure( ssl_context *ssl );
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index 42ddf4131..545906a2a 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -1274,121 +1274,113 @@ static int ssl_write_certificate_verify( ssl_context *ssl )
}
/*
- * SSL handshake -- client side
+ * SSL handshake -- client side -- single step
*/
-int ssl_handshake_client( ssl_context *ssl )
+int ssl_handshake_client_step( ssl_context *ssl )
{
int ret = 0;
- SSL_DEBUG_MSG( 2, ( "=> handshake client" ) );
+ if( ssl->state == SSL_HANDSHAKE_OVER )
+ return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
- while( ssl->state != SSL_HANDSHAKE_OVER )
+ SSL_DEBUG_MSG( 2, ( "client state: %d", ssl->state ) );
+
+ if( ( ret = ssl_flush_output( ssl ) ) != 0 )
+ return( ret );
+
+ switch( ssl->state )
{
- SSL_DEBUG_MSG( 2, ( "client state: %d", ssl->state ) );
-
- if( ( ret = ssl_flush_output( ssl ) ) != 0 )
+ case SSL_HELLO_REQUEST:
+ ssl->state = SSL_CLIENT_HELLO;
break;
- switch( ssl->state )
- {
- case SSL_HELLO_REQUEST:
- ssl->state = SSL_CLIENT_HELLO;
- break;
+ /*
+ * ==> ClientHello
+ */
+ case SSL_CLIENT_HELLO:
+ ret = ssl_write_client_hello( ssl );
+ break;
- /*
- * ==> ClientHello
- */
- case SSL_CLIENT_HELLO:
- ret = ssl_write_client_hello( ssl );
- break;
+ /*
+ * <== ServerHello
+ * Certificate
+ * ( ServerKeyExchange )
+ * ( CertificateRequest )
+ * ServerHelloDone
+ */
+ case SSL_SERVER_HELLO:
+ ret = ssl_parse_server_hello( ssl );
+ break;
- /*
- * <== ServerHello
- * Certificate
- * ( ServerKeyExchange )
- * ( CertificateRequest )
- * ServerHelloDone
- */
- case SSL_SERVER_HELLO:
- ret = ssl_parse_server_hello( ssl );
- break;
+ case SSL_SERVER_CERTIFICATE:
+ ret = ssl_parse_certificate( ssl );
+ break;
- case SSL_SERVER_CERTIFICATE:
- ret = ssl_parse_certificate( ssl );
- break;
+ case SSL_SERVER_KEY_EXCHANGE:
+ ret = ssl_parse_server_key_exchange( ssl );
+ break;
- case SSL_SERVER_KEY_EXCHANGE:
- ret = ssl_parse_server_key_exchange( ssl );
- break;
+ case SSL_CERTIFICATE_REQUEST:
+ ret = ssl_parse_certificate_request( ssl );
+ break;
- case SSL_CERTIFICATE_REQUEST:
- ret = ssl_parse_certificate_request( ssl );
- break;
+ case SSL_SERVER_HELLO_DONE:
+ ret = ssl_parse_server_hello_done( ssl );
+ break;
- case SSL_SERVER_HELLO_DONE:
- ret = ssl_parse_server_hello_done( ssl );
- break;
+ /*
+ * ==> ( Certificate/Alert )
+ * ClientKeyExchange
+ * ( CertificateVerify )
+ * ChangeCipherSpec
+ * Finished
+ */
+ case SSL_CLIENT_CERTIFICATE:
+ ret = ssl_write_certificate( ssl );
+ break;
- /*
- * ==> ( Certificate/Alert )
- * ClientKeyExchange
- * ( CertificateVerify )
- * ChangeCipherSpec
- * Finished
- */
- case SSL_CLIENT_CERTIFICATE:
- ret = ssl_write_certificate( ssl );
- break;
+ case SSL_CLIENT_KEY_EXCHANGE:
+ ret = ssl_write_client_key_exchange( ssl );
+ break;
- case SSL_CLIENT_KEY_EXCHANGE:
- ret = ssl_write_client_key_exchange( ssl );
- break;
+ case SSL_CERTIFICATE_VERIFY:
+ ret = ssl_write_certificate_verify( ssl );
+ break;
- case SSL_CERTIFICATE_VERIFY:
- ret = ssl_write_certificate_verify( ssl );
- break;
+ case SSL_CLIENT_CHANGE_CIPHER_SPEC:
+ ret = ssl_write_change_cipher_spec( ssl );
+ break;
- case SSL_CLIENT_CHANGE_CIPHER_SPEC:
- ret = ssl_write_change_cipher_spec( ssl );
- break;
+ case SSL_CLIENT_FINISHED:
+ ret = ssl_write_finished( ssl );
+ break;
- case SSL_CLIENT_FINISHED:
- ret = ssl_write_finished( ssl );
- break;
+ /*
+ * <== ChangeCipherSpec
+ * Finished
+ */
+ case SSL_SERVER_CHANGE_CIPHER_SPEC:
+ ret = ssl_parse_change_cipher_spec( ssl );
+ break;
- /*
- * <== ChangeCipherSpec
- * Finished
- */
- case SSL_SERVER_CHANGE_CIPHER_SPEC:
- ret = ssl_parse_change_cipher_spec( ssl );
- break;
+ case SSL_SERVER_FINISHED:
+ ret = ssl_parse_finished( ssl );
+ break;
- case SSL_SERVER_FINISHED:
- ret = ssl_parse_finished( ssl );
- break;
+ case SSL_FLUSH_BUFFERS:
+ SSL_DEBUG_MSG( 2, ( "handshake: done" ) );
+ ssl->state = SSL_HANDSHAKE_WRAPUP;
+ break;
- case SSL_FLUSH_BUFFERS:
- SSL_DEBUG_MSG( 2, ( "handshake: done" ) );
- ssl->state = SSL_HANDSHAKE_WRAPUP;
- break;
+ case SSL_HANDSHAKE_WRAPUP:
+ ssl_handshake_wrapup( ssl );
+ break;
- case SSL_HANDSHAKE_WRAPUP:
- ssl_handshake_wrapup( ssl );
- break;
-
- default:
- SSL_DEBUG_MSG( 1, ( "invalid state %d", ssl->state ) );
- return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
- }
-
- if( ret != 0 )
- break;
- }
-
- SSL_DEBUG_MSG( 2, ( "<= handshake client" ) );
+ default:
+ SSL_DEBUG_MSG( 1, ( "invalid state %d", ssl->state ) );
+ return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
+ }
return( ret );
}
-
#endif
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index 38253933a..df57cb31f 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -1293,121 +1293,113 @@ static int ssl_parse_certificate_verify( ssl_context *ssl )
}
/*
- * SSL handshake -- server side
+ * SSL handshake -- server side -- single step
*/
-int ssl_handshake_server( ssl_context *ssl )
+int ssl_handshake_server_step( ssl_context *ssl )
{
int ret = 0;
- SSL_DEBUG_MSG( 2, ( "=> handshake server" ) );
+ if( ssl->state == SSL_HANDSHAKE_OVER )
+ return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
- while( ssl->state != SSL_HANDSHAKE_OVER )
+ SSL_DEBUG_MSG( 2, ( "server state: %d", ssl->state ) );
+
+ if( ( ret = ssl_flush_output( ssl ) ) != 0 )
+ return( ret );
+
+ switch( ssl->state )
{
- SSL_DEBUG_MSG( 2, ( "server state: %d", ssl->state ) );
-
- if( ( ret = ssl_flush_output( ssl ) ) != 0 )
+ case SSL_HELLO_REQUEST:
+ ssl->state = SSL_CLIENT_HELLO;
break;
- switch( ssl->state )
- {
- case SSL_HELLO_REQUEST:
- ssl->state = SSL_CLIENT_HELLO;
- break;
-
- /*
- * <== ClientHello
- */
- case SSL_CLIENT_HELLO:
- ret = ssl_parse_client_hello( ssl );
- break;
-
- /*
- * ==> ServerHello
- * Certificate
- * ( ServerKeyExchange )
- * ( CertificateRequest )
- * ServerHelloDone
- */
- case SSL_SERVER_HELLO:
- ret = ssl_write_server_hello( ssl );
- break;
-
- case SSL_SERVER_CERTIFICATE:
- ret = ssl_write_certificate( ssl );
- break;
-
- case SSL_SERVER_KEY_EXCHANGE:
- ret = ssl_write_server_key_exchange( ssl );
- break;
-
- case SSL_CERTIFICATE_REQUEST:
- ret = ssl_write_certificate_request( ssl );
- break;
-
- case SSL_SERVER_HELLO_DONE:
- ret = ssl_write_server_hello_done( ssl );
- break;
-
- /*
- * <== ( Certificate/Alert )
- * ClientKeyExchange
- * ( CertificateVerify )
- * ChangeCipherSpec
- * Finished
- */
- case SSL_CLIENT_CERTIFICATE:
- ret = ssl_parse_certificate( ssl );
- break;
-
- case SSL_CLIENT_KEY_EXCHANGE:
- ret = ssl_parse_client_key_exchange( ssl );
- break;
-
- case SSL_CERTIFICATE_VERIFY:
- ret = ssl_parse_certificate_verify( ssl );
- break;
-
- case SSL_CLIENT_CHANGE_CIPHER_SPEC:
- ret = ssl_parse_change_cipher_spec( ssl );
- break;
-
- case SSL_CLIENT_FINISHED:
- ret = ssl_parse_finished( ssl );
- break;
-
- /*
- * ==> ChangeCipherSpec
- * Finished
- */
- case SSL_SERVER_CHANGE_CIPHER_SPEC:
- ret = ssl_write_change_cipher_spec( ssl );
- break;
-
- case SSL_SERVER_FINISHED:
- ret = ssl_write_finished( ssl );
- break;
-
- case SSL_FLUSH_BUFFERS:
- SSL_DEBUG_MSG( 2, ( "handshake: done" ) );
- ssl->state = SSL_HANDSHAKE_WRAPUP;
- break;
-
- case SSL_HANDSHAKE_WRAPUP:
- ssl_handshake_wrapup( ssl );
- break;
-
- default:
- SSL_DEBUG_MSG( 1, ( "invalid state %d", ssl->state ) );
- return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
- }
-
- if( ret != 0 )
+ /*
+ * <== ClientHello
+ */
+ case SSL_CLIENT_HELLO:
+ ret = ssl_parse_client_hello( ssl );
break;
+
+ /*
+ * ==> ServerHello
+ * Certificate
+ * ( ServerKeyExchange )
+ * ( CertificateRequest )
+ * ServerHelloDone
+ */
+ case SSL_SERVER_HELLO:
+ ret = ssl_write_server_hello( ssl );
+ break;
+
+ case SSL_SERVER_CERTIFICATE:
+ ret = ssl_write_certificate( ssl );
+ break;
+
+ case SSL_SERVER_KEY_EXCHANGE:
+ ret = ssl_write_server_key_exchange( ssl );
+ break;
+
+ case SSL_CERTIFICATE_REQUEST:
+ ret = ssl_write_certificate_request( ssl );
+ break;
+
+ case SSL_SERVER_HELLO_DONE:
+ ret = ssl_write_server_hello_done( ssl );
+ break;
+
+ /*
+ * <== ( Certificate/Alert )
+ * ClientKeyExchange
+ * ( CertificateVerify )
+ * ChangeCipherSpec
+ * Finished
+ */
+ case SSL_CLIENT_CERTIFICATE:
+ ret = ssl_parse_certificate( ssl );
+ break;
+
+ case SSL_CLIENT_KEY_EXCHANGE:
+ ret = ssl_parse_client_key_exchange( ssl );
+ break;
+
+ case SSL_CERTIFICATE_VERIFY:
+ ret = ssl_parse_certificate_verify( ssl );
+ break;
+
+ case SSL_CLIENT_CHANGE_CIPHER_SPEC:
+ ret = ssl_parse_change_cipher_spec( ssl );
+ break;
+
+ case SSL_CLIENT_FINISHED:
+ ret = ssl_parse_finished( ssl );
+ break;
+
+ /*
+ * ==> ChangeCipherSpec
+ * Finished
+ */
+ case SSL_SERVER_CHANGE_CIPHER_SPEC:
+ ret = ssl_write_change_cipher_spec( ssl );
+ break;
+
+ case SSL_SERVER_FINISHED:
+ ret = ssl_write_finished( ssl );
+ break;
+
+ case SSL_FLUSH_BUFFERS:
+ SSL_DEBUG_MSG( 2, ( "handshake: done" ) );
+ ssl->state = SSL_HANDSHAKE_WRAPUP;
+ break;
+
+ case SSL_HANDSHAKE_WRAPUP:
+ ssl_handshake_wrapup( ssl );
+ break;
+
+ default:
+ SSL_DEBUG_MSG( 1, ( "invalid state %d", ssl->state ) );
+ return( POLARSSL_ERR_SSL_BAD_INPUT_DATA );
}
- SSL_DEBUG_MSG( 2, ( "<= handshake server" ) );
-
return( ret );
}
-
#endif
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index e0a64ab06..94113928f 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -3513,24 +3513,42 @@ const int ssl_default_ciphersuites[] =
};
/*
- * Perform the SSL handshake
+ * Perform a single step of the SSL handshake
*/
-int ssl_handshake( ssl_context *ssl )
+int ssl_handshake_step( ssl_context *ssl )
{
int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE;
- SSL_DEBUG_MSG( 2, ( "=> handshake" ) );
-
#if defined(POLARSSL_SSL_CLI_C)
if( ssl->endpoint == SSL_IS_CLIENT )
- ret = ssl_handshake_client( ssl );
+ ret = ssl_handshake_client_step( ssl );
#endif
#if defined(POLARSSL_SSL_SRV_C)
if( ssl->endpoint == SSL_IS_SERVER )
- ret = ssl_handshake_server( ssl );
+ ret = ssl_handshake_server_step( ssl );
#endif
+ return( ret );
+}
+
+/*
+ * Perform the SSL handshake
+ */
+int ssl_handshake( ssl_context *ssl )
+{
+ int ret = 0;
+
+ SSL_DEBUG_MSG( 2, ( "=> handshake" ) );
+
+ while( ssl->state != SSL_HANDSHAKE_OVER )
+ {
+ ret = ssl_handshake_step( ssl );
+
+ if( ret != 0 )
+ break;
+ }
+
SSL_DEBUG_MSG( 2, ( "<= handshake" ) );
return( ret );