blob: fe9f9b76af6569e0344841d2b735a23014574361 [file] [log] [blame]
# Copyright 2016 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
'''Generates a test suite from NIST PKITS test descriptions.
The output is a set of Type Parameterized Tests which are included by
pkits_unittest.h. See pkits_unittest.h for information on using the tests.
GoogleTest has a limit of 50 tests per type parameterized testcase, so the tests
are split up by section number (this also makes it possible to easily skip
sections that pertain to non-implemented features).
Usage:
generate_tests.py <PKITS.pdf> <output.h>
'''
import os
import re
import subprocess
import sys
import tempfile
def sanitize_name(s):
    '''Returns |s| with every space and hyphen removed.

    The result is used to build C++ identifiers and file-name-like strings
    from the human readable names found in the PKITS document.

    Implemented with str.replace rather than the two-argument form of
    str.translate, which only exists for Python 2 byte strings and raises
    TypeError for unicode/Python 3 strings.
    '''
    return s.replace(' ', '').replace('-', '')
def finalize_test_case(test_case_name, sanitized_test_names, output):
    '''Writes the registration macro that closes a typed test case.

    Emits a WRAPPED_REGISTER_TYPED_TEST_CASE_P(...) invocation to |output|
    naming |test_case_name| followed by each entry of |sanitized_test_names|,
    one per line.
    '''
    pieces = ['\nWRAPPED_REGISTER_TYPED_TEST_CASE_P(%s' % test_case_name]
    pieces.extend(',\n %s' % test_name for test_name in sanitized_test_names)
    pieces.append(');\n')
    output.write(''.join(pieces))
def generate_test(test_case_name, test_number, raw_test_name, certs, crls, should_validate,
                  output):
    '''Writes one WRAPPED_TYPED_TEST_P definition to |output|.

    The C++ test name is "Section" + the second component of |test_number|
    + the sanitized |raw_test_name|. |certs| and |crls| are emitted as
    quoted, comma separated array initializers, and the verification call is
    wrapped in ASSERT_TRUE or ASSERT_FALSE depending on |should_validate|.

    Returns the generated (sanitized) test name.
    '''
    section_component = test_number.split('.')[1]
    # Keys must match the %(name)s placeholders in the template below.
    fields = {
        'test_case_name': test_case_name,
        'test_number': test_number,
        'raw_test_name': raw_test_name,
        'sanitized_test_name': 'Section%s%s' % (
            section_component, sanitize_name(raw_test_name)),
        'certs_formatted': ', '.join(['"%s"' % cert for cert in certs]),
        'crls_formatted': ', '.join(['"%s"' % crl for crl in crls]),
        'assert_function': 'ASSERT_TRUE' if should_validate else 'ASSERT_FALSE',
    }
    output.write('''
// %(test_number)s %(raw_test_name)s
WRAPPED_TYPED_TEST_P(%(test_case_name)s, %(sanitized_test_name)s) {
const char* const certs[] = {
%(certs_formatted)s
};
const char* const crls[] = {
%(crls_formatted)s
};
%(assert_function)s(this->Verify(certs, crls));
}
''' % fields)
    return fields['sanitized_test_name']
# All patterns are raw strings so that regex escapes such as \s and \d are
# passed to the re module untouched instead of being treated as (invalid)
# string escapes.

# Matches a section header, ex: "4.1 Signature Verification"
SECTION_MATCHER = re.compile(r'^\s*(\d+\.\d+)\s+(.+)\s*$')
# Matches a test header, ex: "4.1.1 Valid Signatures Test1". Every separator
# is a literal dot, so a two-component number such as "4.123" is not mistaken
# for a test number.
TEST_MATCHER = re.compile(r'^\s*(\d+\.\d+\.\d+)\s+(.+)\s*$')
# Match an expected test result. Note that some results in the PDF have a typo
# "path not should validate" instead of "path should not validate".
TEST_RESULT_MATCHER = re.compile(
    r'^\s*Expected Result:.*path (should validate|'
    r'should not validate|not should validate)')
# Matches the line that introduces the list of certification path entries.
PATH_HEADER_MATCHER = re.compile(r'^\s*Certification Path:')
# Matches a line in the certification path, ex: "\u2022 Good CA Cert, Good CA CRL"
# The bullet is spelled as the three bytes of its UTF-8 encoding (e2 80 a2).
PATH_MATCHER = re.compile(r'^\s*\xe2\x80\xa2\s*(.+)\s*$')
# Matches a page number. These may appear in the middle of multi-line fields and
# thus need to be ignored.
PAGE_NUMBER_MATCHER = re.compile(r'^\s*\d+\s*$')
# Matches if an entry in a certification path refers to a CRL, ex:
# "onlySomeReasons CA2 CRL1".
CRL_MATCHER = re.compile(r'^.*CRL\d*$')
def parse_test(lines, i, test_case_name, test_number, test_name, output):
    '''Parses the body of one test description and emits the C++ test.

    Starting at index |i| of |lines| (the line after the test header), this
    scans for the "Expected Result:" line, then for the "Certification Path:"
    header, then collects the bulleted path entries up to (but not including)
    the next test or section header. Each comma separated entry is sanitized
    and classified as a CRL or a certificate, and the test is written to
    |output| through generate_test().

    Returns a tuple (index of the first unconsumed line, sanitized test name).
    Asserts that an expected result, at least one certificate and at least one
    CRL were found.
    '''
    line_count = len(lines)

    # Step 1: find the expected outcome of the test.
    expected_result = None
    while i < line_count:
        verdict = TEST_RESULT_MATCHER.match(lines[i])
        i += 1
        if verdict:
            expected_result = (verdict.group(1) == 'should validate')
            break

    # Step 2: skip everything up to and including the path header line.
    while i < line_count:
        is_path_header = PATH_HEADER_MATCHER.match(lines[i])
        i += 1
        if is_path_header:
            break

    # Step 3: gather the path entries, joining wrapped lines together.
    path_lines = []
    while i < line_count:
        text = lines[i].strip()
        if TEST_MATCHER.match(text) or SECTION_MATCHER.match(text):
            # The next header belongs to the caller; leave it unconsumed.
            break
        i += 1
        if not text or PAGE_NUMBER_MATCHER.match(text):
            continue
        bullet = PATH_MATCHER.match(text)
        if bullet:
            path_lines.append(bullet.group(1))
        else:
            # Continuation of previous path line.
            path_lines[-1] += ' ' + text

    # Step 4: split the entries into certificates and CRLs, keeping order.
    certs = []
    crls = []
    for path_line in path_lines:
        for raw_entry in path_line.split(','):
            entry = sanitize_name(raw_entry.strip())
            bucket = crls if CRL_MATCHER.match(entry) else certs
            bucket.append(entry)

    assert certs
    assert crls
    assert expected_result is not None
    return i, generate_test(test_case_name, test_number, test_name,
                            certs, crls, expected_result, output)
