diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h index 18982f89a..765da7a71 100644 --- a/include/mbedtls/ssl_internal.h +++ b/include/mbedtls/ssl_internal.h @@ -561,7 +561,7 @@ int mbedtls_ssl_read_record( mbedtls_ssl_context *ssl ); int mbedtls_ssl_fetch_input( mbedtls_ssl_context *ssl, size_t nb_want ); int mbedtls_ssl_write_handshake_msg( mbedtls_ssl_context *ssl ); -int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl ); +int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl, uint8_t force_flush ); int mbedtls_ssl_flush_output( mbedtls_ssl_context *ssl ); int mbedtls_ssl_parse_certificate( mbedtls_ssl_context *ssl ); diff --git a/library/ssl_tls.c b/library/ssl_tls.c index ad071a976..878495b17 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -100,6 +100,10 @@ static void ssl_update_out_pointers( mbedtls_ssl_context *ssl, mbedtls_ssl_transform *transform ); static void ssl_update_in_pointers( mbedtls_ssl_context *ssl, mbedtls_ssl_transform *transform ); + +#define SSL_DONT_FORCE_FLUSH 0 +#define SSL_FORCE_FLUSH 1 + #if defined(MBEDTLS_SSL_PROTO_DTLS) static uint16_t ssl_get_maximum_datagram_size( mbedtls_ssl_context const *ssl ) @@ -112,6 +116,55 @@ static uint16_t ssl_get_maximum_datagram_size( mbedtls_ssl_context const *ssl ) return( MBEDTLS_SSL_OUT_BUFFER_LEN ); } +static int ssl_get_remaining_space_in_datagram( mbedtls_ssl_context const *ssl ) +{ + size_t const bytes_written = ssl->out_left; + uint16_t const mtu = ssl_get_maximum_datagram_size( ssl ); + + /* Double-check that the write-index hasn't gone + * past what we can transmit in a single datagram. */ + if( bytes_written > (size_t) mtu ) + { + /* Should never happen... */ + return( MBEDTLS_ERR_SSL_INTERNAL_ERROR ); + } + + return( (int) ( mtu - bytes_written ) ); +} + +static int ssl_get_remaining_payload_in_datagram( mbedtls_ssl_context const *ssl ) +{ + int ret; + size_t remaining, expansion; + size_t max_len = MBEDTLS_SSL_MAX_CONTENT_LEN; + +#if defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH) + const size_t mfl = mbedtls_ssl_get_max_frag_len( ssl ); + + if( max_len > mfl ) + max_len = mfl; +#endif + + ret = ssl_get_remaining_space_in_datagram( ssl ); + if( ret < 0 ) + return( ret ); + remaining = (size_t) ret; + + ret = mbedtls_ssl_get_record_expansion( ssl ); + if( ret < 0 ) + return( ret ); + expansion = (size_t) ret; + + if( remaining <= expansion ) + return( 0 ); + + remaining -= expansion; + if( remaining >= max_len ) + remaining = max_len; + + return( (int) remaining ); +} + /* * Double the retransmit timeout value, within the allowed range, * returning -1 if the maximum value has already been reached. @@ -2857,20 +2910,9 @@ int mbedtls_ssl_resend( mbedtls_ssl_context *ssl ) */ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl ) { - const int ret_payload = mbedtls_ssl_get_max_out_record_payload( ssl ); - const size_t max_record_payload = (size_t) ret_payload; - /* DTLS handshake headers are 12 bytes */ - const size_t max_hs_fragment_len = max_record_payload - 12; - + int ret; MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> mbedtls_ssl_flight_transmit" ) ); - if( ret_payload < 0 ) - { - MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_get_max_out_record_payload", - ret_payload ); - return( ret_payload ); - } - if( ssl->handshake->retransmit_state != MBEDTLS_SSL_RETRANS_SENDING ) { MBEDTLS_SSL_DEBUG_MSG( 2, ( "initialise flight transmission" ) ); @@ -2884,22 +2926,38 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl ) while( ssl->handshake->cur_msg != NULL ) { - int ret; + size_t max_frag_len; const mbedtls_ssl_flight_item * const cur = ssl->handshake->cur_msg; + /* Swap epochs before sending Finished: we can't do it after * sending ChangeCipherSpec, in case write returns WANT_READ. * Must be done before copying, may change out_msg pointer */ if( cur->type == MBEDTLS_SSL_MSG_HANDSHAKE && - cur->p[0] == MBEDTLS_SSL_HS_FINISHED ) + cur->p[0] == MBEDTLS_SSL_HS_FINISHED && + ssl->handshake->cur_msg_p == ( cur->p + 12 ) ) { + MBEDTLS_SSL_DEBUG_MSG( 2, ( "swap epochs to send finished message" ) ); ssl_swap_epochs( ssl ); } + ret = ssl_get_remaining_payload_in_datagram( ssl ); + if( ret < 0 ) + return( ret ); + max_frag_len = (size_t) ret; + /* CCS is copied as is, while HS messages may need fragmentation */ if( cur->type == MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC ) { + if( max_frag_len == 0 ) + { + if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) + return( ret ); + + continue; + } + memcpy( ssl->out_msg, cur->p, cur->len ); - ssl->out_msglen = cur->len; + ssl->out_msglen = cur->len; ssl->out_msgtype = cur->type; /* Update position inside current message */ @@ -2911,14 +2969,31 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl ) const size_t hs_len = cur->len - 12; const size_t frag_off = p - ( cur->p + 12 ); const size_t rem_len = hs_len - frag_off; - const size_t frag_len = rem_len > max_hs_fragment_len - ? max_hs_fragment_len : rem_len; + size_t cur_hs_frag_len, max_hs_frag_len; - if( frag_off == 0 && frag_len != hs_len ) + if( max_frag_len < 12 ) + { + if( cur->type == MBEDTLS_SSL_MSG_HANDSHAKE && + cur->p[0] == MBEDTLS_SSL_HS_FINISHED ) + { + ssl_swap_epochs( ssl ); + } + + if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) + return( ret ); + + continue; + } + max_hs_frag_len = max_frag_len - 12; + + cur_hs_frag_len = rem_len > max_hs_frag_len ? + max_hs_frag_len : rem_len; + + if( frag_off == 0 && cur_hs_frag_len != hs_len ) { MBEDTLS_SSL_DEBUG_MSG( 2, ( "fragmenting handshake message (%u > %u)", - (unsigned) hs_len, - (unsigned) max_hs_fragment_len ) ); + (unsigned) cur_hs_frag_len, + (unsigned) max_hs_frag_len ) ); } /* Messages are stored with handshake headers as if not fragmented, @@ -2930,19 +3005,19 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl ) ssl->out_msg[7] = ( ( frag_off >> 8 ) & 0xff ); ssl->out_msg[8] = ( ( frag_off ) & 0xff ); - ssl->out_msg[ 9] = ( ( frag_len >> 16 ) & 0xff ); - ssl->out_msg[10] = ( ( frag_len >> 8 ) & 0xff ); - ssl->out_msg[11] = ( ( frag_len ) & 0xff ); + ssl->out_msg[ 9] = ( ( cur_hs_frag_len >> 16 ) & 0xff ); + ssl->out_msg[10] = ( ( cur_hs_frag_len >> 8 ) & 0xff ); + ssl->out_msg[11] = ( ( cur_hs_frag_len ) & 0xff ); MBEDTLS_SSL_DEBUG_BUF( 3, "handshake header", ssl->out_msg, 12 ); - /* Copy the handshake message content and set records fields */ - memcpy( ssl->out_msg + 12, p, frag_len ); - ssl->out_msglen = frag_len + 12; + /* Copy the handshame message content and set records fields */ + memcpy( ssl->out_msg + 12, p, cur_hs_frag_len ); + ssl->out_msglen = cur_hs_frag_len + 12; ssl->out_msgtype = cur->type; /* Update position inside current message */ - ssl->handshake->cur_msg_p += frag_len; + ssl->handshake->cur_msg_p += cur_hs_frag_len; } /* If done with the current message move to the next one if any */ @@ -2961,13 +3036,17 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl ) } /* Actually send the message out */ - if( ( ret = mbedtls_ssl_write_record( ssl ) ) != 0 ) + if( ( ret = mbedtls_ssl_write_record( ssl, + SSL_DONT_FORCE_FLUSH ) ) != 0 ) { MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret ); return( ret ); } } + if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) + return( ret ); + /* Update state and set timer */ if( ssl->state == MBEDTLS_SSL_HANDSHAKE_OVER ) ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_FINISHED; @@ -3158,7 +3237,7 @@ int mbedtls_ssl_write_handshake_msg( mbedtls_ssl_context *ssl ) else #endif { - if( ( ret = mbedtls_ssl_write_record( ssl ) ) != 0 ) + if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 ) { MBEDTLS_SSL_DEBUG_RET( 1, "ssl_write_record", ret ); return( ret ); @@ -3182,10 +3261,11 @@ int mbedtls_ssl_write_handshake_msg( mbedtls_ssl_context *ssl ) * - ssl->out_msglen: length of the record content (excl headers) * - ssl->out_msg: record content */ -int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl ) +int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl, uint8_t force_flush ) { int ret, done = 0; size_t len = ssl->out_msglen; + uint8_t flush = force_flush; MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write record" ) ); @@ -3288,7 +3368,21 @@ int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl ) } } - if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) +#if defined(MBEDTLS_SSL_PROTO_DTLS) + if( ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM ) + { + size_t remaining = ssl_get_remaining_payload_in_datagram( ssl ); + if( remaining == 0 ) + flush = SSL_FORCE_FLUSH; + else + { + MBEDTLS_SSL_DEBUG_MSG( 2, ( "Stil %u bytes available in current datagram", (unsigned) remaining ) ); + } + } +#endif /* MBEDTLS_SSL_PROTO_DTLS */ + + if( ( flush == SSL_FORCE_FLUSH ) && + ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) { MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flush_output", ret ); return( ret ); @@ -4570,7 +4664,7 @@ int mbedtls_ssl_send_alert_message( mbedtls_ssl_context *ssl, ssl->out_msg[0] = level; ssl->out_msg[1] = message; - if( ( ret = mbedtls_ssl_write_record( ssl ) ) != 0 ) + if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 ) { MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret ); return( ret ); @@ -7815,7 +7909,7 @@ static int ssl_write_real( mbedtls_ssl_context *ssl, ssl->out_msgtype = MBEDTLS_SSL_MSG_APPLICATION_DATA; memcpy( ssl->out_msg, buf, len ); - if( ( ret = mbedtls_ssl_write_record( ssl ) ) != 0 ) + if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 ) { MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret ); return( ret );