#!/usr/bin/env python3
"""Test the program psa_constant_names.
Gather constant names from header files and test cases. Compile a C program
to print out their numerical values, feed these numerical values to
psa_constant_names, and check that the output is the original name.
Return 0 if all test cases pass, 1 if the output was not always as expected,
or 1 (with a Python backtrace) if there was an operational error.
"""

import argparse
import itertools
import os
import platform
import re
import subprocess
import sys
import tempfile


class ReadFileLineException(Exception):
    """Exception annotating an error with a file name and a line position.

    The position is whatever read_file_lines recorded last: a 0-based line
    index, or the strings 'entry' / 'exit' when no line is being processed.
    """

    def __init__(self, filename, line_number):
        message = 'in {} at {}'.format(filename, line_number)
        super().__init__(message)
        self.filename = filename
        self.line_number = line_number


class read_file_lines:
    # Dear Pylint, conventionally, a context manager class name is lowercase.
    # pylint: disable=invalid-name,too-few-public-methods
    """Context manager to read a text file line by line.

    ```
    with read_file_lines(filename) as lines:
        for line in lines:
            process(line)
    ```
    is equivalent to
    ```
    with open(filename, 'r') as input_file:
        for line in input_file:
            process(line)
    ```
    except that if process(line) raises an exception, then the read_file_lines
    snippet annotates the exception with the file name and line number.
    """

    def __init__(self, filename):
        self.filename = filename
        # Underlying file object; opened in __enter__, closed in __exit__.
        self.file = None
        # 'entry' before iteration starts, then the 0-based index of the
        # line being processed, then 'exit' once the file is exhausted.
        self.line_number = 'entry'
        self.generator = None

    def __enter__(self):
        self.file = open(self.filename, 'r')
        self.generator = enumerate(self.file)
        return self

    def __iter__(self):
        for line_number, content in self.generator:
            self.line_number = line_number
            yield content
        self.line_number = 'exit'

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Always release the file handle, whether or not the body failed.
        if self.file is not None:
            self.file.close()
        if exc_type is not None:
            raise ReadFileLineException(self.filename, self.line_number) \
                from exc_value


