diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c index 747523b5f..fe9ed6b2e 100644 --- a/programs/ssl/ssl_server2.c +++ b/programs/ssl/ssl_server2.c @@ -621,7 +621,7 @@ static void my_debug( void *ctx, int level, * Test recv/send functions that make sure each try returns * WANT_READ/WANT_WRITE at least once before sucesseding */ -static int my_recv( void *ctx, unsigned char *buf, size_t len ) +static int delayed_recv( void *ctx, unsigned char *buf, size_t len ) { static int first_try = 1; int ret; @@ -638,7 +638,7 @@ static int my_recv( void *ctx, unsigned char *buf, size_t len ) return( ret ); } -static int my_send( void *ctx, const unsigned char *buf, size_t len ) +static int delayed_send( void *ctx, const unsigned char *buf, size_t len ) { static int first_try = 1; int ret; @@ -658,6 +658,70 @@ static int my_send( void *ctx, const unsigned char *buf, size_t len ) MBEDTLS_SSL_CONF_SEND && MBEDTLS_SSL_CONF_RECV_TIMEOUT */ +typedef struct +{ + mbedtls_ssl_context *ssl; + mbedtls_net_context *net; +} io_ctx_t; + +static int recv_cb( void *ctx, unsigned char *buf, size_t len ) +{ + io_ctx_t *io_ctx = (io_ctx_t*) ctx; + size_t recv_len; + int ret; + + if( opt.nbio == 2 ) + ret = delayed_recv( io_ctx->net, buf, len ); + else + ret = mbedtls_net_recv( io_ctx->net, buf, len ); + if( ret < 0 ) + return( ret ); + recv_len = (size_t) ret; + + if( opt.transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM ) + { + /* Here's the place to do any datagram/record checking + * in between receiving the packet from the underlying + * transport and passing it on to the TLS stack. */ + mbedtls_printf( "[RECV] Datagram of size %u\n", (unsigned) recv_len ); + } + + return( (int) recv_len ); +} + +static int recv_timeout_cb( void *ctx, unsigned char *buf, size_t len, + uint32_t timeout ) +{ + io_ctx_t *io_ctx = (io_ctx_t*) ctx; + int ret; + size_t recv_len; + + ret = mbedtls_net_recv_timeout( io_ctx->net, buf, len, timeout ); + if( ret < 0 ) + return( ret ); + recv_len = (size_t) ret; + + if( opt.transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM ) + { + /* Here's the place to do any datagram/record checking + * in between receiving the packet from the underlying + * transport and passing it on to the TLS stack. */ + mbedtls_printf( "[RECV] Datagram of size %u\n", (unsigned) recv_len ); + } + + return( (int) recv_len ); +} + +static int send_cb( void *ctx, unsigned char const *buf, size_t len ) +{ + io_ctx_t *io_ctx = (io_ctx_t*) ctx; + + if( opt.nbio == 2 ) + return( delayed_send( io_ctx->net, buf, len ) ); + + return( mbedtls_net_send( io_ctx->net, buf, len ) ); +} + #if !defined(MBEDTLS_SSL_CONF_AUTHMODE) /* * Return authmode from string, or -1 on error @@ -1376,6 +1440,7 @@ int main( int argc, char *argv[] ) { int ret = 0, len, written, frags, exchanges_left; int version_suites[4][2]; + io_ctx_t io_ctx; unsigned char* buf = 0; #if defined(MBEDTLS_KEY_EXCHANGE__SOME__PSK_ENABLED) unsigned char psk[MBEDTLS_PSK_MAX_LEN]; @@ -2918,12 +2983,10 @@ int main( int argc, char *argv[] ) #if !defined(MBEDTLS_SSL_CONF_RECV) && \ !defined(MBEDTLS_SSL_CONF_SEND) && \ !defined(MBEDTLS_SSL_CONF_RECV_TIMEOUT) - if( opt.nbio == 2 ) - mbedtls_ssl_set_bio( &ssl, &client_fd, my_send, my_recv, NULL ); - else - mbedtls_ssl_set_bio( &ssl, &client_fd, - mbedtls_net_send, mbedtls_net_recv, - opt.nbio == 0 ? mbedtls_net_recv_timeout : NULL ); + io_ctx.ssl = &ssl; + io_ctx.net = &client_fd; + mbedtls_ssl_set_bio( &ssl, &io_ctx, send_cb, recv_cb, + opt.nbio == 0 ? recv_timeout_cb : NULL ); #else mbedtls_ssl_set_bio_ctx( &ssl, &client_fd ); #endif @@ -3586,13 +3649,8 @@ data_exchange: * if you want to share your set up code between the case of * establishing a new connection and this case. */ - if( opt.nbio == 2 ) - mbedtls_ssl_set_bio( &ssl, &client_fd, my_send, my_recv, - NULL ); - else - mbedtls_ssl_set_bio( &ssl, &client_fd, mbedtls_net_send, - mbedtls_net_recv, - opt.nbio == 0 ? mbedtls_net_recv_timeout : NULL ); + mbedtls_ssl_set_bio( &ssl, &io_ctx, send_cb, recv_cb, + opt.nbio == 0 ? recv_timeout_cb : NULL ); #if defined(MBEDTLS_TIMING_C) #if !defined(MBEDTLS_SSL_CONF_SET_TIMER) && \