| #!/usr/bin/env python3 |
| # |
| # Copyright 2024 The Chromium Authors |
| # Use of this source code is governed by a BSD-style license that can be |
| # found in the LICENSE file. |
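
"""Generates the protobuf definitions and the corresponding C++ parsing code
that DomatoLPM needs from a Domato grammar file."""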
| |
| import argparse |
| import collections |
| import copy |
| import functools |
| import os |
| import pathlib |
| import sys |
| import typing |
| import re |
| import dataclasses |
| |
| |
| def _GetDirAbove(dirname: str): |
| """Returns the directory "above" this file containing |dirname| (which must |
| also be "above" this file).""" |
| path = os.path.abspath(__file__) |
| while True: |
| path, tail = os.path.split(path) |
| if not tail: |
| return None |
| if tail == dirname: |
| return path |
| |
| |
| SOURCE_DIR = _GetDirAbove('testing') |
| |
| # //build imports. |
| sys.path.append(os.path.join(SOURCE_DIR, 'build')) |
| import action_helpers |
| |
| # //third_party imports. |
| sys.path.insert(1, os.path.join(SOURCE_DIR, 'third_party')) |
| import jinja2 |
| |
| # //third_party/domato/src imports. |
| sys.path.insert(1, os.path.join(SOURCE_DIR, 'third_party/domato/src')) |
| import grammar |
| |
| # TODO(crbug.com/361369290): Remove this disable once DomatoLPM development is |
| # finished and upstream changes can be made to expose the relevant protected |
| # fields. |
| # pylint: disable=protected-access |
| |
| def to_snake_case(name): |
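  # Illustrative examples (assumed, not part of the original code):
  #   to_snake_case('HTMLElement') -> 'html_element'
  #   to_snake_case('fuzzInt') -> 'fuzz_int'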
| name = re.sub(r'([A-Z]{2,})([A-Z][a-z])', r'\1_\2', name) |
| return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', name, count=sys.maxsize).lower() |
| |
| |
| def GetProtoId(name): |
| # We reserve ids [0,15] |
| # Protobuf implementation reserves [19000,19999] |
| # Max proto id is 2^29-1 |
| # 32-bit fnv-1a |
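  # Sketch of the resulting range, derived from the arithmetic below (not from
  # upstream documentation): the returned id always falls in [15, 18999] or
  # [20000, 536870909], i.e. it avoids the protobuf-reserved [19000, 19999]
  # range and stays below the 2^29-1 maximum.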
| fnv = 2166136261 |
| for c in name: |
| fnv = fnv ^ ord(c) |
| fnv = (fnv * 16777619) & 0xffffffff |
| # xor-fold to 29-bits |
| fnv = (fnv >> 29) ^ (fnv & 0x1fffffff) |
| # now use a modulo to reduce to [0,2^29-1 - 1016] |
| fnv = fnv % 536869895 |
| # now we move out the disallowed ranges |
| fnv = fnv + 15 |
| if fnv >= 19000: |
| fnv += 1000 |
| return fnv |
| |
| |
| DOMATO_INT_TYPE_TO_CPP_INT_TYPE = { |
| 'int': 'int', |
| 'int32': 'int32_t', |
| 'uint32': 'uint32_t', |
| 'int8': 'int8_t', |
| 'uint8': 'uint8_t', |
| 'int16': 'int16_t', |
| 'uint16': 'uint16_t', |
    'int64': 'int64_t',
| 'uint64': 'uint64_t', |
| } |
| |
| DOMATO_TO_PROTO_BUILT_IN = { |
| 'int': 'int32', |
| 'int32': 'int32', |
| 'uint32': 'uint32', |
| 'int8': 'int32', |
| 'uint8': 'uint32', |
| 'int16': 'int32', |
| 'uint16': 'uint32', |
| 'int64': 'int64', |
| 'uint64': 'uint64', |
| 'float': 'float', |
| 'double': 'double', |
| 'char': 'int32', |
| 'string': 'string', |
| 'htmlsafestring': 'string', |
| 'hex': 'int32', |
| 'lines': 'repeated lines', |
| } |
| |
| DOMATO_TO_CPP_HANDLERS = { |
| 'int': 'handle_int_conversion<int32_t, int>', |
| 'int32': 'handle_int_conversion<int32_t, int32_t>', |
| 'uint32': 'handle_int_conversion<uint32_t, uint32_t>', |
| 'int8': 'handle_int_conversion<int32_t, int8_t>', |
| 'uint8': 'handle_int_conversion<uint32_t, uint8_t>', |
| 'int16': 'handle_int_conversion<int16_t, int16_t>', |
| 'uint16': 'handle_int_conversion<uint16_t, uint16_t>', |
| 'int64': 'handle_int_conversion<int64_t, int64_t>', |
| 'uint64': 'handle_int_conversion<uint64_t, uint64_t>', |
| 'float': 'handle_float', |
| 'double': 'handle_double', |
| 'char': 'handle_char', |
| 'string': 'handle_string', |
| 'htmlsafestring': 'handle_string', |
| 'hex': 'handle_hex', |
| } |
| |
| _C_STR_TRANS = str.maketrans({ |
| '\n': '\\n', |
| '\r': '\\r', |
| '\t': '\\t', |
| '\"': '\\\"', |
| '\\': '\\\\' |
| }) |
| |
| BASE_PROTO_NS = 'domatolpm.generated' |
| |
| |
| def to_cpp_ns(proto_ns: str) -> str: |
| return proto_ns.replace('.', '::') |
| |
| |
| CPP_HANDLER_PREFIX = 'handle_' |
| |
| |
| def to_proto_field_name(name: str) -> str: |
| """Converts a creator or rule name to a proto field name. This tries to |
| respect the protobuf naming convention that field names should be snake case. |
| |
| Args: |
| name: the name of the creator or the rule. |
| |
| Returns: |
| the proto field name to use. |
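
  Illustrative examples (assumed, not part of the original code):
    to_proto_field_name('cssProperty-value') -> 'css_property_value'
    to_proto_field_name('class') -> 'class_proto'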
| """ |
| res = to_snake_case(name.replace('-', '_')) |
| if res in ['short', 'class', 'bool', 'boolean', 'long', 'void']: |
| res += '_proto' |
| return res |
| |
| |
| def to_proto_type(creator_name: str) -> str: |
| """Converts a creator name to a proto type. This is deliberately very simple |
| so that we avoid naming conflicts. |
| |
| Args: |
| creator_name: the name of the creator. |
| |
| Returns: |
| the name of the proto type. |
| """ |
| res = creator_name.replace('-', '_') |
| if res in ['short', 'class', 'bool', 'boolean', 'long', 'void']: |
| res += '_proto' |
| return res |
| |
| |
| def c_escape(v: str) -> str: |
| return v.translate(_C_STR_TRANS) |
| |
| |
| def tarjan(g): |
| """This is a simple implementation of Tarjan's algorithm for finding the |
| strongly connected components of the graph @g in a topological order.""" |
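  # Illustrative example (not part of the original code): with
  #   g = {'a': ['b'], 'b': ['a'], 'c': ['a']}
  # tarjan(g) returns [['b', 'a'], ['c']]: the {a, b} cycle forms one strongly
  # connected component and is emitted before 'c', which references it.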
| stack = [] |
| index = {} |
| lowlink = {} |
| ret = [] |
| |
| def visit(v): |
| index[v] = len(index) |
| lowlink[v] = index[v] |
| stack.append(v) |
| for w in g.get(v, ()): |
| if w not in index: |
| visit(w) |
| lowlink[v] = min(lowlink[w], lowlink[v]) |
| elif w in stack: |
| lowlink[v] = min(lowlink[v], index[w]) |
| if lowlink[v] == index[v]: |
| scc = [] |
| w = None |
| while v != w: |
| w = stack.pop() |
| scc.append(w) |
| ret.append(scc) |
| |
| for v in g: |
| if v not in index: |
| visit(v) |
| return ret |
| |
| |
| @dataclasses.dataclass |
| class ProtoType: |
| """Represents a Proto type.""" |
| name: str |
| |
| @property |
| def is_one_of(self) -> bool: |
| return False |
| |
| |
| @dataclasses.dataclass |
| class ProtoField: |
| """Represents a proto message field.""" |
| type: ProtoType |
| name: str |
| proto_id: int |
| |
| |
| @dataclasses.dataclass |
| class ProtoMessage(ProtoType): |
| """Represents a Proto message.""" |
| fields: typing.List[ProtoField] |
| |
| |
| @dataclasses.dataclass |
| class OneOfProtoMessage(ProtoMessage): |
| """Represents a Proto message with a oneof field.""" |
| oneofname: str |
| |
| @property |
| def is_one_of(self) -> bool: |
| return True |
| |
| |
class CppExpression:
  """Base class for C++ expressions."""

  # pylint: disable=no-self-use
  def repr(self):
    raise NotImplementedError()
| # pylint: enable=no-self-use |
| |
| |
| @dataclasses.dataclass |
| class CppTxtExpression(CppExpression): |
| """Represents a Raw text expression.""" |
| content: str |
| |
| def repr(self): |
| return self.content |
| |
| |
| @dataclasses.dataclass |
| class CppCallExpr(CppExpression): |
| """Represents a CallExpr.""" |
| fct_name: str |
| args: typing.List[CppExpression] |
| ns: str = '' |
| |
| def repr(self): |
| arg_s = ', '.join([a.repr() for a in self.args]) |
| return f'{self.ns}{self.fct_name}({arg_s})' |
| |
| |
| class CppHandlerCallExpr(CppCallExpr): |
| |
| def __init__(self, |
| handler: str, |
| field_name: str, |
| extra_args: typing.Optional[typing.List[CppExpression]] = None): |
| args = [CppTxtExpression('ctx'), CppTxtExpression(f'arg.{field_name}()')] |
| if extra_args: |
| args += extra_args |
| super().__init__(fct_name=handler, args=args) |
| self.handler = handler |
| self.field_name = field_name |
| self.extra_args = extra_args |
| |
| |
| @dataclasses.dataclass |
| class CppStringExpr(CppExpression): |
| """Represents a C++ literal string. |
| """ |
| content: str |
| |
| def repr(self): |
| return f'\"{c_escape(self.content)}\"' |
| |
| |
| @dataclasses.dataclass |
| class CppFunctionHandler: |
| """Represents a C++ function. |
| """ |
| name: str |
| exprs: typing.List[CppExpression] |
| |
| @property |
| def is_oneof_handler(self) -> bool: |
| return False |
| |
| @property |
| def is_string_table_handler(self) -> bool: |
| return False |
| |
| @property |
| def is_message_handler(self) -> bool: |
| return False |
| |
| |
| class CppStringTableHandler(CppFunctionHandler): |
| """Represents a C++ function that implements a string table and returns one |
| of the represented strings. |
| """ |
| |
| def __init__(self, name: str, var_name: str, |
| strings: typing.List[CppStringExpr]): |
| super().__init__(name=f'{CPP_HANDLER_PREFIX}{name}', exprs=[]) |
| self.proto_type = f'{name}& arg' |
| self.strings = strings |
| self.var_name = var_name |
| |
| @property |
| def is_string_table_handler(self) -> bool: |
| return True |
| |
| |
| class CppProtoMessageFunctionHandler(CppFunctionHandler): |
| """Represents a C++ function that handles a ProtoMessage. |
| """ |
| |
| def __init__(self, |
| name: str, |
| exprs: typing.List[CppExpression], |
| creator: typing.Optional[typing.Dict[str, str]] = None): |
| super().__init__(name=f'{CPP_HANDLER_PREFIX}{name}', exprs=exprs) |
| self.proto_type = f'{name}& arg' |
| self.creator = creator |
| |
| def creates_new(self): |
| return self.creator is not None |
| |
| @property |
| def is_message_handler(self) -> bool: |
| return True |
| |
| |
| class CppOneOfMessageFunctionHandler(CppFunctionHandler): |
| """Represents a C++ function that handles a OneOfProtoMessage. |
| """ |
| |
| def __init__(self, name: str, switch_name: str, |
| cases: typing.Dict[str, typing.List[CppExpression]]): |
| super().__init__(name=f'{CPP_HANDLER_PREFIX}{name}', exprs=[]) |
| self.proto_type = f'{name}& arg' |
| self.switch_name = switch_name |
| self.cases = cases |
| |
| def all_except_last(self): |
| a = list(self.cases.keys())[:-1] |
| return {e: self.cases[e] for e in a} |
| |
| def last(self): |
| a = list(self.cases.keys())[-1] |
| return self.cases[a] |
| |
| @property |
| def is_oneof_handler(self) -> bool: |
| return True |
| |
| |
@dataclasses.dataclass
class File:
  """Represents a single generated file and its contents."""
  name: str
  deps: typing.List['File'] = dataclasses.field(default_factory=list)
  protos: typing.List[ProtoMessage] = dataclasses.field(default_factory=list)
  cpps: typing.List[CppFunctionHandler] = dataclasses.field(
      default_factory=list)
| |
| |
class DomatoBuilder:
  """DomatoBuilder takes a Domato grammar and models it as a protobuf
  representation together with its corresponding C++ parsing code.
  """
| |
| @dataclasses.dataclass |
| class Entry: |
| msg: ProtoMessage |
| func: CppFunctionHandler |
| |
| def __init__(self, g: grammar.Grammar, stabilize_grammar=False): |
| self.handlers: typing.Dict[str, DomatoBuilder.Entry] = {} |
| self.backrefs: typing.Dict[str, |
| typing.List[str]] = collections.defaultdict(list) |
| self.grammar = g |
| self.stabilize_grammar = stabilize_grammar |
| if self.grammar._root and self.grammar._root != 'root': |
| self.root = self.grammar._root |
| else: |
| self.root = 'line' |
| if self.grammar._root and self.grammar._root == 'root': |
| rules = self.grammar._creators[self.grammar._root] |
      # Multiple roots don't make sense, so we only consider the last defined
      # one.
| rule = rules[-1] |
| for part in rule['parts']: |
| if part['type'] == 'tag' and part[ |
| 'tagname'] == 'lines' and 'count' in part: |
| self.root = f'lines_{part["count"]}' |
| break |
| self._built_in_types_parser = { |
| 'int': self._int_handler, |
| 'int32': self._int_handler, |
| 'uint32': self._int_handler, |
| 'int8': self._int_handler, |
| 'uint8': self._int_handler, |
| 'int16': self._int_handler, |
| 'uint16': self._int_handler, |
| 'int64': self._int_handler, |
| 'uint64': self._int_handler, |
| 'float': self._default_handler, |
| 'double': self._default_handler, |
| 'char': self._default_handler, |
| 'string': self._default_handler, |
| 'htmlsafestring': self._default_handler, |
| 'hex': self._default_handler, |
| 'lines': self._lines_handler, |
| } |
| self.unique_id = 0 |
| |
| def create_internal_message(self) -> str: |
| """Returns a unique name for a newly created message. |
| """ |
| self.unique_id += 1 |
| return f'DomatoLPMInternalMsg{self.unique_id}' |
| |
| def parse_grammar(self): |
| for creator, rules in self.grammar._creators.items(): |
| field_name = to_proto_field_name(creator) |
| type_name = to_proto_type(creator) |
| messages = self._parse_rule(creator, rules) |
| proto_fields: typing.List[ProtoField] = [] |
| for proto_id, msg in enumerate(messages, start=1): |
| proto_fields.append( |
| ProtoField(type=ProtoType(name=msg.name), |
| name=f'{field_name}_{proto_id}', |
| proto_id=proto_id)) |
| msg = OneOfProtoMessage(name=type_name, |
| oneofname='oneoffield', |
| fields=proto_fields) |
| cases = { |
| f.name: [ |
| CppHandlerCallExpr(handler=f'{CPP_HANDLER_PREFIX}{f.type.name}', |
| field_name=f.name) |
| ] |
| for f in proto_fields |
| } |
| func = CppOneOfMessageFunctionHandler(name=type_name, |
| switch_name='oneoffield', |
| cases=cases) |
| self._add(msg, func) |
| |
| def all_proto_messages(self): |
| return [v.msg for v in self.handlers.values()] |
| |
| def all_cpp_functions(self): |
| return [v.func for v in self.handlers.values()] |
| |
| def get_line_prefix(self) -> str: |
| if not self.grammar._line_guard: |
| return '' |
| return self.grammar._line_guard.split('<line>')[0] |
| |
| def get_line_suffix(self) -> str: |
| if not self.grammar._line_guard: |
| return '' |
| return self.grammar._line_guard.split('<line>')[1] |
| |
| def maybe_add_lines_handler(self, number: int) -> bool: |
| name = f'lines_{number}' |
| if name in self.handlers: |
| return False |
| fields = [] |
| exprs = [] |
| for i in range(1, number + 1): |
| fields.append(ProtoField(ProtoType('line'), f'line_{i}', i)) |
| exprs.append(CppHandlerCallExpr('handle_one_line', f'line_{i}')) |
| msg = ProtoMessage(name, fields=fields) |
| handler = CppProtoMessageFunctionHandler(name, exprs=exprs) |
| self._add(msg, handler) |
| return True |
| |
| def get_root(self) -> typing.Tuple[ProtoMessage, CppFunctionHandler]: |
| # If the root is 'line', we actually want to generate an arbitrary number |
| # of lines. In this case, we'll invoke the special proto message 'lines'. |
| # In any other case, we just use the existing root, which has been defined |
| # during grammar construction. |
| root = self.root |
| if self.root == 'line': |
| root = 'lines' |
| root_handler = f'{CPP_HANDLER_PREFIX}{root}' |
| fuzz_case = ProtoMessage( |
| name='fuzzcase', |
| fields=[ProtoField(type=ProtoType(name=root), name='root', proto_id=1)]) |
| fuzz_fct = CppProtoMessageFunctionHandler( |
| name='fuzzcase', |
| exprs=[CppHandlerCallExpr(handler=root_handler, field_name='root')]) |
| return fuzz_case, fuzz_fct |
| |
| def split_files(self, file_prefix: str, file_num=15): |
| res = self._fusion_similar_messages() |
| protos = self._split_protos(file_num) |
| files = [File(f'tmp{i}') for i in range(0, file_num)] |
| for i, (file, proto) in enumerate(zip(files, protos)): |
| file.deps = [] |
| all_deps = set() |
| for field in (f for e in proto for f in self.handlers[e].msg.fields): |
| if field.type.name in self.handlers: |
| all_deps.add(field.type.name) |
| for j in range(i + 1, len(protos)): |
| if any(dep in protos[j] for dep in all_deps): |
| file.deps.append(files[j]) |
| |
| file.protos = [self.handlers[elt].msg for elt in proto] |
| file.cpps = [] |
| for elt in proto: |
| if elt in res: |
| file.cpps += res[elt] |
| else: |
| file.cpps.append(self.handlers[elt].func) |
| |
    # Sort the files so that the ones with the most definitions come first.
| def comp(f1, f2): |
| return len(f1.protos) - len(f2.protos) |
| |
| files = list(sorted(files, key=functools.cmp_to_key(comp), reverse=True)) |
| for i, file in enumerate(files): |
| file.name = f'{file_prefix}{i}' |
| return files |
| |
| def simplify(self): |
| """Simplifies the proto and functions.""" |
| should_continue = True |
| while should_continue: |
| should_continue = False |
| should_continue |= self._merge_unary_oneofs() |
| should_continue |= self._merge_strings() |
| should_continue |= self._merge_multistrings_oneofs() |
| should_continue |= self._remove_unlinked_nodes() |
| should_continue |= self._merge_proto_messages() |
| should_continue |= self._merge_oneofs() |
| if self.stabilize_grammar: |
| self._hash_line_proto_ids() |
| self._oneofs_reorderer() |
| self._oneof_message_renamer() |
| self._message_renamer() |
| |
| def _hash_line_proto_ids(self): |
| if 'line' not in self.handlers: |
| return |
| rules = self.grammar._creators['line'] |
| for (rule, field) in zip(rules, self.handlers['line'].msg.fields): |
| concat = ''.join(p['text'] if p['type'] == 'text' else p['tagname'] |
| for p in rule['parts']) |
| field.proto_id = GetProtoId(concat) |
| |
| def _add(self, message: ProtoMessage, |
| handler: CppProtoMessageFunctionHandler): |
| self.handlers[message.name] = DomatoBuilder.Entry(message, handler) |
| for field in message.fields: |
| self.backrefs[field.type.name].append(message.name) |
| |
| # Handlers should be together even if some of them don't actually use self. |
| # pylint: disable=no-self-use |
| def _int_handler( |
| self, part, |
| field_name: str) -> typing.Tuple[ProtoType, CppHandlerCallExpr]: |
| proto_type = DOMATO_TO_PROTO_BUILT_IN[part['tagname']] |
| handler = DOMATO_TO_CPP_HANDLERS[part['tagname']] |
| extra_args = [] |
| if 'min' in part: |
| extra_args.append(CppTxtExpression(part['min'])) |
| if 'max' in part: |
| if not extra_args: |
| cpp_type = DOMATO_INT_TYPE_TO_CPP_INT_TYPE[part['tagname']] |
| extra_args.append( |
| CppTxtExpression(f'std::numeric_limits<{cpp_type}>::min()')) |
| extra_args.append(CppTxtExpression(part['max'])) |
| contents = CppHandlerCallExpr(handler=handler, |
| field_name=field_name, |
| extra_args=extra_args) |
| return proto_type, contents |
| |
| def _default_handler( |
| self, part, |
| field_name: str) -> typing.Tuple[ProtoType, CppHandlerCallExpr]: |
| proto_type = DOMATO_TO_PROTO_BUILT_IN[part['tagname']] |
| handler = DOMATO_TO_CPP_HANDLERS[part['tagname']] |
| contents = CppHandlerCallExpr(handler=handler, field_name=field_name) |
| return proto_type, contents |
| |
| # pylint: enable=no-self-use |
| |
| def _lines_handler( |
| self, part, |
| field_name: str) -> typing.Tuple[ProtoType, CppHandlerCallExpr]: |
| handler_name = 'lines' |
| if 'count' in part: |
| count = part['count'] |
| handler_name = f'{handler_name}_{count}' |
| self.maybe_add_lines_handler(int(part['count'])) |
| proto_type = handler_name |
| contents = CppHandlerCallExpr(handler=f'handle_{handler_name}', |
| field_name=field_name) |
| return proto_type, contents |
| |
| def _parse_rule(self, creator_name, rules): |
| messages = [] |
| for rule_id, rule in enumerate(rules, start=1): |
| rule_msg_field_name = f'{to_proto_field_name(creator_name)}_{rule_id}' |
| proto_fields = [] |
| cpp_contents = [] |
| ret_vars = 0 |
| for part_id, part in enumerate(rule['parts'], start=1): |
| field_name = f'{rule_msg_field_name}_{part_id}' |
| proto_type = None |
| if rule['type'] == 'code' and 'new' in part: |
| proto_fields.insert( |
| 0, |
| ProtoField(type=ProtoType('optional int32'), |
| name='old', |
| proto_id=part_id)) |
| ret_vars += 1 |
| continue |
| if part['type'] == 'text': |
| contents = CppStringExpr(part['text']) |
| elif part['tagname'] == 'import': |
          # Domato itself does not handle <import> in its built-in rules
          # either, and we do not plan on using the feature in newly written
          # rules: the directive has a lot of constraints for little added
          # value.
| continue |
| elif part['tagname'] == 'call': |
| raise Exception( |
| 'DomatoLPM does not implement <call> and <import> tags.') |
| elif part['tagname'] in self.grammar._constant_types.keys(): |
| contents = CppStringExpr( |
| self.grammar._constant_types[part['tagname']]) |
| elif part['tagname'] in self._built_in_types_parser: |
| handler = self._built_in_types_parser[part['tagname']] |
| proto_type, contents = handler(part, field_name) |
| elif part['type'] == 'tag': |
| proto_type = to_proto_type(part['tagname']) |
| contents = CppHandlerCallExpr( |
| handler=f'{CPP_HANDLER_PREFIX}{proto_type}', |
| field_name=field_name) |
| if proto_type: |
| proto_fields.append( |
| ProtoField(type=ProtoType(name=proto_type), |
| name=field_name, |
| proto_id=part_id)) |
| cpp_contents.append(contents) |
| |
| if ret_vars > 1: |
| raise Exception('Not implemented.') |
| |
| creator = None |
| if rule['type'] == 'code' and ret_vars > 0: |
| creates = rule['creates'] |
        # For some reason, Domato sets a dictionary when the creator is a line
        # and a list when it's a helper, hence the unpacking code below. The
        # assertion ensures we are not dealing with yet another unknown format.
| if isinstance(creates, list): |
| assert len(creates) == 1 |
| creates = creates[0] |
| creator = {'var_type': creates['tagname'], 'var_prefix': 'var'} |
| proto_type = to_proto_type(creator_name) |
| rule_msg = ProtoMessage(name=f'{proto_type}_{rule_id}', |
| fields=proto_fields) |
| rule_func = CppProtoMessageFunctionHandler(name=f'{proto_type}_{rule_id}', |
| exprs=cpp_contents, |
| creator=creator) |
| |
| self._add(rule_msg, rule_func) |
| messages.append(rule_msg) |
| return messages |
| |
| def _remove(self, name: str): |
| assert name in self.handlers |
| for field in self.handlers[name].msg.fields: |
| if field.type.name in self.backrefs: |
| self.backrefs[field.type.name].remove(name) |
| if name in self.backrefs: |
| self.backrefs.pop(name) |
| self.handlers.pop(name) |
| |
| def _update(self, name: str): |
| assert name in self.handlers |
| for field in self.handlers[name].msg.fields: |
| self.backrefs[field.type.name].append(name) |
| |
  def _count_backref(self, proto_name: str) -> int:
    """Counts the number of backreferences a given proto message has.
| |
| Args: |
| proto_name: the proto message name. |
| |
| Returns: |
| the number of backreferences. |
| """ |
| return len(self.backrefs[proto_name]) |
| |
  def _merge_proto_messages(self) -> bool:
    """Merges messages referencing other plain messages into the same message.
    This tremendously reduces the number of protobuf messages that will be
    generated.
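
    Illustrative sketch (assumed, not from the original code): if message A
    has a field whose type is plain message B, B's fields and C++ expressions
    are inlined into A, and B can later be dropped as an unlinked node.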
| """ |
| to_merge = collections.defaultdict(set) |
| for name in self.handlers: |
| msg = self.handlers[name].msg |
| func = self.handlers[name].func |
| if msg.is_one_of or not func.is_message_handler or func.creates_new( |
| ) or name == self.root: |
| continue |
| if name not in self.backrefs: |
| continue |
| for elt in self.backrefs[name]: |
| if elt == name or elt not in self.handlers: |
| continue |
| if self.handlers[elt].msg.is_one_of: |
| continue |
| to_merge[elt].add(name) |
| |
    for parent, children in to_merge.items():
      msg = self.handlers[parent].msg
      fct = self.handlers[parent].func
      for child in children:
| new_contents = [] |
| for expr in fct.exprs: |
| if isinstance(expr, CppStringExpr): |
| new_contents.append(expr) |
| continue |
| assert isinstance(expr, CppHandlerCallExpr) |
| field: ProtoField = next( |
| (f for f in msg.fields if f.type.name == child), None) |
| if not field or not expr.field_name == field.name: |
| new_contents.append(expr) |
| continue |
| self.backrefs[field.type.name].remove(msg.name) |
| idx = msg.fields.index(field) |
| field_msg = self.handlers[child].msg |
| field_fct = self.handlers[child].func |
| |
          # The following deepcopy is required because we might change the
          # child's message fields at some point, and we don't want those
          # changes to affect the current message's fields.
| fields_copy = copy.deepcopy(field_msg.fields) |
| msg.fields = msg.fields[:idx] + fields_copy + msg.fields[idx + 1:] |
| new_contents += copy.deepcopy(field_fct.exprs) |
| for f in field_msg.fields: |
| self.backrefs[f.type.name].append(msg.name) |
| fct.exprs = new_contents |
| return len(to_merge) > 0 |
| |
| def _message_renamer(self): |
| """Renames ProtoMessage fields that might have been merged. This ensures |
| proto field naming remains consistent with the current rule being |
| generated. |
| """ |
| for entry in self.handlers.values(): |
| if entry.msg.is_one_of or entry.func.is_string_table_handler: |
| continue |
| for proto_id, field in enumerate(entry.msg.fields, start=1): |
| field.proto_id = proto_id |
| if entry.func.creates_new() and field.name == 'old': |
| continue |
| field.name = to_proto_field_name(f'field_{proto_id}') |
| index = 2 if entry.func.creates_new() else 1 |
| new_contents = [] |
| for expr in entry.func.exprs: |
| if not isinstance(expr, CppHandlerCallExpr): |
| new_contents.append(expr) |
| continue |
| new_contents.append( |
| CppHandlerCallExpr(expr.handler, |
| to_proto_field_name(f'field_{index}'), |
| expr.extra_args)) |
| index += 1 |
| entry.func.exprs = new_contents |
| |
| def _oneof_message_renamer(self): |
| """Renames OneOfProtoMessage fields that might have been merged. This |
| ensures proto field naming remains consistent with the current rule being |
| generated. |
| """ |
| for entry in self.handlers.values(): |
| if not entry.msg.is_one_of: |
| continue |
| cases = {} |
| for proto_id, field in enumerate(entry.msg.fields, start=1): |
| if entry.msg.name != 'line': |
| field.proto_id = proto_id |
| exprs = entry.func.cases.pop(field.name) |
| field.name = to_proto_field_name(f'field_{proto_id}') |
| new_contents = [] |
| for expr in exprs: |
| if not isinstance(expr, CppHandlerCallExpr): |
| new_contents.append(expr) |
| continue |
| new_contents.append( |
| CppHandlerCallExpr(expr.handler, field.name, expr.extra_args)) |
| cases[field.name] = new_contents |
| entry.func.cases = cases |
| |
| def _merge_multistrings_oneofs(self) -> bool: |
| """Merges multiple strings into a string table function.""" |
| has_made_changes = False |
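    # Illustrative sketch (derived from the code below, not from upstream
    # docs): a oneof whose alternatives each expand to a single literal string
    # is rewritten into a message with a single 'val' field, and the strings
    # are moved into a C++ string table indexed by that value.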
| for name in list(self.handlers.keys()): |
| msg = self.handlers[name].msg |
| |
| if not msg.is_one_of: |
| continue |
| |
| if not all(f.type.name in self.handlers and len(self.handlers[ |
| f.type.name].msg.fields) == 0 and not self.handlers[f.type.name].msg. |
| is_one_of and len(self.handlers[f.type.name].func.exprs) == 1 |
| for f in msg.fields): |
| continue |
| |
| fields = [ProtoField(type=ProtoType('uint32'), name='val', proto_id=1)] |
| new_msg = ProtoMessage(name=msg.name, fields=fields) |
| strings = [] |
| for field in msg.fields: |
| self.backrefs[field.type.name].remove(name) |
| for expr in self.handlers[field.type.name].func.exprs: |
| assert isinstance(expr, CppStringExpr) |
| strings += [expr] |
| new_func = CppStringTableHandler(name=msg.name, |
| var_name='val', |
| strings=strings) |
| self.handlers[name] = DomatoBuilder.Entry(new_msg, new_func) |
| self._update(name) |
| has_made_changes = True |
| return has_made_changes |
| |
  def _oneofs_reorderer(self):
    """Reorders the fields of every OneOfProtoMessage so that the last field
    can be extracted out of the protobuf oneof and used as a fallback, in
    order to always have a correct path to generate. This requires having at
    least one terminal path in the grammar.
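
    Illustrative sketch (assumed, not from the original code): if a oneof has
    both a recursive alternative and a literal (terminal) alternative, the
    terminal field is moved to the last position so that the default case in
    the generated C++ switch always terminates.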
| """ |
| _terminal_messages = set() |
| _being_visited = set() |
| |
| def recursive_terminal_marker(name: str): |
| if name in _terminal_messages or name not in self.handlers: |
| return True |
| if name in _being_visited: |
| return False |
| _being_visited.add(name) |
| msg = self.handlers[name].msg |
| func = self.handlers[name].func |
| if len(msg.fields) == 0: |
| _terminal_messages.add(name) |
| _being_visited.remove(name) |
| return True |
| if msg.is_one_of: |
| f = next( |
| (f for f in msg.fields if recursive_terminal_marker(f.type.name)), |
| None) |
| if not f: |
          # FIXME: for testing purposes only, we're not hard-failing on this.
| _being_visited.remove(name) |
| return False |
| msg.fields.remove(f) |
| msg.fields.append(f) |
| m = next(k for k in func.cases.keys() if k == f.name) |
| func.cases[m] = func.cases.pop(m) |
| _terminal_messages.add(name) |
| _being_visited.remove(name) |
| return True |
| res = all(recursive_terminal_marker(f.type.name) for f in msg.fields) |
      # FIXME: for testing purposes only, we're not hard-failing on this.
| _being_visited.remove(name) |
| return res |
| |
| for name in self.handlers: |
| recursive_terminal_marker(name) |
| |
| def _merge_oneofs(self) -> bool: |
| has_made_changes = False |
| for name in list(self.handlers.keys()): |
| msg = self.handlers[name].msg |
| func = self.handlers[name].func |
| if not msg.is_one_of: |
| continue |
| |
| for field in msg.fields: |
| if not field.type.name in self.handlers: |
| continue |
| field_msg = self.handlers[field.type.name].msg |
| field_func = self.handlers[field.type.name].func |
| if (field_msg.is_one_of or len(field_msg.fields) != 1 |
| or not field_func.is_message_handler or field_func.creates_new()): |
| continue |
| func.cases.pop(field.name) |
| field.name = field_msg.fields[0].name |
| field.type = field_msg.fields[0].type |
| while field.name in func.cases: |
| field.name += '_1' |
| func.cases[field.name] = copy.deepcopy(field_func.exprs) |
| self.backrefs[field_msg.name].remove(name) |
| self.backrefs[field.type.name].append(name) |
| has_made_changes = True |
| return has_made_changes |
| |
  def _merge_unary_oneofs(self) -> bool:
    """Transforms OneOfProtoMessage messages containing only one field into a
| ProtoMessage containing the fields of the contained message. E.g.: |
| message B { |
| int field1 = 1; |
| Whatever field2 = 2; |
| } |
| message A { |
| oneof field { |
| B b = 1; |
| } |
| } |
| Into: |
| message A { |
| int field1 = 1; |
| Whatever field2 = 2; |
| } |
| """ |
| has_made_changes = False |
| for name in list(self.handlers.keys()): |
| msg = self.handlers[name].msg |
| func = self.handlers[name].func |
| |
| if not msg.is_one_of or len(msg.fields) > 1: |
| continue |
| |
      # The message is a unary oneof. Let's make sure its only child doesn't
      # have any other backrefs.
| if self._count_backref(msg.fields[0].type.name) > 1: |
| continue |
| |
      # The only backref should be this message itself. If not, we screwed up
      # somewhere else.
| assert name in self.backrefs[msg.fields[0].type.name] |
| field_msg: ProtoMessage = self.handlers[msg.fields[0].type.name].msg |
| if field_msg.is_one_of: |
| continue |
| |
| field_func = self.handlers[msg.fields[0].type.name].func |
| self._remove(msg.fields[0].type.name) |
| msg = ProtoMessage(name=msg.name, fields=field_msg.fields) |
| func = CppProtoMessageFunctionHandler(name=msg.name, |
| exprs=field_func.exprs, |
| creator=field_func.creator) |
| self.handlers[name] = DomatoBuilder.Entry(msg, func) |
| self._update(name) |
| has_made_changes = True |
| return has_made_changes |
| |
  def _merge_strings(self) -> bool:
    """Merges consecutive CppStringExpr expressions, e.g.
      [ CppStringExpr("<first>"), CppStringExpr("<second>") ]
    Into:
      [ CppStringExpr("<first><second>") ]
    """
| has_made_changes = False |
| for name in self.handlers: |
| func: CppFunctionHandler = self.handlers[name].func |
| if not func.is_message_handler or len(func.exprs) <= 1: |
| continue |
| |
| exprs = [] |
| prev = func.exprs[0] |
| for i in range(1, len(func.exprs)): |
| cur = func.exprs[i] |
| if isinstance(prev, CppStringExpr) and isinstance(cur, CppStringExpr): |
| cur = CppStringExpr(prev.content + cur.content) |
| has_made_changes = True |
| else: |
| exprs.append(prev) |
| prev = cur |
| exprs.append(prev) |
| func.exprs = exprs |
| return has_made_changes |
| |
| def _remove_unlinked_nodes(self) -> bool: |
| """Removes proto messages that are neither part of the root definition nor |
    referenced by any other message. Such orphans can be created by the other
    optimization passes.
| |
| Returns: |
| whether a change was made. |
| """ |
| to_remove = set() |
| for name in (n for n in self.handlers if n != self.root): |
| if name not in self.backrefs or len(self.backrefs[name]) == 0: |
| to_remove.add(name) |
| seen = set() |
| |
| def visit_msg(msg: ProtoMessage): |
| if msg.name in seen: |
| return |
| seen.add(msg.name) |
| for field in msg.fields: |
| if field.type.name in self.handlers: |
| visit_msg(self.handlers[field.type.name].msg) |
| |
| visit_msg(self.handlers[self.root].msg) |
| not_seen = set(self.handlers.keys()) - seen |
| to_remove.update(set(filter(lambda x: x != self.root, not_seen))) |
| for t in to_remove: |
| self._remove(t) |
| return len(to_remove) > 0 |
| |
| def _split_oneof_internal(self, entry): |
| assert entry.msg.is_one_of |
| low_name = self.create_internal_message() |
| high_name = self.create_internal_message() |
| fields_low = copy.copy(entry.msg.fields[:len(entry.msg.fields) // 2]) |
| fields_high = copy.copy(entry.msg.fields[len(entry.msg.fields) // 2:]) |
| low = OneOfProtoMessage(low_name, fields=fields_low, oneofname='oneoffield') |
| high = OneOfProtoMessage(high_name, |
| fields=fields_high, |
| oneofname='oneoffield') |
| low_cases = {} |
| for field in low.fields: |
| low_cases[field.name] = entry.func.cases[field.name] |
| high_cases = {} |
| for field in high.fields: |
| high_cases[field.name] = entry.func.cases[field.name] |
| func_low = CppOneOfMessageFunctionHandler(low_name, |
| switch_name='oneoffield', |
| cases=low_cases) |
| func_high = CppOneOfMessageFunctionHandler(high_name, |
| switch_name='oneoffield', |
| cases=high_cases) |
| entry.msg.fields = [ |
| ProtoField(type=ProtoType(low_name), name='line_1', proto_id=1), |
| ProtoField(type=ProtoType(high_name), name='line_2', proto_id=2), |
| ] |
| entry.func = CppOneOfMessageFunctionHandler( |
| f'{entry.msg.name}', |
| switch_name='oneoffield', |
| cases={ |
| 'line_1': [ |
| CppHandlerCallExpr(handler=f'{CPP_HANDLER_PREFIX}{low_name}', |
| field_name='line_1') |
| ], |
| 'line_2': [ |
| CppHandlerCallExpr(handler=f'{CPP_HANDLER_PREFIX}{high_name}', |
| field_name='line_2') |
| ] |
| }) |
| self.handlers[low_name] = DomatoBuilder.Entry(low, func_low) |
| self.handlers[high_name] = DomatoBuilder.Entry(high, func_high) |
| self.backrefs[low_name] = [entry.msg.name] |
| self.backrefs[high_name] = [entry.msg.name] |
| for field in low.fields: |
| self.backrefs[field.type.name].remove(entry.msg.name) |
| self.backrefs[field.type.name].append(low_name) |
| for field in high.fields: |
| self.backrefs[field.type.name].remove(entry.msg.name) |
| self.backrefs[field.type.name].append(high_name) |
| |
  def _split_oneofs(self):
    """Splits oneofs that are too big and would otherwise bloat the generated
    protobuf files.
    """
| for entry in self.handlers.values(): |
| if entry.msg.is_one_of and len(entry.msg.fields) > 200: |
| self._split_oneof_internal(entry) |
| return True |
| return False |
| |
  def _split_protos(self, num_files: int):
    """Splits the current proto definition graph into multiple files
    referencing each other. This helps reduce the overall .pb.cc compile
    time.
| |
| Args: |
| num_files: the number of files to be generated. |
| |
| Returns: |
| a list of Files. |
| """ |
| graph_rep = {} |
| for v in self.handlers.values(): |
| graph_rep[v.msg.name] = set() |
| for f in (f for f in v.msg.fields if f.type.name in self.handlers): |
| graph_rep[v.msg.name].add(f.type.name) |
| components = tarjan(graph_rep) |
| |
| assert self.root in components[-1] |
| |
| def weight(elts): |
| return sum([len(self.handlers[entry].msg.fields) + 1 for entry in elts]) |
| |
| def _get_comp_list(max_weight): |
| comp_list = [] |
| current_list = [] |
| for comp in reversed(components): |
| if weight(comp) + weight(current_list) > max_weight: |
| comp_list.append(copy.copy(current_list)) |
| current_list.clear() |
| current_list += comp |
| if len(current_list) > 0: |
| comp_list.append(current_list) |
| return comp_list |
| |
| total_weight = sum( |
| [len(entry.msg.fields) + 1 for entry in self.handlers.values()]) |
    # We purposefully start with a larger divisor (hence more, smaller chunks)
    # and then grow the per-chunk weight until we end up with at most the
    # requested number of files.
| cur_weight = total_weight / (num_files + 10) |
| comp_list = _get_comp_list(cur_weight) |
| while len(comp_list) > num_files: |
| cur_weight *= 1.2 |
| comp_list = _get_comp_list(cur_weight) |
| |
| return comp_list |
| |
| def _fusion_similar_messages_impl(self, new_msg_name, messages): |
| is_one_of = self.handlers[messages[0]].msg.is_one_of |
| if is_one_of: |
| new_msg = OneOfProtoMessage(new_msg_name, |
| fields=self.handlers[messages[0]].msg.fields, |
| oneofname='oneoffield') |
| new_func = CppOneOfMessageFunctionHandler( |
| new_msg_name, 'oneoffield', |
| {field.name: [CppStringExpr('')] |
| for field in new_msg.fields}) |
| else: |
| new_msg = ProtoMessage(new_msg_name, |
| fields=self.handlers[messages[0]].msg.fields) |
| new_func = CppProtoMessageFunctionHandler(new_msg_name, exprs=[]) |
| self._add(new_msg, new_func) |
| res = [] |
| for e in messages: |
| self.handlers[e].func.proto_type = new_func.proto_type |
| for b in self.backrefs[e]: |
| for f in (f for f in self.handlers[b].msg.fields if f.type.name == e): |
| f.type.name = new_msg.name |
| self.backrefs[new_msg.name].append(b) |
| res.append(self.handlers[e].func) |
| self._remove(e) |
| return res |
| |
| def _fusion_similar_messages(self): |
| oneof_entries = collections.defaultdict(list) |
| msg_entries = collections.defaultdict(list) |
| for e in (e for e in self.handlers.values() if e.msg.is_one_of): |
| h = hash(''.join(f.type.name for f in e.msg.fields)) |
| oneof_entries[h].append(e.msg.name) |
| for e in (e for e in self.handlers.values() if not e.msg.is_one_of): |
| h = hash(''.join(f.type.name for f in e.msg.fields)) |
| msg_entries[h].append(e.msg.name) |
| res = {} |
| for value in list(oneof_entries.values()) + list(msg_entries.values()): |
| if len(value) <= 1: |
| continue |
| msg_name = self.create_internal_message() |
| res[msg_name] = self._fusion_similar_messages_impl(msg_name, value) |
| self._remove_unlinked_nodes() |
| return res |
| |
| |
| def _render_internal(template: jinja2.Template, |
| context: typing.Dict[str, typing.Any], out_f: str): |
| with action_helpers.atomic_output(out_f, mode='w') as f: |
| f.write(template.render(context)) |
| |
| |
| def _render_proto_internal( |
| template: jinja2.Template, out_f: str, |
| proto_messages: typing.List[typing.Union[ProtoMessage, OneOfProtoMessage]], |
| should_generate_repeated_lines: bool, proto_ns: str, |
| imports: typing.List[str]): |
| _render_internal(template, { |
| 'messages': [m for m in proto_messages if not m.is_one_of], |
| 'oneofmessages': [m for m in proto_messages if m.is_one_of], |
| 'generate_repeated_lines': should_generate_repeated_lines, |
| 'proto_ns': proto_ns, |
| 'imports': imports, |
| }, |
| out_f=out_f) |
| |
| |
| def to_relative_path(generated_dir: str, filepath: str): |
| return str( |
| pathlib.PurePosixPath(generated_dir).joinpath( |
| pathlib.PurePosixPath(filepath).name)) |
| |
| |
| def render_proto(environment: jinja2.Environment, generated_dir: str, |
| out_f: str, name: str, builder: DomatoBuilder, |
| files: typing.List[File]): |
| template = environment.get_template('domatolpm.proto.tmpl') |
| ns = f'{BASE_PROTO_NS}.{name}' |
| for file in files: |
| _render_proto_internal( |
| template, f'{file.name}.proto', file.protos, False, ns, [ |
| to_relative_path(generated_dir, f'{dep.name}.proto') |
| for dep in file.deps |
| ]) |
| root, _ = builder.get_root() |
| _render_proto_internal( |
| template, f'{out_f}.proto', [root], builder.root == 'line', ns, [ |
| to_relative_path(generated_dir, f'{file.name}.proto') |
| for file in files if builder.root in (m.name for m in file.protos) |
| ]) |
| |
| |
| def render_cpp(environment: jinja2.Environment, gen_dir: str, out_f: str, |
| name: str, builder: DomatoBuilder, files: typing.List[File]): |
| for file in files: |
| funcs = [f for f in file.cpps if f.is_message_handler] |
| oneofs = [f for f in file.cpps if f.is_oneof_handler] |
| stfunctions = [f for f in file.cpps if f.is_string_table_handler] |
| has_line = 'line' in (f.type.name for msg in file.protos |
| for f in msg.fields) |
| rendering_context = { |
| 'includes': |
| [to_relative_path(gen_dir, f'{dep.name}.h') for dep in file.deps], |
| 'functions': |
| funcs, |
| 'oneoffunctions': |
| oneofs, |
| 'stfunctions': |
| stfunctions, |
| 'root': |
| None, |
| 'generate_root': |
| False, |
| 'generate_repeated_lines': |
| False, |
| 'generate_one_line_handler': |
| has_line, |
| 'line_prefix': |
| builder.get_line_prefix(), |
| 'line_suffix': |
| builder.get_line_suffix(), |
| 'proto_ns': |
| to_cpp_ns(f'{BASE_PROTO_NS}.{name}'), |
| 'cpp_ns': |
| f'domatolpm::{name}', |
| } |
| rendering_context['includes'].append( |
| to_relative_path(gen_dir, f'{file.name}.h')) |
| rendering_context['includes'].append( |
| to_relative_path(gen_dir, f'{file.name}.pb.h')) |
| template = environment.get_template('domatolpm.cc.tmpl') |
| _render_internal(template, rendering_context, f'{file.name}.cc') |
| rendering_context['includes'] = [ |
| to_relative_path(gen_dir, f'{file.name}.pb.h') |
| ] |
| template = environment.get_template('domatolpm.h.tmpl') |
| _render_internal(template, rendering_context, f'{file.name}.h') |
| _, root_func = builder.get_root() |
| |
| rendering_context = { |
| 'includes': |
| [to_relative_path(gen_dir, f'{file.name}.h') |
| for file in files] + [f'{os.path.basename(out_f)}.pb.h'], |
| 'functions': [], |
| 'oneoffunctions': [], |
| 'stfunctions': [], |
| 'root': |
| root_func, |
| 'generate_root': |
| True, |
| 'generate_repeated_lines': |
| builder.root == 'line', |
| 'generate_one_line_handler': |
| builder.root == 'line', |
| 'line_prefix': |
| builder.get_line_prefix(), |
| 'line_suffix': |
| builder.get_line_suffix(), |
| 'proto_ns': |
| to_cpp_ns(f'{BASE_PROTO_NS}.{name}'), |
| 'cpp_ns': |
| f'domatolpm::{name}', |
| } |
| template = environment.get_template('domatolpm.cc.tmpl') |
| _render_internal(template, rendering_context, f'{out_f}.cc') |
| template = environment.get_template('domatolpm.h.tmpl') |
| _render_internal(template, rendering_context, f'{out_f}.h') |
| |
| |
| def main(): |
| parser = argparse.ArgumentParser( |
| description= |
| 'Generate the necessary files for DomatoLPM to function properly.') |
| parser.add_argument('-p', |
| '--path', |
| required=True, |
| help='The path to a Domato grammar file.') |
| parser.add_argument('-n', |
| '--name', |
| required=True, |
| help='The name of this grammar.') |
| parser.add_argument( |
| '-f', |
| '--file-format', |
| required=True, |
| help='The path prefix to which the files should be generated.') |
| parser.add_argument('-d', |
| '--generated-dir', |
| required=True, |
| help='The path to the target gen directory.') |
| parser.add_argument('-s', |
| '--stabilize-grammar', |
| required=False, |
| default=False, |
| action='store_true', |
                      help='Whether we should stabilize the proto generation. '
                      'Grammars should not have duplicate lines.')
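
  # Example invocation (illustrative only; the script name and paths are
  # hypothetical):
  #   generate_domatolpm.py -p /path/to/grammar.txt -n my_grammar \
  #       -f out/gen/domatolpm/my_grammar -d gen/domatolpm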
| |
| args = parser.parse_args() |
| g = grammar.Grammar() |
| g.parse_from_file(filename=args.path) |
| |
| template_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), |
| 'templates') |
| environment = jinja2.Environment(loader=jinja2.FileSystemLoader(template_dir)) |
| builder = DomatoBuilder(g, args.stabilize_grammar) |
| builder.parse_grammar() |
| builder.simplify() |
| files = builder.split_files(f'{args.file_format}_sub', file_num=12) |
| render_cpp(environment, args.generated_dir, args.file_format, args.name, |
| builder, files) |
| render_proto(environment, args.generated_dir, args.file_format, args.name, |
| builder, files) |
| |
| |
| if __name__ == '__main__': |
| main() |