class Inputs:
    """Accumulate information about macros to test.

    This includes macro names as well as information about their arguments
    when applicable.
    """

    def __init__(self):
        # Sets of names per type
        self.statuses = {'PSA_SUCCESS'}
        self.algorithms = {'0xffffffff'}
        self.ecc_curves = {'0xffff'}
        self.dh_groups = {'0xffff'}
        self.key_types = {'0xffffffff'}
        self.key_usage_flags = {'0x80000000'}
        # Hard-coded value for unknown algorithms
        self.hash_algorithms = {'0x010000fe'}
        self.mac_algorithms = {'0x02ff00ff'}
        self.ka_algorithms = {'0x30fc0000'}
        self.kdf_algorithms = {'0x200000ff'}
        # For AEAD algorithms, the only variability is over the tag length,
        # and this only applies to known algorithms, so don't test an
        # unknown algorithm.
        self.aead_algorithms = set()
        # Identifier prefixes. The keys are compared against group 2 of
        # _header_line_re, i.e. the first all-uppercase word after "PSA_"
        # (optionally preceded by "KEY_").
        self.table_by_prefix = {
            'ERROR': self.statuses,
            'ALG': self.algorithms,
            'CURVE': self.ecc_curves,
            'GROUP': self.dh_groups,
            'KEY_TYPE': self.key_types,
            'KEY_USAGE': self.key_usage_flags,
        }
        # macro name -> list of argument names
        self.argspecs = {}
        # argument name -> list of values
        self.arguments_for = {
            'mac_length': ['1', '63'],
            'tag_length': ['1', '63'],
        }

    def gather_arguments(self):
        """Populate the list of values for macro arguments.

        Call this after parsing all the inputs.
        """
        self.arguments_for['hash_alg'] = sorted(self.hash_algorithms)
        self.arguments_for['mac_alg'] = sorted(self.mac_algorithms)
        self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
        self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
        self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
        self.arguments_for['curve'] = sorted(self.ecc_curves)
        self.arguments_for['group'] = sorted(self.dh_groups)

    @staticmethod
    def _format_arguments(name, arguments):
        """Format a macro call with arguments."""
        return name + '(' + ', '.join(arguments) + ')'

    def distribute_arguments(self, name):
        """Generate macro calls with each tested argument set.

        If name is a macro without arguments, just yield "name".
        If name is a macro with arguments, yield a series of
        "name(arg1,...,argN)" where each argument takes each possible
        value at least once.

        Any error raised while expanding the macro is re-raised as an
        Exception whose message names the macro, chained to the original.
        """
        try:
            if name not in self.argspecs:
                yield name
                return
            argspec = self.argspecs[name]
            if not argspec:
                yield name + '()'
                return
            argument_lists = [self.arguments_for[arg] for arg in argspec]
            # Start from the first value of every argument, then vary one
            # argument at a time while the others keep their first value.
            arguments = [values[0] for values in argument_lists]
            yield self._format_arguments(name, arguments)
            for i, values in enumerate(argument_lists):
                for value in values[1:]:
                    arguments[i] = value
                    yield self._format_arguments(name, arguments)
                # Restore this argument's own first value before moving on
                # to the next argument.
                arguments[i] = values[0]
        except Exception as e:
            # Exception rather than BaseException: GeneratorExit and
            # KeyboardInterrupt must propagate unchanged.
            raise Exception('distribute_arguments({})'.format(name)) from e

    _argument_split_re = re.compile(r' *, *')

    @classmethod
    def _argument_split(cls, arguments):
        """Split a comma-separated parameter list into parameter names."""
        return cls._argument_split_re.split(arguments)

    # Regex for interesting header lines.
    # Groups: 1=macro name, 2=type, 3=argument list (optional).
    _header_line_re = \
        re.compile(r'#define +' +
                   r'(PSA_((?:KEY_)?[A-Z]+)_\w+)' +
                   r'(?:\(([^\n()]*)\))?')
    # Regex of macro names to exclude.
    _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z')
    # Additional excluded macros.
    _excluded_names = {
        # Macros that provide an alternative way to build the same
        # algorithm as another macro.
        'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH',
        'PSA_ALG_FULL_LENGTH_MAC',
        # Auxiliary macro whose name doesn't fit the usual patterns for
        # auxiliary macros.
        'PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH_CASE',
        # PSA_ALG_ECDH and PSA_ALG_FFDH are excluded for now as the script
        # currently doesn't support them.
        'PSA_ALG_ECDH',
        'PSA_ALG_FFDH',
        # Deprecated aliases.
        'PSA_ERROR_UNKNOWN_ERROR',
        'PSA_ERROR_OCCUPIED_SLOT',
        'PSA_ERROR_EMPTY_SLOT',
        'PSA_ERROR_INSUFFICIENT_CAPACITY',
        'PSA_ERROR_TAMPERING_DETECTED',
    }

    def parse_header_line(self, line):
        """Parse a C header line, looking for "#define PSA_xxx"."""
        m = self._header_line_re.match(line)
        if not m:
            return
        name = m.group(1)
        if self._excluded_name_re.search(name) or \
           name in self._excluded_names:
            return
        dest = self.table_by_prefix.get(m.group(2))
        if dest is None:
            return
        dest.add(name)
        if m.group(3):
            self.argspecs[name] = self._argument_split(m.group(3))

    def parse_header(self, filename):
        """Parse a C header file, looking for "#define PSA_xxx"."""
        with read_file_lines(filename) as lines:
            for line in lines:
                self.parse_header_line(line)

    def add_test_case_line(self, function, argument):
        """Parse a test case data line, looking for algorithm metadata tests."""
        if function.endswith('_algorithm'):
            # As above, ECDH and FFDH algorithms are excluded for now.
            # Support for them will be added in the future.
            if 'ECDH' in argument or 'FFDH' in argument:
                return
            self.algorithms.add(argument)
            if function == 'hash_algorithm':
                self.hash_algorithms.add(argument)
            elif function in ['mac_algorithm', 'hmac_algorithm']:
                self.mac_algorithms.add(argument)
            elif function == 'aead_algorithm':
                self.aead_algorithms.add(argument)
        elif function == 'key_type':
            self.key_types.add(argument)
        elif function == 'ecc_key_types':
            self.ecc_curves.add(argument)
        elif function == 'dh_key_types':
            self.dh_groups.add(argument)

    # Regex matching a *.data line containing a test function call and
    # its arguments. The actual definition is partly positional, but this
    # regex is good enough in practice.
    _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)')

    def parse_test_cases(self, filename):
        """Parse a test case file (*.data), looking for algorithm metadata tests."""
        with read_file_lines(filename) as lines:
            for line in lines:
                m = self._test_case_line_re.match(line)
                if m:
                    self.add_test_case_line(m.group(1), m.group(2))


