diff --git a/library/pk.c b/library/pk.c index 513d8cadd..a37f58e12 100644 --- a/library/pk.c +++ b/library/pk.c @@ -306,18 +306,24 @@ int pk_encrypt( pk_context *ctx, int pk_check_pair( const pk_context *pub, const pk_context *prv ) { if( pub == NULL || pub->pk_info == NULL || - prv == NULL || prv->pk_info == NULL ) + prv == NULL || prv->pk_info == NULL || + prv->pk_info->check_pair_func == NULL ) { return( POLARSSL_ERR_PK_BAD_INPUT_DATA ); } - if( pub->pk_info != prv->pk_info || - pub->pk_info->check_pair_func == NULL ) + if( prv->pk_info->type == POLARSSL_PK_RSA_ALT ) { - return( POLARSSL_ERR_PK_TYPE_MISMATCH ); + if( pub->pk_info->type != POLARSSL_PK_RSA ) + return( POLARSSL_ERR_PK_TYPE_MISMATCH ); + } + else + { + if( pub->pk_info != prv->pk_info ) + return( POLARSSL_ERR_PK_TYPE_MISMATCH ); } - return( pub->pk_info->check_pair_func( pub->pk_ctx, prv->pk_ctx ) ); + return( prv->pk_info->check_pair_func( pub->pk_ctx, prv->pk_ctx ) ); } /* diff --git a/library/pk_wrap.c b/library/pk_wrap.c index 0d2a368ad..a75ab3248 100644 --- a/library/pk_wrap.c +++ b/library/pk_wrap.c @@ -117,10 +117,11 @@ static int rsa_encrypt_wrap( void *ctx, unsigned char *output, size_t *olen, size_t osize, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) { - ((void) osize); - *olen = ((rsa_context *) ctx)->len; + if( *olen > osize ) + return( POLARSSL_ERR_RSA_OUTPUT_TOO_LARGE ); + return( rsa_pkcs1_encrypt( (rsa_context *) ctx, f_rng, p_rng, RSA_PUBLIC, ilen, input, output ) ); } @@ -435,6 +436,34 @@ static int rsa_alt_decrypt_wrap( void *ctx, RSA_PRIVATE, olen, input, output, osize ) ); } +static int rsa_alt_check_pair( const void *pub, const void *prv ) +{ + unsigned char sig[POLARSSL_MPI_MAX_SIZE]; + unsigned char hash[32]; + size_t sig_len = 0; + int ret; + + if( rsa_alt_get_size( prv ) != rsa_get_size( pub ) ) + return( POLARSSL_ERR_RSA_KEY_CHECK_FAILED ); + + memset( hash, 0x2a, sizeof( hash ) ); + + if( ( ret = rsa_alt_sign_wrap( (void *) prv, POLARSSL_MD_NONE, + hash, sizeof( hash ), + sig, &sig_len, NULL, NULL ) ) != 0 ) + { + return( ret ); + } + + if( rsa_verify_wrap( (void *) pub, POLARSSL_MD_NONE, + hash, sizeof( hash ), sig, sig_len ) != 0 ) + { + return( POLARSSL_ERR_RSA_KEY_CHECK_FAILED ); + } + + return( 0 ); +} + static void *rsa_alt_alloc_wrap( void ) { void *ctx = polarssl_malloc( sizeof( rsa_alt_context ) ); @@ -460,7 +489,7 @@ const pk_info_t rsa_alt_info = { rsa_alt_sign_wrap, rsa_alt_decrypt_wrap, NULL, - NULL, /* No public key */ + rsa_alt_check_pair, rsa_alt_alloc_wrap, rsa_alt_free_wrap, NULL, diff --git a/tests/suites/test_suite_pk.function b/tests/suites/test_suite_pk.function index 352169c24..c88d36588 100644 --- a/tests/suites/test_suite_pk.function +++ b/tests/suites/test_suite_pk.function @@ -83,18 +83,27 @@ exit: /* BEGIN_CASE depends_on:POLARSSL_PK_PARSE_C */ void pk_check_pair( char *pub_file, char *prv_file, int ret ) { - pk_context pub, prv; + pk_context pub, prv, alt; pk_init( &pub ); pk_init( &prv ); + pk_init( &alt ); TEST_ASSERT( pk_parse_public_keyfile( &pub, pub_file ) == 0 ); TEST_ASSERT( pk_parse_keyfile( &prv, prv_file, NULL ) == 0 ); TEST_ASSERT( pk_check_pair( &pub, &prv ) == ret ); + if( pk_get_type( &prv ) == POLARSSL_PK_RSA ) + { + TEST_ASSERT( pk_init_ctx_rsa_alt( &alt, pk_rsa( prv ), + rsa_decrypt_func, rsa_sign_func, rsa_key_len_func ) == 0 ); + TEST_ASSERT( pk_check_pair( &pub, &alt ) == ret ); + } + pk_free( &pub ); pk_free( &prv ); + pk_free( &alt ); } /* END_CASE */