def main():
    '''Converts the PKITS PDF named on the command line into a C++ header.

    Expects exactly two arguments: the path of the PKITS PDF and the path of
    the header to write. The PDF is converted to text with the external
    `pdftotext` tool, the text between the "4 Certification Path Validation
    Tests" and "5 Relationship to Previous Test Suite" headings is scanned for
    section and test headers, and one typed test case per section is written
    to the output file. Sections 4.8 through 4.12 are skipped.

    Exits with a usage message if the argument count is wrong. Raises
    subprocess.CalledProcessError if pdftotext fails.
    '''
    if len(sys.argv) != 3:
        sys.exit('Usage: %s <PKITS.pdf> <output.h>' % sys.argv[0])
    pkits_pdf_path, output_path = sys.argv[1:]

    # The temporary file is closed (and thereby deleted) as soon as its
    # contents have been read.
    with tempfile.NamedTemporaryFile() as pkits_txt_file:
        subprocess.check_call(['pdftotext', '-layout', '-nopgbrk', '-eol', 'unix',
                               pkits_pdf_path, pkits_txt_file.name])
        test_descriptions = pkits_txt_file.read()

    # Extract section 4 of the text, which is the part that contains the tests.
    test_descriptions = test_descriptions.split(
        '4 Certification Path Validation Tests')[-1]
    test_descriptions = test_descriptions.split(
        '5 Relationship to Previous Test Suite', 1)[0]
    lines = test_descriptions.splitlines()

    # The context manager guarantees the header is flushed and closed even if
    # parsing fails part way through.
    with open(output_path, 'w') as output:
        output.write('// Autogenerated by %s, do not edit\n\n' % sys.argv[0])
        output.write('// Hack to allow disabling type parameterized test cases.\n'
                     '// See https://github.com/google/googletest/issues/389\n')
        output.write('#define WRAPPED_TYPED_TEST_P(CaseName, TestName) '
                     'TYPED_TEST_P(CaseName, TestName)\n')
        output.write('#define WRAPPED_REGISTER_TYPED_TEST_CASE_P(CaseName, ...) '
                     'REGISTER_TYPED_TEST_CASE_P(CaseName, __VA_ARGS__)\n\n')

        # Name of the test case currently being filled, or None while inside
        # a skipped section (or before the first section header).
        test_case_name = None
        sanitized_test_names = []
        i = 0
        while i < len(lines):
            section_match = SECTION_MATCHER.match(lines[i])
            match = TEST_MATCHER.match(lines[i])
            i += 1

            if section_match:
                # A new section starts: close the previous test case first.
                if test_case_name:
                    finalize_test_case(test_case_name, sanitized_test_names, output)
                    sanitized_test_names = []
                # TODO(mattm): Handle certificate policies tests.
                if section_match.group(1) in ('4.8', '4.9', '4.10', '4.11', '4.12'):
                    test_case_name = None
                    output.write('\n// Skipping section %s\n' % section_match.group(1))
                    continue
                test_case_name = 'PkitsTest%02d%s' % (
                    int(section_match.group(1).split('.')[-1]),
                    sanitize_name(section_match.group(2)))
                output.write('\ntemplate <typename PkitsTestDelegate>\n')
                output.write('class %s : public PkitsTest<PkitsTestDelegate> {};\n' % test_case_name)
                output.write('TYPED_TEST_CASE_P(%s);\n' % test_case_name)

            if match:
                test_number = match.group(1)
                test_name = match.group(2)
                if not test_case_name:
                    output.write('// Skipped %s %s\n' % (test_number, test_name))
                    continue
                # parse_test consumes the test body and returns the index of
                # the next unprocessed line.
                i, sanitized_test_name = parse_test(lines, i, test_case_name, test_number,
                                                    test_name, output)
                if sanitized_test_name:
                    sanitized_test_names.append(sanitized_test_name)

        # Close the last open test case, if any.
        if test_case_name:
            finalize_test_case(test_case_name, sanitized_test_names, output)
# Run the generator only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == '__main__':
    main()