diff --git a/library/psa_crypto.c b/library/psa_crypto.c index ba43e1968..b5208f0d0 100755 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -1544,6 +1544,14 @@ psa_status_t psa_aead_encrypt( psa_key_slot_t key, if( nonce_length < 7 || nonce_length > 13 ) return( PSA_ERROR_INVALID_ARGUMENT ); + tag_length = 16; + status = psa_aead_unpadded_locate_tag( tag_length, + ciphertext, ciphertext_length, + plaintext_size, plaintext_length, + &tag ); + if( status != PSA_SUCCESS ) + return( status ); + mbedtls_ccm_init( &ccm ); ret = mbedtls_ccm_setkey( &ccm, cipher_id, slot->data.raw.data, key_bits ); @@ -1575,6 +1583,29 @@ psa_status_t psa_aead_encrypt( psa_key_slot_t key, return( PSA_SUCCESS ); } +/* Locate the tag in a ciphertext buffer containing the encrypted data + * followed by the tag. Return the length of the part preceding the tag in + * *plaintext_length. This is the size of the plaintext in modes where + * the encrypted data has the same size as the plaintext, such as + * CCM and GCM. */ +static psa_status_t psa_aead_unpadded_locate_tag( size_t tag_length, + const uint8_t *ciphertext, + size_t ciphertext_length, + size_t plaintext_size, + size_t *plaintext_length, + const uint8_t **p_tag ) +{ + size_t payload_length; + if( tag_length > ciphertext_length ) + return( PSA_ERROR_INVALID_ARGUMENT ); + payload_length = ciphertext_length - tag_length; + if( payload_length > plaintext_size ) + return( PSA_ERROR_BUFFER_TOO_SMALL ); + *p_tag = ciphertext + payload_length; + *plaintext_length = payload_length; + return( PSA_SUCCESS ); +} + psa_status_t psa_aead_decrypt( psa_key_slot_t key, psa_algorithm_t alg, const uint8_t *nonce, @@ -1592,11 +1623,11 @@ psa_status_t psa_aead_decrypt( psa_key_slot_t key, key_slot_t *slot; psa_key_type_t key_type; size_t key_bits; - unsigned char tag[16]; + const uint8_t *tag; + size_t tag_length; mbedtls_cipher_id_t cipher_id; - if( plaintext_size < ciphertext_length ) - return( PSA_ERROR_INVALID_ARGUMENT ); + *plaintext_length = 0; status = psa_get_key_information( key, &key_type, &key_bits ); if( status != PSA_SUCCESS ) @@ -1622,6 +1653,14 @@ psa_status_t psa_aead_decrypt( psa_key_slot_t key, { mbedtls_gcm_context gcm; + tag_length = 16; + status = psa_aead_unpadded_locate_tag( tag_length, + ciphertext, ciphertext_length, + plaintext_size, plaintext_length, + &tag ); + if( status != PSA_SUCCESS ) + return( status ); + mbedtls_gcm_init( &gcm ); ret = mbedtls_gcm_setkey( &gcm, cipher_id, slot->data.raw.data, key_bits ); @@ -1630,18 +1669,13 @@ psa_status_t psa_aead_decrypt( psa_key_slot_t key, mbedtls_gcm_free( &gcm ); return( mbedtls_to_psa_error( ret ) ); } - ret = mbedtls_gcm_crypt_and_tag( &gcm, MBEDTLS_GCM_DECRYPT, - ciphertext_length, nonce, nonce_length, - additional_data, additional_data_length, - ciphertext, plaintext, - sizeof( tag ), tag ); - if( ret != 0 ) - { - mbedtls_gcm_free( &gcm ); - mbedtls_zeroize( plaintext, ciphertext_length ); - return( mbedtls_to_psa_error( ret ) ); - } + ret = mbedtls_gcm_auth_decrypt( &gcm, + *plaintext_length, + nonce, nonce_length, + additional_data, additional_data_length, + tag, tag_length, + ciphertext, plaintext ); mbedtls_gcm_free( &gcm ); } else if( alg == PSA_ALG_CCM ) @@ -1659,17 +1693,11 @@ psa_status_t psa_aead_decrypt( psa_key_slot_t key, mbedtls_ccm_free( &ccm ); return( mbedtls_to_psa_error( ret ) ); } - ret = mbedtls_ccm_auth_decrypt( &ccm, ciphertext_length, - nonce, nonce_length, additional_data, - additional_data_length, ciphertext, - plaintext, tag, sizeof( tag ) ); - if( ret != 0 ) - { - mbedtls_ccm_free( &ccm ); - mbedtls_zeroize( plaintext, ciphertext_length ); - return( mbedtls_to_psa_error( ret ) ); - } - + ret = mbedtls_ccm_auth_decrypt( &ccm, *plaintext_length, + nonce, nonce_length, + additional_data, additional_data_length, + ciphertext, plaintext, + tag, tag_length ); mbedtls_ccm_free( &ccm ); } else @@ -1677,8 +1705,12 @@ psa_status_t psa_aead_decrypt( psa_key_slot_t key, return( PSA_ERROR_INVALID_ARGUMENT ); } - *plaintext_length = ciphertext_length; - return( PSA_SUCCESS ); + if( ret != 0 ) + { + mbedtls_zeroize( plaintext, *plaintext_length ); + *plaintext_length = 0; + } + return( mbedtls_to_psa_error( ret ) ); } diff --git a/tests/suites/test_suite_psa_crypto.function b/tests/suites/test_suite_psa_crypto.function index 93bb9cc2a..e36719d31 100755 --- a/tests/suites/test_suite_psa_crypto.function +++ b/tests/suites/test_suite_psa_crypto.function @@ -637,7 +637,7 @@ void aead_encrypt_decrypt( int key_type_arg, char * key_hex, TEST_ASSERT( psa_aead_decrypt( slot, alg, nonce, nonce_length, additional_data, additional_data_length, - output_data, output_length - tag_length, output_data2, + output_data, output_length, output_data2, output_length, &output_length2 ) == expected_result );