diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 7cfdbf213..3930b471c 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -977,6 +977,11 @@ int mbedtls_ssl_tls_prf( const mbedtls_tls_prf_types prf, return( tls_prf( secret, slen, label, random, rlen, dstbuf, dlen ) ); } +/* Type for the TLS PRF */ +typedef int ssl_tls_prf_t(const unsigned char *, size_t, const char *, + const unsigned char *, size_t, + unsigned char *, size_t); + /* * Populate a transform structure with session keys and all the other * necessary information. @@ -985,13 +990,15 @@ int mbedtls_ssl_tls_prf( const mbedtls_tls_prf_types prf, * - [in/out]: transform: structure to populate * [in] must be just initialised with mbedtls_ssl_transform_init() * [out] fully populate, ready for use by mbedtls_ssl_{en,de}crypt_buf() - * - [in] session: used members: encrypt_then_max, master, compression - * - [in] handshake: used members: prf, ciphersuite_info, randbytes - * - [in]: ssl: used members: minor_ver, conf->endpoint + * - [in] session: used: ciphersuite, encrypt_then_mac, master, compression + * - [in] tls_prf: pointer to PRF to use for key derivation + * - [in] randbytes: buffer holding ServerHello.random + ClientHello.random + * - [in] ssl: used members: minor_ver, conf->endpoint */ static int ssl_populate_transform( mbedtls_ssl_transform *transform, const mbedtls_ssl_session *session, - const mbedtls_ssl_handshake_params *handshake, + ssl_tls_prf_t tls_prf, + const unsigned char randbytes[64], const mbedtls_ssl_context *ssl ) { int ret = 0; @@ -1010,13 +1017,24 @@ static int ssl_populate_transform( mbedtls_ssl_transform *transform, const mbedtls_cipher_info_t *cipher_info; const mbedtls_md_info_t *md_info; + /* Copy info about negotiated version and extensions */ #if defined(MBEDTLS_SSL_ENCRYPT_THEN_MAC) && \ defined(MBEDTLS_SSL_SOME_MODES_USE_MAC) transform->encrypt_then_mac = session->encrypt_then_mac; #endif transform->minor_ver = ssl->minor_ver; - ciphersuite_info = handshake->ciphersuite_info; + /* + * Get various info structures + */ + ciphersuite_info = mbedtls_ssl_ciphersuite_from_id( session->ciphersuite ); + if( ciphersuite_info == NULL ) + { + MBEDTLS_SSL_DEBUG_MSG( 1, ( "ciphersuite info for %d not found", + session->ciphersuite ) ); + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + } + cipher_info = mbedtls_cipher_info_from_type( ciphersuite_info->cipher ); if( cipher_info == NULL ) { @@ -1054,19 +1072,10 @@ static int ssl_populate_transform( mbedtls_ssl_transform *transform, #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */ /* - * SSLv3: - * key block = - * MD5( master + SHA1( 'A' + master + randbytes ) ) + - * MD5( master + SHA1( 'BB' + master + randbytes ) ) + - * MD5( master + SHA1( 'CCC' + master + randbytes ) ) + - * MD5( master + SHA1( 'DDDD' + master + randbytes ) ) + - * ... - * - * TLSv1: - * key block = PRF( master, "key expansion", randbytes ) + * Compute key block using the PRF */ - ret = handshake->tls_prf( session->master, 48, "key expansion", - handshake->randbytes, 64, keyblk, 256 ); + ret = tls_prf( session->master, 48, "key expansion", + randbytes, 64, keyblk, 256 ); if( ret != 0 ) { MBEDTLS_SSL_DEBUG_RET( 1, "prf", ret ); @@ -1076,7 +1085,7 @@ static int ssl_populate_transform( mbedtls_ssl_transform *transform, MBEDTLS_SSL_DEBUG_MSG( 3, ( "ciphersuite = %s", mbedtls_ssl_get_ciphersuite_name( session->ciphersuite ) ) ); MBEDTLS_SSL_DEBUG_BUF( 3, "master secret", session->master, 48 ); - MBEDTLS_SSL_DEBUG_BUF( 4, "random bytes", handshake->randbytes, 64 ); + MBEDTLS_SSL_DEBUG_BUF( 4, "random bytes", randbytes, 64 ); MBEDTLS_SSL_DEBUG_BUF( 4, "key block", keyblk, 256 ); /* @@ -1337,9 +1346,9 @@ static int ssl_populate_transform( mbedtls_ssl_transform *transform, mac_key_len, keylen, iv_copy_len, /* work around bug in exporter type */ - (unsigned char *) handshake->randbytes + 32, - (unsigned char *) handshake->randbytes, - tls_prf_get_type( handshake->tls_prf ) ); + (unsigned char *) randbytes + 32, + (unsigned char *) randbytes, + tls_prf_get_type( tls_prf ) ); } #endif @@ -1740,7 +1749,8 @@ int mbedtls_ssl_derive_keys( mbedtls_ssl_context *ssl ) /* Populate transform structure */ ret = ssl_populate_transform( ssl->transform_negotiate, ssl->session_negotiate, - ssl->handshake, + ssl->handshake->tls_prf, + ssl->handshake->randbytes, ssl ); if( ret != 0 ) {