From 91895853acee4e9b26ee482db70f2233fc7cf1a7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Manuel=20P=C3=A9gouri=C3=A9-Gonnard?= Date: Tue, 30 Jun 2015 13:34:45 +0200 Subject: [PATCH] Move from naked int to a structure in net.c Provides more flexibility for future changes/extensions. --- ChangeLog | 2 + include/mbedtls/net.h | 48 ++++++++++++++++------ library/net.c | 93 +++++++++++++++++++++++++------------------ 3 files changed, 91 insertions(+), 52 deletions(-) diff --git a/ChangeLog b/ChangeLog index 1f0152d26..7c238e54a 100644 --- a/ChangeLog +++ b/ChangeLog @@ -71,6 +71,8 @@ API Changes mbedtls_base64_decode() mbedtls_mpi_write_string() mbedtls_dhm_calc_secret() + * In the NET module, all "int" and "int *" arguments for file descriptors + changed type to "mbedtls_net_context *". * net_accept() gained new arguments for the size of the client_ip buffer. * In the threading layer, mbedtls_mutex_init() and mbedtls_mutex_free() now return void. diff --git a/include/mbedtls/net.h b/include/mbedtls/net.h index 92f94c0ba..d2cb8d48d 100644 --- a/include/mbedtls/net.h +++ b/include/mbedtls/net.h @@ -55,10 +55,31 @@ extern "C" { #endif +/** + * Wrapper type for sockets. + * + * Currently backed by just a file descriptor, but might be more in the future + * (eg two file descriptors for combined IPv4 + IPv6 support, or additional + * structures for hand-made UDP demultiplexing). + */ +typedef struct +{ + int fd; /**< The underlying file descriptor */ +} +mbedtls_net_context; + +/** + * \brief Initialize a context + * Just makes the context ready to be used or freed safely. + * + * \param ctx Context to initialize + */ +void mbedtls_net_init( mbedtls_net_context *ctx ); + /** * \brief Initiate a connection with host:port in the given protocol * - * \param fd Socket to use + * \param ctx Socket to use * \param host Host to connect to * \param port Port to connect to * \param proto Protocol: MBEDTLS_NET_PROTO_TCP or MBEDTLS_NET_PROTO_UDP @@ -70,13 +91,13 @@ extern "C" { * * \note Sets the socket in connected mode even with UDP. */ -int mbedtls_net_connect( int *fd, const char *host, const char *port, int proto ); +int mbedtls_net_connect( mbedtls_net_context *ctx, const char *host, const char *port, int proto ); /** * \brief Create a receiving socket on bind_ip:port in the chosen * protocol. If bind_ip == NULL, all interfaces are bound. * - * \param fd Socket to use + * \param ctx Socket to use * \param bind_ip IP to bind to, can be NULL * \param port Port number to use * \param proto Protocol: MBEDTLS_NET_PROTO_TCP or MBEDTLS_NET_PROTO_UDP @@ -89,13 +110,13 @@ int mbedtls_net_connect( int *fd, const char *host, const char *port, int proto * \note Regardless of the protocol, opens the sockets and binds it. * In addition, make the socket listening if protocol is TCP. */ -int mbedtls_net_bind( int *fd, const char *bind_ip, const char *port, int proto ); +int mbedtls_net_bind( mbedtls_net_context *ctx, const char *bind_ip, const char *port, int proto ); /** * \brief Accept a connection from a remote client * - * \param bind_fd Relevant socket - * \param client_fd Will contain the connected client socket + * \param bind_ctx Relevant socket + * \param client_ctx Will contain the connected client socket * \param client_ip Will contain the client IP address * \param buf_size Size of the client_ip buffer * \param ip_len Will receive the size of the client IP written @@ -110,26 +131,27 @@ int mbedtls_net_bind( int *fd, const char *bind_ip, const char *port, int proto * its descriptor to client_fd. New clients will not be able * to connect until you close the socket and bind a new one. */ -int mbedtls_net_accept( int bind_fd, int *client_fd, +int mbedtls_net_accept( mbedtls_net_context *bind_ctx, + mbedtls_net_context *client_ctx, void *client_ip, size_t buf_size, size_t *ip_len ); /** * \brief Set the socket blocking * - * \param fd Socket to set + * \param ctx Socket to set * * \return 0 if successful, or a non-zero error code */ -int mbedtls_net_set_block( int fd ); +int mbedtls_net_set_block( mbedtls_net_context *ctx ); /** * \brief Set the socket non-blocking * - * \param fd Socket to set + * \param ctx Socket to set * * \return 0 if successful, or a non-zero error code */ -int mbedtls_net_set_nonblock( int fd ); +int mbedtls_net_set_nonblock( mbedtls_net_context *ctx ); /** * \brief Portable usleep helper @@ -196,9 +218,9 @@ int mbedtls_net_recv_timeout( void *ctx, unsigned char *buf, size_t len, /** * \brief Gracefully shutdown the connection * - * \param fd The socket to close + * \param ctx The socket to close */ -void mbedtls_net_close( int fd ); +void mbedtls_net_close( mbedtls_net_context *ctx ); #ifdef __cplusplus } diff --git a/library/net.c b/library/net.c index b3928799d..0576ed6b1 100644 --- a/library/net.c +++ b/library/net.c @@ -110,10 +110,18 @@ static int net_prepare( void ) return( 0 ); } +/* + * Initialize a context + */ +void mbedtls_net_init( mbedtls_net_context *ctx ) +{ + ctx->fd = -1; +} + /* * Initiate a TCP connection with host:port and the given protocol */ -int mbedtls_net_connect( int *fd, const char *host, const char *port, int proto ) +int mbedtls_net_connect( mbedtls_net_context *ctx, const char *host, const char *port, int proto ) { int ret; struct addrinfo hints, *addr_list, *cur; @@ -134,21 +142,21 @@ int mbedtls_net_connect( int *fd, const char *host, const char *port, int proto ret = MBEDTLS_ERR_NET_UNKNOWN_HOST; for( cur = addr_list; cur != NULL; cur = cur->ai_next ) { - *fd = (int) socket( cur->ai_family, cur->ai_socktype, + ctx->fd = (int) socket( cur->ai_family, cur->ai_socktype, cur->ai_protocol ); - if( *fd < 0 ) + if( ctx->fd < 0 ) { ret = MBEDTLS_ERR_NET_SOCKET_FAILED; continue; } - if( connect( *fd, cur->ai_addr, cur->ai_addrlen ) == 0 ) + if( connect( ctx->fd, cur->ai_addr, cur->ai_addrlen ) == 0 ) { ret = 0; break; } - close( *fd ); + close( ctx->fd ); ret = MBEDTLS_ERR_NET_CONNECT_FAILED; } @@ -160,7 +168,7 @@ int mbedtls_net_connect( int *fd, const char *host, const char *port, int proto /* * Create a listening socket on bind_ip:port */ -int mbedtls_net_bind( int *fd, const char *bind_ip, const char *port, int proto ) +int mbedtls_net_bind( mbedtls_net_context *ctx, const char *bind_ip, const char *port, int proto ) { int n, ret; struct addrinfo hints, *addr_list, *cur; @@ -183,26 +191,26 @@ int mbedtls_net_bind( int *fd, const char *bind_ip, const char *port, int proto ret = MBEDTLS_ERR_NET_UNKNOWN_HOST; for( cur = addr_list; cur != NULL; cur = cur->ai_next ) { - *fd = (int) socket( cur->ai_family, cur->ai_socktype, + ctx->fd = (int) socket( cur->ai_family, cur->ai_socktype, cur->ai_protocol ); - if( *fd < 0 ) + if( ctx->fd < 0 ) { ret = MBEDTLS_ERR_NET_SOCKET_FAILED; continue; } n = 1; - if( setsockopt( *fd, SOL_SOCKET, SO_REUSEADDR, + if( setsockopt( ctx->fd, SOL_SOCKET, SO_REUSEADDR, (const char *) &n, sizeof( n ) ) != 0 ) { - close( *fd ); + close( ctx->fd ); ret = MBEDTLS_ERR_NET_SOCKET_FAILED; continue; } - if( bind( *fd, cur->ai_addr, cur->ai_addrlen ) != 0 ) + if( bind( ctx->fd, cur->ai_addr, cur->ai_addrlen ) != 0 ) { - close( *fd ); + close( ctx->fd ); ret = MBEDTLS_ERR_NET_BIND_FAILED; continue; } @@ -210,9 +218,9 @@ int mbedtls_net_bind( int *fd, const char *bind_ip, const char *port, int proto /* Listen only makes sense for TCP */ if( proto == MBEDTLS_NET_PROTO_TCP ) { - if( listen( *fd, MBEDTLS_NET_LISTEN_BACKLOG ) != 0 ) + if( listen( ctx->fd, MBEDTLS_NET_LISTEN_BACKLOG ) != 0 ) { - close( *fd ); + close( ctx->fd ); ret = MBEDTLS_ERR_NET_LISTEN_FAILED; continue; } @@ -235,9 +243,9 @@ int mbedtls_net_bind( int *fd, const char *bind_ip, const char *port, int proto * Check if the requested operation would be blocking on a non-blocking socket * and thus 'failed' with a negative return value. */ -static int net_would_block( int fd ) +static int net_would_block( const mbedtls_net_context *ctx ) { - ((void) fd); + ((void) ctx); return( WSAGetLastError() == WSAEWOULDBLOCK ); } #else @@ -247,12 +255,12 @@ static int net_would_block( int fd ) * * Note: on a blocking socket this function always returns 0! */ -static int net_would_block( int fd ) +static int net_would_block( const mbedtls_net_context *ctx ) { /* * Never return 'WOULD BLOCK' on a non-blocking socket */ - if( ( fcntl( fd, F_GETFL ) & O_NONBLOCK ) != O_NONBLOCK ) + if( ( fcntl( ctx->fd, F_GETFL ) & O_NONBLOCK ) != O_NONBLOCK ) return( 0 ); switch( errno ) @@ -272,7 +280,8 @@ static int net_would_block( int fd ) /* * Accept a connection from a remote client */ -int mbedtls_net_accept( int bind_fd, int *client_fd, +int mbedtls_net_accept( mbedtls_net_context *bind_ctx, + mbedtls_net_context *client_ctx, void *client_ip, size_t buf_size, size_t *ip_len ) { int ret; @@ -290,7 +299,8 @@ int mbedtls_net_accept( int bind_fd, int *client_fd, #endif /* Is this a TCP or UDP socket? */ - if( getsockopt( bind_fd, SOL_SOCKET, SO_TYPE, (void *) &type, &type_len ) != 0 || + if( getsockopt( bind_ctx->fd, SOL_SOCKET, SO_TYPE, + (void *) &type, &type_len ) != 0 || ( type != SOCK_STREAM && type != SOCK_DGRAM ) ) { return( MBEDTLS_ERR_NET_ACCEPT_FAILED ); @@ -299,7 +309,7 @@ int mbedtls_net_accept( int bind_fd, int *client_fd, if( type == SOCK_STREAM ) { /* TCP: actual accept() */ - ret = *client_fd = (int) accept( bind_fd, + ret = client_ctx->fd = (int) accept( bind_ctx->fd, (struct sockaddr *) &client_addr, &n ); } else @@ -307,7 +317,7 @@ int mbedtls_net_accept( int bind_fd, int *client_fd, /* UDP: wait for a message, but keep it in the queue */ char buf[1] = { 0 }; - ret = recvfrom( bind_fd, buf, sizeof( buf ), MSG_PEEK, + ret = recvfrom( bind_ctx->fd, buf, sizeof( buf ), MSG_PEEK, (struct sockaddr *) &client_addr, &n ); #if defined(_WIN32) @@ -322,7 +332,7 @@ int mbedtls_net_accept( int bind_fd, int *client_fd, if( ret < 0 ) { - if( net_would_block( bind_fd ) != 0 ) + if( net_would_block( bind_ctx ) != 0 ) return( MBEDTLS_ERR_SSL_WANT_READ ); return( MBEDTLS_ERR_NET_ACCEPT_FAILED ); @@ -331,10 +341,10 @@ int mbedtls_net_accept( int bind_fd, int *client_fd, /* UDP: hijack the listening socket for communicating with the client */ if( type != SOCK_STREAM ) { - if( connect( bind_fd, (struct sockaddr *) &client_addr, n ) != 0 ) + if( connect( bind_ctx->fd, (struct sockaddr *) &client_addr, n ) != 0 ) return( MBEDTLS_ERR_NET_ACCEPT_FAILED ); - *client_fd = bind_fd; + client_ctx->fd = bind_ctx->fd; } if( client_ip != NULL ) @@ -367,25 +377,25 @@ int mbedtls_net_accept( int bind_fd, int *client_fd, /* * Set the socket blocking or non-blocking */ -int mbedtls_net_set_block( int fd ) +int mbedtls_net_set_block( mbedtls_net_context *ctx ) { #if ( defined(_WIN32) || defined(_WIN32_WCE) ) && !defined(EFIX64) && \ !defined(EFI32) u_long n = 0; - return( ioctlsocket( fd, FIONBIO, &n ) ); + return( ioctlsocket( ctx->fd, FIONBIO, &n ) ); #else - return( fcntl( fd, F_SETFL, fcntl( fd, F_GETFL ) & ~O_NONBLOCK ) ); + return( fcntl( ctx->fd, F_SETFL, fcntl( ctx->fd, F_GETFL ) & ~O_NONBLOCK ) ); #endif } -int mbedtls_net_set_nonblock( int fd ) +int mbedtls_net_set_nonblock( mbedtls_net_context *ctx ) { #if ( defined(_WIN32) || defined(_WIN32_WCE) ) && !defined(EFIX64) && \ !defined(EFI32) u_long n = 1; - return( ioctlsocket( fd, FIONBIO, &n ) ); + return( ioctlsocket( ctx->fd, FIONBIO, &n ) ); #else - return( fcntl( fd, F_SETFL, fcntl( fd, F_GETFL ) | O_NONBLOCK ) ); + return( fcntl( ctx->fd, F_SETFL, fcntl( ctx->fd, F_GETFL ) | O_NONBLOCK ) ); #endif } @@ -410,12 +420,12 @@ void mbedtls_net_usleep( unsigned long usec ) */ int mbedtls_net_recv( void *ctx, unsigned char *buf, size_t len ) { - int fd = *((int *) ctx); + int fd = ((mbedtls_net_context *) ctx)->fd; int ret = (int) read( fd, buf, len ); if( ret < 0 ) { - if( net_would_block( fd ) != 0 ) + if( net_would_block( ctx ) != 0 ) return( MBEDTLS_ERR_SSL_WANT_READ ); #if ( defined(_WIN32) || defined(_WIN32_WCE) ) && !defined(EFIX64) && \ @@ -445,7 +455,7 @@ int mbedtls_net_recv_timeout( void *ctx, unsigned char *buf, size_t len, int ret; struct timeval tv; fd_set read_fds; - int fd = *((int *) ctx); + int fd = ((mbedtls_net_context *) ctx)->fd; FD_ZERO( &read_fds ); FD_SET( fd, &read_fds ); @@ -482,12 +492,12 @@ int mbedtls_net_recv_timeout( void *ctx, unsigned char *buf, size_t len, */ int mbedtls_net_send( void *ctx, const unsigned char *buf, size_t len ) { - int fd = *((int *) ctx); + int fd = ((mbedtls_net_context *) ctx)->fd; int ret = (int) write( fd, buf, len ); if( ret < 0 ) { - if( net_would_block( fd ) != 0 ) + if( net_would_block( ctx ) != 0 ) return( MBEDTLS_ERR_SSL_WANT_WRITE ); #if ( defined(_WIN32) || defined(_WIN32_WCE) ) && !defined(EFIX64) && \ @@ -511,10 +521,15 @@ int mbedtls_net_send( void *ctx, const unsigned char *buf, size_t len ) /* * Gracefully close the connection */ -void mbedtls_net_close( int fd ) +void mbedtls_net_close( mbedtls_net_context *ctx ) { - shutdown( fd, 2 ); - close( fd ); + if( ctx->fd == -1 ) + return; + + shutdown( ctx->fd, 2 ); + close( ctx->fd ); + + ctx->fd = -1; } #endif /* MBEDTLS_NET_C */