diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index bf5d7243d..012b89a70 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -1762,10 +1762,6 @@ int ssl_send_fatal_handshake_failure( ssl_context *ssl ); int ssl_derive_keys( ssl_context *ssl ); int ssl_read_record( ssl_context *ssl ); -/** - * \return 0 if successful, POLARSSL_ERR_SSL_CONN_EOF on EOF or - * another negative error code. - */ int ssl_fetch_input( ssl_context *ssl, size_t nb_want ); int ssl_write_record( ssl_context *ssl ); diff --git a/library/ssl_tls.c b/library/ssl_tls.c index e009b9717..e44ffa6ee 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -1821,6 +1821,13 @@ static int ssl_decompress_buf( ssl_context *ssl ) /* * Fill the input message buffer + * + * If we return 0, is it guaranteed that (at least) nb_want bytes are + * available (from this read and/or a previous one). Otherwise, an error code + * is returned (possibly EOF or WANT_READ). + * + * Set ssl->in_left to 0 before calling to start a new record. Apart from + * this, ssl->in_left is an internal variable and should never be read. */ int ssl_fetch_input( ssl_context *ssl, size_t nb_want ) { @@ -1829,19 +1836,40 @@ int ssl_fetch_input( ssl_context *ssl, size_t nb_want ) SSL_DEBUG_MSG( 2, ( "=> fetch input" ) ); - if( nb_want > SSL_BUFFER_LEN - 8 ) + if( nb_want > SSL_BUFFER_LEN - (size_t)( ssl->in_hdr - ssl->in_buf ) ) { SSL_DEBUG_MSG( 1, ( "requesting more data than fits" ) ); return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); } - while( ssl->in_left < nb_want ) +#if defined(POLARSSL_SSL_PROTO_DTLS) + if( ssl->transport == SSL_TRANSPORT_DATAGRAM ) { - len = nb_want - ssl->in_left; - ret = ssl->f_recv( ssl->p_recv, ssl->in_hdr + ssl->in_left, len ); - SSL_DEBUG_MSG( 2, ( "in_left: %d, nb_want: %d", ssl->in_left, nb_want ) ); + + /* + * With UDP, we must always read a full datagram. + * Just remember how much we read and avoid reading again if we + * already have enough data. + */ + if( nb_want <= ssl->in_left) + return( 0 ); + + /* + * A record can't be split accross datagrams. If we need to read but + * are not at the beginning of a new record, the caller did something + * wrong. + */ + if( ssl->in_left != 0 ) + { + SSL_DEBUG_MSG( 1, ( "should never happen" ) ); + return( POLARSSL_ERR_SSL_INTERNAL_ERROR ); + } + + len = SSL_BUFFER_LEN - ( ssl->in_hdr - ssl->in_buf ); + ret = ssl->f_recv( ssl->p_recv, ssl->in_hdr, len ); + SSL_DEBUG_RET( 2, "ssl->f_recv", ret ); if( ret == 0 ) @@ -1850,7 +1878,28 @@ int ssl_fetch_input( ssl_context *ssl, size_t nb_want ) if( ret < 0 ) return( ret ); - ssl->in_left += ret; + ssl->in_left = ret; + } + else +#endif + { + while( ssl->in_left < nb_want ) + { + len = nb_want - ssl->in_left; + ret = ssl->f_recv( ssl->p_recv, ssl->in_hdr + ssl->in_left, len ); + + SSL_DEBUG_MSG( 2, ( "in_left: %d, nb_want: %d", + ssl->in_left, nb_want ) ); + SSL_DEBUG_RET( 2, "ssl->f_recv", ret ); + + if( ret == 0 ) + return( POLARSSL_ERR_SSL_CONN_EOF ); + + if( ret < 0 ) + return( ret ); + + ssl->in_left += ret; + } } SSL_DEBUG_MSG( 2, ( "<= fetch input" ) ); @@ -2140,7 +2189,8 @@ int ssl_read_record( ssl_context *ssl ) } /* Sanity check (outer boundaries) */ - if( ssl->in_msglen < 1 || ssl->in_msglen > SSL_BUFFER_LEN - 13 ) + if( ssl->in_msglen < 1 || + ssl->in_msglen > SSL_BUFFER_LEN - (size_t)( ssl->in_msg - ssl->in_buf ) ) { SSL_DEBUG_MSG( 1, ( "bad message length" ) ); return( POLARSSL_ERR_SSL_INVALID_RECORD );