From eb2c65816390eeda045b9012ba036bf95e22a4df Mon Sep 17 00:00:00 2001
From: Paul Bakker
Date: Thu, 27 Sep 2012 19:15:01 +0000
Subject: [PATCH] - Generalized external private key implementation handling
(like PKCS#11) in SSL/TLS
---
ChangeLog | 2 ++
include/polarssl/config.h | 4 +--
include/polarssl/pkcs11.h | 35 ++++++++++++++++++++
include/polarssl/ssl.h | 48 +++++++++++++++++++--------
library/ssl_cli.c | 32 ++++--------------
library/ssl_srv.c | 70 +++++++++------------------------------
library/ssl_tls.c | 41 ++++++++++++++++++++---
7 files changed, 130 insertions(+), 102 deletions(-)
diff --git a/ChangeLog b/ChangeLog
index d662bcf54..bfe342126 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -40,6 +40,8 @@ Changes
POLARSSL_MODE_CFB, to also handle different block size CFB modes.
* Removed handling for SSLv2 Client Hello (as per RFC 5246 recommendation)
* Revamped session resumption handling
+ * Generalized external private key implementation handling (like PKCS#11)
+ in SSL/TLS
Bugfix
* Fixed handling error in mpi_cmp_mpi() on longer B values (found by
diff --git a/include/polarssl/config.h b/include/polarssl/config.h
index 538ef817d..543b96c8c 100644
--- a/include/polarssl/config.h
+++ b/include/polarssl/config.h
@@ -612,7 +612,7 @@
/**
* \def POLARSSL_PKCS11_C
*
- * Enable support for PKCS#11 smartcard support.
+ * Enable wrapper for PKCS#11 smartcard support.
*
* Module: library/ssl_srv.c
* Caller: library/ssl_cli.c
@@ -620,7 +620,7 @@
*
* Requires: POLARSSL_SSL_TLS_C
*
- * This module is required for SSL/TLS PKCS #11 smartcard support.
+ * This module enables SSL/TLS PKCS #11 smartcard support.
* Requires the presence of the PKCS#11 helper library (libpkcs11-helper)
#define POLARSSL_PKCS11_C
*/
diff --git a/include/polarssl/pkcs11.h b/include/polarssl/pkcs11.h
index a65a72e81..ddfae3017 100644
--- a/include/polarssl/pkcs11.h
+++ b/include/polarssl/pkcs11.h
@@ -37,6 +37,14 @@
#include
+#if defined(_MSC_VER) && !defined(inline)
+#define inline _inline
+#else
+#if defined(__ARMCC_VERSION) && !defined(inline)
+#define inline __inline
+#endif /* __ARMCC_VERSION */
+#endif /*_MSC_VER */
+
/**
* Context for PKCS #11 private keys.
*/
@@ -121,6 +129,33 @@ int pkcs11_sign( pkcs11_context *ctx,
const unsigned char *hash,
unsigned char *sig );
+/**
+ * SSL/TLS wrappers for PKCS#11 functions
+ */
+static inline int ssl_pkcs11_decrypt( void *ctx, int mode, size_t *olen,
+ const unsigned char *input, unsigned char *output,
+ unsigned int output_max_len )
+{
+ return pkcs11_decrypt( (pkcs11_context *) ctx, mode, olen, input, output,
+ output_max_len );
+}
+
+static inline int ssl_pkcs11_sign( void *ctx,
+ int (*f_rng)(void *, unsigned char *, size_t), void *p_rng,
+ int mode, int hash_id, unsigned int hashlen,
+ const unsigned char *hash, unsigned char *sig )
+{
+ ((void) f_rng);
+ ((void) p_rng);
+ return pkcs11_sign( (pkcs11_context *) ctx, mode, hash_id,
+ hashlen, hash, sig );
+}
+
+static inline size_t ssl_pkcs11_key_len( void *ctx )
+{
+ return ( (pkcs11_context *) ctx )->len;
+}
+
#endif /* POLARSSL_PKCS11_C */
#endif /* POLARSSL_PKCS11_H */
diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h
index fcf8a8ffe..62ffba2d3 100644
--- a/include/polarssl/ssl.h
+++ b/include/polarssl/ssl.h
@@ -42,10 +42,6 @@
#include "dhm.h"
#endif
-#if defined(POLARSSL_PKCS11_C)
-#include "pkcs11.h"
-#endif
-
#if defined(POLARSSL_ZLIB_SUPPORT)
#include "zlib.h"
#endif
@@ -253,6 +249,20 @@
#define TLS_EXT_RENEGOTIATION_INFO 0xFF01
+
+/*
+ * Generic function pointers for allowing external RSA private key
+ * implementations.
+ */
+typedef int (*rsa_decrypt_func)( void *ctx, int mode, size_t *olen,
+ const unsigned char *input, unsigned char *output,
+ size_t output_max_len );
+typedef int (*rsa_sign_func)( void *ctx,
+ int (*f_rng)(void *, unsigned char *, size_t), void *p_rng,
+ int mode, int hash_id, unsigned int hashlen,
+ const unsigned char *hash, unsigned char *sig );
+typedef size_t (*rsa_key_len_func)( void *ctx );
+
/*
* SSL state machine
*/
@@ -446,10 +456,11 @@ struct _ssl_context
/*
* PKI layer
*/
- rsa_context *rsa_key; /*!< own RSA private key */
-#if defined(POLARSSL_PKCS11_C)
- pkcs11_context *pkcs11_key; /*!< own PKCS#11 RSA private key */
-#endif
+ void *rsa_key; /*!< own RSA private key */
+ rsa_decrypt_func rsa_decrypt; /*!< function for RSA decrypt*/
+ rsa_sign_func rsa_sign; /*!< function for RSA sign */
+ rsa_key_len_func rsa_key_len; /*!< function for RSA key len*/
+
x509_cert *own_cert; /*!< own X.509 certificate */
x509_cert *ca_chain; /*!< own trusted CA chain */
x509_crl *ca_crl; /*!< trusted CA CRLs */
@@ -722,17 +733,26 @@ void ssl_set_ca_chain( ssl_context *ssl, x509_cert *ca_chain,
void ssl_set_own_cert( ssl_context *ssl, x509_cert *own_cert,
rsa_context *rsa_key );
-#if defined(POLARSSL_PKCS11_C)
/**
- * \brief Set own certificate and PKCS#11 private key
+ * \brief Set own certificate and alternate non-PolarSSL private
+ * key and handling callbacks, such as the PKCS#11 wrappers
+ * or any other external private key handler.
+ * (see the respective RSA functions in rsa.h for documentation
+ * of the callback parameters, with the only change being
+ * that the rsa_context * is a void * in the callbacks)
*
* \param ssl SSL context
* \param own_cert own public certificate
- * \param pkcs11_key own PKCS#11 RSA key
+ * \param rsa_key alternate implementation private RSA key
+ * \param rsa_decrypt_func alternate implementation of \c rsa_pkcs1_decrypt()
+ * \param rsa_sign_func alternate implementation of \c rsa_pkcs1_sign()
+ * \param rsa_key_len_func function returning length of RSA key in bytes
*/
-void ssl_set_own_cert_pkcs11( ssl_context *ssl, x509_cert *own_cert,
- pkcs11_context *pkcs11_key );
-#endif
+void ssl_set_own_cert_alt( ssl_context *ssl, x509_cert *own_cert,
+ void *rsa_key,
+ rsa_decrypt_func rsa_decrypt,
+ rsa_sign_func rsa_sign,
+ rsa_key_len_func rsa_key_len );
#if defined(POLARSSL_DHM_C)
/**
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index b44af2ba3..3e1b0569f 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -30,10 +30,6 @@
#include "polarssl/debug.h"
#include "polarssl/ssl.h"
-#if defined(POLARSSL_PKCS11_C)
-#include "polarssl/pkcs11.h"
-#endif /* defined(POLARSSL_PKCS11_C) */
-
#include
#include
#include
@@ -1115,15 +1111,8 @@ static int ssl_write_certificate_verify( ssl_context *ssl )
if( ssl->rsa_key == NULL )
{
-#if defined(POLARSSL_PKCS11_C)
- if( ssl->pkcs11_key == NULL )
- {
-#endif /* defined(POLARSSL_PKCS11_C) */
- SSL_DEBUG_MSG( 1, ( "got no private key" ) );
- return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
-#if defined(POLARSSL_PKCS11_C)
- }
-#endif /* defined(POLARSSL_PKCS11_C) */
+ SSL_DEBUG_MSG( 1, ( "got no private key" ) );
+ return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
}
/*
@@ -1132,11 +1121,7 @@ static int ssl_write_certificate_verify( ssl_context *ssl )
ssl->handshake->calc_verify( ssl, hash );
if ( ssl->rsa_key )
- n = ssl->rsa_key->len;
-#if defined(POLARSSL_PKCS11_C)
- else
- n = ssl->pkcs11_key->len;
-#endif /* defined(POLARSSL_PKCS11_C) */
+ n = ssl->rsa_key_len ( ssl->rsa_key );
if( ssl->minor_ver == SSL_MINOR_VERSION_3 )
{
@@ -1164,14 +1149,9 @@ static int ssl_write_certificate_verify( ssl_context *ssl )
if( ssl->rsa_key )
{
- ret = rsa_pkcs1_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng,
- RSA_PRIVATE, hash_id,
- hashlen, hash, ssl->out_msg + 6 + offset );
- } else {
-#if defined(POLARSSL_PKCS11_C)
- ret = pkcs11_sign( ssl->pkcs11_key, RSA_PRIVATE, hash_id,
- hashlen, hash, ssl->out_msg + 6 + offset );
-#endif /* defined(POLARSSL_PKCS11_C) */
+ ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng,
+ RSA_PRIVATE, hash_id,
+ hashlen, hash, ssl->out_msg + 6 + offset );
}
if (ret != 0)
diff --git a/library/ssl_srv.c b/library/ssl_srv.c
index 64b0d2df4..e31145864 100644
--- a/library/ssl_srv.c
+++ b/library/ssl_srv.c
@@ -30,10 +30,6 @@
#include "polarssl/debug.h"
#include "polarssl/ssl.h"
-#if defined(POLARSSL_PKCS11_C)
-#include "polarssl/pkcs11.h"
-#endif /* defined(POLARSSL_PKCS11_C) */
-
#include
#include
#include
@@ -644,15 +640,8 @@ static int ssl_write_server_key_exchange( ssl_context *ssl )
if( ssl->rsa_key == NULL )
{
-#if defined(POLARSSL_PKCS11_C)
- if( ssl->pkcs11_key == NULL )
- {
-#endif /* defined(POLARSSL_PKCS11_C) */
- SSL_DEBUG_MSG( 1, ( "got no private key" ) );
- return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
-#if defined(POLARSSL_PKCS11_C)
- }
-#endif /* defined(POLARSSL_PKCS11_C) */
+ SSL_DEBUG_MSG( 1, ( "got no private key" ) );
+ return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
}
/*
@@ -738,11 +727,7 @@ static int ssl_write_server_key_exchange( ssl_context *ssl )
SSL_DEBUG_BUF( 3, "parameters hash", hash, hashlen );
if ( ssl->rsa_key )
- rsa_key_len = ssl->rsa_key->len;
-#if defined(POLARSSL_PKCS11_C)
- else
- rsa_key_len = ssl->pkcs11_key->len;
-#endif /* defined(POLARSSL_PKCS11_C) */
+ rsa_key_len = ssl->rsa_key_len( ssl->rsa_key );
if( ssl->minor_ver == SSL_MINOR_VERSION_3 )
{
@@ -758,16 +743,11 @@ static int ssl_write_server_key_exchange( ssl_context *ssl )
if ( ssl->rsa_key )
{
- ret = rsa_pkcs1_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng,
- RSA_PRIVATE,
- hash_id, hashlen, hash, ssl->out_msg + 6 + n );
+ ret = ssl->rsa_sign( ssl->rsa_key, ssl->f_rng, ssl->p_rng,
+ RSA_PRIVATE,
+ hash_id, hashlen, hash,
+ ssl->out_msg + 6 + n );
}
-#if defined(POLARSSL_PKCS11_C)
- else {
- ret = pkcs11_sign( ssl->pkcs11_key, RSA_PRIVATE,
- hash_id, hashlen, hash, ssl->out_msg + 6 + n );
- }
-#endif /* defined(POLARSSL_PKCS11_C) */
if( ret != 0 )
{
@@ -898,15 +878,8 @@ static int ssl_parse_client_key_exchange( ssl_context *ssl )
{
if( ssl->rsa_key == NULL )
{
-#if defined(POLARSSL_PKCS11_C)
- if( ssl->pkcs11_key == NULL )
- {
-#endif
- SSL_DEBUG_MSG( 1, ( "got no private key" ) );
- return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
-#if defined(POLARSSL_PKCS11_C)
- }
-#endif
+ SSL_DEBUG_MSG( 1, ( "got no private key" ) );
+ return( POLARSSL_ERR_SSL_PRIVATE_KEY_REQUIRED );
}
/*
@@ -914,11 +887,7 @@ static int ssl_parse_client_key_exchange( ssl_context *ssl )
*/
i = 4;
if( ssl->rsa_key )
- n = ssl->rsa_key->len;
-#if defined(POLARSSL_PKCS11_C)
- else
- n = ssl->pkcs11_key->len;
-#endif
+ n = ssl->rsa_key_len( ssl->rsa_key );
ssl->handshake->pmslen = 48;
if( ssl->minor_ver != SSL_MINOR_VERSION_0 )
@@ -939,21 +908,12 @@ static int ssl_parse_client_key_exchange( ssl_context *ssl )
}
if( ssl->rsa_key ) {
- ret = rsa_pkcs1_decrypt( ssl->rsa_key, RSA_PRIVATE,
- &ssl->handshake->pmslen,
- ssl->in_msg + i,
- ssl->handshake->premaster,
- sizeof(ssl->handshake->premaster) );
+ ret = ssl->rsa_decrypt( ssl->rsa_key, RSA_PRIVATE,
+ &ssl->handshake->pmslen,
+ ssl->in_msg + i,
+ ssl->handshake->premaster,
+ sizeof(ssl->handshake->premaster) );
}
-#if defined(POLARSSL_PKCS11_C)
- else {
- ret = pkcs11_decrypt( ssl->pkcs11_key, RSA_PRIVATE,
- &ssl->handshake->pmslen,
- ssl->in_msg + i,
- ssl->handshake->premaster,
- sizeof(ssl->handshake->premaster) );
- }
-#endif /* defined(POLARSSL_PKCS11_C) */
if( ret != 0 || ssl->handshake->pmslen != 48 ||
ssl->handshake->premaster[0] != ssl->max_major_ver ||
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 61920042b..cc0f65c57 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -65,6 +65,28 @@ int (*ssl_hw_record_read)(ssl_context *ssl) = NULL;
int (*ssl_hw_record_finish)(ssl_context *ssl) = NULL;
#endif
+static int ssl_rsa_decrypt( void *ctx, int mode, size_t *olen,
+ const unsigned char *input, unsigned char *output,
+ size_t output_max_len )
+{
+ return rsa_pkcs1_decrypt( (rsa_context *) ctx, mode, olen, input, output,
+ output_max_len );
+}
+
+static int ssl_rsa_sign( void *ctx,
+ int (*f_rng)(void *, unsigned char *, size_t), void *p_rng,
+ int mode, int hash_id, unsigned int hashlen,
+ const unsigned char *hash, unsigned char *sig )
+{
+ return rsa_pkcs1_sign( (rsa_context *) ctx, f_rng, p_rng, mode, hash_id,
+ hashlen, hash, sig );
+}
+
+static size_t ssl_rsa_key_len( void *ctx )
+{
+ return ( (rsa_context *) ctx )->len;
+}
+
/*
* Key material generation
*/
@@ -2826,6 +2848,10 @@ int ssl_init( ssl_context *ssl )
memset( ssl, 0, sizeof( ssl_context ) );
+ ssl->rsa_decrypt = ssl_rsa_decrypt;
+ ssl->rsa_sign = ssl_rsa_sign;
+ ssl->rsa_key_len = ssl_rsa_key_len;
+
ssl->in_ctr = (unsigned char *) malloc( len );
ssl->in_hdr = ssl->in_ctr + 8;
ssl->in_msg = ssl->in_ctr + 13;
@@ -3002,14 +3028,19 @@ void ssl_set_own_cert( ssl_context *ssl, x509_cert *own_cert,
ssl->rsa_key = rsa_key;
}
-#if defined(POLARSSL_PKCS11_C)
-void ssl_set_own_cert_pkcs11( ssl_context *ssl, x509_cert *own_cert,
- pkcs11_context *pkcs11_key )
+void ssl_set_own_cert_alt( ssl_context *ssl, x509_cert *own_cert,
+ void *rsa_key,
+ rsa_decrypt_func rsa_decrypt,
+ rsa_sign_func rsa_sign,
+ rsa_key_len_func rsa_key_len )
{
ssl->own_cert = own_cert;
- ssl->pkcs11_key = pkcs11_key;
+ ssl->rsa_key = rsa_key;
+ ssl->rsa_decrypt = rsa_decrypt;
+ ssl->rsa_sign = rsa_sign;
+ ssl->rsa_key_len = rsa_key_len;
}
-#endif
+
#if defined(POLARSSL_DHM_C)
int ssl_set_dh_param( ssl_context *ssl, const char *dhm_P, const char *dhm_G )