diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 09c72a703..cddeb74e7 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -336,40 +336,42 @@ static int tls1_prf( const unsigned char *secret, size_t slen, #endif /* POLARSSL_SSL_PROTO_TLS1) || POLARSSL_SSL_PROTO_TLS1_1 */ #if defined(POLARSSL_SSL_PROTO_TLS1_2) -#if defined(POLARSSL_SHA256_C) -static int tls_prf_sha256( const unsigned char *secret, size_t slen, - const char *label, - const unsigned char *random, size_t rlen, - unsigned char *dstbuf, size_t dlen ) +static int tls_prf_generic( md_type_t md_type, + const unsigned char *secret, size_t slen, + const char *label, + const unsigned char *random, size_t rlen, + unsigned char *dstbuf, size_t dlen ) { size_t nb; - size_t i, j, k; + size_t i, j, k, md_len; unsigned char tmp[128]; - unsigned char h_i[32]; + unsigned char h_i[POLARSSL_MD_MAX_SIZE]; const md_info_t *md_info; - if( sizeof( tmp ) < 32 + strlen( label ) + rlen ) + if( ( md_info = md_info_from_type( md_type ) ) == NULL ) + return( POLARSSL_ERR_SSL_INTERNAL_ERROR ); + + md_len = md_get_size( md_info ); + + if( sizeof( tmp ) < md_len + strlen( label ) + rlen ) return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); nb = strlen( label ); - memcpy( tmp + 32, label, nb ); - memcpy( tmp + 32 + nb, random, rlen ); + memcpy( tmp + md_len, label, nb ); + memcpy( tmp + md_len + nb, random, rlen ); nb += rlen; /* * Compute P_(secret, label + random)[0..dlen] */ - if( ( md_info = md_info_from_type( POLARSSL_MD_SHA256 ) ) == NULL ) - return( POLARSSL_ERR_SSL_INTERNAL_ERROR ); + md_hmac( md_info, secret, slen, tmp + md_len, nb, tmp ); - md_hmac( md_info, secret, slen, tmp + 32, nb, tmp ); - - for( i = 0; i < dlen; i += 32 ) + for( i = 0; i < dlen; i += md_len ) { - md_hmac( md_info, secret, slen, tmp, 32 + nb, h_i ); - md_hmac( md_info, secret, slen, tmp, 32, tmp ); + md_hmac( md_info, secret, slen, tmp, md_len + nb, h_i ); + md_hmac( md_info, secret, slen, tmp, md_len, tmp ); - k = ( i + 32 > dlen ) ? dlen % 32 : 32; + k = ( i + md_len > dlen ) ? dlen % md_len : md_len; for( j = 0; j < k; j++ ) dstbuf[i + j] = h_i[j]; @@ -380,6 +382,16 @@ static int tls_prf_sha256( const unsigned char *secret, size_t slen, return( 0 ); } + +#if defined(POLARSSL_SHA256_C) +static int tls_prf_sha256( const unsigned char *secret, size_t slen, + const char *label, + const unsigned char *random, size_t rlen, + unsigned char *dstbuf, size_t dlen ) +{ + return( tls_prf_generic( POLARSSL_MD_SHA256, secret, slen, + label, random, rlen, dstbuf, dlen ) ); +} #endif /* POLARSSL_SHA256_C */ #if defined(POLARSSL_SHA512_C) @@ -388,43 +400,8 @@ static int tls_prf_sha384( const unsigned char *secret, size_t slen, const unsigned char *random, size_t rlen, unsigned char *dstbuf, size_t dlen ) { - size_t nb; - size_t i, j, k; - unsigned char tmp[128]; - unsigned char h_i[48]; - const md_info_t *md_info; - - if( sizeof( tmp ) < 48 + strlen( label ) + rlen ) - return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); - - nb = strlen( label ); - memcpy( tmp + 48, label, nb ); - memcpy( tmp + 48 + nb, random, rlen ); - nb += rlen; - - /* - * Compute P_(secret, label + random)[0..dlen] - */ - if( ( md_info = md_info_from_type( POLARSSL_MD_SHA384 ) ) == NULL ) - return( POLARSSL_ERR_SSL_INTERNAL_ERROR ); - - md_hmac( md_info, secret, slen, tmp + 48, nb, tmp ); - - for( i = 0; i < dlen; i += 48 ) - { - md_hmac( md_info, secret, slen, tmp, 48 + nb, h_i ); - md_hmac( md_info, secret, slen, tmp, 48, tmp ); - - k = ( i + 48 > dlen ) ? dlen % 48 : 48; - - for( j = 0; j < k; j++ ) - dstbuf[i + j] = h_i[j]; - } - - polarssl_zeroize( tmp, sizeof( tmp ) ); - polarssl_zeroize( h_i, sizeof( h_i ) ); - - return( 0 ); + return( tls_prf_generic( POLARSSL_MD_SHA384, secret, slen, + label, random, rlen, dstbuf, dlen ) ); } #endif /* POLARSSL_SHA512_C */ #endif /* POLARSSL_SSL_PROTO_TLS1_2 */