diff --git a/library/ssl_msg.c b/library/ssl_msg.c index 981d94e16..bdf882d87 100644 --- a/library/ssl_msg.c +++ b/library/ssl_msg.c @@ -2018,14 +2018,6 @@ int mbedtls_ssl_fetch_input( mbedtls_ssl_context *ssl, size_t nb_want ) { uint32_t timeout; - /* Just to be sure */ - if( ssl->f_set_timer == NULL || ssl->f_get_timer == NULL ) - { - MBEDTLS_SSL_DEBUG_MSG( 1, ( "You must use " - "mbedtls_ssl_set_timer_cb() for DTLS" ) ); - return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - } - /* * The point is, we need to always read a full datagram at once, so we * sometimes read more then requested, and handle the additional data. diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 7062d53b7..34953f269 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -5682,11 +5682,24 @@ int mbedtls_ssl_handshake( mbedtls_ssl_context *ssl ) { int ret = 0; + /* Sanity checks */ + if( ssl == NULL || ssl->conf == NULL ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); +#if defined(MBEDTLS_SSL_PROTO_DTLS) + if( ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM && + ( ssl->f_set_timer == NULL || ssl->f_get_timer == NULL ) ) + { + MBEDTLS_SSL_DEBUG_MSG( 1, ( "You must use " + "mbedtls_ssl_set_timer_cb() for DTLS" ) ); + return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); + } +#endif /* MBEDTLS_SSL_PROTO_DTLS */ + MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> handshake" ) ); + /* Main handshake loop */ while( ssl->state != MBEDTLS_SSL_HANDSHAKE_OVER ) { ret = mbedtls_ssl_handshake_step( ssl );