def gather_inputs(headers, test_suites):
    """Read the list of inputs to test psa_constant_names with.

    headers is an iterable of C header file names and test_suites an
    iterable of *.data file names. Return the populated Inputs object.
    """
    inputs = Inputs()
    for header in headers:
        inputs.parse_header(header)
    for test_cases in test_suites:
        inputs.parse_test_cases(test_cases)
    inputs.gather_arguments()
    return inputs


def remove_file_if_exists(filename):
    """Remove the specified file, ignoring errors.

    Do nothing if filename is empty or None.
    """
    if not filename:
        return
    try:
        os.remove(filename)
    except OSError:
        # Deliberately best-effort: this is only used for cleanup.
        pass


def run_c(options, type_word, names):
    """Generate and run a program to print out numerical values for names.

    Write a temporary C file under programs/psa, compile it with $CC
    (default: cc) using the include directories in options.include, run it,
    and return the list of lines it printed (one per name).

    Raise subprocess.CalledProcessError if compiling or running fails. If
    compilation fails, the C source is left in place; the executable is
    always removed.
    """
    if type_word == 'status':
        cast_to = 'long'
        printf_format = '%ld'
    else:
        cast_to = 'unsigned long'
        printf_format = '0x%08lx'
    c_name = None
    exe_name = None
    try:
        c_fd, c_name = tempfile.mkstemp(prefix='tmp-{}-'.format(type_word),
                                        suffix='.c',
                                        dir='programs/psa')
        exe_suffix = '.exe' if platform.system() == 'Windows' else ''
        # Strip the '.c' suffix to derive the executable name.
        exe_name = c_name[:-2] + exe_suffix
        remove_file_if_exists(exe_name)
        # The with-statement closes the file even if a write fails.
        with os.fdopen(c_fd, 'w', encoding='ascii') as c_file:
            c_file.write('/* Generated by test_psa_constant_names.py for {} values */'
                         .format(type_word))
            # <stdio.h> provides printf; <psa/crypto.h> is expected to
            # provide the PSA macros being evaluated.
            c_file.write('''
#include <stdio.h>
#include <psa/crypto.h>
int main(void)
{
''')
            for name in names:
                c_file.write('    printf("{}\\n", ({}) {});\n'
                             .format(printf_format, cast_to, name))
            c_file.write('''    return 0;
}
''')
        cc = os.getenv('CC', 'cc')
        subprocess.check_call([cc] +
                              ['-I' + dir for dir in options.include] +
                              ['-o', exe_name, c_name])
        if options.keep_c:
            sys.stderr.write('List of {} tests kept at {}\n'
                             .format(type_word, c_name))
        else:
            os.remove(c_name)
        output = subprocess.check_output([exe_name])
        return output.decode('ascii').strip().split('\n')
    finally:
        remove_file_if_exists(exe_name)


