diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c index c932d14b1..e80038e87 100644 --- a/programs/ssl/ssl_server2.c +++ b/programs/ssl/ssl_server2.c @@ -77,6 +77,7 @@ #define DFL_KEY_FILE2 "" #define DFL_PSK "" #define DFL_PSK_IDENTITY "Client_identity" +#define DFL_PSK_LIST NULL #define DFL_FORCE_CIPHER 0 #define DFL_RENEGOTIATION SSL_RENEGOTIATION_DISABLED #define DFL_ALLOW_LEGACY SSL_LEGACY_NO_RENEGOTIATION @@ -127,6 +128,7 @@ struct options const char *key_file2; /* the file with the 2nd server key */ const char *psk; /* the pre-shared key */ const char *psk_identity; /* the pre-shared key identity */ + char *psk_list; /* list of PSK id/key pairs for callback */ int force_ciphersuite[2]; /* protocol/ciphersuite to use, or all */ int renegotiation; /* enable / disable renegotiation */ int allow_legacy; /* allow legacy renegotiation */ @@ -474,6 +476,97 @@ int unhexify( unsigned char *output, const char *input, size_t *olen ) return( 0 ); } + +typedef struct _psk_entry psk_entry; + +struct _psk_entry +{ + const char *name; + size_t key_len; + unsigned char key[MAX_PSK_LEN]; + psk_entry *next; +}; + +/* + * Parse a string of pairs name1,key1[,name2,key2[,...]] + * into a usable psk_entry list. + * + * Modifies the input string! This is not production quality! + * (leaks memory if parsing fails, no error reporting, ...) + */ +psk_entry *psk_parse( char *psk_string ) +{ + psk_entry *cur = NULL, *new = NULL; + char *p = psk_string; + char *end = p; + char *key_hex; + + while( *end != '\0' ) + ++end; + *end = ','; + + while( p <= end ) + { + if( ( new = polarssl_malloc( sizeof( psk_entry ) ) ) == NULL ) + return( NULL ); + + memset( new, 0, sizeof( psk_entry ) ); + + new->name = p; + while( *p != ',' ) if( ++p > end ) return( NULL ); + *p++ = '\0'; + + key_hex = p; + while( *p != ',' ) if( ++p > end ) return( NULL ); + *p++ = '\0'; + + if( unhexify( new->key, key_hex, &new->key_len ) != 0 ) + return( NULL ); + + new->next = cur; + cur = new; + } + + return( cur ); +} + +/* + * Free a list of psk_entry's + */ +void psk_free( psk_entry *head ) +{ + psk_entry *next; + + while( head != NULL ) + { + next = head->next; + polarssl_free( head ); + head = next; + } +} + +/* + * PSK callback + */ +int psk_callback( void *p_info, ssl_context *ssl, + const unsigned char *name, size_t name_len ) +{ + psk_entry *cur = (psk_entry *) p_info; + + while( cur != NULL ) + { + if( name_len == strlen( cur->name ) && + memcmp( name, cur->name, name_len ) == 0 ) + { + return( ssl_set_psk( ssl, cur->key, cur->key_len, + name, name_len ) ); + } + + cur = cur->next; + } + + return( -1 ); +} #endif /* POLARSSL_KEY_EXCHANGE__SOME__PSK_ENABLED */ int main( int argc, char *argv[] ) @@ -485,6 +578,7 @@ int main( int argc, char *argv[] ) #if defined(POLARSSL_KEY_EXCHANGE__SOME__PSK_ENABLED) unsigned char psk[MAX_PSK_LEN]; size_t psk_len = 0; + psk_entry *psk_info; #endif const char *pers = "ssl_server2"; @@ -579,6 +673,7 @@ int main( int argc, char *argv[] ) opt.key_file2 = DFL_KEY_FILE2; opt.psk = DFL_PSK; opt.psk_identity = DFL_PSK_IDENTITY; + opt.psk_list = DFL_PSK_LIST; opt.force_ciphersuite[0]= DFL_FORCE_CIPHER; opt.renegotiation = DFL_RENEGOTIATION; opt.allow_legacy = DFL_ALLOW_LEGACY; @@ -640,6 +735,8 @@ int main( int argc, char *argv[] ) opt.psk = q; else if( strcmp( p, "psk_identity" ) == 0 ) opt.psk_identity = q; + else if( strcmp( p, "psk_list" ) == 0 ) + opt.psk_list = q; else if( strcmp( p, "force_ciphersuite" ) == 0 ) { opt.force_ciphersuite[0] = -1; @@ -812,13 +909,19 @@ int main( int argc, char *argv[] ) #if defined(POLARSSL_KEY_EXCHANGE__SOME__PSK_ENABLED) /* - * Unhexify the pre-shared key if any is given + * Unhexify the pre-shared key and parse the list if any given */ - if( opt.psk != NULL ) + if( unhexify( psk, opt.psk, &psk_len ) != 0 ) { - if( unhexify( psk, opt.psk, &psk_len ) != 0 ) + printf( "pre-shared key not valid hex\n" ); + goto exit; + } + + if( opt.psk_list != NULL ) + { + if( ( psk_info = psk_parse( opt.psk_list ) ) == NULL ) { - printf("pre-shared key not valid hex\n"); + printf( "psk_list invalid" ); goto exit; } } @@ -1127,6 +1230,8 @@ int main( int argc, char *argv[] ) #if defined(POLARSSL_KEY_EXCHANGE__SOME__PSK_ENABLED) ssl_set_psk( &ssl, psk, psk_len, (const unsigned char *) opt.psk_identity, strlen( opt.psk_identity ) ); + if( opt.psk_list != NULL ) + ssl_set_psk_cb( &ssl, psk_callback, psk_info ); #endif #if defined(POLARSSL_DHM_C)