diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index aa5daf414..c866b6ffd 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -320,6 +320,7 @@ #define SSL_ALERT_MSG_UNSUPPORTED_EXT 110 /* 0x6E */ #define SSL_ALERT_MSG_UNRECOGNIZED_NAME 112 /* 0x70 */ #define SSL_ALERT_MSG_UNKNOWN_PSK_IDENTITY 115 /* 0x73 */ +#define SSL_ALERT_MSG_NO_APPLICATION_PROTOCOL 120 /* 0x78 */ #define SSL_HS_HELLO_REQUEST 0 #define SSL_HS_CLIENT_HELLO 1 diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 16fb264dc..2cbc79853 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -683,6 +683,69 @@ static int ssl_parse_session_ticket_ext( ssl_context *ssl, } #endif /* POLARSSL_SSL_SESSION_TICKETS */ +#if defined(POLARSSL_SSL_ALPN) +static int ssl_parse_alpn_ext( ssl_context *ssl, + unsigned char *buf, size_t len ) +{ + size_t list_len, cur_len; + const unsigned char *theirs, *start, *end; + const char **ours; + + /* If ALPN not configured, just ignore the extension */ + if( ssl->alpn_list == NULL ) + return( 0 ); + + /* + * opaque ProtocolName<1..2^8-1>; + * + * struct { + * ProtocolName protocol_name_list<2..2^16-1> + * } ProtocolNameList; + */ + + /* Min length is 2 (list_len) + 1 (name_len) + 1 (name) */ + if( len < 4 ) + return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO ); + + list_len = ( buf[0] << 8 ) | buf[1]; + if( list_len != len - 2 ) + return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO ); + + /* + * Use our order of preference + */ + start = buf + 2; + end = buf + len; + for( ours = ssl->alpn_list; *ours != NULL; ours++ ) + { + for( theirs = start; theirs != end; theirs += cur_len ) + { + /* If the list is well formed, we should get equality first */ + if( theirs > end ) + return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO ); + + cur_len = *theirs++; + + /* Empty strings MUST NOT be included */ + if( cur_len == 0 ) + return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO ); + + if( cur_len == strlen( *ours ) && + memcmp( theirs, *ours, cur_len ) == 0 ) + { + ssl->alpn_chosen = *ours; + return( 0 ); + } + } + } + + /* If we get there, no match was found */ + ssl_send_alert_message( ssl, SSL_ALERT_LEVEL_FATAL, + SSL_ALERT_MSG_NO_APPLICATION_PROTOCOL ); + return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_HELLO ); +} +#endif /* POLARSSL_SSL_ALPN */ + /* * Auxiliary functions for ServerHello parsing and related actions */ @@ -1385,6 +1448,16 @@ static int ssl_parse_client_hello( ssl_context *ssl ) break; #endif /* POLARSSL_SSL_SESSION_TICKETS */ +#if defined(POLARSSL_SSL_ALPN) + case TLS_EXT_ALPN: + SSL_DEBUG_MSG( 3, ( "found ALPN extension" ) ); + + ret = ssl_parse_alpn_ext( ssl, ext + 4, ext_size ); + if( ret != 0 ) + return( ret ); + break; +#endif /* POLARSSL_SSL_SESSION_TICKETS */ + default: SSL_DEBUG_MSG( 3, ( "unknown extension found: %d (ignoring)", ext_id ) ); @@ -1625,6 +1698,42 @@ static void ssl_write_supported_point_formats_ext( ssl_context *ssl, } #endif /* POLARSSL_ECDH_C || POLARSSL_ECDSA_C */ +#if defined(POLARSSL_SSL_ALPN ) +static void ssl_write_alpn_ext( ssl_context *ssl, + unsigned char *buf, size_t *olen ) +{ + if( ssl->alpn_chosen == NULL ) + { + *olen = 0; + return; + } + + SSL_DEBUG_MSG( 3, ( "server hello, alpn extension" ) ); + + /* + * 0 . 1 ext identifier + * 2 . 3 ext length + * 4 . 5 protocol list length + * 6 . 6 protocol name length + * 7 . 7+n protocol name + */ + buf[0] = (unsigned char)( ( TLS_EXT_ALPN >> 8 ) & 0xFF ); + buf[1] = (unsigned char)( ( TLS_EXT_ALPN ) & 0xFF ); + + *olen = 7 + strlen( ssl->alpn_chosen ); + + buf[2] = (unsigned char)( ( ( *olen - 4 ) >> 8 ) & 0xFF ); + buf[3] = (unsigned char)( ( ( *olen - 4 ) ) & 0xFF ); + + buf[4] = (unsigned char)( ( ( *olen - 6 ) >> 8 ) & 0xFF ); + buf[5] = (unsigned char)( ( ( *olen - 6 ) ) & 0xFF ); + + buf[6] = (unsigned char)( ( ( *olen - 7 ) ) & 0xFF ); + + memcpy( buf + 7, ssl->alpn_chosen, *olen - 7 ); +} +#endif /* POLARSSL_ECDH_C || POLARSSL_ECDSA_C */ + static int ssl_write_server_hello( ssl_context *ssl ) { #if defined(POLARSSL_HAVE_TIME) @@ -1791,6 +1900,11 @@ static int ssl_write_server_hello( ssl_context *ssl ) ext_len += olen; #endif +#if defined(POLARSSL_SSL_ALPN) + ssl_write_alpn_ext( ssl, p + 2 + ext_len, &olen ); + ext_len += olen; +#endif + SSL_DEBUG_MSG( 3, ( "server hello, total extension length: %d", ext_len ) ); *p++ = (unsigned char)( ( ext_len >> 8 ) & 0xFF );