Don't immediately flush datagram after preparing a record

This commit finally enables datagram packing by modifying the
record preparation function ssl_write_record() to not always
calling mbedtls_ssl_flush_output().
This commit is contained in:
Hanno Becker 2018-08-06 11:33:50 +01:00
parent 2b1e354754
commit 67bc7c3a38
2 changed files with 128 additions and 34 deletions

View File

@ -561,7 +561,7 @@ int mbedtls_ssl_read_record( mbedtls_ssl_context *ssl );
int mbedtls_ssl_fetch_input( mbedtls_ssl_context *ssl, size_t nb_want ); int mbedtls_ssl_fetch_input( mbedtls_ssl_context *ssl, size_t nb_want );
int mbedtls_ssl_write_handshake_msg( mbedtls_ssl_context *ssl ); int mbedtls_ssl_write_handshake_msg( mbedtls_ssl_context *ssl );
int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl ); int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl, uint8_t force_flush );
int mbedtls_ssl_flush_output( mbedtls_ssl_context *ssl ); int mbedtls_ssl_flush_output( mbedtls_ssl_context *ssl );
int mbedtls_ssl_parse_certificate( mbedtls_ssl_context *ssl ); int mbedtls_ssl_parse_certificate( mbedtls_ssl_context *ssl );

View File

