diff --git a/library/psa_crypto.c b/library/psa_crypto.c index fac1c7564..30b68faf6 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -541,6 +541,43 @@ static psa_status_t prepare_raw_data_slot( psa_key_type_t type, return( PSA_SUCCESS ); } +#if defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PK_PARSE_C) +static psa_status_t psa_import_rsa_key( mbedtls_pk_context *pk, + mbedtls_rsa_context **p_rsa ) +{ + if( mbedtls_pk_get_type( pk ) != MBEDTLS_PK_RSA ) + return( PSA_ERROR_INVALID_ARGUMENT ); + else + { + mbedtls_rsa_context *rsa = mbedtls_pk_rsa( *pk ); + size_t bits = mbedtls_rsa_get_bitlen( rsa ); + if( bits > PSA_VENDOR_RSA_MAX_KEY_BITS ) + return( PSA_ERROR_NOT_SUPPORTED ); + *p_rsa = rsa; + return( PSA_SUCCESS ); + } +} +#endif /* defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PK_PARSE_C) */ + +#if defined(MBEDTLS_ECP_C) && defined(MBEDTLS_PK_PARSE_C) +static psa_status_t psa_import_ecp_key( psa_ecc_curve_t expected_curve, + mbedtls_pk_context *pk, + mbedtls_ecp_keypair **p_ecp ) +{ + if( mbedtls_pk_get_type( pk ) != MBEDTLS_PK_ECKEY ) + return( PSA_ERROR_INVALID_ARGUMENT ); + else + { + mbedtls_ecp_keypair *ecp = mbedtls_pk_ec( *pk ); + psa_ecc_curve_t actual_curve = mbedtls_ecc_group_to_psa( ecp->grp.id ); + if( actual_curve != expected_curve ) + return( PSA_ERROR_INVALID_ARGUMENT ); + *p_ecp = ecp; + return( PSA_SUCCESS ); + } +} +#endif /* defined(MBEDTLS_ECP_C) && defined(MBEDTLS_PK_PARSE_C) */ + psa_status_t psa_import_key( psa_key_slot_t key, psa_key_type_t type, const uint8_t *data, @@ -572,55 +609,33 @@ psa_status_t psa_import_key( psa_key_slot_t key, int ret; mbedtls_pk_context pk; mbedtls_pk_init( &pk ); + + /* Parse the data. */ if( PSA_KEY_TYPE_IS_KEYPAIR( type ) ) ret = mbedtls_pk_parse_key( &pk, data, data_length, NULL, 0 ); else ret = mbedtls_pk_parse_public_key( &pk, data, data_length ); if( ret != 0 ) return( mbedtls_to_psa_error( ret ) ); - switch( mbedtls_pk_get_type( &pk ) ) - { + + /* We have something that the pkparse module recognizes. + * If it has the expected type and passes any type-specific + * checks, store it. */ #if defined(MBEDTLS_RSA_C) - case MBEDTLS_PK_RSA: - if( PSA_KEY_TYPE_IS_RSA( type ) ) - { - mbedtls_rsa_context *rsa = mbedtls_pk_rsa( pk ); - size_t bits = mbedtls_rsa_get_bitlen( rsa ); - if( bits > PSA_VENDOR_RSA_MAX_KEY_BITS ) - { - status = PSA_ERROR_NOT_SUPPORTED; - break; - } - slot->data.rsa = rsa; - } - else - status = PSA_ERROR_INVALID_ARGUMENT; - break; + if( PSA_KEY_TYPE_IS_RSA( type ) ) + status = psa_import_rsa_key( &pk, &slot->data.rsa ); + else #endif /* MBEDTLS_RSA_C */ #if defined(MBEDTLS_ECP_C) - case MBEDTLS_PK_ECKEY: - if( PSA_KEY_TYPE_IS_ECC( type ) ) - { - mbedtls_ecp_keypair *ecp = mbedtls_pk_ec( pk ); - psa_ecc_curve_t actual_curve = - mbedtls_ecc_group_to_psa( ecp->grp.id ); - psa_ecc_curve_t expected_curve = - PSA_KEY_TYPE_GET_CURVE( type ); - if( actual_curve != expected_curve ) - { - status = PSA_ERROR_INVALID_ARGUMENT; - break; - } - slot->data.ecp = ecp; - } - else - status = PSA_ERROR_INVALID_ARGUMENT; - break; + if( PSA_KEY_TYPE_IS_ECC( type ) ) + status = psa_import_ecp_key( PSA_KEY_TYPE_GET_CURVE( type ), + &pk, &slot->data.ecp ); + else #endif /* MBEDTLS_ECP_C */ - default: - status = PSA_ERROR_INVALID_ARGUMENT; - break; + { + status = PSA_ERROR_NOT_SUPPORTED; } + /* Free the content of the pk object only on error. On success, * the content of the object has been stored in the slot. */ if( status != PSA_SUCCESS )