diff --git a/programs/test/selftest.c b/programs/test/selftest.c index 5c1d354a8..c46b8e758 100644 --- a/programs/test/selftest.c +++ b/programs/test/selftest.c @@ -60,12 +60,41 @@ #else #include #define mbedtls_printf printf +#define mbedtls_snprintf snprintf #endif #if defined(MBEDTLS_MEMORY_BUFFER_ALLOC_C) #include "mbedtls/memory_buffer_alloc.h" #endif +static int test_snprintf( size_t n, const char ref_buf[10], int ref_ret ) +{ + int ret; + char buf[10] = "xxxxxxxxx"; + + ret = mbedtls_snprintf( buf, n, "%s", "123" ); + if( ret < 0 || (size_t) ret >= n ) + ret = -1; + + if( memcmp( ref_buf, buf, sizeof buf ) != 0 || + ref_ret != ret ) + { + return( 1 ); + } + + return( 0 ); +} + +static int run_test_snprintf( void ) +{ + return( test_snprintf( 0, "xxxxxxxxx", -1 ) != 0 || + test_snprintf( 1, "\0xxxxxxxx", -1 ) != 0 || + test_snprintf( 2, "1\0xxxxxxx", -1 ) != 0 || + test_snprintf( 3, "12\0xxxxxx", -1 ) != 0 || + test_snprintf( 4, "123\0xxxxx", 3 ) != 0 || + test_snprintf( 5, "123\0xxxxx", 3 ) != 0 ); +} + int main( int argc, char *argv[] ) { int ret = 0, v; @@ -86,6 +115,15 @@ int main( int argc, char *argv[] ) return( 1 ); } + /* + * Make sure we have a snprintf that correctly zero-terminates + */ + if( run_test_snprintf() != 0 ) + { + mbedtls_printf( "the snprintf implementation is broken\n" ); + return( 0 ); + } + if( argc == 2 && strcmp( argv[1], "-quiet" ) == 0 ) v = 0; else diff --git a/tests/suites/main_test.function b/tests/suites/main_test.function index f1ef9175c..ba1328877 100644 --- a/tests/suites/main_test.function +++ b/tests/suites/main_test.function @@ -4,11 +4,13 @@ #include "mbedtls/platform.h" #else #include +#include #define mbedtls_exit exit #define mbedtls_free free -#define mbedtls_calloc calloc +#define mbedtls_calloc calloc #define mbedtls_fprintf fprintf #define mbedtls_printf printf +#define mbedtls_snprintf snprintf #endif #if defined(MBEDTLS_MEMORY_BUFFER_ALLOC_C) @@ -209,6 +211,34 @@ int parse_arguments( char *buf, size_t len, char *params[50] ) return( cnt ); } +static int test_snprintf( size_t n, const char ref_buf[10], int ref_ret ) +{ + int ret; + char buf[10] = "xxxxxxxxx"; + + ret = mbedtls_snprintf( buf, n, "%s", "123" ); + if( ret < 0 || (size_t) ret >= n ) + ret = -1; + + if( memcmp( ref_buf, buf, sizeof buf ) != 0 || + ref_ret != ret ) + { + return( 1 ); + } + + return( 0 ); +} + +static int run_test_snprintf( void ) +{ + return( test_snprintf( 0, "xxxxxxxxx", -1 ) != 0 || + test_snprintf( 1, "\0xxxxxxxx", -1 ) != 0 || + test_snprintf( 2, "1\0xxxxxxx", -1 ) != 0 || + test_snprintf( 3, "12\0xxxxxx", -1 ) != 0 || + test_snprintf( 4, "123\0xxxxx", 3 ) != 0 || + test_snprintf( 5, "123\0xxxxx", 3 ) != 0 ); +} + int main() { int ret, i, cnt, total_errors = 0, total_tests = 0, total_skipped = 0; @@ -236,6 +266,15 @@ int main() return( 1 ); } + /* + * Make sure we have a snprintf that correctly zero-terminates + */ + if( run_test_snprintf() != 0 ) + { + mbedtls_fprintf( stderr, "the snprintf implementation is broken\n" ); + return( 0 ); + } + file = fopen( filename, "r" ); if( file == NULL ) {