diff --git a/include/mbedtls/pk_info.h b/include/mbedtls/pk_info.h index a8b735fd1..a4bba4680 100644 --- a/include/mbedtls/pk_info.h +++ b/include/mbedtls/pk_info.h @@ -202,7 +202,7 @@ struct mbedtls_pk_info_t * is guaranteed to be initialized. * * Opaque implementations may omit this method. */ - int (*check_pair_func)( const mbedtls_pk_context *pub, const void *prv ); + int (*check_pair_func)( const mbedtls_pk_context *pub, const mbedtls_pk_context *prv ); /** Allocate a new context * diff --git a/library/pk.c b/library/pk.c index 980256a25..ac9635cb7 100644 --- a/library/pk.c +++ b/library/pk.c @@ -329,14 +329,14 @@ int mbedtls_pk_check_pair( const mbedtls_pk_context *pub, const mbedtls_pk_conte return( MBEDTLS_ERR_PK_FEATURE_UNAVAILABLE ); } - if( prv->pk_info->type != MBEDTLS_PK_RSA_ALT && - prv->pk_info->type != MBEDTLS_PK_OPAQUE ) - { - if( pub->pk_info != prv->pk_info ) - return( MBEDTLS_ERR_PK_TYPE_MISMATCH ); - } + if( prv->pk_info->type != MBEDTLS_PK_OPAQUE && + prv->pk_info->type != MBEDTLS_PK_RSA_ALT ) + { + if( pub->pk_info != prv->pk_info ) + return( MBEDTLS_ERR_PK_TYPE_MISMATCH ); + } - return( prv->pk_info->check_pair_func( pub, prv->pk_ctx ) ); + return( prv->pk_info->check_pair_func( pub, prv ) ); } /* diff --git a/library/pk_wrap.c b/library/pk_wrap.c index 6098ac178..d90228c9e 100644 --- a/library/pk_wrap.c +++ b/library/pk_wrap.c @@ -154,9 +154,10 @@ static int rsa_encrypt_wrap( void *ctx, ilen, input, output ) ); } -static int rsa_check_pair_wrap( const mbedtls_pk_context *pub, const void *prv ) +static int rsa_check_pair_wrap( const mbedtls_pk_context *pub, + const mbedtls_pk_context *prv ) { - return( mbedtls_rsa_check_pub_priv( pub->pk_ctx, prv ) ); + return( mbedtls_rsa_check_pub_priv( pub->pk_ctx, prv->pk_ctx ) ); } static void *rsa_alloc_wrap( void ) @@ -277,9 +278,10 @@ static size_t ecdsa_signature_size( const void *ctx_arg ) #endif /* MBEDTLS_ECDSA_C */ -static int eckey_check_pair( const mbedtls_pk_context *pub, const void *prv ) +static int eckey_check_pair( const mbedtls_pk_context *pub, + const mbedtls_pk_context *prv ) { - return( mbedtls_ecp_check_pub_priv( pub->pk_ctx, prv ) ); + return( mbedtls_ecp_check_pub_priv( pub->pk_ctx, prv->pk_ctx ) ); } static void *eckey_alloc_wrap( void ) @@ -480,26 +482,25 @@ static int rsa_alt_decrypt_wrap( void *ctx, } #if defined(MBEDTLS_RSA_C) -static int rsa_alt_check_pair( const mbedtls_pk_context *pub, const void *prv ) +static int rsa_alt_check_pair( const mbedtls_pk_context *pub, + const mbedtls_pk_context *prv ) { unsigned char sig[MBEDTLS_MPI_MAX_SIZE]; unsigned char hash[32]; size_t sig_len = 0; int ret; - const mbedtls_pk_context* prv_context = prv; - - if( prv_context->pk_info->type == MBEDTLS_PK_RSA_ALT ) + if( prv->pk_info->type == MBEDTLS_PK_RSA_ALT ) { if( pub->pk_info->type != MBEDTLS_PK_RSA ) return( MBEDTLS_ERR_PK_TYPE_MISMATCH ); } - if( rsa_alt_get_bitlen( prv ) != rsa_get_bitlen( pub->pk_ctx ) ) + if( rsa_alt_get_bitlen( prv->pk_ctx ) != rsa_get_bitlen( pub->pk_ctx ) ) return( MBEDTLS_ERR_RSA_KEY_CHECK_FAILED ); memset( hash, 0x2a, sizeof( hash ) ); - if( ( ret = rsa_alt_sign_wrap( (void *) prv, MBEDTLS_MD_NONE, + if( ( ret = rsa_alt_sign_wrap( (void *) prv->pk_ctx, MBEDTLS_MD_NONE, hash, sizeof( hash ), sig, &sig_len, NULL, NULL ) ) != 0 ) { diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function index d9246cf80..e1c123012 100644 --- a/tests/suites/test_suite_pk.function +++ b/tests/suites/test_suite_pk.function @@ -273,9 +273,9 @@ exit: } static int opaque_mock_check_pair_func( const mbedtls_pk_context *pub, - const void *prv ) + const mbedtls_pk_context *prv ) { - TEST_ASSERT( prv == &opaque_mock_fake_ctx ); + TEST_ASSERT( prv->pk_ctx == &opaque_mock_fake_ctx ); if( mbedtls_pk_get_type( pub ) != MBEDTLS_PK_RSA ) return( MBEDTLS_ERR_PK_TYPE_MISMATCH ); return( 0 );