diff --git a/include/mbedtls/ssl_cookie.h b/include/mbedtls/ssl_cookie.h index 7f612e6b6..395768b58 100644 --- a/include/mbedtls/ssl_cookie.h +++ b/include/mbedtls/ssl_cookie.h @@ -26,6 +26,10 @@ #include "ssl.h" +#if defined(MBEDTLS_THREADING_C) +#include "threading.h" +#endif + /** * \name SECTION: Module settings * @@ -55,6 +59,9 @@ typedef struct unsigned long timeout; /*!< timeout delay, in seconds if HAVE_TIME, or in number of tickets issued */ +#if defined(MBEDTLS_THREADING_C) + mbedtls_threading_mutex_t mutex; +#endif } mbedtls_ssl_cookie_ctx; /** diff --git a/library/ssl_cookie.c b/library/ssl_cookie.c index 8b993ac1c..cc88905c4 100644 --- a/library/ssl_cookie.c +++ b/library/ssl_cookie.c @@ -83,6 +83,10 @@ void mbedtls_ssl_cookie_init( mbedtls_ssl_cookie_ctx *ctx ) ctx->serial = 0; #endif ctx->timeout = MBEDTLS_SSL_COOKIE_TIMEOUT; + +#if defined(MBEDTLS_THREADING_C) + mbedtls_mutex_init( &ctx->mutex ); +#endif } void mbedtls_ssl_cookie_set_timeout( mbedtls_ssl_cookie_ctx *ctx, unsigned long delay ) @@ -93,6 +97,12 @@ void mbedtls_ssl_cookie_set_timeout( mbedtls_ssl_cookie_ctx *ctx, unsigned long void mbedtls_ssl_cookie_free( mbedtls_ssl_cookie_ctx *ctx ) { mbedtls_md_free( &ctx->hmac_ctx ); + +#if defined(MBEDTLS_THREADING_C) + mbedtls_mutex_init( &ctx->mutex ); +#endif + + mbedtls_zeroize( ctx, sizeof( mbedtls_ssl_cookie_ctx ) ); } int mbedtls_ssl_cookie_setup( mbedtls_ssl_cookie_ctx *ctx, @@ -152,6 +162,7 @@ int mbedtls_ssl_cookie_write( void *p_ctx, unsigned char **p, unsigned char *end, const unsigned char *cli_id, size_t cli_id_len ) { + int ret; mbedtls_ssl_cookie_ctx *ctx = (mbedtls_ssl_cookie_ctx *) p_ctx; unsigned long t; @@ -173,8 +184,21 @@ int mbedtls_ssl_cookie_write( void *p_ctx, (*p)[3] = (unsigned char)( t ); *p += 4; - return( ssl_cookie_hmac( &ctx->hmac_ctx, *p - 4, - p, end, cli_id, cli_id_len ) ); +#if defined(MBEDTLS_THREADING_C) + if( ( ret = mbedtls_mutex_lock( &ctx->mutex ) ) != 0 ) + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR + ret ); +#endif + + ret = ssl_cookie_hmac( &ctx->hmac_ctx, *p - 4, + p, end, cli_id, cli_id_len ); + +#if defined(MBEDTLS_THREADING_C) + if( mbedtls_mutex_unlock( &ctx->mutex ) != 0 ) + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR + + MBEDTLS_ERR_THREADING_MUTEX_ERROR ); +#endif + + return( ret ); } /* @@ -185,6 +209,7 @@ int mbedtls_ssl_cookie_check( void *p_ctx, const unsigned char *cli_id, size_t cli_id_len ) { unsigned char ref_hmac[COOKIE_HMAC_LEN]; + int ret = 0; unsigned char *p = ref_hmac; mbedtls_ssl_cookie_ctx *ctx = (mbedtls_ssl_cookie_ctx *) p_ctx; unsigned long cur_time, cookie_time; @@ -195,10 +220,24 @@ int mbedtls_ssl_cookie_check( void *p_ctx, if( cookie_len != COOKIE_LEN ) return( -1 ); +#if defined(MBEDTLS_THREADING_C) + if( ( ret = mbedtls_mutex_lock( &ctx->mutex ) ) != 0 ) + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR + ret ); +#endif + if( ssl_cookie_hmac( &ctx->hmac_ctx, cookie, &p, p + sizeof( ref_hmac ), cli_id, cli_id_len ) != 0 ) - return( -1 ); + ret = -1; + +#if defined(MBEDTLS_THREADING_C) + if( mbedtls_mutex_unlock( &ctx->mutex ) != 0 ) + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR + + MBEDTLS_ERR_THREADING_MUTEX_ERROR ); +#endif + + if( ret != 0 ) + return( ret ); if( mbedtls_ssl_safer_memcmp( cookie + 4, ref_hmac, sizeof( ref_hmac ) ) != 0 ) return( -1 );