diff --git a/library/ssl_msg.c b/library/ssl_msg.c index e5def644f..981d94e16 100644 --- a/library/ssl_msg.c +++ b/library/ssl_msg.c @@ -1048,6 +1048,12 @@ int mbedtls_ssl_encrypt_buf( mbedtls_ssl_context *ssl, * Turn a bit into a mask: * - if bit == 1, return the all-bits 1 mask, aka (size_t) -1 * - if bit == 0, return the all-bits 0 mask, aka 0 + * + * This function can be used to write constant-time code by replacing branches + * with bit operations using masks. + * + * This function is implemented without using comparison operators, as those + * might be translated to branches by some compilers on some platforms. */ static size_t mbedtls_ssl_cf_mask_from_bit( size_t bit ) { @@ -1068,15 +1074,18 @@ static size_t mbedtls_ssl_cf_mask_from_bit( size_t bit ) * - if x < y, return all bits 1, that is (size_t) -1 * - otherwise, return all bits 0, that is 0 * - * Use only bit operations to avoid branches that could be used by some - * compilers on some platforms to translate comparison operators. + * This function can be used to write constant-time code by replacing branches + * with bit operations using masks. + * + * This function is implemented without using comparison operators, as those + * might be translated to branches by some compilers on some platforms. */ static size_t mbedtls_ssl_cf_mask_lt( size_t x, size_t y ) { - /* This has the msb set if and only if x < y */ + /* This has the most significant bit set if and only if x < y */ const size_t sub = x - y; - /* sub1 = (x < y) in {0, 1} */ + /* sub1 = (x < y) ? 1 : 0 */ const size_t sub1 = sub >> ( sizeof( sub ) * 8 - 1 ); /* mask = (x < y) ? 0xff... : 0x00... */ @@ -1090,8 +1099,11 @@ static size_t mbedtls_ssl_cf_mask_lt( size_t x, size_t y ) * - if x >= y, return all bits 1, that is (size_t) -1 * - otherwise, return all bits 0, that is 0 * - * Use only bit operations to avoid branches that could be used by some - * compilers on some platforms to translate comparison operators. + * This function can be used to write constant-time code by replacing branches + * with bit operations using masks. + * + * This function is implemented without using comparison operators, as those + * might be translated to branches by some compilers on some platforms. */ static size_t mbedtls_ssl_cf_mask_ge( size_t x, size_t y ) { @@ -1102,8 +1114,12 @@ static size_t mbedtls_ssl_cf_mask_ge( size_t x, size_t y ) * Constant-flow boolean "equal" comparison: * return x == y * - * Use only bit operations to avoid branches that could be used by some - * compilers on some platforms to translate comparison operators. + * This function can be used to write constant-time code by replacing branches + * with bit operations - it can be used in conjunction with + * mbedtls_ssl_cf_mask_from_bit(). + * + * This function is implemented without using comparison operators, as those + * might be translated to branches by some compilers on some platforms. */ static size_t mbedtls_ssl_cf_bool_eq( size_t x, size_t y ) { @@ -1124,7 +1140,7 @@ static size_t mbedtls_ssl_cf_bool_eq( size_t x, size_t y ) #pragma warning( pop ) #endif - /* diff1 = (x != y) in {0, 1} */ + /* diff1 = (x != y) ? 1 : 0 */ const size_t diff1 = diff_msb >> ( sizeof( diff_msb ) * 8 - 1 ); return( 1 ^ diff1 ); @@ -1136,8 +1152,8 @@ static size_t mbedtls_ssl_cf_bool_eq( size_t x, size_t y ) * - otherwise, a no-op, * but with execution flow independent of the values of c1 and c2. * - * Use only bit operations to avoid branches that could be used by some - * compilers on some platforms to translate comparison operators. + * This function is implemented without using comparison operators, as those + * might be translated to branches by some compilers on some platforms. */ static void mbedtls_ssl_cf_memcpy_if_eq( unsigned char *dst, const unsigned char *src, @@ -1148,9 +1164,9 @@ static void mbedtls_ssl_cf_memcpy_if_eq( unsigned char *dst, const size_t equal = mbedtls_ssl_cf_bool_eq( c1, c2 ); const unsigned char mask = (unsigned char) mbedtls_ssl_cf_mask_from_bit( equal ); - /* dst[i] = c1 != c2 ? dst[i] : src[i] */ + /* dst[i] = c1 == c2 ? src[i] : dst[i] */ for( size_t i = 0; i < len; i++ ) - dst[i] = ( dst[i] & ~mask ) | ( src[i] & mask ); + dst[i] = ( src[i] & mask ) | ( dst[i] & ~mask ); } /* @@ -1630,7 +1646,8 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl, { /* This is the SSL 3.0 path, we don't have to worry about Lucky * 13, because there's a strictly worse padding attack built in - * the protocol (known as part of POODLE), so branches are OK. */ + * the protocol (known as part of POODLE), so we don't care if the + * code is not constant-time, in particular branches are OK. */ if( padlen > transform->ivlen ) { #if defined(MBEDTLS_SSL_DEBUG_ALL) @@ -1666,7 +1683,7 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context const *ssl, for( idx = start_idx; idx < rec->data_len; idx++ ) { /* pad_count += (idx >= padding_idx) && - * (chech[idx] == padlen - 1); + * (check[idx] == padlen - 1); */ const size_t mask = mbedtls_ssl_cf_mask_ge( idx, padding_idx ); const size_t equal = mbedtls_ssl_cf_bool_eq( check[idx],