diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 4ffaeb0f6..4af56f0e9 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -4371,6 +4371,19 @@ static int ssl_handle_possible_reconnect( mbedtls_ssl_context *ssl ) } #endif /* MBEDTLS_SSL_DTLS_CLIENT_PORT_REUSE && MBEDTLS_SSL_SRV_C */ +static int ssl_check_record_type( uint8_t record_type ) +{ + if( record_type != MBEDTLS_SSL_MSG_HANDSHAKE && + record_type != MBEDTLS_SSL_MSG_ALERT && + record_type != MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC && + record_type != MBEDTLS_SSL_MSG_APPLICATION_DATA ) + { + return( MBEDTLS_ERR_SSL_INVALID_RECORD ); + } + + return( 0 ); +} + /* * ContentType type; * ProtocolVersion version; @@ -4406,10 +4419,7 @@ static int ssl_parse_record_header( mbedtls_ssl_context *ssl ) major_ver, minor_ver, ssl->in_msglen ) ); /* Check record type */ - if( ssl->in_msgtype != MBEDTLS_SSL_MSG_HANDSHAKE && - ssl->in_msgtype != MBEDTLS_SSL_MSG_ALERT && - ssl->in_msgtype != MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC && - ssl->in_msgtype != MBEDTLS_SSL_MSG_APPLICATION_DATA ) + if( ssl_check_record_type( ssl->in_msgtype ) ) { MBEDTLS_SSL_DEBUG_MSG( 1, ( "unknown record type" ) );