Handle DTLS version encoding and fix some checks

This commit is contained in:
Manuel Pégourié-Gonnard 2014-02-11 18:15:03 +01:00 committed by Paul Bakker
parent 864a81fdc0
commit abc7e3b4ba
4 changed files with 105 additions and 35 deletions

View File

@ -1818,6 +1818,11 @@ int ssl_check_cert_usage( const x509_crt *cert,
int cert_endpoint ); int cert_endpoint );
#endif /* POLARSSL_X509_CRT_PARSE_C */ #endif /* POLARSSL_X509_CRT_PARSE_C */
void ssl_write_version( int major, int minor, int transport,
unsigned char ver[2] );
void ssl_read_version( int *major, int *minor, int transport,
const unsigned char ver[2] );
/* constant-time buffer comparison */ /* constant-time buffer comparison */
static inline int safer_memcmp( const void *a, const void *b, size_t n ) static inline int safer_memcmp( const void *a, const void *b, size_t n )
{ {

View File

@ -486,8 +486,9 @@ static int ssl_write_client_hello( ssl_context *ssl )
buf = ssl->out_msg; buf = ssl->out_msg;
p = buf + 4; p = buf + 4;
*p++ = (unsigned char) ssl->max_major_ver; ssl_write_version( ssl->max_major_ver, ssl->max_minor_ver,
*p++ = (unsigned char) ssl->max_minor_ver; ssl->transport, p );
p += 2;
SSL_DEBUG_MSG( 3, ( "client hello, max version: [%d:%d]", SSL_DEBUG_MSG( 3, ( "client hello, max version: [%d:%d]",
buf[4], buf[5] ) ); buf[4], buf[5] ) );
@ -932,26 +933,25 @@ static int ssl_parse_server_hello( ssl_context *ssl )
buf[4], buf[5] ) ); buf[4], buf[5] ) );
if( ssl->in_hslen < 42 || if( ssl->in_hslen < 42 ||
buf[0] != SSL_HS_SERVER_HELLO || buf[0] != SSL_HS_SERVER_HELLO )
buf[4] != SSL_MAJOR_VERSION_3 )
{ {
SSL_DEBUG_MSG( 1, ( "bad server hello message" ) ); SSL_DEBUG_MSG( 1, ( "bad server hello message" ) );
return( POLARSSL_ERR_SSL_BAD_HS_SERVER_HELLO ); return( POLARSSL_ERR_SSL_BAD_HS_SERVER_HELLO );
} }
if( buf[5] > ssl->max_minor_ver ) ssl_read_version( &ssl->major_ver, &ssl->minor_ver,
{ ssl->transport, buf + 4 );
SSL_DEBUG_MSG( 1, ( "bad server hello message" ) );
return( POLARSSL_ERR_SSL_BAD_HS_SERVER_HELLO );
}
ssl->minor_ver = buf[5]; if( ssl->major_ver < ssl->min_major_ver ||
ssl->minor_ver < ssl->min_minor_ver ||
if( ssl->minor_ver < ssl->min_minor_ver ) ssl->major_ver > ssl->max_major_ver ||
ssl->minor_ver > ssl->max_minor_ver )
{ {
SSL_DEBUG_MSG( 1, ( "server only supports ssl smaller than minimum" SSL_DEBUG_MSG( 1, ( "server version out of bounds - "
" [%d:%d] < [%d:%d]", ssl->major_ver, " min: [%d:%d], server: [%d:%d], max: [%d:%d]",
ssl->minor_ver, buf[4], buf[5] ) ); ssl->min_major_ver, ssl->min_minor_ver,
ssl->major_ver, ssl->minor_ver,
ssl->max_major_ver, ssl->max_minor_ver ) );
ssl_send_alert_message( ssl, SSL_ALERT_LEVEL_FATAL, ssl_send_alert_message( ssl, SSL_ALERT_LEVEL_FATAL,
SSL_ALERT_MSG_PROTOCOL_VERSION ); SSL_ALERT_MSG_PROTOCOL_VERSION );
@ -1404,8 +1404,8 @@ static int ssl_write_encrypted_pms( ssl_context *ssl,
* opaque random[46]; * opaque random[46];
* } PreMasterSecret; * } PreMasterSecret;
*/ */
p[0] = (unsigned char) ssl->max_major_ver; ssl_write_version( ssl->max_major_ver, ssl->max_minor_ver,
p[1] = (unsigned char) ssl->max_minor_ver; ssl->transport, p );
if( ( ret = ssl->f_rng( ssl->p_rng, p + 2, 46 ) ) != 0 ) if( ( ret = ssl->f_rng( ssl->p_rng, p + 2, 46 ) ) != 0 )
{ {

View File

@ -1129,6 +1129,7 @@ static int ssl_parse_client_hello( ssl_context *ssl )
int handshake_failure = 0; int handshake_failure = 0;
const int *ciphersuites; const int *ciphersuites;
const ssl_ciphersuite_t *ciphersuite_info; const ssl_ciphersuite_t *ciphersuite_info;
int major, minor;
SSL_DEBUG_MSG( 2, ( "=> parse client hello" ) ); SSL_DEBUG_MSG( 2, ( "=> parse client hello" ) );
@ -1142,7 +1143,7 @@ static int ssl_parse_client_hello( ssl_context *ssl )
buf = ssl->in_hdr; buf = ssl->in_hdr;
#if defined(POLARSSL_SSL_SRV_SUPPORT_SSLV2_CLIENT_HELLO) #if defined(POLARSSL_SSL_SRV_SUPPORT_SSLV2_CLIENT_HELLO)
if( ( buf[0] & 0x80 ) != 0 ) if( ssl->transport == SSL_TRANSPORT_STREAM && ( buf[0] & 0x80 ) != 0 )
return ssl_parse_client_hello_v2( ssl ); return ssl_parse_client_hello_v2( ssl );
#endif #endif
@ -1163,13 +1164,19 @@ static int ssl_parse_client_hello( ssl_context *ssl )
* 1 . 2 protocol version * 1 . 2 protocol version
* 3 . 4 message length * 3 . 4 message length
*/ */
if( buf[0] != SSL_MSG_HANDSHAKE )
{
SSL_DEBUG_MSG( 1, ( "bad client hello message" ) );
return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO );
}
ssl_read_version( &major, &minor, ssl->transport, buf + 1 );
/* According to RFC 5246 Appendix E.1, the version here is typically /* According to RFC 5246 Appendix E.1, the version here is typically
* "{03,00}, the lowest version number supported by the client, [or] the * "{03,00}, the lowest version number supported by the client, [or] the
* value of ClientHello.client_version", so the only meaningful check here * value of ClientHello.client_version", so the only meaningful check here
* is the major version shouldn't be less than 3 */ * is the major version shouldn't be less than 3 */
if( buf[0] != SSL_MSG_HANDSHAKE || if( major < SSL_MAJOR_VERSION_3 )
buf[1] < SSL_MAJOR_VERSION_3 )
{ {
SSL_DEBUG_MSG( 1, ( "bad client hello message" ) ); SSL_DEBUG_MSG( 1, ( "bad client hello message" ) );
return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO ); return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO );
@ -1231,8 +1238,8 @@ static int ssl_parse_client_hello( ssl_context *ssl )
return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO ); return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO );
} }
ssl->major_ver = buf[4]; ssl_read_version( &ssl->major_ver, &ssl->minor_ver,
ssl->minor_ver = buf[5]; ssl->transport, buf + 4 );
ssl->handshake->max_major_ver = ssl->major_ver; ssl->handshake->max_major_ver = ssl->major_ver;
ssl->handshake->max_minor_ver = ssl->minor_ver; ssl->handshake->max_minor_ver = ssl->minor_ver;
@ -1782,8 +1789,9 @@ static int ssl_write_server_hello( ssl_context *ssl )
buf = ssl->out_msg; buf = ssl->out_msg;
p = buf + 4; p = buf + 4;
*p++ = (unsigned char) ssl->major_ver; ssl_write_version( ssl->major_ver, ssl->minor_ver,
*p++ = (unsigned char) ssl->minor_ver; ssl->transport, p );
p += 2;
SSL_DEBUG_MSG( 3, ( "server hello, chosen version: [%d:%d]", SSL_DEBUG_MSG( 3, ( "server hello, chosen version: [%d:%d]",
buf[4], buf[5] ) ); buf[4], buf[5] ) );
@ -2564,6 +2572,7 @@ static int ssl_parse_encrypted_pms( ssl_context *ssl,
int ret; int ret;
size_t len = pk_get_len( ssl_own_key( ssl ) ); size_t len = pk_get_len( ssl_own_key( ssl ) );
unsigned char *pms = ssl->handshake->premaster + pms_offset; unsigned char *pms = ssl->handshake->premaster + pms_offset;
unsigned char ver[2];
if( ! pk_can_do( ssl_own_key( ssl ), POLARSSL_PK_RSA ) ) if( ! pk_can_do( ssl_own_key( ssl ), POLARSSL_PK_RSA ) )
{ {
@ -2593,14 +2602,18 @@ static int ssl_parse_encrypted_pms( ssl_context *ssl,
return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_KEY_EXCHANGE ); return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_KEY_EXCHANGE );
} }
ssl_write_version( ssl->handshake->max_major_ver,
ssl->handshake->max_minor_ver,
ssl->transport, ver );
ret = pk_decrypt( ssl_own_key( ssl ), p, len, ret = pk_decrypt( ssl_own_key( ssl ), p, len,
pms, &ssl->handshake->pmslen, pms, &ssl->handshake->pmslen,
sizeof( ssl->handshake->premaster ) - pms_offset, sizeof( ssl->handshake->premaster ) - pms_offset,
ssl->f_rng, ssl->p_rng ); ssl->f_rng, ssl->p_rng );
if( ret != 0 || ssl->handshake->pmslen != 48 || if( ret != 0 || ssl->handshake->pmslen != 48 ||
pms[0] != ssl->handshake->max_major_ver || pms[0] != ver[0] ||
pms[1] != ssl->handshake->max_minor_ver ) pms[1] != ver[1] )
{ {
SSL_DEBUG_MSG( 1, ( "bad client key exchange message" ) ); SSL_DEBUG_MSG( 1, ( "bad client key exchange message" ) );

View File

@ -1126,8 +1126,8 @@ static int ssl_encrypt_buf( ssl_context *ssl )
memcpy( add_data, ssl->out_ctr, 8 ); memcpy( add_data, ssl->out_ctr, 8 );
add_data[8] = ssl->out_msgtype; add_data[8] = ssl->out_msgtype;
add_data[9] = ssl->major_ver; ssl_write_version( ssl->major_ver, ssl->minor_ver,
add_data[10] = ssl->minor_ver; ssl->transport, add_data + 9 );
add_data[11] = ( ssl->out_msglen >> 8 ) & 0xFF; add_data[11] = ( ssl->out_msglen >> 8 ) & 0xFF;
add_data[12] = ssl->out_msglen & 0xFF; add_data[12] = ssl->out_msglen & 0xFF;
@ -1377,8 +1377,8 @@ static int ssl_decrypt_buf( ssl_context *ssl )
memcpy( add_data, ssl->in_ctr, 8 ); memcpy( add_data, ssl->in_ctr, 8 );
add_data[8] = ssl->in_msgtype; add_data[8] = ssl->in_msgtype;
add_data[9] = ssl->major_ver; ssl_write_version( ssl->major_ver, ssl->minor_ver,
add_data[10] = ssl->minor_ver; ssl->transport, add_data + 9 );
add_data[11] = ( ssl->in_msglen >> 8 ) & 0xFF; add_data[11] = ( ssl->in_msglen >> 8 ) & 0xFF;
add_data[12] = ssl->in_msglen & 0xFF; add_data[12] = ssl->in_msglen & 0xFF;
@ -1937,8 +1937,8 @@ int ssl_write_record( ssl_context *ssl )
if( !done ) if( !done )
{ {
ssl->out_hdr[0] = (unsigned char) ssl->out_msgtype; ssl->out_hdr[0] = (unsigned char) ssl->out_msgtype;
ssl->out_hdr[1] = (unsigned char) ssl->major_ver; ssl_write_version( ssl->major_ver, ssl->minor_ver,
ssl->out_hdr[2] = (unsigned char) ssl->minor_ver; ssl->transport, ssl->out_hdr + 1 );
ssl->out_hdr[3] = (unsigned char)( len >> 8 ); ssl->out_hdr[3] = (unsigned char)( len >> 8 );
ssl->out_hdr[4] = (unsigned char)( len ); ssl->out_hdr[4] = (unsigned char)( len );
@ -1980,6 +1980,7 @@ int ssl_write_record( ssl_context *ssl )
int ssl_read_record( ssl_context *ssl ) int ssl_read_record( ssl_context *ssl )
{ {
int ret, done = 0; int ret, done = 0;
int major_ver, minor_ver;
SSL_DEBUG_MSG( 2, ( "=> read record" ) ); SSL_DEBUG_MSG( 2, ( "=> read record" ) );
@ -2038,13 +2039,15 @@ int ssl_read_record( ssl_context *ssl )
ssl->in_hdr[0], ssl->in_hdr[1], ssl->in_hdr[2], ssl->in_hdr[0], ssl->in_hdr[1], ssl->in_hdr[2],
( ssl->in_hdr[3] << 8 ) | ssl->in_hdr[4] ) ); ( ssl->in_hdr[3] << 8 ) | ssl->in_hdr[4] ) );
if( ssl->in_hdr[1] != ssl->major_ver ) ssl_read_version( &major_ver, &minor_ver, ssl->transport, ssl->in_hdr + 1 );
if( major_ver != ssl->major_ver )
{ {
SSL_DEBUG_MSG( 1, ( "major version mismatch" ) ); SSL_DEBUG_MSG( 1, ( "major version mismatch" ) );
return( POLARSSL_ERR_SSL_INVALID_RECORD ); return( POLARSSL_ERR_SSL_INVALID_RECORD );
} }
if( ssl->in_hdr[2] > ssl->max_minor_ver ) if( minor_ver > ssl->max_minor_ver )
{ {
SSL_DEBUG_MSG( 1, ( "minor version mismatch" ) ); SSL_DEBUG_MSG( 1, ( "minor version mismatch" ) );
return( POLARSSL_ERR_SSL_INVALID_RECORD ); return( POLARSSL_ERR_SSL_INVALID_RECORD );
@ -4947,4 +4950,53 @@ int ssl_check_cert_usage( const x509_crt *cert,
} }
#endif /* POLARSSL_X509_CRT_PARSE_C */ #endif /* POLARSSL_X509_CRT_PARSE_C */
/*
* Convert version numbers to/from wire format
* and, for DTLS, to/from TLS equivalent.
*
* For TLS this is the identity.
* For DTLS, use one complement (v -> 255 - v, and then map as follows:
* 1.0 <-> 3.2 (DTLS 1.0 is based on TLS 1.1)
* 1.x <-> 3.x+1 for x != 0 (DTLS 1.2 based on TLS 1.2)
*/
void ssl_write_version( int major, int minor, int transport,
unsigned char ver[2] )
{
if( transport == SSL_TRANSPORT_STREAM )
{
ver[0] = (unsigned char) major;
ver[1] = (unsigned char) minor;
}
#if defined(POLARSSL_SSL_PROTO_DTLS)
else
{
if( minor == SSL_MINOR_VERSION_2 )
--minor; /* DTLS 1.0 stored as TLS 1.1 internally */
ver[0] = (unsigned char)( 255 - ( major - 2 ) );
ver[1] = (unsigned char)( 255 - ( minor - 1 ) );
}
#endif
}
void ssl_read_version( int *major, int *minor, int transport,
const unsigned char ver[2] )
{
if( transport == SSL_TRANSPORT_STREAM )
{
*major = ver[0];
*minor = ver[1];
}
#if defined(POLARSSL_SSL_PROTO_DTLS)
else
{
*major = 255 - ver[0] + 2;
*minor = 255 - ver[1] + 1;
if( *minor == SSL_MINOR_VERSION_1 )
++*minor; /* DTLS 1.0 stored as TLS 1.1 internally */
}
#endif
}
#endif /* POLARSSL_SSL_TLS_C */ #endif /* POLARSSL_SSL_TLS_C */