diff --git a/include/polarssl/net.h b/include/polarssl/net.h index 22698b4ce..4271a7ce4 100644 --- a/include/polarssl/net.h +++ b/include/polarssl/net.h @@ -43,38 +43,48 @@ #define POLARSSL_NET_LISTEN_BACKLOG 10 /**< The backlog that listen() should use. */ +#define NET_PROTO_TCP 0 /**< The TCP transport protocol */ +#define NET_PROTO_UDP 1 /**< The UDP transport protocol */ + #ifdef __cplusplus extern "C" { #endif /** - * \brief Initiate a TCP connection with host:port + * \brief Initiate a connection with host:port in the given protocol * * \param fd Socket to use * \param host Host to connect to * \param port Port to connect to + * \param proto Protocol: NET_PROTO_TCP or NET_PROTO_UDP * * \return 0 if successful, or one of: * POLARSSL_ERR_NET_SOCKET_FAILED, * POLARSSL_ERR_NET_UNKNOWN_HOST, * POLARSSL_ERR_NET_CONNECT_FAILED + * + * \note Sets the socket in connected mode even with UDP. */ -int net_connect( int *fd, const char *host, int port ); +int net_connect( int *fd, const char *host, int port, int proto ); /** - * \brief Create a listening socket on bind_ip:port. - * If bind_ip == NULL, all interfaces are binded. + * \brief Create a receiving socket on bind_ip:port in the chosen + * protocol. If bind_ip == NULL, all interfaces are bound. * * \param fd Socket to use * \param bind_ip IP to bind to, can be NULL * \param port Port number to use + * \param proto Protocol: NET_PROTO_TCP or NET_PROTO_UDP * * \return 0 if successful, or one of: * POLARSSL_ERR_NET_SOCKET_FAILED, * POLARSSL_ERR_NET_BIND_FAILED, * POLARSSL_ERR_NET_LISTEN_FAILED + * + * \note Regardless of the protocol, opens the sockets and binds it. + * In addition, make the socket listening if protocol is TCP. */ -int net_bind( int *fd, const char *bind_ip, int port ); +int net_bind( int *fd, const char *bind_ip, int port, int proto ); /** * \brief Accept a connection from a remote client @@ -87,6 +97,10 @@ int net_bind( int *fd, const char *bind_ip, int port ); * \return 0 if successful, POLARSSL_ERR_NET_ACCEPT_FAILED, or * POLARSSL_ERR_NET_WANT_READ is bind_fd was set to * non-blocking and accept() is blocking. + * + * \note With UDP, connects the bind_fd to the client and just copy + * its descriptor to client_fd. New clients will not be able + * to connect until you close the socket and bind a new one. */ int net_accept( int bind_fd, int *client_fd, void *client_ip ); diff --git a/library/net.c b/library/net.c index 3f0e448ba..16dcab21c 100644 --- a/library/net.c +++ b/library/net.c @@ -160,9 +160,9 @@ static int net_prepare( void ) } /* - * Initiate a TCP connection with host:port + * Initiate a TCP connection with host:port and the given protocol */ -int net_connect( int *fd, const char *host, int port ) +int net_connect( int *fd, const char *host, int port, int proto ) { #if defined(POLARSSL_HAVE_IPV6) int ret; @@ -176,11 +176,11 @@ int net_connect( int *fd, const char *host, int port ) memset( port_str, 0, sizeof( port_str ) ); snprintf( port_str, sizeof( port_str ), "%d", port ); - /* Do name resolution with both IPv6 and IPv4, but only TCP */ + /* Do name resolution with both IPv6 and IPv4 */ memset( &hints, 0, sizeof( hints ) ); hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = IPPROTO_TCP; + hints.ai_socktype = proto == NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM; + hints.ai_protocol = proto == NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP; if( getaddrinfo( host, port_str, &hints, &addr_list ) != 0 ) return( POLARSSL_ERR_NET_UNKNOWN_HOST ); @@ -224,7 +224,9 @@ int net_connect( int *fd, const char *host, int port ) if( ( server_host = gethostbyname( host ) ) == NULL ) return( POLARSSL_ERR_NET_UNKNOWN_HOST ); - if( ( *fd = (int) socket( AF_INET, SOCK_STREAM, IPPROTO_IP ) ) < 0 ) + if( ( *fd = (int) socket( AF_INET, + proto == NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM, + proto == NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP ) ) < 0 ) return( POLARSSL_ERR_NET_SOCKET_FAILED ); memcpy( (void *) &server_addr.sin_addr, @@ -248,7 +250,7 @@ int net_connect( int *fd, const char *host, int port ) /* * Create a listening socket on bind_ip:port */ -int net_bind( int *fd, const char *bind_ip, int port ) +int net_bind( int *fd, const char *bind_ip, int port, int proto ) { #if defined(POLARSSL_HAVE_IPV6) int n, ret; @@ -265,8 +267,8 @@ int net_bind( int *fd, const char *bind_ip, int port ) /* Bind to IPv6 and/or IPv4, but only in TCP */ memset( &hints, 0, sizeof( hints ) ); hints.ai_family = AF_UNSPEC; - hints.ai_socktype = SOCK_STREAM; - hints.ai_protocol = IPPROTO_TCP; + hints.ai_socktype = proto == NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM; + hints.ai_protocol = proto == NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP; if( bind_ip == NULL ) hints.ai_flags = AI_PASSIVE; @@ -301,11 +303,15 @@ int net_bind( int *fd, const char *bind_ip, int port ) continue; } - if( listen( *fd, POLARSSL_NET_LISTEN_BACKLOG ) != 0 ) + /* Listen only makes sense for TCP */ + if( proto == NET_PROTO_TCP ) { - close( *fd ); - ret = POLARSSL_ERR_NET_LISTEN_FAILED; - continue; + if( listen( *fd, POLARSSL_NET_LISTEN_BACKLOG ) != 0 ) + { + close( *fd ); + ret = POLARSSL_ERR_NET_LISTEN_FAILED; + continue; + } } /* I we ever get there, it's a success */ @@ -326,7 +332,9 @@ int net_bind( int *fd, const char *bind_ip, int port ) if( ( ret = net_prepare() ) != 0 ) return( ret ); - if( ( *fd = (int) socket( AF_INET, SOCK_STREAM, IPPROTO_IP ) ) < 0 ) + if( ( *fd = (int) socket( AF_INET, + proto == NET_PROTO_UDP ? SOCK_DGRAM : SOCK_STREAM, + proto == NET_PROTO_UDP ? IPPROTO_UDP : IPPROTO_TCP ) ) < 0 ) return( POLARSSL_ERR_NET_SOCKET_FAILED ); n = 1; @@ -361,10 +369,14 @@ int net_bind( int *fd, const char *bind_ip, int port ) return( POLARSSL_ERR_NET_BIND_FAILED ); } - if( listen( *fd, POLARSSL_NET_LISTEN_BACKLOG ) != 0 ) + /* Listen only makes sense for TCP */ + if( proto == NET_PROTO_TCP ) { - close( *fd ); - return( POLARSSL_ERR_NET_LISTEN_FAILED ); + if( listen( *fd, POLARSSL_NET_LISTEN_BACKLOG ) != 0 ) + { + close( *fd ); + return( POLARSSL_ERR_NET_LISTEN_FAILED ); + } } return( 0 ); @@ -416,6 +428,9 @@ static int net_would_block( int fd ) */ int net_accept( int bind_fd, int *client_fd, void *client_ip ) { + int ret; + int type; + #if defined(POLARSSL_HAVE_IPV6) struct sockaddr_storage client_addr; #else @@ -425,14 +440,35 @@ int net_accept( int bind_fd, int *client_fd, void *client_ip ) #if defined(__socklen_t_defined) || defined(_SOCKLEN_T) || \ defined(_SOCKLEN_T_DECLARED) socklen_t n = (socklen_t) sizeof( client_addr ); + socklen_t type_len = (socklen_t) sizeof( type ); #else int n = (int) sizeof( client_addr ); + int type_len = (int) sizeof( type ); #endif - *client_fd = (int) accept( bind_fd, (struct sockaddr *) - &client_addr, &n ); + /* Is this a TCP or UDP socket? */ + if( getsockopt( bind_fd, SOL_SOCKET, SO_TYPE, &type, &type_len ) != 0 || + ( type != SOCK_STREAM && type != SOCK_DGRAM ) ) + { + return( POLARSSL_ERR_NET_ACCEPT_FAILED ); + } - if( *client_fd < 0 ) + if( type == SOCK_STREAM ) + { + /* TCP: actual accept() */ + ret = *client_fd = (int) accept( bind_fd, + (struct sockaddr *) &client_addr, &n ); + } + else + { + /* UDP: wait for a message, but keep it in the queue */ + char buf[1] = { 0 }; + + ret = recvfrom( bind_fd, buf, 0, MSG_PEEK, + (struct sockaddr *) &client_addr, &n ); + } + + if( ret < 0 ) { if( net_would_block( bind_fd ) != 0 ) return( POLARSSL_ERR_NET_WANT_READ ); @@ -440,6 +476,15 @@ int net_accept( int bind_fd, int *client_fd, void *client_ip ) return( POLARSSL_ERR_NET_ACCEPT_FAILED ); } + /* UDP: hijack the listening socket for communicating with the client */ + if( type != SOCK_STREAM ) + { + if( connect( bind_fd, (struct sockaddr *) &client_addr, n ) != 0 ) + return( POLARSSL_ERR_NET_ACCEPT_FAILED ); + + *client_fd = bind_fd; + } + if( client_ip != NULL ) { #if defined(POLARSSL_HAVE_IPV6) diff --git a/programs/pkey/dh_client.c b/programs/pkey/dh_client.c index 5315eb921..ba0ca9273 100644 --- a/programs/pkey/dh_client.c +++ b/programs/pkey/dh_client.c @@ -135,7 +135,7 @@ int main( int argc, char *argv[] ) fflush( stdout ); if( ( ret = net_connect( &server_fd, SERVER_NAME, - SERVER_PORT ) ) != 0 ) + SERVER_PORT, NET_PROTO_TCP ) ) != 0 ) { printf( " failed\n ! net_connect returned %d\n\n", ret ); goto exit; diff --git a/programs/pkey/dh_server.c b/programs/pkey/dh_server.c index 976da4ca8..d4eb61355 100644 --- a/programs/pkey/dh_server.c +++ b/programs/pkey/dh_server.c @@ -163,7 +163,7 @@ int main( int argc, char *argv[] ) printf( "\n . Waiting for a remote connection" ); fflush( stdout ); - if( ( ret = net_bind( &listen_fd, NULL, SERVER_PORT ) ) != 0 ) + if( ( ret = net_bind( &listen_fd, NULL, SERVER_PORT, NET_PROTO_TCP ) ) != 0 ) { printf( " failed\n ! net_bind returned %d\n\n", ret ); goto exit; diff --git a/programs/ssl/ssl_client1.c b/programs/ssl/ssl_client1.c index 1b369a658..f1b4e6df1 100644 --- a/programs/ssl/ssl_client1.c +++ b/programs/ssl/ssl_client1.c @@ -140,7 +140,7 @@ int main( int argc, char *argv[] ) fflush( stdout ); if( ( ret = net_connect( &server_fd, SERVER_NAME, - SERVER_PORT ) ) != 0 ) + SERVER_PORT, NET_PROTO_TCP ) ) != 0 ) { printf( " failed\n ! net_connect returned %d\n\n", ret ); goto exit; diff --git a/programs/ssl/ssl_client2.c b/programs/ssl/ssl_client2.c index c0a957bd8..9aacfee7a 100644 --- a/programs/ssl/ssl_client2.c +++ b/programs/ssl/ssl_client2.c @@ -844,7 +844,7 @@ int main( int argc, char *argv[] ) fflush( stdout ); if( ( ret = net_connect( &server_fd, opt.server_addr, - opt.server_port ) ) != 0 ) + opt.server_port, NET_PROTO_TCP ) ) != 0 ) { printf( " failed\n ! net_connect returned -0x%x\n\n", -ret ); goto exit; @@ -1260,7 +1260,7 @@ reconnect: } if( ( ret = net_connect( &server_fd, opt.server_name, - opt.server_port ) ) != 0 ) + opt.server_port , NET_PROTO_TCP) ) != 0 ) { printf( " failed\n ! net_connect returned -0x%x\n\n", -ret ); goto exit; diff --git a/programs/ssl/ssl_fork_server.c b/programs/ssl/ssl_fork_server.c index 706cdd492..7c759a2e4 100644 --- a/programs/ssl/ssl_fork_server.c +++ b/programs/ssl/ssl_fork_server.c @@ -179,7 +179,7 @@ int main( int argc, char *argv[] ) printf( " . Bind on https://localhost:4433/ ..." ); fflush( stdout ); - if( ( ret = net_bind( &listen_fd, NULL, 4433 ) ) != 0 ) + if( ( ret = net_bind( &listen_fd, NULL, 4433, NET_PROTO_TCP ) ) != 0 ) { printf( " failed\n ! net_bind returned %d\n\n", ret ); goto exit; diff --git a/programs/ssl/ssl_mail_client.c b/programs/ssl/ssl_mail_client.c index 4cf59d03a..4e6602a23 100644 --- a/programs/ssl/ssl_mail_client.c +++ b/programs/ssl/ssl_mail_client.c @@ -574,7 +574,7 @@ int main( int argc, char *argv[] ) fflush( stdout ); if( ( ret = net_connect( &server_fd, opt.server_name, - opt.server_port ) ) != 0 ) + opt.server_port, NET_PROTO_TCP ) ) != 0 ) { printf( " failed\n ! net_connect returned %d\n\n", ret ); goto exit; diff --git a/programs/ssl/ssl_pthread_server.c b/programs/ssl/ssl_pthread_server.c index 9a4c554e9..c19e3e072 100644 --- a/programs/ssl/ssl_pthread_server.c +++ b/programs/ssl/ssl_pthread_server.c @@ -445,7 +445,7 @@ int main( int argc, char *argv[] ) printf( " . Bind on https://localhost:4433/ ..." ); fflush( stdout ); - if( ( ret = net_bind( &listen_fd, NULL, 4433 ) ) != 0 ) + if( ( ret = net_bind( &listen_fd, NULL, 4433, NET_PROTO_TCP ) ) != 0 ) { printf( " failed\n ! net_bind returned %d\n\n", ret ); goto exit; diff --git a/programs/ssl/ssl_server.c b/programs/ssl/ssl_server.c index 9e097998f..962b0985c 100644 --- a/programs/ssl/ssl_server.c +++ b/programs/ssl/ssl_server.c @@ -159,7 +159,7 @@ int main( int argc, char *argv[] ) printf( " . Bind on https://localhost:4433/ ..." ); fflush( stdout ); - if( ( ret = net_bind( &listen_fd, NULL, 4433 ) ) != 0 ) + if( ( ret = net_bind( &listen_fd, NULL, 4433, NET_PROTO_TCP ) ) != 0 ) { printf( " failed\n ! net_bind returned %d\n\n", ret ); goto exit; diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c index fcc8adbd0..51a4213e4 100644 --- a/programs/ssl/ssl_server2.c +++ b/programs/ssl/ssl_server2.c @@ -1246,7 +1246,7 @@ int main( int argc, char *argv[] ) fflush( stdout ); if( ( ret = net_bind( &listen_fd, opt.server_addr, - opt.server_port ) ) != 0 ) + opt.server_port, NET_PROTO_TCP ) ) != 0 ) { printf( " failed\n ! net_bind returned -0x%x\n\n", -ret ); goto exit; diff --git a/programs/test/ssl_test.c b/programs/test/ssl_test.c index b436d17e8..9bde5de15 100644 --- a/programs/test/ssl_test.c +++ b/programs/test/ssl_test.c @@ -193,7 +193,7 @@ static int ssl_test( struct options *opt ) if( opt->opmode == OPMODE_CLIENT ) { if( ( ret = net_connect( &client_fd, opt->server_name, - opt->server_port ) ) != 0 ) + opt->server_port, NET_PROTO_TCP ) ) != 0 ) { printf( " ! net_connect returned %d\n\n", ret ); return( ret ); @@ -242,7 +242,7 @@ static int ssl_test( struct options *opt ) if( server_fd < 0 ) { if( ( ret = net_bind( &server_fd, NULL, - opt->server_port ) ) != 0 ) + opt->server_port, NET_PROTO_TCP ) ) != 0 ) { printf( " ! net_bind returned %d\n\n", ret ); return( ret ); diff --git a/programs/x509/cert_app.c b/programs/x509/cert_app.c index 5f8636b10..8b528a839 100644 --- a/programs/x509/cert_app.c +++ b/programs/x509/cert_app.c @@ -402,7 +402,7 @@ int main( int argc, char *argv[] ) fflush( stdout ); if( ( ret = net_connect( &server_fd, opt.server_name, - opt.server_port ) ) != 0 ) + opt.server_port, NET_PROTO_TCP ) ) != 0 ) { printf( " failed\n ! net_connect returned %d\n\n", ret ); goto exit;