NORMALIZE_STRIP_RE = re.compile(r'\s+')


def normalize(expr):
    """Normalize the C expression so as not to care about trivial differences.

    Currently "trivial differences" means whitespace.

    Return a one-element list containing expr with all whitespace removed.
    (The list wrapper is kept for compatibility with existing callers,
    which only compare normalized values with each other.)
    """
    return [NORMALIZE_STRIP_RE.sub('', expr)]


def do_test(options, inputs, type_word, names):
    """Test psa_constant_names for the specified type.

    Run program on names.
    Use inputs to figure out what arguments to pass to macros that
    take arguments.

    Return a tuple (count, errors) where count is the number of expressions
    tested and errors is a list of (type_word, name, value, output) tuples
    for each expression whose output did not match.
    """
    names = sorted(itertools.chain.from_iterable(
        map(inputs.distribute_arguments, names)))
    values = run_c(options, type_word, names)
    output = subprocess.check_output([options.program, type_word] + values)
    outputs = output.decode('ascii').strip().split('\n')
    # zip_longest rather than zip: if the program prints fewer (or more)
    # lines than there are names, report the unmatched entries as failures
    # instead of silently dropping them.
    errors = [(type_word, name, value, output)
              for (name, value, output)
              in itertools.zip_longest(names, values, outputs, fillvalue='')
              if normalize(name) != normalize(output)]
    return len(names), errors


def report_errors(errors):
    """Describe each case where the output is not as expected."""
    for type_word, name, value, output in errors:
        print('For {} "{}", got "{}" (value: {})'
              .format(type_word, name, output, value))


def run_tests(options, inputs):
    """Run psa_constant_names on all the gathered inputs.

    Return a tuple (count, errors) where count is the total number of inputs
    that were tested and errors is the list of cases where the output was
    not as expected.
    """
    count = 0
    errors = []
    for type_word, names in [('status', inputs.statuses),
                             ('algorithm', inputs.algorithms),
                             ('ecc_curve', inputs.ecc_curves),
                             ('dh_group', inputs.dh_groups),
                             ('key_type', inputs.key_types),
                             ('key_usage', inputs.key_usage_flags)]:
        c, e = do_test(options, inputs, type_word, names)
        count += c
        errors += e
    return count, errors


def main():
    """Parse the command line, run all tests and report the results.

    Exit with status 1 if any test case failed.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    # NOTE: with action='append' and a non-empty default, directories given
    # on the command line are appended after 'include', so options.include[0]
    # below is always 'include'.
    parser.add_argument('--include', '-I',
                        action='append', default=['include'],
                        help='Directory for header files')
    parser.add_argument('--program',
                        default='programs/psa/psa_constant_names',
                        help='Program to test')
    parser.add_argument('--keep-c',
                        action='store_true', dest='keep_c', default=False,
                        help='Keep the intermediate C file')
    parser.add_argument('--no-keep-c',
                        action='store_false', dest='keep_c',
                        help="Don't keep the intermediate C file (default)")
    options = parser.parse_args()
    headers = [os.path.join(options.include[0], 'psa', h)
               for h in ['crypto.h', 'crypto_extra.h', 'crypto_values.h']]
    test_suites = ['tests/suites/test_suite_psa_crypto_metadata.data']
    inputs = gather_inputs(headers, test_suites)
    count, errors = run_tests(options, inputs)
    report_errors(errors)
    if not errors:
        print('{} test cases PASS'.format(count))
    else:
        print('{} test cases, {} FAIL'.format(count, len(errors)))
        # sys.exit rather than the site-provided exit(), which is meant for
        # interactive use and is not guaranteed to exist.
        sys.exit(1)


if __name__ == '__main__':
    main()