diff --git a/include/polarssl/pk.h b/include/polarssl/pk.h index a39fadf7f..da13136a8 100644 --- a/include/polarssl/pk.h +++ b/include/polarssl/pk.h @@ -147,31 +147,38 @@ typedef struct void * pk_ctx; /**< Underlying public key context */ } pk_context; +/** + * \brief Return information associated with the given PK type + * + * \param type PK type to search for. + * + * \return The PK info associated with the type or NULL if not found. + */ +const pk_info_t *pk_info_from_type( pk_type_t pk_type ); + /** * \brief Initialize a pk_context (as NONE) */ void pk_init( pk_context *ctx ); +/** + * \brief Initialize a PK context with the information given + * and allocates the type-specific PK subcontext. + * + * \param ctx Context to initialize. Must be empty (type NONE). + * \param info Information to use + * + * \return 0 on success, + * POLARSSL_ERR_PK_BAD_INPUT_DATA on invalid input, + * POLARSSL_ERR_PK_MALLOC_FAILED on allocation failure. + */ +int pk_init_ctx( pk_context *ctx, const pk_info_t *info ); + /** * \brief Free a pk_context */ void pk_free( pk_context *ctx ); -/** - * \brief Set a pk_context to a given type - * - * \param ctx Context to initialize - * \param type Type of key - * - * \note Once the type of a key has been set, it cannot be reset. - * If you want to do so, you need to use pk_free() first. - * - * \return O on success, - * POLARSSL_ERR_PK_MALLOC_FAILED on memory allocation fail, - * POLARSSL_ERR_PK_TYPE_MISMATCH on attempts to reset type. - */ -int pk_set_type( pk_context *ctx, pk_type_t type ); - /** * \brief Get the size in bits of the underlying key * diff --git a/library/pk.c b/library/pk.c index 61544ebd4..4c16de8d7 100644 --- a/library/pk.c +++ b/library/pk.c @@ -67,7 +67,7 @@ void pk_free( pk_context *ctx ) /* * Get pk_info structure from type */ -static const pk_info_t * pk_info_from_type( pk_type_t pk_type ) +const pk_info_t * pk_info_from_type( pk_type_t pk_type ) { switch( pk_type ) { #if defined(POLARSSL_RSA_C) @@ -90,21 +90,11 @@ static const pk_info_t * pk_info_from_type( pk_type_t pk_type ) } /* - * Set a pk_context to a given type + * Initialise context */ -int pk_set_type( pk_context *ctx, pk_type_t type ) +int pk_init_ctx( pk_context *ctx, const pk_info_t *info ) { - const pk_info_t *info; - - if( ctx->pk_info != NULL ) - { - if( ctx->pk_info->type == type ) - return 0; - - return( POLARSSL_ERR_PK_TYPE_MISMATCH ); - } - - if( ( info = pk_info_from_type( type ) ) == NULL ) + if( ctx == NULL || info == NULL || ctx->pk_info != NULL ) return( POLARSSL_ERR_PK_BAD_INPUT_DATA ); if( ( ctx->pk_ctx = info->ctx_alloc_func() ) == NULL ) diff --git a/library/x509parse.c b/library/x509parse.c index e080174e8..4da4e7518 100644 --- a/library/x509parse.c +++ b/library/x509parse.c @@ -570,6 +570,7 @@ static int x509_get_pubkey( unsigned char **p, size_t len; x509_buf alg_params; pk_type_t pk_alg = POLARSSL_PK_NONE; + const pk_info_t *pk_info; if( ( ret = asn1_get_tag( p, end, &len, ASN1_CONSTRUCTED | ASN1_SEQUENCE ) ) != 0 ) @@ -589,7 +590,10 @@ static int x509_get_pubkey( unsigned char **p, return( POLARSSL_ERR_X509_CERT_INVALID_PUBKEY + POLARSSL_ERR_ASN1_LENGTH_MISMATCH ); - if( ( ret = pk_set_type( pk, pk_alg ) ) != 0 ) + if( ( pk_info = pk_info_from_type( pk_alg ) ) == NULL ) + return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG ); + + if( ( ret = pk_init_ctx( pk, pk_info ) ) != 0 ) return( ret ); #if defined(POLARSSL_RSA_C) @@ -2142,10 +2146,12 @@ int x509parse_keyfile_rsa( rsa_context *rsa, const char *path, const char *pwd ) pk_context pk; pk_init( &pk ); - pk_set_type( &pk, POLARSSL_PK_RSA ); ret = x509parse_keyfile( &pk, path, pwd ); + if( ret == 0 && ! pk_can_do( &pk, POLARSSL_PK_RSA ) ) + ret = POLARSSL_ERR_PK_TYPE_MISMATCH; + if( ret == 0 ) rsa_copy( rsa, pk_rsa( pk ) ); else @@ -2165,10 +2171,12 @@ int x509parse_public_keyfile_rsa( rsa_context *rsa, const char *path ) pk_context pk; pk_init( &pk ); - pk_set_type( &pk, POLARSSL_PK_RSA ); ret = x509parse_public_keyfile( &pk, path ); + if( ret == 0 && ! pk_can_do( &pk, POLARSSL_PK_RSA ) ) + ret = POLARSSL_ERR_PK_TYPE_MISMATCH; + if( ret == 0 ) rsa_copy( rsa, pk_rsa( pk ) ); else @@ -2380,6 +2388,7 @@ static int x509parse_key_pkcs8_unencrypted_der( unsigned char *p = (unsigned char *) key; unsigned char *end = p + keylen; pk_type_t pk_alg = POLARSSL_PK_NONE; + const pk_info_t *pk_info; /* * This function parses the PrivatKeyInfo object (PKCS#8 v1.2 = RFC 5208) @@ -2421,7 +2430,10 @@ static int x509parse_key_pkcs8_unencrypted_der( return( POLARSSL_ERR_X509_KEY_INVALID_FORMAT + POLARSSL_ERR_ASN1_OUT_OF_DATA ); - if( ( ret = pk_set_type( pk, pk_alg ) ) != 0 ) + if( ( pk_info = pk_info_from_type( pk_alg ) ) == NULL ) + return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG ); + + if( ( ret = pk_init_ctx( pk, pk_info ) ) != 0 ) return( ret ); #if defined(POLARSSL_RSA_C) @@ -2568,6 +2580,7 @@ int x509parse_key( pk_context *pk, const unsigned char *pwd, size_t pwdlen ) { int ret; + const pk_info_t *pk_info; #if defined(POLARSSL_PEM_C) size_t len; @@ -2582,7 +2595,10 @@ int x509parse_key( pk_context *pk, key, pwd, pwdlen, &len ); if( ret == 0 ) { - if( ( ret = pk_set_type( pk, POLARSSL_PK_RSA ) ) != 0 || + if( ( pk_info = pk_info_from_type( POLARSSL_PK_RSA ) ) == NULL ) + return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG ); + + if( ( ret = pk_init_ctx( pk, pk_info ) ) != 0 || ( ret = x509parse_key_pkcs1_der( pk_rsa( *pk ), pem.buf, pem.buflen ) ) != 0 ) { @@ -2607,7 +2623,10 @@ int x509parse_key( pk_context *pk, key, pwd, pwdlen, &len ); if( ret == 0 ) { - if( ( ret = pk_set_type( pk, POLARSSL_PK_ECKEY ) ) != 0 || + if( ( pk_info = pk_info_from_type( POLARSSL_PK_ECKEY ) ) == NULL ) + return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG ); + + if( ( ret = pk_init_ctx( pk, pk_info ) ) != 0 || ( ret = x509parse_key_sec1_der( pk_ec( *pk ), pem.buf, pem.buflen ) ) != 0 ) { @@ -2692,7 +2711,10 @@ int x509parse_key( pk_context *pk, pk_free( pk ); #if defined(POLARSSL_RSA_C) - if( ( ret = pk_set_type( pk, POLARSSL_PK_RSA ) ) == 0 && + if( ( pk_info = pk_info_from_type( POLARSSL_PK_RSA ) ) == NULL ) + return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG ); + + if( ( ret = pk_init_ctx( pk, pk_info ) ) != 0 || ( ret = x509parse_key_pkcs1_der( pk_rsa( *pk ), key, keylen ) ) == 0 ) { return( 0 ); @@ -2702,7 +2724,10 @@ int x509parse_key( pk_context *pk, #endif /* POLARSSL_RSA_C */ #if defined(POLARSSL_ECP_C) - if( ( ret = pk_set_type( pk, POLARSSL_PK_ECKEY ) ) == 0 && + if( ( pk_info = pk_info_from_type( POLARSSL_PK_ECKEY ) ) == NULL ) + return( POLARSSL_ERR_X509_UNKNOWN_PK_ALG ); + + if( ( ret = pk_init_ctx( pk, pk_info ) ) != 0 || ( ret = x509parse_key_sec1_der( pk_ec( *pk ), key, keylen ) ) == 0 ) { return( 0 ); @@ -2769,10 +2794,12 @@ int x509parse_key_rsa( rsa_context *rsa, pk_context pk; pk_init( &pk ); - pk_set_type( &pk, POLARSSL_PK_RSA ); ret = x509parse_key( &pk, key, keylen, pwd, pwdlen ); + if( ret == 0 && ! pk_can_do( &pk, POLARSSL_PK_RSA ) ) + ret = POLARSSL_ERR_PK_TYPE_MISMATCH; + if( ret == 0 ) rsa_copy( rsa, pk_rsa( pk ) ); else @@ -2793,10 +2820,12 @@ int x509parse_public_key_rsa( rsa_context *rsa, pk_context pk; pk_init( &pk ); - pk_set_type( &pk, POLARSSL_PK_RSA ); ret = x509parse_public_key( &pk, key, keylen ); + if( ret == 0 && ! pk_can_do( &pk, POLARSSL_PK_RSA ) ) + ret = POLARSSL_ERR_PK_TYPE_MISMATCH; + if( ret == 0 ) rsa_copy( rsa, pk_rsa( pk ) ); else