From 42a0a0aeeab7bf793c510e7c8618fb248855e2a7 Mon Sep 17 00:00:00 2001 From: Gilles Peskine Date: Mon, 27 May 2019 18:29:47 +0200 Subject: [PATCH] Obey Python naming and method structure conventions * Rename internal methods and fields to start with an underscore. * Rename global constants to uppercase. * Change methods that don't use self to be class methods or static methods as appropriate. No behavior change in this commit. --- scripts/generate_psa_constants.py | 112 ++++++++++++----------- tests/scripts/test_psa_constant_names.py | 92 ++++++++++--------- 2 files changed, 106 insertions(+), 98 deletions(-) diff --git a/scripts/generate_psa_constants.py b/scripts/generate_psa_constants.py index 997bd3c95..d772a77fe 100755 --- a/scripts/generate_psa_constants.py +++ b/scripts/generate_psa_constants.py @@ -10,7 +10,7 @@ import os import re import sys -output_template = '''\ +OUTPUT_TEMPLATE = '''\ /* Automatically generated by generate_psa_constant.py. DO NOT EDIT. */ static const char *psa_strerror(psa_status_t status) @@ -154,19 +154,19 @@ static int psa_snprint_key_usage(char *buffer, size_t buffer_size, /* End of automatically generated file. */ ''' -key_type_from_curve_template = '''if (%(tester)s(type)) { +KEY_TYPE_FROM_CURVE_TEMPLATE = '''if (%(tester)s(type)) { append_with_curve(&buffer, buffer_size, &required_size, "%(builder)s", %(builder_length)s, PSA_KEY_TYPE_GET_CURVE(type)); } else ''' -key_type_from_group_template = '''if (%(tester)s(type)) { +KEY_TYPE_FROM_GROUP_TEMPLATE = '''if (%(tester)s(type)) { append_with_group(&buffer, buffer_size, &required_size, "%(builder)s", %(builder_length)s, PSA_KEY_TYPE_GET_GROUP(type)); } else ''' -algorithm_from_hash_template = '''if (%(tester)s(core_alg)) { +ALGORITHM_FROM_HASH_TEMPLATE = '''if (%(tester)s(core_alg)) { append(&buffer, buffer_size, &required_size, "%(builder)s(", %(builder_length)s + 1); append_with_alg(&buffer, buffer_size, &required_size, @@ -175,7 +175,7 @@ algorithm_from_hash_template = '''if (%(tester)s(core_alg)) { append(&buffer, buffer_size, &required_size, ")", 1); } else ''' -bit_test_template = '''\ +BIT_TEST_TEMPLATE = '''\ if (%(var)s & %(flag)s) { if (required_size != 0) { append(&buffer, buffer_size, &required_size, " | ", 3); @@ -274,102 +274,104 @@ class MacroCollector: for line in header_file: self.read_line(line) - def make_return_case(self, name): + @staticmethod + def _make_return_case(name): return 'case %(name)s: return "%(name)s";' % {'name': name} - def make_append_case(self, name): + @staticmethod + def _make_append_case(name): template = ('case %(name)s: ' 'append(&buffer, buffer_size, &required_size, "%(name)s", %(length)d); ' 'break;') return template % {'name': name, 'length': len(name)} - def make_inner_append_case(self, name): - template = ('case %(name)s: ' - 'append(buffer, buffer_size, required_size, "%(name)s", %(length)d); ' - 'break;') - return template % {'name': name, 'length': len(name)} - - def make_bit_test(self, var, flag): - return bit_test_template % {'var': var, + @staticmethod + def _make_bit_test(var, flag): + return BIT_TEST_TEMPLATE % {'var': var, 'flag': flag, 'length': len(flag)} - def make_status_cases(self): - return '\n '.join(map(self.make_return_case, + def _make_status_cases(self): + return '\n '.join(map(self._make_return_case, sorted(self.statuses))) - def make_ecc_curve_cases(self): - return '\n '.join(map(self.make_return_case, + def _make_ecc_curve_cases(self): + return '\n '.join(map(self._make_return_case, sorted(self.ecc_curves))) - def make_dh_group_cases(self): - return '\n '.join(map(self.make_return_case, + def _make_dh_group_cases(self): + return '\n '.join(map(self._make_return_case, sorted(self.dh_groups))) - def make_key_type_cases(self): - return '\n '.join(map(self.make_append_case, + def _make_key_type_cases(self): + return '\n '.join(map(self._make_append_case, sorted(self.key_types))) - def make_key_type_from_curve_code(self, builder, tester): - return key_type_from_curve_template % {'builder': builder, + @staticmethod + def _make_key_type_from_curve_code(builder, tester): + return KEY_TYPE_FROM_CURVE_TEMPLATE % {'builder': builder, 'builder_length': len(builder), 'tester': tester} - def make_key_type_from_group_code(self, builder, tester): - return key_type_from_group_template % {'builder': builder, + @staticmethod + def _make_key_type_from_group_code(builder, tester): + return KEY_TYPE_FROM_GROUP_TEMPLATE % {'builder': builder, 'builder_length': len(builder), 'tester': tester} - def make_ecc_key_type_code(self): + def _make_ecc_key_type_code(self): d = self.key_types_from_curve - make = self.make_key_type_from_curve_code + make = self._make_key_type_from_curve_code return ''.join([make(k, d[k]) for k in sorted(d.keys())]) - def make_dh_key_type_code(self): + def _make_dh_key_type_code(self): d = self.key_types_from_group - make = self.make_key_type_from_group_code + make = self._make_key_type_from_group_code return ''.join([make(k, d[k]) for k in sorted(d.keys())]) - def make_hash_algorithm_cases(self): - return '\n '.join(map(self.make_return_case, + def _make_hash_algorithm_cases(self): + return '\n '.join(map(self._make_return_case, sorted(self.hash_algorithms))) - def make_ka_algorithm_cases(self): - return '\n '.join(map(self.make_return_case, + def _make_ka_algorithm_cases(self): + return '\n '.join(map(self._make_return_case, sorted(self.ka_algorithms))) - def make_algorithm_cases(self): - return '\n '.join(map(self.make_append_case, + def _make_algorithm_cases(self): + return '\n '.join(map(self._make_append_case, sorted(self.algorithms))) - def make_algorithm_from_hash_code(self, builder, tester): - return algorithm_from_hash_template % {'builder': builder, + @staticmethod + def _make_algorithm_from_hash_code(builder, tester): + return ALGORITHM_FROM_HASH_TEMPLATE % {'builder': builder, 'builder_length': len(builder), 'tester': tester} - def make_algorithm_code(self): + def _make_algorithm_code(self): d = self.algorithms_from_hash - make = self.make_algorithm_from_hash_code + make = self._make_algorithm_from_hash_code return ''.join([make(k, d[k]) for k in sorted(d.keys())]) - def make_key_usage_code(self): - return '\n'.join([self.make_bit_test('usage', bit) + def _make_key_usage_code(self): + return '\n'.join([self._make_bit_test('usage', bit) for bit in sorted(self.key_usages)]) def write_file(self, output_file): + """Generate the pretty-printer function code from the gathered + constant definitions.""" data = {} - data['status_cases'] = self.make_status_cases() - data['ecc_curve_cases'] = self.make_ecc_curve_cases() - data['dh_group_cases'] = self.make_dh_group_cases() - data['key_type_cases'] = self.make_key_type_cases() - data['key_type_code'] = (self.make_ecc_key_type_code() + - self.make_dh_key_type_code()) - data['hash_algorithm_cases'] = self.make_hash_algorithm_cases() - data['ka_algorithm_cases'] = self.make_ka_algorithm_cases() - data['algorithm_cases'] = self.make_algorithm_cases() - data['algorithm_code'] = self.make_algorithm_code() - data['key_usage_code'] = self.make_key_usage_code() - output_file.write(output_template % data) + data['status_cases'] = self._make_status_cases() + data['ecc_curve_cases'] = self._make_ecc_curve_cases() + data['dh_group_cases'] = self._make_dh_group_cases() + data['key_type_cases'] = self._make_key_type_cases() + data['key_type_code'] = (self._make_ecc_key_type_code() + + self._make_dh_key_type_code()) + data['hash_algorithm_cases'] = self._make_hash_algorithm_cases() + data['ka_algorithm_cases'] = self._make_ka_algorithm_cases() + data['algorithm_cases'] = self._make_algorithm_cases() + data['algorithm_code'] = self._make_algorithm_code() + data['key_usage_code'] = self._make_key_usage_code() + output_file.write(OUTPUT_TEMPLATE % data) def generate_psa_constants(header_file_names, output_file_name): collector = MacroCollector() diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py index cbe68b10d..43056bca3 100755 --- a/tests/scripts/test_psa_constant_names.py +++ b/tests/scripts/test_psa_constant_names.py @@ -44,10 +44,10 @@ snippet annotates the exception with the file name and line number.''' self.line_number = line_number yield content self.line_number = 'exit' - def __exit__(self, type, value, traceback): - if type is not None: + def __exit__(self, exc_type, exc_value, exc_traceback): + if exc_type is not None: raise ReadFileLineException(self.filename, self.line_number) \ - from value + from exc_value class Inputs: '''Accumulate information about macros to test. @@ -98,7 +98,8 @@ Call this after parsing all the inputs.''' self.arguments_for['curve'] = sorted(self.ecc_curves) self.arguments_for['group'] = sorted(self.dh_groups) - def format_arguments(self, name, arguments): + @staticmethod + def _format_arguments(name, arguments): '''Format a macro call with arguments..''' return name + '(' + ', '.join(arguments) + ')' @@ -117,51 +118,56 @@ where each argument takes each possible value at least once.''' return argument_lists = [self.arguments_for[arg] for arg in argspec] arguments = [values[0] for values in argument_lists] - yield self.format_arguments(name, arguments) + yield self._format_arguments(name, arguments) for i in range(len(arguments)): for value in argument_lists[i][1:]: arguments[i] = value - yield self.format_arguments(name, arguments) + yield self._format_arguments(name, arguments) arguments[i] = argument_lists[0][0] except BaseException as e: raise Exception('distribute_arguments({})'.format(name)) from e + _argument_split_re = re.compile(r' *, *') + @classmethod + def _argument_split(cls, arguments): + return re.split(cls._argument_split_re, arguments) + # Regex for interesting header lines. # Groups: 1=macro name, 2=type, 3=argument list (optional). - header_line_re = \ + _header_line_re = \ re.compile(r'#define +' + r'(PSA_((?:KEY_)?[A-Z]+)_\w+)' + r'(?:\(([^\n()]*)\))?') # Regex of macro names to exclude. - excluded_name_re = re.compile('_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z') + _excluded_name_re = re.compile(r'_(?:GET|IS|OF)_|_(?:BASE|FLAG|MASK)\Z') # Additional excluded macros. # PSA_ALG_ECDH and PSA_ALG_FFDH are excluded for now as the script # currently doesn't support them. Deprecated errors are also excluded. - excluded_names = set(['PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH', - 'PSA_ALG_FULL_LENGTH_MAC', - 'PSA_ALG_ECDH', - 'PSA_ALG_FFDH', - 'PSA_ERROR_UNKNOWN_ERROR', - 'PSA_ERROR_OCCUPIED_SLOT', - 'PSA_ERROR_EMPTY_SLOT', - 'PSA_ERROR_INSUFFICIENT_CAPACITY', + _excluded_names = set(['PSA_ALG_AEAD_WITH_DEFAULT_TAG_LENGTH', + 'PSA_ALG_FULL_LENGTH_MAC', + 'PSA_ALG_ECDH', + 'PSA_ALG_FFDH', + 'PSA_ERROR_UNKNOWN_ERROR', + 'PSA_ERROR_OCCUPIED_SLOT', + 'PSA_ERROR_EMPTY_SLOT', + 'PSA_ERROR_INSUFFICIENT_CAPACITY', ]) - argument_split_re = re.compile(r' *, *') + def parse_header_line(self, line): '''Parse a C header line, looking for "#define PSA_xxx".''' - m = re.match(self.header_line_re, line) + m = re.match(self._header_line_re, line) if not m: return name = m.group(1) - if re.search(self.excluded_name_re, name) or \ - name in self.excluded_names: + if re.search(self._excluded_name_re, name) or \ + name in self._excluded_names: return dest = self.table_by_prefix.get(m.group(2)) if dest is None: return dest.add(name) if m.group(3): - self.argspecs[name] = re.split(self.argument_split_re, m.group(3)) + self.argspecs[name] = self._argument_split(m.group(3)) def parse_header(self, filename): '''Parse a C header file, looking for "#define PSA_xxx".''' @@ -193,12 +199,12 @@ where each argument takes each possible value at least once.''' # Regex matching a *.data line containing a test function call and # its arguments. The actual definition is partly positional, but this # regex is good enough in practice. - test_case_line_re = re.compile('(?!depends_on:)(\w+):([^\n :][^:\n]*)') + _test_case_line_re = re.compile(r'(?!depends_on:)(\w+):([^\n :][^:\n]*)') def parse_test_cases(self, filename): '''Parse a test case file (*.data), looking for algorithm metadata tests.''' with read_file_lines(filename) as lines: for line in lines: - m = re.match(self.test_case_line_re, line) + m = re.match(self._test_case_line_re, line) if m: self.add_test_case_line(m.group(1), m.group(2)) @@ -221,9 +227,9 @@ def remove_file_if_exists(filename): except: pass -def run_c(options, type, names): +def run_c(options, type_word, names): '''Generate and run a program to print out numerical values for names.''' - if type == 'status': + if type_word == 'status': cast_to = 'long' printf_format = '%ld' else: @@ -232,7 +238,7 @@ def run_c(options, type, names): c_name = None exe_name = None try: - c_fd, c_name = tempfile.mkstemp(prefix='tmp-{}-'.format(type), + c_fd, c_name = tempfile.mkstemp(prefix='tmp-{}-'.format(type_word), suffix='.c', dir='programs/psa') exe_suffix = '.exe' if platform.system() == 'Windows' else '' @@ -240,7 +246,7 @@ def run_c(options, type, names): remove_file_if_exists(exe_name) c_file = os.fdopen(c_fd, 'w', encoding='ascii') c_file.write('/* Generated by test_psa_constant_names.py for {} values */' - .format(type)) + .format(type_word)) c_file.write(''' #include #include @@ -260,7 +266,7 @@ int main(void) ['-o', exe_name, c_name]) if options.keep_c: sys.stderr.write('List of {} tests kept at {}\n' - .format(type, c_name)) + .format(type_word, c_name)) else: os.remove(c_name) output = subprocess.check_output([exe_name]) @@ -268,31 +274,31 @@ int main(void) finally: remove_file_if_exists(exe_name) -normalize_strip_re = re.compile(r'\s+') +NORMALIZE_STRIP_RE = re.compile(r'\s+') def normalize(expr): '''Normalize the C expression so as not to care about trivial differences. Currently "trivial differences" means whitespace.''' - expr = re.sub(normalize_strip_re, '', expr, len(expr)) + expr = re.sub(NORMALIZE_STRIP_RE, '', expr, len(expr)) return expr.strip().split('\n') -def do_test(options, inputs, type, names): +def do_test(options, inputs, type_word, names): '''Test psa_constant_names for the specified type. Run program on names. Use inputs to figure out what arguments to pass to macros that take arguments.''' names = sorted(itertools.chain(*map(inputs.distribute_arguments, names))) - values = run_c(options, type, names) - output = subprocess.check_output([options.program, type] + values) + values = run_c(options, type_word, names) + output = subprocess.check_output([options.program, type_word] + values) outputs = output.decode('ascii').strip().split('\n') - errors = [(type, name, value, output) + errors = [(type_word, name, value, output) for (name, value, output) in zip(names, values, outputs) if normalize(name) != normalize(output)] return len(names), errors def report_errors(errors): '''Describe each case where the output is not as expected.''' - for type, name, value, output in errors: + for type_word, name, value, output in errors: print('For {} "{}", got "{}" (value: {})' - .format(type, name, output, value)) + .format(type_word, name, output, value)) def run_tests(options, inputs): '''Run psa_constant_names on all the gathered inputs. @@ -301,13 +307,13 @@ that were tested and errors is the list of cases where the output was not as expected.''' count = 0 errors = [] - for type, names in [('status', inputs.statuses), - ('algorithm', inputs.algorithms), - ('ecc_curve', inputs.ecc_curves), - ('dh_group', inputs.dh_groups), - ('key_type', inputs.key_types), - ('key_usage', inputs.key_usage_flags)]: - c, e = do_test(options, inputs, type, names) + for type_word, names in [('status', inputs.statuses), + ('algorithm', inputs.algorithms), + ('ecc_curve', inputs.ecc_curves), + ('dh_group', inputs.dh_groups), + ('key_type', inputs.key_types), + ('key_usage', inputs.key_usage_flags)]: + c, e = do_test(options, inputs, type_word, names) count += c errors += e return count, errors