diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h index 22b07bc17..391ce5bb8 100644 --- a/include/mbedtls/ssl_internal.h +++ b/include/mbedtls/ssl_internal.h @@ -374,11 +374,17 @@ mbedtls_pk_type_t mbedtls_ssl_pk_alg_from_sig( unsigned char sig ); #endif mbedtls_md_type_t mbedtls_ssl_md_alg_from_hash( unsigned char hash ); +unsigned char mbedtls_ssl_hash_from_md_alg( int md ); #if defined(MBEDTLS_ECP_C) int mbedtls_ssl_check_curve( const mbedtls_ssl_context *ssl, mbedtls_ecp_group_id grp_id ); #endif +#if defined(MBEDTLS_KEY_EXCHANGE__SOME__SIGNATURE_ENABLED) +int mbedtls_ssl_check_sig_hash( const mbedtls_ssl_context *ssl, + mbedtls_md_type_t md ); +#endif + #if defined(MBEDTLS_X509_CRT_PARSE_C) static inline mbedtls_pk_context *mbedtls_ssl_own_key( mbedtls_ssl_context *ssl ) { diff --git a/library/ssl_cli.c b/library/ssl_cli.c index 40c2d4c86..e2c2d3fa9 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -156,6 +156,7 @@ static void ssl_write_signature_algorithms_ext( mbedtls_ssl_context *ssl, { unsigned char *p = buf; size_t sig_alg_len = 0; + const int *md; #if defined(MBEDTLS_RSA_C) || defined(MBEDTLS_ECDSA_C) unsigned char *sig_alg_list = buf + 6; #endif @@ -170,55 +171,22 @@ static void ssl_write_signature_algorithms_ext( mbedtls_ssl_context *ssl, /* * Prepare signature_algorithms extension (TLS 1.2) */ -#if defined(MBEDTLS_RSA_C) -#if defined(MBEDTLS_SHA512_C) - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_HASH_SHA512; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_RSA; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_HASH_SHA384; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_RSA; -#endif -#if defined(MBEDTLS_SHA256_C) - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_HASH_SHA256; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_RSA; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_HASH_SHA224; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_RSA; -#endif -#if defined(MBEDTLS_SHA1_C) - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_HASH_SHA1; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_RSA; -#endif -#if defined(MBEDTLS_MD5_C) - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_HASH_MD5; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_RSA; -#endif -#endif /* MBEDTLS_RSA_C */ + for( md = ssl->conf->sig_hashes; *md != MBEDTLS_MD_NONE; md++ ) + { #if defined(MBEDTLS_ECDSA_C) -#if defined(MBEDTLS_SHA512_C) - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_HASH_SHA512; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_ECDSA; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_HASH_SHA384; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_ECDSA; + sig_alg_list[sig_alg_len++] = mbedtls_ssl_hash_from_md_alg( *md ); + sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_ECDSA; #endif -#if defined(MBEDTLS_SHA256_C) - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_HASH_SHA256; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_ECDSA; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_HASH_SHA224; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_ECDSA; +#if defined(MBEDTLS_RSA_C) + sig_alg_list[sig_alg_len++] = mbedtls_ssl_hash_from_md_alg( *md ); + sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_RSA; #endif -#if defined(MBEDTLS_SHA1_C) - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_HASH_SHA1; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_ECDSA; -#endif -#if defined(MBEDTLS_MD5_C) - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_HASH_MD5; - sig_alg_list[sig_alg_len++] = MBEDTLS_SSL_SIG_ECDSA; -#endif -#endif /* MBEDTLS_ECDSA_C */ + } /* * enum { - * none(0), mbedtls_md5(1), mbedtls_sha1(2), sha224(3), mbedtls_sha256(4), sha384(5), - * mbedtls_sha512(6), (255) + * none(0), md5(1), sha1(2), sha224(3), sha256(4), sha384(5), + * sha512(6), (255) * } HashAlgorithm; * * enum { anonymous(0), rsa(1), dsa(2), ecdsa(3), (255) } @@ -1876,6 +1844,16 @@ static int ssl_parse_signature_algorithm( mbedtls_ssl_context *ssl, return( MBEDTLS_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE ); } + /* + * Check if the hash is acceptable + */ + if( mbedtls_ssl_check_sig_hash( ssl, *md_alg ) != 0 ) + { + MBEDTLS_SSL_DEBUG_MSG( 2, ( "server used HashAlgorithm " + "that was not offered" ) ); + return( MBEDTLS_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE ); + } + MBEDTLS_SSL_DEBUG_MSG( 2, ( "Server used SignatureAlgorithm %d", (*p)[1] ) ); MBEDTLS_SSL_DEBUG_MSG( 2, ( "Server used HashAlgorithm %d", (*p)[0] ) ); *p += 2; diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 554a55239..457362f96 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -211,7 +211,7 @@ static int ssl_parse_signature_algorithms_ext( mbedtls_ssl_context *ssl, * * So, just look at the HashAlgorithm part. */ - for( md_cur = mbedtls_md_list(); *md_cur != MBEDTLS_MD_NONE; md_cur++ ) { + for( md_cur = ssl->conf->sig_hashes; *md_cur != MBEDTLS_MD_NONE; md_cur++ ) { for( p = buf + 2; p < end; p += 2 ) { if( *md_cur == (int) mbedtls_ssl_md_alg_from_hash( p[0] ) ) { ssl->handshake->sig_alg = p[0]; diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 63d2e83a6..9007e0540 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -6788,7 +6788,7 @@ mbedtls_pk_type_t mbedtls_ssl_pk_alg_from_sig( unsigned char sig ) #endif /* MBEDTLS_PK_C */ /* - * Convert between SSL_HASH_XXX and MBEDTLS_MD_XXX + * Convert from MBEDTLS_SSL_HASH_XXX to MBEDTLS_MD_XXX */ mbedtls_md_type_t mbedtls_ssl_md_alg_from_hash( unsigned char hash ) { @@ -6819,9 +6819,41 @@ mbedtls_md_type_t mbedtls_ssl_md_alg_from_hash( unsigned char hash ) } } +/* + * Convert from MBEDTLS_MD_XXX to MBEDTLS_SSL_HASH_XXX + */ +unsigned char mbedtls_ssl_hash_from_md_alg( int md ) +{ + switch( md ) + { +#if defined(MBEDTLS_MD5_C) + case MBEDTLS_MD_MD5: + return( MBEDTLS_SSL_HASH_MD5 ); +#endif +#if defined(MBEDTLS_SHA1_C) + case MBEDTLS_MD_SHA1: + return( MBEDTLS_SSL_HASH_SHA1 ); +#endif +#if defined(MBEDTLS_SHA256_C) + case MBEDTLS_MD_SHA224: + return( MBEDTLS_SSL_HASH_SHA224 ); + case MBEDTLS_MD_SHA256: + return( MBEDTLS_SSL_HASH_SHA256 ); +#endif +#if defined(MBEDTLS_SHA512_C) + case MBEDTLS_MD_SHA384: + return( MBEDTLS_SSL_HASH_SHA384 ); + case MBEDTLS_MD_SHA512: + return( MBEDTLS_SSL_HASH_SHA512 ); +#endif + default: + return( MBEDTLS_SSL_HASH_NONE ); + } +} + #if defined(MBEDTLS_ECP_C) /* - * Check is a curve proposed by the peer is in our list. + * Check if a curve proposed by the peer is in our list. * Return 0 if we're willing to use it, -1 otherwise. */ int mbedtls_ssl_check_curve( const mbedtls_ssl_context *ssl, mbedtls_ecp_group_id grp_id ) @@ -6839,6 +6871,27 @@ int mbedtls_ssl_check_curve( const mbedtls_ssl_context *ssl, mbedtls_ecp_group_i } #endif /* MBEDTLS_ECP_C */ +#if defined(MBEDTLS_KEY_EXCHANGE__SOME__SIGNATURE_ENABLED) +/* + * Check if a hash proposed by the peer is in our list. + * Return 0 if we're willing to use it, -1 otherwise. + */ +int mbedtls_ssl_check_sig_hash( const mbedtls_ssl_context *ssl, + mbedtls_md_type_t md ) +{ + const int *cur; + + if( ssl->conf->sig_hashes == NULL ) + return( -1 ); + + for( cur = ssl->conf->sig_hashes; *cur != MBEDTLS_MD_NONE; cur++ ) + if( *cur == (int) md ) + return( 0 ); + + return( -1 ); +} +#endif /* MBEDTLS_KEY_EXCHANGE__SOME__SIGNATURE_ENABLED */ + #if defined(MBEDTLS_X509_CRT_PARSE_C) int mbedtls_ssl_check_cert_usage( const mbedtls_x509_crt *cert, const mbedtls_ssl_ciphersuite_t *ciphersuite,