diff --git a/include/polarssl/pk.h b/include/polarssl/pk.h index 778efa703..6a3d4b8a7 100644 --- a/include/polarssl/pk.h +++ b/include/polarssl/pk.h @@ -52,7 +52,7 @@ * \warning You must make sure the PK context actually holds an RSA context * before using this macro! */ -#define pk_rsa( pk ) ( (rsa_context *) (pk).data ) +#define pk_rsa( pk ) ( (rsa_context *) (pk).pk_ctx ) #endif /* POLARSSL_RSA_C */ #if defined(POLARSSL_ECP_C) @@ -62,7 +62,7 @@ * \warning You must make sure the PK context actually holds an EC context * before using this macro! */ -#define pk_ec( pk ) ( (ecp_keypair *) (pk).data ) +#define pk_ec( pk ) ( (ecp_keypair *) (pk).pk_ctx ) #endif /* POLARSSL_ECP_C */ @@ -105,7 +105,7 @@ typedef struct #define POLARSSL_PK_DEBUG_MAX_ITEMS 3 /** - * \brief Public key info + * \brief Public key information and operations */ typedef struct { @@ -142,8 +142,8 @@ typedef struct */ typedef struct { - const pk_info_t * info; /**< Public key informations */ - void * data; /**< Public key data */ + const pk_info_t * pk_info; /**< Public key informations */ + void * pk_ctx; /**< Underlying public key context */ } pk_context; /** @@ -217,6 +217,15 @@ int pk_verify( pk_context *ctx, */ int pk_debug( const pk_context *ctx, pk_debug_item *items ); +/** + * \brief Access the type name + * + * \param ctx Context to use + * + * \return Type name on success, or "invalid PK" + */ +const char * pk_get_name( const pk_context *ctx ); + #ifdef __cplusplus } #endif diff --git a/library/pk.c b/library/pk.c index f3c64cb42..d8b4c8598 100644 --- a/library/pk.c +++ b/library/pk.c @@ -55,8 +55,8 @@ void pk_init( pk_context *ctx ) if( ctx == NULL ) return; - ctx->info = NULL; - ctx->data = NULL; + ctx->pk_info = NULL; + ctx->pk_ctx = NULL; } /* @@ -64,13 +64,13 @@ void pk_init( pk_context *ctx ) */ void pk_free( pk_context *ctx ) { - if( ctx == NULL || ctx->info == NULL) + if( ctx == NULL || ctx->pk_info == NULL) return; - ctx->info->ctx_free_func( ctx->data ); - ctx->data = NULL; + ctx->pk_info->ctx_free_func( ctx->pk_ctx ); + ctx->pk_ctx = NULL; - ctx->info = NULL; + ctx->pk_info = NULL; } /* @@ -105,9 +105,9 @@ int pk_set_type( pk_context *ctx, pk_type_t type ) { const pk_info_t *info; - if( ctx->info != NULL ) + if( ctx->pk_info != NULL ) { - if( ctx->info->type == type ) + if( ctx->pk_info->type == type ) return 0; return( POLARSSL_ERR_PK_TYPE_MISMATCH ); @@ -116,10 +116,10 @@ int pk_set_type( pk_context *ctx, pk_type_t type ) if( ( info = pk_info_from_type( type ) ) == NULL ) return( POLARSSL_ERR_PK_TYPE_MISMATCH ); - if( ( ctx->data = info->ctx_alloc_func() ) == NULL ) + if( ( ctx->pk_ctx = info->ctx_alloc_func() ) == NULL ) return( POLARSSL_ERR_PK_MALLOC_FAILED ); - ctx->info = info; + ctx->pk_info = info; return( 0 ); } @@ -130,10 +130,10 @@ int pk_set_type( pk_context *ctx, pk_type_t type ) int pk_can_do( pk_context *ctx, pk_type_t type ) { /* null of NONE context can't do anything */ - if( ctx == NULL || ctx->info == NULL ) + if( ctx == NULL || ctx->pk_info == NULL ) return( 0 ); - return( ctx->info->can_do( type ) ); + return( ctx->pk_info->can_do( type ) ); } /* @@ -143,10 +143,10 @@ int pk_verify( pk_context *ctx, const unsigned char *hash, const md_info_t *md_info, const unsigned char *sig, size_t sig_len ) { - if( ctx == NULL || ctx->info == NULL ) + if( ctx == NULL || ctx->pk_info == NULL ) return( POLARSSL_ERR_PK_TYPE_MISMATCH ); // TODO - return( ctx->info->verify_func( ctx->data, hash, md_info, sig, sig_len ) ); + return( ctx->pk_info->verify_func( ctx->pk_ctx, hash, md_info, sig, sig_len ) ); } /* @@ -154,10 +154,10 @@ int pk_verify( pk_context *ctx, */ size_t pk_get_size( const pk_context *ctx ) { - if( ctx == NULL || ctx->info == NULL ) + if( ctx == NULL || ctx->pk_info == NULL ) return( 0 ); - return( ctx->info->get_size( ctx->data ) ); + return( ctx->pk_info->get_size( ctx->pk_ctx ) ); } /* @@ -165,9 +165,20 @@ size_t pk_get_size( const pk_context *ctx ) */ int pk_debug( const pk_context *ctx, pk_debug_item *items ) { - if( ctx == NULL || ctx->info == NULL ) + if( ctx == NULL || ctx->pk_info == NULL ) return( POLARSSL_ERR_PK_TYPE_MISMATCH ); // TODO - ctx->info->debug_func( ctx->data, items ); + ctx->pk_info->debug_func( ctx->pk_ctx, items ); return( 0 ); } + +/* + * Access the PK type name + */ +const char * pk_get_name( const pk_context *ctx ) +{ + if( ctx == NULL || ctx->pk_info == NULL ) + return( "invalid PK" ); + + return( ctx->pk_info->name ); +} diff --git a/library/x509parse.c b/library/x509parse.c index 225f45d18..e080174e8 100644 --- a/library/x509parse.c +++ b/library/x509parse.c @@ -2147,7 +2147,7 @@ int x509parse_keyfile_rsa( rsa_context *rsa, const char *path, const char *pwd ) ret = x509parse_keyfile( &pk, path, pwd ); if( ret == 0 ) - rsa_copy( rsa, pk.data ); + rsa_copy( rsa, pk_rsa( pk ) ); else rsa_free( rsa ); @@ -2170,7 +2170,7 @@ int x509parse_public_keyfile_rsa( rsa_context *rsa, const char *path ) ret = x509parse_public_keyfile( &pk, path ); if( ret == 0 ) - rsa_copy( rsa, pk.data ); + rsa_copy( rsa, pk_rsa( pk ) ); else rsa_free( rsa ); @@ -2774,7 +2774,7 @@ int x509parse_key_rsa( rsa_context *rsa, ret = x509parse_key( &pk, key, keylen, pwd, pwdlen ); if( ret == 0 ) - rsa_copy( rsa, pk.data ); + rsa_copy( rsa, pk_rsa( pk ) ); else rsa_free( rsa ); @@ -2798,7 +2798,7 @@ int x509parse_public_key_rsa( rsa_context *rsa, ret = x509parse_public_key( &pk, key, keylen ); if( ret == 0 ) - rsa_copy( rsa, pk.data ); + rsa_copy( rsa, pk_rsa( pk ) ); else rsa_free( rsa ); @@ -3141,7 +3141,7 @@ int x509parse_cert_info( char *buf, size_t size, const char *prefix, SAFE_SNPRINTF(); if( ( ret = x509_key_size_helper( key_size_str, BEFORE_COLON, - crt->pk.info->name ) ) != 0 ) + pk_get_name( &crt->pk ) ) ) != 0 ) { return( ret ); } diff --git a/tests/suites/test_suite_x509parse.function b/tests/suites/test_suite_x509parse.function index 6bda6faab..ce27a9f39 100644 --- a/tests/suites/test_suite_x509parse.function +++ b/tests/suites/test_suite_x509parse.function @@ -227,7 +227,7 @@ void x509parse_public_keyfile_ec( char *key_file, int result ) { ecp_keypair *eckey; TEST_ASSERT( pk_can_do( &ctx, POLARSSL_PK_ECKEY ) ); - eckey = (ecp_keypair *) ctx.data; + eckey = pk_ec( ctx ); TEST_ASSERT( ecp_check_pubkey( &eckey->grp, &eckey->Q ) == 0 ); } @@ -251,7 +251,7 @@ void x509parse_keyfile_ec( char *key_file, char *password, int result ) { ecp_keypair *eckey; TEST_ASSERT( pk_can_do( &ctx, POLARSSL_PK_ECKEY ) ); - eckey = (ecp_keypair *) ctx.data; + eckey = pk_ec( ctx ); TEST_ASSERT( ecp_check_privkey( &eckey->grp, &eckey->d ) == 0 ); }