From 380da53c4896177b8de837d09cff466cdc05eab2 Mon Sep 17 00:00:00 2001 From: Paul Bakker Date: Wed, 18 Apr 2012 16:10:25 +0000 Subject: [PATCH] - Abstracted checksum updating during handshake --- include/polarssl/ssl.h | 11 +- library/ssl_cli.c | 7 +- library/ssl_srv.c | 16 ++- library/ssl_tls.c | 231 ++++++++++++++++++++++++++++------------- 4 files changed, 177 insertions(+), 88 deletions(-) diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index 8110fbce6..4ac6f86c0 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -342,11 +342,10 @@ struct _ssl_context * Crypto layer */ dhm_context dhm_ctx; /*!< DHM key exchange */ - md5_context fin_md5; /*!< Finished MD5 checksum */ - sha1_context fin_sha1; /*!< Finished SHA-1 checksum */ - sha2_context fin_sha2; /*!< Finished SHA-256 checksum */ - sha4_context fin_sha4; /*!< Finished SHA-384 checksum */ + unsigned char ctx_checksum[500]; /*!< Checksum context(s) */ + void (*update_checksum)(ssl_context *, unsigned char *, size_t); + void (*calc_verify)(ssl_context *, unsigned char *); void (*calc_finished)(ssl_context *, unsigned char *, int); int (*tls_prf)(unsigned char *, size_t, char *, unsigned char *, size_t, @@ -737,7 +736,6 @@ int ssl_handshake_client( ssl_context *ssl ); int ssl_handshake_server( ssl_context *ssl ); int ssl_derive_keys( ssl_context *ssl ); -void ssl_calc_verify( ssl_context *ssl, unsigned char hash[36] ); int ssl_read_record( ssl_context *ssl ); /** @@ -758,6 +756,9 @@ int ssl_write_change_cipher_spec( ssl_context *ssl ); int ssl_parse_finished( ssl_context *ssl ); int ssl_write_finished( ssl_context *ssl ); +void ssl_kickstart_checksum( ssl_context *ssl, int ciphersuite, + unsigned char *input_buf, size_t len ); + #ifdef __cplusplus } #endif diff --git a/library/ssl_cli.c b/library/ssl_cli.c index df4dbb0f9..99120d569 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -368,6 +368,11 @@ static int ssl_parse_server_hello( ssl_context *ssl ) i = ( buf[39 + n] << 8 ) | buf[40 + n]; + /* + * Initialize update checksum functions + */ + ssl_kickstart_checksum( ssl, i, buf, ssl->in_hslen ); + SSL_DEBUG_MSG( 3, ( "server hello, session id len.: %d", n ) ); SSL_DEBUG_BUF( 3, "server hello, session id", buf + 39, n ); @@ -940,7 +945,7 @@ static int ssl_write_certificate_verify( ssl_context *ssl ) /* * Make an RSA signature of the handshake digests */ - ssl_calc_verify( ssl, hash ); + ssl->calc_verify( ssl, hash ); if ( ssl->rsa_key ) n = ssl->rsa_key->len; diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 790b8a7c6..ecf153609 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -106,10 +106,7 @@ static int ssl_parse_client_hello( ssl_context *ssl ) return( ret ); } - md5_update( &ssl->fin_md5 , buf + 2, n ); - sha1_update( &ssl->fin_sha1, buf + 2, n ); - sha2_update( &ssl->fin_sha2, buf + 2, n ); - sha4_update( &ssl->fin_sha4, buf + 2, n ); + ssl->update_checksum( ssl, buf + 2, n ); buf = ssl->in_msg; n = ssl->in_left - 5; @@ -228,10 +225,7 @@ static int ssl_parse_client_hello( ssl_context *ssl ) buf = ssl->in_msg; n = ssl->in_left - 5; - md5_update( &ssl->fin_md5 , buf, n ); - sha1_update( &ssl->fin_sha1, buf, n ); - sha2_update( &ssl->fin_sha2, buf, n ); - sha4_update( &ssl->fin_sha4, buf, n ); + ssl->update_checksum( ssl, buf, n ); /* * SSL layer: @@ -352,6 +346,8 @@ static int ssl_parse_client_hello( ssl_context *ssl ) have_ciphersuite: ssl->session->ciphersuite = ssl->ciphersuites[i]; + ssl_kickstart_checksum( ssl, ssl->session->ciphersuite, buf, n ); + ssl->in_left = 0; ssl->state++; @@ -912,7 +908,7 @@ static int ssl_parse_certificate_verify( ssl_context *ssl ) { int ret; size_t n1, n2; - unsigned char hash[36]; + unsigned char hash[48]; SSL_DEBUG_MSG( 2, ( "=> parse certificate verify" ) ); @@ -923,7 +919,7 @@ static int ssl_parse_certificate_verify( ssl_context *ssl ) return( 0 ); } - ssl_calc_verify( ssl, hash ); + ssl->calc_verify( ssl, hash ); if( ( ret = ssl_read_record( ssl ) ) != 0 ) { diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 9a962b2d1..88c6e55de 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -195,6 +195,16 @@ static int tls_prf_sha384( unsigned char *secret, size_t slen, char *label, return( 0 ); } +static void ssl_update_checksum_start(ssl_context *, unsigned char *, size_t); +static void ssl_update_checksum_md5sha1(ssl_context *, unsigned char *, size_t); +static void ssl_update_checksum_sha256(ssl_context *, unsigned char *, size_t); +static void ssl_update_checksum_sha384(ssl_context *, unsigned char *, size_t); + +static void ssl_calc_verify_ssl(ssl_context *,unsigned char *); +static void ssl_calc_verify_tls(ssl_context *,unsigned char *); +static void ssl_calc_verify_tls_sha256(ssl_context *,unsigned char *); +static void ssl_calc_verify_tls_sha384(ssl_context *,unsigned char *); + static void ssl_calc_finished_ssl(ssl_context *,unsigned char *,int); static void ssl_calc_finished_tls(ssl_context *,unsigned char *,int); static void ssl_calc_finished_tls_sha256(ssl_context *,unsigned char *,int); @@ -221,22 +231,26 @@ int ssl_derive_keys( ssl_context *ssl ) if( ssl->minor_ver == SSL_MINOR_VERSION_0 ) { ssl->tls_prf = tls1_prf; + ssl->calc_verify = ssl_calc_verify_ssl; ssl->calc_finished = ssl_calc_finished_ssl; } else if( ssl->minor_ver < SSL_MINOR_VERSION_3 ) { ssl->tls_prf = tls1_prf; + ssl->calc_verify = ssl_calc_verify_tls; ssl->calc_finished = ssl_calc_finished_tls; } else if( ssl->session->ciphersuite == SSL_RSA_AES_256_GCM_SHA384 || ssl->session->ciphersuite == SSL_EDH_RSA_AES_256_GCM_SHA384 ) { ssl->tls_prf = tls_prf_sha384; + ssl->calc_verify = ssl_calc_verify_tls_sha384; ssl->calc_finished = ssl_calc_finished_tls_sha384; } else { ssl->tls_prf = tls_prf_sha256; + ssl->calc_verify = ssl_calc_verify_tls_sha256; ssl->calc_finished = ssl_calc_finished_tls_sha256; } @@ -602,61 +616,91 @@ int ssl_derive_keys( ssl_context *ssl ) return( 0 ); } -void ssl_calc_verify( ssl_context *ssl, unsigned char hash[48] ) +void ssl_calc_verify_ssl( ssl_context *ssl, unsigned char hash[36] ) { md5_context md5; sha1_context sha1; - sha2_context sha2; - sha4_context sha4; unsigned char pad_1[48]; unsigned char pad_2[48]; - SSL_DEBUG_MSG( 2, ( "=> calc verify" ) ); + SSL_DEBUG_MSG( 2, ( "=> calc verify ssl" ) ); - memcpy( &md5 , &ssl->fin_md5 , sizeof( md5_context ) ); - memcpy( &sha1, &ssl->fin_sha1, sizeof( sha1_context ) ); - memcpy( &sha2, &ssl->fin_sha2, sizeof( sha2_context ) ); - memcpy( &sha4, &ssl->fin_sha4, sizeof( sha4_context ) ); + memcpy( &md5 , (md5_context *) ssl->ctx_checksum, sizeof(md5_context) ); + memcpy( &sha1, (sha1_context *) ( ssl->ctx_checksum + sizeof(md5_context) ), + sizeof( sha1_context ) ); - if( ssl->minor_ver == SSL_MINOR_VERSION_0 ) - { - memset( pad_1, 0x36, 48 ); - memset( pad_2, 0x5C, 48 ); + memset( pad_1, 0x36, 48 ); + memset( pad_2, 0x5C, 48 ); - md5_update( &md5, ssl->session->master, 48 ); - md5_update( &md5, pad_1, 48 ); - md5_finish( &md5, hash ); + md5_update( &md5, ssl->session->master, 48 ); + md5_update( &md5, pad_1, 48 ); + md5_finish( &md5, hash ); - md5_starts( &md5 ); - md5_update( &md5, ssl->session->master, 48 ); - md5_update( &md5, pad_2, 48 ); - md5_update( &md5, hash, 16 ); - md5_finish( &md5, hash ); - - sha1_update( &sha1, ssl->session->master, 48 ); - sha1_update( &sha1, pad_1, 40 ); - sha1_finish( &sha1, hash + 16 ); + md5_starts( &md5 ); + md5_update( &md5, ssl->session->master, 48 ); + md5_update( &md5, pad_2, 48 ); + md5_update( &md5, hash, 16 ); + md5_finish( &md5, hash ); - sha1_starts( &sha1 ); - sha1_update( &sha1, ssl->session->master, 48 ); - sha1_update( &sha1, pad_2, 40 ); - sha1_update( &sha1, hash + 16, 20 ); - sha1_finish( &sha1, hash + 16 ); - } - else if( ssl->minor_ver != SSL_MINOR_VERSION_3 ) /* TLSv1 */ - { - md5_finish( &md5, hash ); - sha1_finish( &sha1, hash + 16 ); - } - else if( ssl->session->ciphersuite == SSL_RSA_AES_256_GCM_SHA384 || - ssl->session->ciphersuite == SSL_EDH_RSA_AES_256_GCM_SHA384 ) - { - sha4_finish( &sha4, hash ); - } - else - { - sha2_finish( &sha2, hash ); - } + sha1_update( &sha1, ssl->session->master, 48 ); + sha1_update( &sha1, pad_1, 40 ); + sha1_finish( &sha1, hash + 16 ); + + sha1_starts( &sha1 ); + sha1_update( &sha1, ssl->session->master, 48 ); + sha1_update( &sha1, pad_2, 40 ); + sha1_update( &sha1, hash + 16, 20 ); + sha1_finish( &sha1, hash + 16 ); + + SSL_DEBUG_BUF( 3, "calculated verify result", hash, 36 ); + SSL_DEBUG_MSG( 2, ( "<= calc verify" ) ); + + return; +} + +void ssl_calc_verify_tls( ssl_context *ssl, unsigned char hash[36] ) +{ + md5_context md5; + sha1_context sha1; + + SSL_DEBUG_MSG( 2, ( "=> calc verify tls" ) ); + + memcpy( &md5 , (md5_context *) ssl->ctx_checksum, sizeof(md5_context) ); + memcpy( &sha1, (sha1_context *) ( ssl->ctx_checksum + sizeof(md5_context) ), + sizeof( sha1_context ) ); + + md5_finish( &md5, hash ); + sha1_finish( &sha1, hash + 16 ); + + SSL_DEBUG_BUF( 3, "calculated verify result", hash, 36 ); + SSL_DEBUG_MSG( 2, ( "<= calc verify" ) ); + + return; +} + +void ssl_calc_verify_tls_sha256( ssl_context *ssl, unsigned char hash[32] ) +{ + sha2_context sha2; + + SSL_DEBUG_MSG( 2, ( "=> calc verify sha256" ) ); + + memcpy( &sha2 , (sha2_context *) ssl->ctx_checksum, sizeof(sha2_context) ); + sha2_finish( &sha2, hash ); + + SSL_DEBUG_BUF( 3, "calculated verify result", hash, 32 ); + SSL_DEBUG_MSG( 2, ( "<= calc verify" ) ); + + return; +} + +void ssl_calc_verify_tls_sha384( ssl_context *ssl, unsigned char hash[48] ) +{ + sha4_context sha4; + + SSL_DEBUG_MSG( 2, ( "=> calc verify sha384" ) ); + + memcpy( &sha4 , (sha4_context *) ssl->ctx_checksum, sizeof(sha4_context) ); + sha4_finish( &sha4, hash ); SSL_DEBUG_BUF( 3, "calculated verify result", hash, 48 ); SSL_DEBUG_MSG( 2, ( "<= calc verify" ) ); @@ -1395,10 +1439,7 @@ int ssl_write_record( ssl_context *ssl ) ssl->out_msg[2] = (unsigned char)( ( len - 4 ) >> 8 ); ssl->out_msg[3] = (unsigned char)( ( len - 4 ) ); - md5_update( &ssl->fin_md5 , ssl->out_msg, len ); - sha1_update( &ssl->fin_sha1, ssl->out_msg, len ); - sha2_update( &ssl->fin_sha2, ssl->out_msg, len ); - sha4_update( &ssl->fin_sha4, ssl->out_msg, len ); + ssl->update_checksum( ssl, ssl->out_msg, len ); } if( ssl->do_crypt != 0 ) @@ -1471,10 +1512,7 @@ int ssl_read_record( ssl_context *ssl ) return( POLARSSL_ERR_SSL_INVALID_RECORD ); } - md5_update( &ssl->fin_md5 , ssl->in_msg, ssl->in_hslen ); - sha1_update( &ssl->fin_sha1, ssl->in_msg, ssl->in_hslen ); - sha2_update( &ssl->fin_sha2, ssl->in_msg, ssl->in_hslen ); - sha4_update( &ssl->fin_sha4, ssl->in_msg, ssl->in_hslen ); + ssl->update_checksum( ssl, ssl->in_msg, ssl->in_hslen ); return( 0 ); } @@ -1618,10 +1656,7 @@ int ssl_read_record( ssl_context *ssl ) return( POLARSSL_ERR_SSL_INVALID_RECORD ); } - md5_update( &ssl->fin_md5 , ssl->in_msg, ssl->in_hslen ); - sha1_update( &ssl->fin_sha1, ssl->in_msg, ssl->in_hslen ); - sha2_update( &ssl->fin_sha2, ssl->in_msg, ssl->in_hslen ); - sha4_update( &ssl->fin_sha4, ssl->in_msg, ssl->in_hslen ); + ssl->update_checksum( ssl, ssl->in_msg, ssl->in_hslen ); } if( ssl->in_msgtype == SSL_MSG_ALERT ) @@ -1990,6 +2025,62 @@ int ssl_parse_change_cipher_spec( ssl_context *ssl ) return( 0 ); } +void ssl_kickstart_checksum( ssl_context *ssl, int ciphersuite, + unsigned char *input_buf, size_t len ) +{ + if( ssl->minor_ver < SSL_MINOR_VERSION_3 ) + { + md5_starts( (md5_context *) ssl->ctx_checksum ); + sha1_starts( (sha1_context *) ( ssl->ctx_checksum + + sizeof(md5_context) ) ); + + ssl->update_checksum = ssl_update_checksum_md5sha1; + } + else if ( ciphersuite == SSL_RSA_AES_256_GCM_SHA384 || + ciphersuite == SSL_EDH_RSA_AES_256_GCM_SHA384 ) + { + sha4_starts( (sha4_context *) ssl->ctx_checksum, 1 ); + ssl->update_checksum = ssl_update_checksum_sha384; + } + else + { + sha2_starts( (sha2_context *) ssl->ctx_checksum, 0 ); + ssl->update_checksum = ssl_update_checksum_sha256; + } + + if( ssl->endpoint == SSL_IS_CLIENT ) + ssl->update_checksum( ssl, ssl->out_msg, ssl->out_msglen ); + ssl->update_checksum( ssl, input_buf, len ); +} + +static void ssl_update_checksum_start( ssl_context *ssl, unsigned char *buf, + size_t len ) +{ + ((void) ssl); + ((void) buf); + ((void) len); +} + +static void ssl_update_checksum_md5sha1( ssl_context *ssl, unsigned char *buf, + size_t len ) +{ + md5_update( (md5_context *) ssl->ctx_checksum, buf, len ); + sha1_update( (sha1_context *) ( ssl->ctx_checksum + sizeof(md5_context) ), + buf, len ); +} + +static void ssl_update_checksum_sha256( ssl_context *ssl, unsigned char *buf, + size_t len ) +{ + sha2_update( (sha2_context *) ssl->ctx_checksum, buf, len ); +} + +static void ssl_update_checksum_sha384( ssl_context *ssl, unsigned char *buf, + size_t len ) +{ + sha4_update( (sha4_context *) ssl->ctx_checksum, buf, len ); +} + static void ssl_calc_finished_ssl( ssl_context *ssl, unsigned char *buf, int from ) { @@ -2003,8 +2094,9 @@ static void ssl_calc_finished_ssl( SSL_DEBUG_MSG( 2, ( "=> calc finished ssl" ) ); - memcpy( &md5 , &ssl->fin_md5 , sizeof( md5_context ) ); - memcpy( &sha1, &ssl->fin_sha1, sizeof( sha1_context ) ); + memcpy( &md5 , (md5_context *) ssl->ctx_checksum, sizeof(md5_context) ); + memcpy( &sha1, (sha1_context *) ( ssl->ctx_checksum + sizeof(md5_context) ), + sizeof( sha1_context ) ); /* * SSLv3: @@ -2073,8 +2165,9 @@ static void ssl_calc_finished_tls( SSL_DEBUG_MSG( 2, ( "=> calc finished tls" ) ); - memcpy( &md5 , &ssl->fin_md5 , sizeof( md5_context ) ); - memcpy( &sha1, &ssl->fin_sha1, sizeof( sha1_context ) ); + memcpy( &md5 , (md5_context *) ssl->ctx_checksum, sizeof(md5_context) ); + memcpy( &sha1, (sha1_context *) ( ssl->ctx_checksum + sizeof(md5_context) ), + sizeof( sha1_context ) ); /* * TLSv1: @@ -2116,9 +2209,9 @@ static void ssl_calc_finished_tls_sha256( sha2_context sha2; unsigned char padbuf[32]; - SSL_DEBUG_MSG( 2, ( "=> calc finished tls 1.2" ) ); + SSL_DEBUG_MSG( 2, ( "=> calc finished tls sha256" ) ); - memcpy( &sha2, &ssl->fin_sha2, sizeof( sha2_context ) ); + memcpy( &sha2 , (sha2_context *) ssl->ctx_checksum, sizeof(sha2_context) ); /* * TLSv1.2: @@ -2155,9 +2248,9 @@ static void ssl_calc_finished_tls_sha384( sha4_context sha4; unsigned char padbuf[48]; - SSL_DEBUG_MSG( 2, ( "=> calc finished tls 1.2" ) ); + SSL_DEBUG_MSG( 2, ( "=> calc finished tls sha384" ) ); - memcpy( &sha4, &ssl->fin_sha4, sizeof( sha4_context ) ); + memcpy( &sha4 , (sha4_context *) ssl->ctx_checksum, sizeof(sha4_context) ); /* * TLSv1.2: @@ -2320,10 +2413,7 @@ int ssl_init( ssl_context *ssl ) ssl->hostname = NULL; ssl->hostname_len = 0; - md5_starts( &ssl->fin_md5 ); - sha1_starts( &ssl->fin_sha1 ); - sha2_starts( &ssl->fin_sha2, 0 ); - sha4_starts( &ssl->fin_sha4, 1 ); + ssl->update_checksum = ssl_update_checksum_start; return( 0 ); } @@ -2367,10 +2457,7 @@ void ssl_session_reset( ssl_context *ssl ) memset( ssl->ctx_enc, 0, 128 ); memset( ssl->ctx_dec, 0, 128 ); - md5_starts( &ssl->fin_md5 ); - sha1_starts( &ssl->fin_sha1 ); - sha2_starts( &ssl->fin_sha2, 0 ); - sha4_starts( &ssl->fin_sha4, 1 ); + ssl->update_checksum = ssl_update_checksum_start; } /*