From 28d81a009cbbe999950b142731b71aea06fe0eb3 Mon Sep 17 00:00:00 2001 From: Steffan Karger Date: Wed, 13 Nov 2013 16:57:58 +0100 Subject: [PATCH] Fix pkcs11.c to conform to PolarSSL 1.3 API. This restores previous functionality, and thus still allows only RSA to be used through PKCS#11. Signed-off-by: Steffan Karger Signed-off-by: Paul Bakker --- include/polarssl/pk.h | 2 +- include/polarssl/pkcs11.h | 6 +-- include/polarssl/ssl.h | 2 +- library/pkcs11.c | 111 ++++++++++++++++++-------------------- 4 files changed, 56 insertions(+), 65 deletions(-) diff --git a/include/polarssl/pk.h b/include/polarssl/pk.h index 251c690e5..958672b0c 100644 --- a/include/polarssl/pk.h +++ b/include/polarssl/pk.h @@ -188,7 +188,7 @@ typedef int (*pk_rsa_alt_decrypt_func)( void *ctx, int mode, size_t *olen, size_t output_max_len ); typedef int (*pk_rsa_alt_sign_func)( void *ctx, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng, - int mode, int hash_id, unsigned int hashlen, + int mode, md_type_t md_alg, unsigned int hashlen, const unsigned char *hash, unsigned char *sig ); typedef size_t (*pk_rsa_alt_key_len_func)( void *ctx ); diff --git a/include/polarssl/pkcs11.h b/include/polarssl/pkcs11.h index c0515e67c..707d00a81 100644 --- a/include/polarssl/pkcs11.h +++ b/include/polarssl/pkcs11.h @@ -128,7 +128,7 @@ int pkcs11_decrypt( pkcs11_context *ctx, */ int pkcs11_sign( pkcs11_context *ctx, int mode, - int hash_id, + md_type_t md_alg, unsigned int hashlen, const unsigned char *hash, unsigned char *sig ); @@ -146,12 +146,12 @@ static inline int ssl_pkcs11_decrypt( void *ctx, int mode, size_t *olen, static inline int ssl_pkcs11_sign( void *ctx, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng, - int mode, int hash_id, unsigned int hashlen, + int mode, md_type_t md_alg, unsigned int hashlen, const unsigned char *hash, unsigned char *sig ) { ((void) f_rng); ((void) p_rng); - return pkcs11_sign( (pkcs11_context *) ctx, mode, hash_id, + return pkcs11_sign( (pkcs11_context *) ctx, mode, md_alg, hashlen, hash, sig ); } diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index e51e5078d..1608df30e 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -374,7 +374,7 @@ typedef int (*rsa_decrypt_func)( void *ctx, int mode, size_t *olen, size_t output_max_len ); typedef int (*rsa_sign_func)( void *ctx, int (*f_rng)(void *, unsigned char *, size_t), void *p_rng, - int mode, int hash_id, unsigned int hashlen, + int mode, md_type_t md_alg, unsigned int hashlen, const unsigned char *hash, unsigned char *sig ); typedef size_t (*rsa_key_len_func)( void *ctx ); diff --git a/library/pkcs11.c b/library/pkcs11.c index 9f68d782a..8a99f2871 100644 --- a/library/pkcs11.c +++ b/library/pkcs11.c @@ -30,6 +30,9 @@ #include "polarssl/pkcs11.h" #if defined(POLARSSL_PKCS11_C) +#include "polarssl/md.h" +#include "polarssl/oid.h" +#include "polarssl/x509_crt.h" #if defined(POLARSSL_MEMORY_C) #include "polarssl/memory.h" @@ -101,7 +104,7 @@ int pkcs11_priv_key_init( pkcs11_context *priv_key, if( 0 != pkcs11_x509_cert_init( &cert, pkcs11_cert ) ) goto cleanup; - priv_key->len = cert.rsa.len; + priv_key->len = pk_get_len(&cert.pk); priv_key->pkcs11h_cert = pkcs11_cert; ret = 0; @@ -129,7 +132,7 @@ int pkcs11_decrypt( pkcs11_context *ctx, if( NULL == ctx ) return( POLARSSL_ERR_RSA_BAD_INPUT_DATA ); - if( RSA_PUBLIC == mode ) + if( RSA_PRIVATE != mode ) return( POLARSSL_ERR_RSA_BAD_INPUT_DATA ); output_len = input_len = ctx->len; @@ -158,79 +161,67 @@ int pkcs11_decrypt( pkcs11_context *ctx, int pkcs11_sign( pkcs11_context *ctx, int mode, - int hash_id, + md_type_t md_alg, unsigned int hashlen, const unsigned char *hash, unsigned char *sig ) { - size_t olen, asn_len; + size_t olen, asn_len = 0, oid_size = 0; unsigned char *p = sig; + const char *oid; if( NULL == ctx ) return POLARSSL_ERR_RSA_BAD_INPUT_DATA; - if( RSA_PUBLIC == mode ) + if( RSA_PRIVATE != mode ) return POLARSSL_ERR_RSA_BAD_INPUT_DATA; olen = ctx->len; - switch( hash_id ) + if( md_alg != POLARSSL_MD_NONE ) { - case SIG_RSA_RAW: - asn_len = 0; - memcpy( p, hash, hashlen ); - break; - - case SIG_RSA_MD2: - asn_len = OID_SIZE(ASN1_HASH_MDX); - memcpy( p, ASN1_HASH_MDX, asn_len ); - memcpy( p + asn_len, hash, hashlen ); - p[13] = 2; break; - - case SIG_RSA_MD4: - asn_len = OID_SIZE(ASN1_HASH_MDX); - memcpy( p, ASN1_HASH_MDX, asn_len ); - memcpy( p + asn_len, hash, hashlen ); - p[13] = 4; break; - - case SIG_RSA_MD5: - asn_len = OID_SIZE(ASN1_HASH_MDX); - memcpy( p, ASN1_HASH_MDX, asn_len ); - memcpy( p + asn_len, hash, hashlen ); - p[13] = 5; break; - - case SIG_RSA_SHA1: - asn_len = OID_SIZE(ASN1_HASH_SHA1); - memcpy( p, ASN1_HASH_SHA1, asn_len ); - memcpy( p + 15, hash, hashlen ); - break; - - case SIG_RSA_SHA224: - asn_len = OID_SIZE(ASN1_HASH_SHA2X); - memcpy( p, ASN1_HASH_SHA2X, asn_len ); - memcpy( p + asn_len, hash, hashlen ); - p[1] += hashlen; p[14] = 4; p[18] += hashlen; break; - - case SIG_RSA_SHA256: - asn_len = OID_SIZE(ASN1_HASH_SHA2X); - memcpy( p, ASN1_HASH_SHA2X, asn_len ); - memcpy( p + asn_len, hash, hashlen ); - p[1] += hashlen; p[14] = 1; p[18] += hashlen; break; - - case SIG_RSA_SHA384: - asn_len = OID_SIZE(ASN1_HASH_SHA2X); - memcpy( p, ASN1_HASH_SHA2X, asn_len ); - memcpy( p + asn_len, hash, hashlen ); - p[1] += hashlen; p[14] = 2; p[18] += hashlen; break; - - case SIG_RSA_SHA512: - asn_len = OID_SIZE(ASN1_HASH_SHA2X); - memcpy( p, ASN1_HASH_SHA2X, asn_len ); - memcpy( p + asn_len, hash, hashlen ); - p[1] += hashlen; p[14] = 3; p[18] += hashlen; break; - - default: + const md_info_t *md_info = md_info_from_type( md_alg ); + if( md_info == NULL ) return( POLARSSL_ERR_RSA_BAD_INPUT_DATA ); + + if( oid_get_oid_by_md( md_alg, &oid, &oid_size ) != 0 ) + return( POLARSSL_ERR_RSA_BAD_INPUT_DATA ); + + hashlen = md_get_size( md_info ); + } + + if( md_alg == POLARSSL_MD_NONE ) + { + memcpy( p, hash, hashlen ); + } + else + { + /* + * DigestInfo ::= SEQUENCE { + * digestAlgorithm DigestAlgorithmIdentifier, + * digest Digest } + * + * DigestAlgorithmIdentifier ::= AlgorithmIdentifier + * + * Digest ::= OCTET STRING + */ + *p++ = ASN1_SEQUENCE | ASN1_CONSTRUCTED; + *p++ = (unsigned char) ( 0x08 + oid_size + hashlen ); + *p++ = ASN1_SEQUENCE | ASN1_CONSTRUCTED; + *p++ = (unsigned char) ( 0x04 + oid_size ); + *p++ = ASN1_OID; + *p++ = oid_size & 0xFF; + memcpy( p, oid, oid_size ); + p += oid_size; + *p++ = ASN1_NULL; + *p++ = 0x00; + *p++ = ASN1_OCTET_STRING; + *p++ = hashlen; + + /* Determine added ASN length */ + asn_len = p - sig; + + memcpy( p, hash, hashlen ); } if( pkcs11h_certificate_signAny( ctx->pkcs11h_cert, CKM_RSA_PKCS, sig,