@ -100,6 +100,10 @@ static void ssl_update_out_pointers( mbedtls_ssl_context *ssl,
mbedtls_ssl_transform *transform ); mbedtls_ssl_transform *transform );
static void ssl_update_in_pointers( mbedtls_ssl_context *ssl, static void ssl_update_in_pointers( mbedtls_ssl_context *ssl,
mbedtls_ssl_transform *transform ); mbedtls_ssl_transform *transform );
#define SSL_DONT_FORCE_FLUSH 0
#define SSL_FORCE_FLUSH 1
#if defined(MBEDTLS_SSL_PROTO_DTLS) #if defined(MBEDTLS_SSL_PROTO_DTLS)
static uint16_t ssl_get_maximum_datagram_size( mbedtls_ssl_context const *ssl ) static uint16_t ssl_get_maximum_datagram_size( mbedtls_ssl_context const *ssl )
@ -112,6 +116,55 @@ static uint16_t ssl_get_maximum_datagram_size( mbedtls_ssl_context const *ssl )
return( MBEDTLS_SSL_OUT_BUFFER_LEN ); return( MBEDTLS_SSL_OUT_BUFFER_LEN );
} }
static int ssl_get_remaining_space_in_datagram( mbedtls_ssl_context const *ssl )
{
size_t const bytes_written = ssl->out_left;
uint16_t const mtu = ssl_get_maximum_datagram_size( ssl );
/* Double-check that the write-index hasn't gone
* past what we can transmit in a single datagram. */
if( bytes_written > (size_t) mtu )
{
/* Should never happen... */
return( MBEDTLS_ERR_SSL_INTERNAL_ERROR );
}
return( (int) ( mtu - bytes_written ) );
}
static int ssl_get_remaining_payload_in_datagram( mbedtls_ssl_context const *ssl )
{
int ret;
size_t remaining, expansion;
size_t max_len = MBEDTLS_SSL_MAX_CONTENT_LEN;
#if defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH)
const size_t mfl = mbedtls_ssl_get_max_frag_len( ssl );
if( max_len > mfl )
max_len = mfl;
#endif
ret = ssl_get_remaining_space_in_datagram( ssl );
if( ret < 0 )
return( ret );
remaining = (size_t) ret;
ret = mbedtls_ssl_get_record_expansion( ssl );
if( ret < 0 )
return( ret );
expansion = (size_t) ret;
if( remaining <= expansion )
return( 0 );
remaining -= expansion;
if( remaining >= max_len )
remaining = max_len;
return( (int) remaining );
}
/* /*
* Double the retransmit timeout value, within the allowed range, * Double the retransmit timeout value, within the allowed range,
* returning -1 if the maximum value has already been reached. * returning -1 if the maximum value has already been reached.
@ -2857,20 +2910,9 @@ int mbedtls_ssl_resend( mbedtls_ssl_context *ssl )
*/ */
int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl ) int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl )
{ {
const int ret_payload = mbedtls_ssl_get_max_out_record_payload( ssl ); int ret;
const size_t max_record_payload = (size_t) ret_payload;
/* DTLS handshake headers are 12 bytes */
const size_t max_hs_fragment_len = max_record_payload - 12;
MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> mbedtls_ssl_flight_transmit" ) ); MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> mbedtls_ssl_flight_transmit" ) );
if( ret_payload < 0 )
{
MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_get_max_out_record_payload",
ret_payload );
return( ret_payload );
}
if( ssl->handshake->retransmit_state != MBEDTLS_SSL_RETRANS_SENDING ) if( ssl->handshake->retransmit_state != MBEDTLS_SSL_RETRANS_SENDING )
{ {
MBEDTLS_SSL_DEBUG_MSG( 2, ( "initialise flight transmission" ) ); MBEDTLS_SSL_DEBUG_MSG( 2, ( "initialise flight transmission" ) );
@ -2884,22 +2926,38 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl )
while( ssl->handshake->cur_msg != NULL ) while( ssl->handshake->cur_msg != NULL )
{ {
int ret; size_t max_frag_len;
const mbedtls_ssl_flight_item * const cur = ssl->handshake->cur_msg; const mbedtls_ssl_flight_item * const cur = ssl->handshake->cur_msg;
/* Swap epochs before sending Finished: we can't do it after /* Swap epochs before sending Finished: we can't do it after
* sending ChangeCipherSpec, in case write returns WANT_READ. * sending ChangeCipherSpec, in case write returns WANT_READ.
* Must be done before copying, may change out_msg pointer */ * Must be done before copying, may change out_msg pointer */
if( cur->type == MBEDTLS_SSL_MSG_HANDSHAKE && if( cur->type == MBEDTLS_SSL_MSG_HANDSHAKE &&
cur->p[0] == MBEDTLS_SSL_HS_FINISHED ) cur->p[0] == MBEDTLS_SSL_HS_FINISHED &&
ssl->handshake->cur_msg_p == ( cur->p + 12 ) )
{ {
MBEDTLS_SSL_DEBUG_MSG( 2, ( "swap epochs to send finished message" ) );
ssl_swap_epochs( ssl ); ssl_swap_epochs( ssl );
} }
ret = ssl_get_remaining_payload_in_datagram( ssl );
if( ret < 0 )
return( ret );
max_frag_len = (size_t) ret;
/* CCS is copied as is, while HS messages may need fragmentation */ /* CCS is copied as is, while HS messages may need fragmentation */
if( cur->type == MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC ) if( cur->type == MBEDTLS_SSL_MSG_CHANGE_CIPHER_SPEC )
{ {
if( max_frag_len == 0 )
{
if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 )
return( ret );
continue;
}
memcpy( ssl->out_msg, cur->p, cur->len ); memcpy( ssl->out_msg, cur->p, cur->len );
ssl->out_msglen = cur->len; ssl->out_msglen = cur->len;
ssl->out_msgtype = cur->type; ssl->out_msgtype = cur->type;
/* Update position inside current message */ /* Update position inside current message */
@ -2911,14 +2969,31 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl )
const size_t hs_len = cur->len - 12; const size_t hs_len = cur->len - 12;
const size_t frag_off = p - ( cur->p + 12 ); const size_t frag_off = p - ( cur->p + 12 );
const size_t rem_len = hs_len - frag_off; const size_t rem_len = hs_len - frag_off;
const size_t frag_len = rem_len > max_hs_fragment_len size_t cur_hs_frag_len, max_hs_frag_len;
? max_hs_fragment_len : rem_len;
if( frag_off == 0 && frag_len != hs_len ) if( max_frag_len < 12 )
{
if( cur->type == MBEDTLS_SSL_MSG_HANDSHAKE &&
cur->p[0] == MBEDTLS_SSL_HS_FINISHED )
{
ssl_swap_epochs( ssl );
}
if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 )
return( ret );
continue;
}
max_hs_frag_len = max_frag_len - 12;
cur_hs_frag_len = rem_len > max_hs_frag_len ?
max_hs_frag_len : rem_len;
if( frag_off == 0 && cur_hs_frag_len != hs_len )
{ {
MBEDTLS_SSL_DEBUG_MSG( 2, ( "fragmenting handshake message (%u > %u)", MBEDTLS_SSL_DEBUG_MSG( 2, ( "fragmenting handshake message (%u > %u)",
(unsigned) hs_len, (unsigned) cur_hs_frag_len,
(unsigned) max_hs_fragment_len ) ); (unsigned) max_hs_frag_len ) );
} }
/* Messages are stored with handshake headers as if not fragmented, /* Messages are stored with handshake headers as if not fragmented,
@ -2930,19 +3005,19 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl )
ssl->out_msg[7] = ( ( frag_off >> 8 ) & 0xff ); ssl->out_msg[7] = ( ( frag_off >> 8 ) & 0xff );
ssl->out_msg[8] = ( ( frag_off ) & 0xff ); ssl->out_msg[8] = ( ( frag_off ) & 0xff );
ssl->out_msg[ 9] = ( ( frag_len >> 16 ) & 0xff ); ssl->out_msg[ 9] = ( ( cur_hs_frag_len >> 16 ) & 0xff );
ssl->out_msg[10] = ( ( frag_len >> 8 ) & 0xff ); ssl->out_msg[10] = ( ( cur_hs_frag_len >> 8 ) & 0xff );
ssl->out_msg[11] = ( ( frag_len ) & 0xff ); ssl->out_msg[11] = ( ( cur_hs_frag_len ) & 0xff );
MBEDTLS_SSL_DEBUG_BUF( 3, "handshake header", ssl->out_msg, 12 ); MBEDTLS_SSL_DEBUG_BUF( 3, "handshake header", ssl->out_msg, 12 );
/* Copy the handshake message content and set records fields */ /* Copy the handshame message content and set records fields */
memcpy( ssl->out_msg + 12, p, frag_len ); memcpy( ssl->out_msg + 12, p, cur_hs_frag_len );
ssl->out_msglen = frag_len + 12; ssl->out_msglen = cur_hs_frag_len + 12;
ssl->out_msgtype = cur->type; ssl->out_msgtype = cur->type;
/* Update position inside current message */ /* Update position inside current message */
ssl->handshake->cur_msg_p += frag_len; ssl->handshake->cur_msg_p += cur_hs_frag_len;
} }
/* If done with the current message move to the next one if any */ /* If done with the current message move to the next one if any */
@ -2961,13 +3036,17 @@ int mbedtls_ssl_flight_transmit( mbedtls_ssl_context *ssl )
} }
/* Actually send the message out */ /* Actually send the message out */
if( ( ret = mbedtls_ssl_write_record( ssl ) ) != 0 ) if( ( ret = mbedtls_ssl_write_record( ssl,
SSL_DONT_FORCE_FLUSH ) ) != 0 )
{ {
MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret ); MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret );
return( ret ); return( ret );
} }
} }
if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 )
return( ret );
/* Update state and set timer */ /* Update state and set timer */
if( ssl->state == MBEDTLS_SSL_HANDSHAKE_OVER ) if( ssl->state == MBEDTLS_SSL_HANDSHAKE_OVER )
ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_FINISHED; ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_FINISHED;
@ -3158,7 +3237,7 @@ int mbedtls_ssl_write_handshake_msg( mbedtls_ssl_context *ssl )
else else
#endif #endif
{ {
if( ( ret = mbedtls_ssl_write_record( ssl ) ) != 0 ) if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 )
{ {
MBEDTLS_SSL_DEBUG_RET( 1, "ssl_write_record", ret ); MBEDTLS_SSL_DEBUG_RET( 1, "ssl_write_record", ret );
return( ret ); return( ret );
@ -3182,10 +3261,11 @@ int mbedtls_ssl_write_handshake_msg( mbedtls_ssl_context *ssl )
* - ssl->out_msglen: length of the record content (excl headers) * - ssl->out_msglen: length of the record content (excl headers)
* - ssl->out_msg: record content * - ssl->out_msg: record content
*/ */
int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl ) int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl, uint8_t force_flush )
{ {
int ret, done = 0; int ret, done = 0;
size_t len = ssl->out_msglen; size_t len = ssl->out_msglen;
uint8_t flush = force_flush;
MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write record" ) ); MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> write record" ) );
@ -3288,7 +3368,21 @@ int mbedtls_ssl_write_record( mbedtls_ssl_context *ssl )
} }
} }
if( ( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 ) #if defined(MBEDTLS_SSL_PROTO_DTLS)
if( ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM )
{
size_t remaining = ssl_get_remaining_payload_in_datagram( ssl );
if( remaining == 0 )
flush = SSL_FORCE_FLUSH;
else
{
MBEDTLS_SSL_DEBUG_MSG( 2, ( "Stil %u bytes available in current datagram", (unsigned) remaining ) );
}
}
#endif /* MBEDTLS_SSL_PROTO_DTLS */
if( ( flush == SSL_FORCE_FLUSH ) &&
( ret = mbedtls_ssl_flush_output( ssl ) ) != 0 )
{ {
MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flush_output", ret ); MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_flush_output", ret );
return( ret ); return( ret );
@ -4570,7 +4664,7 @@ int mbedtls_ssl_send_alert_message( mbedtls_ssl_context *ssl,
ssl->out_msg[0] = level; ssl->out_msg[0] = level;
ssl->out_msg[1] = message; ssl->out_msg[1] = message;
if( ( ret = mbedtls_ssl_write_record( ssl ) ) != 0 ) if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 )
{ {
MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret ); MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret );
return( ret ); return( ret );
@ -7815,7 +7909,7 @@ static int ssl_write_real( mbedtls_ssl_context *ssl,
ssl->out_msgtype = MBEDTLS_SSL_MSG_APPLICATION_DATA; ssl->out_msgtype = MBEDTLS_SSL_MSG_APPLICATION_DATA;
memcpy( ssl->out_msg, buf, len ); memcpy( ssl->out_msg, buf, len );
if( ( ret = mbedtls_ssl_write_record( ssl ) ) != 0 ) if( ( ret = mbedtls_ssl_write_record( ssl, SSL_FORCE_FLUSH ) ) != 0 )
{ {
MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret ); MBEDTLS_SSL_DEBUG_RET( 1, "mbedtls_ssl_write_record", ret );
return( ret ); return( ret );