aes: xts: Rewrite to avoid use of goto

The flow was a bit hard to follow with the `goto` everywhere. Rewrite the
XTS implementation to avoid the use of `goto`.
This commit is contained in:
Jaeden Amero 2018-04-28 15:02:45 +01:00
parent 0a8b02087a
commit d82cd860b2

View File

@ -1135,129 +1135,92 @@ int mbedtls_aes_crypt_xts( mbedtls_aes_xts_context *ctx,
const unsigned char *input, const unsigned char *input,
unsigned char *output ) unsigned char *output )
{ {
union xts_buf128 { int ret;
uint8_t u8[16]; size_t blocks = length / 16;
uint64_t u64[2]; size_t leftover = length % 16;
}; unsigned char tweak[16];
unsigned char prev_tweak[16];
unsigned char tmp[16];
union xts_buf128 scratch; /* Sectors must be at least 16 bytes. */
union xts_buf128 cts_scratch;
union xts_buf128 t_buf;
union xts_buf128 cts_t_buf;
union xts_buf128 *inbuf;
union xts_buf128 *outbuf;
size_t nblk = length / 16;
size_t remn = length % 16;
inbuf = (union xts_buf128*)input;
outbuf = (union xts_buf128*)output;
/* For performing the ciphertext-stealing operation, we have to get at least
* one complete block */
if( length < 16 ) if( length < 16 )
return( MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH ); return MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH;
/* NIST SP 80-38E disallows data units larger than 2**20 blocks. */ /* NIST SP 80-38E disallows data units larger than 2**20 blocks. */
if( length > ( 1 << 20 ) * 16 ) if( length > ( 1 << 20 ) * 16 )
return MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH; return MBEDTLS_ERR_AES_INVALID_INPUT_LENGTH;
mbedtls_aes_crypt_ecb( &ctx->tweak, MBEDTLS_AES_ENCRYPT, iv, t_buf.u8 ); /* Compute the tweak. */
ret = mbedtls_aes_crypt_ecb( &ctx->tweak, MBEDTLS_AES_ENCRYPT, iv, tweak );
if( ret != 0 )
return( ret );
if( mode == MBEDTLS_AES_DECRYPT && remn ) while( blocks-- )
{ {
if( nblk == 1 ) size_t i;
goto decrypt_only_one_full_block;
nblk--; if( leftover && ( mode == MBEDTLS_AES_DECRYPT ) && blocks == 0 )
{
/* We are on the last block in a decrypt operation that has
* leftover bytes, so we need to use the next tweak for this block,
* and this tweak for the lefover bytes. Save the current tweak for
* the leftovers and then update the current tweak for use on this,
* the last full block. */
memcpy( prev_tweak, tweak, sizeof( tweak ) );
mbedtls_gf128mul_x_ble( tweak, tweak );
}
for( i = 0; i < 16; i++ )
tmp[i] = input[i] ^ tweak[i];
ret = mbedtls_aes_crypt_ecb( &ctx->crypt, mode, tmp, tmp );
if( ret != 0 )
return( ret );
for( i = 0; i < 16; i++ )
output[i] = tmp[i] ^ tweak[i];
/* Update the tweak for the next block. */
mbedtls_gf128mul_x_ble( tweak, tweak );
output += 16;
input += 16;
} }
goto first; if( leftover )
do
{ {
mbedtls_gf128mul_x_ble( t_buf.u8, t_buf.u8 ); /* If we are on the leftover bytes in a decrypt operation, we need to
* use the previous tweak for these bytes (as saved in prev_tweak). */
unsigned char *t = mode == MBEDTLS_AES_DECRYPT ? prev_tweak : tweak;
first: /* We are now on the final part of the data unit, which doesn't divide
/* PP <- T xor P */ * evenly by 16. It's time for ciphertext stealing. */
scratch.u64[0] = (uint64_t)( inbuf->u64[0] ^ t_buf.u64[0] ); size_t i;
scratch.u64[1] = (uint64_t)( inbuf->u64[1] ^ t_buf.u64[1] ); unsigned char *prev_output = output - 16;
/* CC <- E(Key2,PP) */ /* Copy ciphertext bytes from the previous block to our output for each
mbedtls_aes_crypt_ecb( &ctx->crypt, mode, scratch.u8, outbuf->u8 ); * byte of cyphertext we won't steal. At the same time, copy the
* remainder of the input for this final round (since the loop bounds
/* C <- T xor CC */ * are the same). */
outbuf->u64[0] = (uint64_t)( outbuf->u64[0] ^ t_buf.u64[0] ); for( i = 0; i < leftover; i++ )
outbuf->u64[1] = (uint64_t)( outbuf->u64[1] ^ t_buf.u64[1] );
inbuf += 1;
outbuf += 1;
nblk -= 1;
} while( nblk > 0 );
/* Ciphertext stealing, if necessary */
if( remn != 0 )
{
outbuf = (union xts_buf128*)output;
inbuf = (union xts_buf128*)input;
nblk = length / 16;
if( mode == MBEDTLS_AES_ENCRYPT )
{ {
memcpy( cts_scratch.u8, (uint8_t*)&inbuf[nblk], remn ); output[i] = prev_output[i];
memcpy( cts_scratch.u8 + remn, ((uint8_t*)&outbuf[nblk - 1]) + remn, 16 - remn ); tmp[i] = input[i] ^ t[i];
memcpy( (uint8_t*)&outbuf[nblk], (uint8_t*)&outbuf[nblk - 1], remn );
mbedtls_gf128mul_x_ble( t_buf.u8, t_buf.u8 );
/* PP <- T xor P */
scratch.u64[0] = (uint64_t)( cts_scratch.u64[0] ^ t_buf.u64[0] );
scratch.u64[1] = (uint64_t)( cts_scratch.u64[1] ^ t_buf.u64[1] );
/* CC <- E(Key2,PP) */
mbedtls_aes_crypt_ecb( &ctx->crypt, mode, scratch.u8, scratch.u8 );
/* C <- T xor CC */
outbuf[nblk - 1].u64[0] = (uint64_t)( scratch.u64[0] ^ t_buf.u64[0] );
outbuf[nblk - 1].u64[1] = (uint64_t)( scratch.u64[1] ^ t_buf.u64[1] );
} }
else /* AES_DECRYPT */
{
mbedtls_gf128mul_x_ble( t_buf.u8, t_buf.u8 );
decrypt_only_one_full_block: /* Copy ciphertext bytes from the previous block for input in this
cts_t_buf.u64[0] = t_buf.u64[0]; * round. */
cts_t_buf.u64[1] = t_buf.u64[1]; for( ; i < 16; i++ )
tmp[i] = prev_output[i] ^ t[i];
mbedtls_gf128mul_x_ble( t_buf.u8, t_buf.u8 ); ret = mbedtls_aes_crypt_ecb( &ctx->crypt, mode, tmp, tmp );
if( ret != 0 )
return ret;
/* PP <- T xor P */ /* Write the result back to the previous block, overriding the previous
scratch.u64[0] = (uint64_t)( inbuf[nblk - 1].u64[0] ^ t_buf.u64[0] ); * output we copied. */
scratch.u64[1] = (uint64_t)( inbuf[nblk - 1].u64[1] ^ t_buf.u64[1] ); for( i = 0; i < 16; i++ )
prev_output[i] = tmp[i] ^ t[i];
/* CC <- E(Key2,PP) */
mbedtls_aes_crypt_ecb( &ctx->crypt, mode, scratch.u8, scratch.u8 );
/* C <- T xor CC */
cts_scratch.u64[0] = (uint64_t)( scratch.u64[0] ^ t_buf.u64[0] );
cts_scratch.u64[1] = (uint64_t)( scratch.u64[1] ^ t_buf.u64[1] );
memcpy( (uint8_t*)&inbuf[nblk - 1], (uint8_t*)&inbuf[nblk], remn );
memcpy( (uint8_t*)&inbuf[nblk - 1] + remn, cts_scratch.u8 + remn, 16 - remn );
memcpy( (uint8_t*)&outbuf[nblk], cts_scratch.u8, remn );
/* PP <- T xor P */
scratch.u64[0] = (uint64_t)( inbuf[nblk - 1].u64[0] ^ cts_t_buf.u64[0] );
scratch.u64[1] = (uint64_t)( inbuf[nblk - 1].u64[1] ^ cts_t_buf.u64[1] );
/* CC <- E(Key2,PP) */
mbedtls_aes_crypt_ecb( &ctx->crypt, mode, scratch.u8, scratch.u8 );
/* C <- T xor CC */
outbuf[nblk - 1].u64[0] = (uint64_t)( scratch.u64[0] ^ cts_t_buf.u64[0] );
outbuf[nblk - 1].u64[1] = (uint64_t)( scratch.u64[1] ^ cts_t_buf.u64[1] );
}
} }
return( 0 ); return( 0 );