diff --git a/include/mbedtls/config.h b/include/mbedtls/config.h index e87284017..fcb92f2c7 100644 --- a/include/mbedtls/config.h +++ b/include/mbedtls/config.h @@ -3457,6 +3457,9 @@ /* Timeout */ //#define MBEDTLS_SSL_CONF_READ_TIMEOUT 0 +/* Endpoint (Client/Server) */ +//#define MBEDTLS_SSL_CONF_ENDPOINT MBED + /* DTLS-specific settings */ //#define MBEDTLS_SSL_CONF_HS_TIMEOUT_MIN MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MIN //#define MBEDTLS_SSL_CONF_HS_TIMEOUT_MAX MBEDTLS_SSL_DTLS_TIMEOUT_DFL_MAX diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h index 1d7c05d40..86759e50c 100644 --- a/include/mbedtls/ssl.h +++ b/include/mbedtls/ssl.h @@ -1055,7 +1055,9 @@ struct mbedtls_ssl_config * Flags (bitfields) */ +#if !defined(MBEDTLS_SSL_CONF_ENDPOINT) unsigned int endpoint : 1; /*!< 0: client, 1: server */ +#endif /* !MBEDTLS_SSL_CONF_ENDPOINT */ unsigned int transport : 1; /*!< stream (TLS) or datagram (DTLS) */ #if !defined(MBEDTLS_SSL_CONF_AUTHMODE) unsigned int authmode : 2; /*!< MBEDTLS_SSL_VERIFY_XXX */ @@ -1381,13 +1383,18 @@ int mbedtls_ssl_setup( mbedtls_ssl_context *ssl, */ int mbedtls_ssl_session_reset( mbedtls_ssl_context *ssl ); +#if !defined(MBEDTLS_SSL_CONF_ENDPOINT) /** * \brief Set the current endpoint type * + * \note On constrained systems, this can also be configured + * at compile-time via MBEDTLS_SSL_CONF_ENDPOINT. + * * \param conf SSL configuration * \param endpoint must be MBEDTLS_SSL_IS_CLIENT or MBEDTLS_SSL_IS_SERVER */ void mbedtls_ssl_conf_endpoint( mbedtls_ssl_config *conf, int endpoint ); +#endif /* !MBEDTLS_SSL_CONF_ENDPOINT */ /** * \brief Set the transport type (TLS or DTLS). diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h index 093c4ac13..b08aae288 100644 --- a/include/mbedtls/ssl_internal.h +++ b/include/mbedtls/ssl_internal.h @@ -1085,6 +1085,21 @@ int mbedtls_ssl_decrypt_buf( mbedtls_ssl_context *ssl, * be fixed at compile time via one of MBEDTLS_SSL_SSL_CONF_XXX. */ +#if !defined(MBEDTLS_SSL_CONF_ENDPOINT) +static inline unsigned int mbedtls_ssl_conf_get_endpoint( + mbedtls_ssl_config const *conf ) +{ + return( conf->endpoint ); +} +#else /* !MBEDTLS_SSL_CONF_ENDPOINT */ +static inline unsigned int mbedtls_ssl_conf_get_endpoint( + mbedtls_ssl_config const *conf ) +{ + ((void) conf); + return( MBEDTLS_SSL_CONF_ENDPOINT ); +} +#endif /* MBEDTLS_SSL_CONF_ENDPOINT */ + #if !defined(MBEDTLS_SSL_CONF_READ_TIMEOUT) static inline uint32_t mbedtls_ssl_conf_get_read_timeout( mbedtls_ssl_config const *conf ) diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 0009f8bed..b6b7750c7 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -55,7 +55,7 @@ int mbedtls_ssl_set_client_transport_id( mbedtls_ssl_context *ssl, const unsigned char *info, size_t ilen ) { - if( ssl->conf->endpoint != MBEDTLS_SSL_IS_SERVER ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) != MBEDTLS_SSL_IS_SERVER ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); mbedtls_free( ssl->cli_id ); diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 95cf554f0..cfd6589b1 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -1388,7 +1388,7 @@ int mbedtls_ssl_derive_keys( mbedtls_ssl_context *ssl ) ssl->handshake->tls_prf, ssl->handshake->randbytes, ssl->minor_ver, - ssl->conf->endpoint, + mbedtls_ssl_conf_get_endpoint( ssl->conf ), ssl ); if( ret != 0 ) { @@ -3196,7 +3196,8 @@ int mbedtls_ssl_fetch_input( mbedtls_ssl_context *ssl, size_t nb_want ) return( MBEDTLS_ERR_SSL_WANT_READ ); } #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_RENEGOTIATION) - else if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER && + else if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_SERVER && ssl->renego_status == MBEDTLS_SSL_RENEGOTIATION_PENDING ) { if( ( ret = ssl_resend_hello_request( ssl ) ) != 0 ) @@ -3722,7 +3723,8 @@ int mbedtls_ssl_write_handshake_msg( mbedtls_ssl_context *ssl ) #if defined(MBEDTLS_SSL_PROTO_SSL3) && defined(MBEDTLS_SSL_CLI_C) if( ! ( ssl->minor_ver == MBEDTLS_SSL_MINOR_VERSION_0 && ssl->out_msgtype == MBEDTLS_SSL_MSG_ALERT && - ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT ) ) + mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_CLIENT ) ) #endif /* MBEDTLS_SSL_PROTO_SSL3 && MBEDTLS_SSL_SRV_C */ { MBEDTLS_SSL_DEBUG_MSG( 1, ( "should never happen" ) ); @@ -4754,7 +4756,8 @@ static int ssl_parse_record_header( mbedtls_ssl_context *ssl ) * have an active transform (possibly iv_len != 0), so use the * fact that the record header len is 13 instead. */ - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER && + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_SERVER && ssl->state == MBEDTLS_SSL_HANDSHAKE_OVER && rec_epoch == 0 && ssl->in_msgtype == MBEDTLS_SSL_MSG_HANDSHAKE && @@ -5937,7 +5940,8 @@ int mbedtls_ssl_handle_message_type( mbedtls_ssl_context *ssl ) #if defined(MBEDTLS_SSL_PROTO_SSL3) && defined(MBEDTLS_SSL_SRV_C) if( ssl->minor_ver == MBEDTLS_SSL_MINOR_VERSION_0 && - ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER && + mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_SERVER && ssl->in_msg[0] == MBEDTLS_SSL_ALERT_LEVEL_WARNING && ssl->in_msg[1] == MBEDTLS_SSL_ALERT_MSG_NO_CERT ) { @@ -6104,7 +6108,8 @@ int mbedtls_ssl_write_certificate( mbedtls_ssl_context *ssl ) } #if defined(MBEDTLS_SSL_CLI_C) - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_CLIENT ) { if( ssl->client_auth == 0 ) { @@ -6133,7 +6138,7 @@ int mbedtls_ssl_write_certificate( mbedtls_ssl_context *ssl ) } #endif /* MBEDTLS_SSL_CLI_C */ #if defined(MBEDTLS_SSL_SRV_C) - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == MBEDTLS_SSL_IS_SERVER ) { if( mbedtls_ssl_own_cert( ssl ) == NULL ) { @@ -6337,7 +6342,8 @@ static int ssl_parse_certificate_chain( mbedtls_ssl_context *ssl, /* Check if we're handling the first CRT in the chain. */ #if defined(MBEDTLS_SSL_RENEGOTIATION) && defined(MBEDTLS_SSL_CLI_C) if( crt_cnt++ == 0 && - ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT && + mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_CLIENT && ssl->renego_status == MBEDTLS_SSL_RENEGOTIATION_IN_PROGRESS ) { /* During client-side renegotiation, check that the server's @@ -6403,7 +6409,7 @@ static int ssl_parse_certificate_chain( mbedtls_ssl_context *ssl, #if defined(MBEDTLS_SSL_SRV_C) static int ssl_srv_check_client_no_crt_notification( mbedtls_ssl_context *ssl ) { - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == MBEDTLS_SSL_IS_CLIENT ) return( -1 ); #if defined(MBEDTLS_SSL_PROTO_SSL3) @@ -6461,7 +6467,7 @@ static int ssl_parse_certificate_coordinate( mbedtls_ssl_context *ssl, return( SSL_CERTIFICATE_SKIP ); #if defined(MBEDTLS_SSL_SRV_C) - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == MBEDTLS_SSL_IS_SERVER ) { if( ciphersuite_info->key_exchange == MBEDTLS_KEY_EXCHANGE_RSA_PSK ) return( SSL_CERTIFICATE_SKIP ); @@ -6551,7 +6557,8 @@ static int ssl_parse_certificate_verify( mbedtls_ssl_context *ssl, if( mbedtls_ssl_check_cert_usage( chain, ciphersuite_info, - ! ssl->conf->endpoint, + ( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_CLIENT ), &ssl->session_negotiate->verify_result ) != 0 ) { MBEDTLS_SSL_DEBUG_MSG( 1, ( "bad certificate (usage extensions)" ) ); @@ -7371,7 +7378,8 @@ int mbedtls_ssl_write_finished( mbedtls_ssl_context *ssl ) ssl_update_out_pointers( ssl, ssl->transform_negotiate ); - ssl->handshake->calc_finished( ssl, ssl->out_msg + 4, ssl->conf->endpoint ); + ssl->handshake->calc_finished( ssl, ssl->out_msg + 4, + mbedtls_ssl_conf_get_endpoint( ssl->conf ) ); /* * RFC 5246 7.4.9 (Page 63) says 12 is the default length and ciphersuites @@ -7397,12 +7405,18 @@ int mbedtls_ssl_write_finished( mbedtls_ssl_context *ssl ) if( ssl->handshake->resume != 0 ) { #if defined(MBEDTLS_SSL_CLI_C) - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_CLIENT ) + { ssl->state = MBEDTLS_SSL_HANDSHAKE_WRAPUP; + } #endif #if defined(MBEDTLS_SSL_SRV_C) - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_SERVER ) + { ssl->state = MBEDTLS_SSL_CLIENT_CHANGE_CIPHER_SPEC; + } #endif } else @@ -7499,7 +7513,8 @@ int mbedtls_ssl_parse_finished( mbedtls_ssl_context *ssl ) MBEDTLS_SSL_DEBUG_MSG( 2, ( "=> parse finished" ) ); - ssl->handshake->calc_finished( ssl, buf, ssl->conf->endpoint ^ 1 ); + ssl->handshake->calc_finished( ssl, buf, + mbedtls_ssl_conf_get_endpoint( ssl->conf ) ^ 1 ); if( ( ret = mbedtls_ssl_read_record( ssl, 1 ) ) != 0 ) { @@ -7549,11 +7564,11 @@ int mbedtls_ssl_parse_finished( mbedtls_ssl_context *ssl ) if( ssl->handshake->resume != 0 ) { #if defined(MBEDTLS_SSL_CLI_C) - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == MBEDTLS_SSL_IS_CLIENT ) ssl->state = MBEDTLS_SSL_CLIENT_CHANGE_CIPHER_SPEC; #endif #if defined(MBEDTLS_SSL_SRV_C) - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == MBEDTLS_SSL_IS_SERVER ) ssl->state = MBEDTLS_SSL_HANDSHAKE_WRAPUP; #endif } @@ -7702,7 +7717,7 @@ static int ssl_handshake_init( mbedtls_ssl_context *ssl ) { ssl->handshake->alt_transform_out = ssl->transform_out; - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == MBEDTLS_SSL_IS_CLIENT ) ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_PREPARING; else ssl->handshake->retransmit_state = MBEDTLS_SSL_RETRANS_WAITING; @@ -8074,10 +8089,12 @@ int mbedtls_ssl_session_reset( mbedtls_ssl_context *ssl ) /* * SSL set accessors */ +#if !defined(MBEDTLS_SSL_CONF_ENDPOINT) void mbedtls_ssl_conf_endpoint( mbedtls_ssl_config *conf, int endpoint ) { conf->endpoint = endpoint; } +#endif /* MBEDTLS_SSL_CONF_ENDPOINT */ void mbedtls_ssl_conf_transport( mbedtls_ssl_config *conf, int transport ) { @@ -8226,7 +8243,7 @@ int mbedtls_ssl_set_session( mbedtls_ssl_context *ssl, const mbedtls_ssl_session if( ssl == NULL || session == NULL || ssl->session_negotiate == NULL || - ssl->conf->endpoint != MBEDTLS_SSL_IS_CLIENT ) + mbedtls_ssl_conf_get_endpoint( ssl->conf ) != MBEDTLS_SSL_IS_CLIENT ) { return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); } @@ -8354,7 +8371,7 @@ int mbedtls_ssl_set_hs_ecjpake_password( mbedtls_ssl_context *ssl, if( ssl->handshake == NULL || ssl->conf == NULL ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == MBEDTLS_SSL_IS_SERVER ) role = MBEDTLS_ECJPAKE_SERVER; else role = MBEDTLS_ECJPAKE_CLIENT; @@ -9024,7 +9041,7 @@ size_t mbedtls_ssl_get_max_frag_len( const mbedtls_ssl_context *ssl ) static size_t ssl_get_current_mtu( const mbedtls_ssl_context *ssl ) { /* Return unlimited mtu for client hello messages to avoid fragmentation. */ - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT && + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == MBEDTLS_SSL_IS_CLIENT && ( ssl->state == MBEDTLS_SSL_CLIENT_HELLO || ssl->state == MBEDTLS_SSL_SERVER_HELLO ) ) return ( 0 ); @@ -9106,7 +9123,7 @@ int mbedtls_ssl_get_session( const mbedtls_ssl_context *ssl, if( ssl == NULL || dst == NULL || ssl->session == NULL || - ssl->conf->endpoint != MBEDTLS_SSL_IS_CLIENT ) + mbedtls_ssl_conf_get_endpoint( ssl->conf ) != MBEDTLS_SSL_IS_CLIENT ) { return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); } @@ -9690,11 +9707,11 @@ int mbedtls_ssl_handshake_step( mbedtls_ssl_context *ssl ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); #if defined(MBEDTLS_SSL_CLI_C) - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == MBEDTLS_SSL_IS_CLIENT ) ret = mbedtls_ssl_handshake_client_step( ssl ); #endif #if defined(MBEDTLS_SSL_SRV_C) - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == MBEDTLS_SSL_IS_SERVER ) ret = mbedtls_ssl_handshake_server_step( ssl ); #endif @@ -9777,10 +9794,15 @@ static int ssl_start_renegotiation( mbedtls_ssl_context *ssl ) if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) && ssl->renego_status == MBEDTLS_SSL_RENEGOTIATION_PENDING ) { - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_SERVER ) + { ssl->handshake->out_msg_seq = 1; + } else + { ssl->handshake->in_msg_seq = 1; + } } #endif @@ -9811,7 +9833,7 @@ int mbedtls_ssl_renegotiate( mbedtls_ssl_context *ssl ) #if defined(MBEDTLS_SSL_SRV_C) /* On server, just send the request */ - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER ) + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == MBEDTLS_SSL_IS_SERVER ) { if( ssl->state != MBEDTLS_SSL_HANDSHAKE_OVER ) return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA ); @@ -9994,7 +10016,8 @@ int mbedtls_ssl_read( mbedtls_ssl_context *ssl, unsigned char *buf, size_t len ) */ #if defined(MBEDTLS_SSL_CLI_C) - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT && + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_CLIENT && ( ssl->in_msg[0] != MBEDTLS_SSL_HS_HELLO_REQUEST || ssl->in_hslen != mbedtls_ssl_hs_hdr_len( ssl ) ) ) { @@ -10017,7 +10040,8 @@ int mbedtls_ssl_read( mbedtls_ssl_context *ssl, unsigned char *buf, size_t len ) #endif /* MBEDTLS_SSL_CLI_C */ #if defined(MBEDTLS_SSL_SRV_C) - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER && + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_SERVER && ssl->in_msg[0] != MBEDTLS_SSL_HS_CLIENT_HELLO ) { MBEDTLS_SSL_DEBUG_MSG( 1, ( "handshake received (not ClientHello)" ) ); @@ -10052,7 +10076,8 @@ int mbedtls_ssl_read( mbedtls_ssl_context *ssl, unsigned char *buf, size_t len ) /* DTLS clients need to know renego is server-initiated */ #if defined(MBEDTLS_SSL_PROTO_DTLS) if( MBEDTLS_SSL_TRANSPORT_IS_DTLS( ssl->conf->transport ) && - ssl->conf->endpoint == MBEDTLS_SSL_IS_CLIENT ) + mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_CLIENT ) { ssl->renego_status = MBEDTLS_SSL_RENEGOTIATION_PENDING; } @@ -10165,7 +10190,8 @@ int mbedtls_ssl_read( mbedtls_ssl_context *ssl, unsigned char *buf, size_t len ) * Do it now, after setting in_offt, to avoid taking this branch * again if ssl_write_hello_request() returns WANT_WRITE */ #if defined(MBEDTLS_SSL_SRV_C) && defined(MBEDTLS_SSL_RENEGOTIATION) - if( ssl->conf->endpoint == MBEDTLS_SSL_IS_SERVER && + if( mbedtls_ssl_conf_get_endpoint( ssl->conf ) == + MBEDTLS_SSL_IS_SERVER && ssl->renego_status == MBEDTLS_SSL_RENEGOTIATION_PENDING ) { if( ( ret = ssl_resend_hello_request( ssl ) ) != 0 ) diff --git a/programs/ssl/query_config.c b/programs/ssl/query_config.c index e58ddc4d5..29b778c4c 100644 --- a/programs/ssl/query_config.c +++ b/programs/ssl/query_config.c @@ -2602,6 +2602,14 @@ int query_config( const char *config ) } #endif /* MBEDTLS_SSL_CONF_READ_TIMEOUT */ +#if defined(MBEDTLS_SSL_CONF_ENDPOINT) + if( strcmp( "MBEDTLS_SSL_CONF_ENDPOINT", config ) == 0 ) + { + MACRO_EXPANSION_TO_STR( MBEDTLS_SSL_CONF_ENDPOINT ); + return( 0 ); + } +#endif /* MBEDTLS_SSL_CONF_ENDPOINT */ + #if defined(MBEDTLS_SSL_CONF_HS_TIMEOUT_MIN) if( strcmp( "MBEDTLS_SSL_CONF_HS_TIMEOUT_MIN", config ) == 0 ) {