diff --git a/include/polarssl/pk.h b/include/polarssl/pk.h index cc8a2fcfb..ab21e50a5 100644 --- a/include/polarssl/pk.h +++ b/include/polarssl/pk.h @@ -136,6 +136,18 @@ typedef struct int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ); + /** Decrypt message */ + int (*decrypt_func)( void *ctx, const unsigned char *input, size_t ilen, + unsigned char *output, size_t *olen, size_t osize, + int (*f_rng)(void *, unsigned char *, size_t), + void *p_rng ); + + /** Encrypt message */ + int (*encrypt_func)( void *ctx, const unsigned char *input, size_t ilen, + unsigned char *output, size_t *olen, size_t osize, + int (*f_rng)(void *, unsigned char *, size_t), + void *p_rng ); + /** Allocate a new context */ void * (*ctx_alloc_func)( void ); @@ -244,6 +256,40 @@ int pk_sign( pk_context *ctx, md_type_t md_alg, unsigned char *sig, size_t *sig_len, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ); +/** + * \brief Decrypt message + * + * \param ctx PK context to use + * \param input Input to decrypt + * \param ilen Input size + * \param output Decrypted output + * \param olen Decrypted message lenght + * \param osize Size of the output buffer + * + * \return 0 on success, or a specific error code. + */ +int pk_decrypt( pk_context *ctx, + const unsigned char *input, size_t ilen, + unsigned char *output, size_t *olen, size_t osize, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ); + +/** + * \brief Encrypt message + * + * \param ctx PK context to use + * \param input Message to encrypt + * \param ilen Message size + * \param output Encrypted output + * \param olen Encrypted output length + * \param osize Size of the output buffer + * + * \return 0 on success, or a specific error code. + */ +int pk_encrypt( pk_context *ctx, + const unsigned char *input, size_t ilen, + unsigned char *output, size_t *olen, size_t osize, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ); + /** * \brief Export debug information * diff --git a/library/pk.c b/library/pk.c index 6f68c7392..6e6057462 100644 --- a/library/pk.c +++ b/library/pk.c @@ -152,6 +152,42 @@ int pk_sign( pk_context *ctx, md_type_t md_alg, sig, sig_len, f_rng, p_rng ) ); } +/* + * Decrypt message + */ +int pk_decrypt( pk_context *ctx, + const unsigned char *input, size_t ilen, + unsigned char *output, size_t *olen, size_t osize, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) +{ + if( ctx == NULL || ctx->pk_info == NULL ) + return( POLARSSL_ERR_PK_BAD_INPUT_DATA ); + + if( ctx->pk_info->decrypt_func == NULL ) + return( POLARSSL_ERR_PK_TYPE_MISMATCH ); + + return( ctx->pk_info->decrypt_func( ctx->pk_ctx, input, ilen, + output, olen, osize, f_rng, p_rng ) ); +} + +/* + * Encrypt message + */ +int pk_encrypt( pk_context *ctx, + const unsigned char *input, size_t ilen, + unsigned char *output, size_t *olen, size_t osize, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) +{ + if( ctx == NULL || ctx->pk_info == NULL ) + return( POLARSSL_ERR_PK_BAD_INPUT_DATA ); + + if( ctx->pk_info->encrypt_func == NULL ) + return( POLARSSL_ERR_PK_TYPE_MISMATCH ); + + return( ctx->pk_info->encrypt_func( ctx->pk_ctx, input, ilen, + output, olen, osize, f_rng, p_rng ) ); +} + /* * Get key size in bits */ diff --git a/library/pk_wrap.c b/library/pk_wrap.c index eb91d895f..2c55ce08f 100644 --- a/library/pk_wrap.c +++ b/library/pk_wrap.c @@ -80,6 +80,34 @@ static int rsa_sign_wrap( void *ctx, md_type_t md_alg, md_alg, hash_len, hash, sig ) ); } +static int rsa_decrypt_wrap( void *ctx, + const unsigned char *input, size_t ilen, + unsigned char *output, size_t *olen, size_t osize, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) +{ + ((void) f_rng); + ((void) p_rng); + + if( ilen != ((rsa_context *) ctx)->len ) + return( POLARSSL_ERR_RSA_BAD_INPUT_DATA ); + + return( rsa_pkcs1_decrypt( (rsa_context *) ctx, + RSA_PRIVATE, olen, input, output, osize ) ); +} + +static int rsa_encrypt_wrap( void *ctx, + const unsigned char *input, size_t ilen, + unsigned char *output, size_t *olen, size_t osize, + int (*f_rng)(void *, unsigned char *, size_t), void *p_rng ) +{ + ((void) osize); + + *olen = ((rsa_context *) ctx)->len; + + return( rsa_pkcs1_encrypt( (rsa_context *) ctx, + f_rng, p_rng, RSA_PUBLIC, ilen, input, output ) ); +} + static void *rsa_alloc_wrap( void ) { void *ctx = polarssl_malloc( sizeof( rsa_context ) ); @@ -116,6 +144,8 @@ const pk_info_t rsa_info = { rsa_can_do, rsa_verify_wrap, rsa_sign_wrap, + rsa_decrypt_wrap, + rsa_encrypt_wrap, rsa_alloc_wrap, rsa_free_wrap, rsa_debug, @@ -222,6 +252,8 @@ const pk_info_t eckey_info = { NULL, NULL, #endif + NULL, + NULL, eckey_alloc_wrap, eckey_free_wrap, eckey_debug, @@ -243,6 +275,8 @@ const pk_info_t eckeydh_info = { eckeydh_can_do, NULL, NULL, + NULL, + NULL, eckey_alloc_wrap, /* Same underlying key structure */ eckey_free_wrap, /* Same underlying key structure */ eckey_debug, /* Same underlying key structure */ @@ -299,6 +333,8 @@ const pk_info_t ecdsa_info = { ecdsa_can_do, ecdsa_verify_wrap, ecdsa_sign_wrap, + NULL, + NULL, ecdsa_alloc_wrap, ecdsa_free_wrap, eckey_debug, /* Compatible key structures */ diff --git a/library/ssl_cli.c b/library/ssl_cli.c index 829e46b75..cd77eb82d 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -1870,26 +1870,24 @@ static int ssl_write_client_key_exchange( ssl_context *ssl ) return( POLARSSL_ERR_SSL_PK_TYPE_MISMATCH ); } - i = 4; - n = pk_get_size( &ssl->session_negotiate->peer_cert->pk ) / 8; + i = ssl->minor_ver == SSL_MINOR_VERSION_0 ? 4 : 6; - if( ssl->minor_ver != SSL_MINOR_VERSION_0 ) - { - i += 2; - ssl->out_msg[4] = (unsigned char)( n >> 8 ); - ssl->out_msg[5] = (unsigned char)( n ); - } - - ret = rsa_pkcs1_encrypt( - pk_rsa( ssl->session_negotiate->peer_cert->pk ), - ssl->f_rng, ssl->p_rng, RSA_PUBLIC, - ssl->handshake->pmslen, ssl->handshake->premaster, - ssl->out_msg + i ); + ret = pk_encrypt( &ssl->session_negotiate->peer_cert->pk, + ssl->handshake->premaster, ssl->handshake->pmslen, + ssl->out_msg + i, &n, SSL_BUFFER_LEN, + ssl->f_rng, ssl->p_rng ); if( ret != 0 ) { SSL_DEBUG_RET( 1, "rsa_pkcs1_encrypt", ret ); return( ret ); } + + if( ssl->minor_ver != SSL_MINOR_VERSION_0 ) + { + ssl->out_msg[4] = (unsigned char)( n >> 8 ); + ssl->out_msg[5] = (unsigned char)( n ); + } + } else #endif /* POLARSSL_KEY_EXCHANGE_RSA_ENABLED */ diff --git a/library/ssl_srv.c b/library/ssl_srv.c index ffd754e36..6fb16ecab 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -2259,9 +2259,9 @@ static int ssl_parse_encrypted_pms_secret( ssl_context *ssl ) int ret = POLARSSL_ERR_SSL_FEATURE_UNAVAILABLE; size_t i, n = 0; - if( ssl->rsa_key == NULL ) + if( ! pk_can_do( ssl->pk_key, POLARSSL_PK_RSA ) ) { - SSL_DEBUG_MSG( 1, ( "got no private key" ) ); + SSL_DEBUG_MSG( 1, ( "got no RSA private key" ) ); return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED ); } @@ -2269,8 +2269,7 @@ static int ssl_parse_encrypted_pms_secret( ssl_context *ssl ) * Decrypt the premaster using own private RSA key */ i = 4; - if( ssl->rsa_key ) - n = ssl->rsa_key_len( ssl->rsa_key ); + n = ssl->rsa_key_len( ssl->rsa_key ); ssl->handshake->pmslen = 48; if( ssl->minor_ver != SSL_MINOR_VERSION_0 ) @@ -2290,13 +2289,21 @@ static int ssl_parse_encrypted_pms_secret( ssl_context *ssl ) return( POLARSSL_ERR_SSL_BAD_HS_CLIENT_KEY_EXCHANGE ); } - if( ssl->rsa_key ) { + if( ssl->rsa_use_alt ) { ret = ssl->rsa_decrypt( ssl->rsa_key, RSA_PRIVATE, &ssl->handshake->pmslen, ssl->in_msg + i, ssl->handshake->premaster, sizeof(ssl->handshake->premaster) ); } + else + { + ret = pk_decrypt( ssl->pk_key, + ssl->in_msg + i, n, + ssl->handshake->premaster, &ssl->handshake->pmslen, + sizeof(ssl->handshake->premaster), + ssl->f_rng, ssl->p_rng ); + } if( ret != 0 || ssl->handshake->pmslen != 48 || ssl->handshake->premaster[0] != ssl->handshake->max_major_ver ||