| # Licensed to the Software Freedom Conservancy (SFC) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The SFC licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| |
| """ |
| Generate Python WebDriver BiDi command modules from CDDL specification. |
| |
| This generator reads CDDL (Concise Data Definition Language) specification files |
| and produces Python type definitions and command classes that conform to the |
| WebDriver BiDi protocol. |
| |
| Usage: |
| python generate_bidi.py <cddl_file> <output_dir> <spec_version> |
| |
| Example: |
| python generate_bidi.py local.cddl ./selenium/webdriver/common/bidi 1.0 |
| """ |
| |
| import argparse |
| import importlib.util |
| import logging |
| import re |
| import sys |
| from dataclasses import dataclass, field |
| from enum import Enum |
| from pathlib import Path |
| from textwrap import indent as tw_indent |
| from typing import Any |
| |
__version__ = "1.0.0"

# Logging setup: root logging is configured at INFO level when this module is
# imported, and the generator reports through the named "generate_bidi" logger.
log_level = logging.INFO
logging.basicConfig(level=log_level)
logger = logging.getLogger("generate_bidi")

# File headers
# Banner placed at the very top of every generated module, telling readers the
# file is machine-generated and that changes belong in the generator instead.
SHARED_HEADER = """# DO NOT EDIT THIS FILE!
#
# This file is generated from the WebDriver BiDi specification. If you need to make
# changes, edit the generator and regenerate all of the modules."""

# Split header: comments section and imports section are separate so a
# module_docstring can be injected between them (before imports → real __doc__).
# The doubled braces survive this f-string as a literal "{}" placeholder, which
# is filled in later with the module name via str.format().
_MODULE_HEADER_COMMENTS = f"""{SHARED_HEADER}
#
# WebDriver BiDi module: {{}}
"""
# A __future__ import may only be preceded by the docstring and comments, so it
# starts the code section of every generated module.
_MODULE_HEADER_IMPORTS = "from __future__ import annotations\n\n"
| |
| |
def indent(s: str, n: int) -> str:
    """Return ``s`` with ``n`` spaces prepended to each line that is not whitespace-only."""
    padding = " " * n
    return tw_indent(s, padding)
| |
| |
| def _docstring_text(custom: str | None, fallback_name: str, fallback_desc: str = "") -> str: |
| """Select the appropriate raw docstring text (no triple-quotes). |
| |
| Priority: manifest custom string > CDDL description (if different from name) > 'ClassName.' |
| """ |
| if custom: |
| return custom.strip() |
| if fallback_desc and fallback_desc != fallback_name: |
| return fallback_desc |
| return f"{fallback_name}." |
| |
| |
| def _emit_docstring(text: str, indent_width: int) -> str: |
| r"""Produce a PEP 257-compliant docstring block with a trailing newline. |
| |
| Single-line output: <indent>\"\"\"text.\"\"\"\n |
| Multi-line output: |
| <indent>\"\"\" |
| <indent>First line. |
| <indent> |
| <indent>Continuation. |
| <indent>\"\"\" |
| The opening and closing triple-quotes each occupy their own line for |
| multi-line strings so that inspect.getdoc() / Sphinx dedent cleanly. |
| """ |
| prefix = " " * indent_width |
| stripped = text.strip() |
| lines = stripped.splitlines() |
| if len(lines) <= 1: |
| return f'{prefix}"""{stripped}"""\n' |
| content = "\n".join(f"{prefix}{line}" if line.strip() else "" for line in lines) |
| return f'{prefix}"""\n{content}\n{prefix}"""\n' |
| |
| |
| def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]: |
| """Load enhancement manifest from a Python file. |
| |
| Args: |
| manifest_path: Path to Python file containing ENHANCEMENTS dict |
| |
| Returns: |
| Dictionary with enhancement rules, or empty dict if no manifest provided |
| """ |
| if not manifest_path: |
| return {} |
| |
| manifest_file = Path(manifest_path) |
| if not manifest_file.exists(): |
| logger.warning(f"Enhancement manifest not found: {manifest_path}") |
| return {} |
| |
| try: |
| spec = importlib.util.spec_from_file_location("bidi_enhancements", manifest_file) |
| if spec is None or spec.loader is None: |
| logger.warning(f"Could not load manifest: {manifest_path}") |
| return {} |
| |
| module = importlib.util.module_from_spec(spec) |
| spec.loader.exec_module(module) |
| |
| enhancements = getattr(module, "ENHANCEMENTS", {}) |
| dataclass_methods = getattr(module, "DATACLASS_METHOD_TEMPLATES", {}) |
| method_docstrings = getattr(module, "DATACLASS_METHOD_DOCSTRINGS", {}) |
| |
| logger.info(f"Loaded enhancement manifest from: {manifest_path}") |
| logger.debug(f"Enhancements for modules: {list(enhancements.keys())}") |
| |
| return { |
| "enhancements": enhancements, |
| "dataclass_methods": dataclass_methods, |
| "method_docstrings": method_docstrings, |
| } |
| except Exception as e: |
| logger.error(f"Failed to load enhancement manifest: {e}", exc_info=True) |
| return {} |
| |
| |
class CddlType(Enum):
    """CDDL type mappings to Python types.

    Several members share a value: ``TEXT`` duplicates ``TSTR``, and ``INT`` and
    ``NINT`` duplicate ``UINT``. ``Enum`` turns such duplicates into aliases, and
    aliases are skipped when iterating the class. Lookups therefore go through
    ``__members__``, which includes them.
    """

    TSTR = "str"  # text string
    TEXT = "str"  # text (alias)
    UINT = "int"  # unsigned integer
    INT = "int"  # signed integer
    NINT = "int"  # negative integer
    BOOL = "bool"  # boolean
    NULL = "None"  # null
    ANY = "Any"  # any type

    @classmethod
    def get_annotation(cls, cddl_type: str) -> str:
        """Get Python type annotation for a CDDL type.

        Args:
            cddl_type: CDDL type expression, e.g. ``"text"``, ``"[+ uint]"`` or a map.
                Matching is case-insensitive and ignores surrounding whitespace.

        Returns:
            The annotation string. Arrays become ``list[...]``, maps become
            ``dict[str, Any]``, and anything unrecognised becomes ``"Any"``.
        """
        cddl_type = cddl_type.strip().lower()

        # Basic types. Iterating cls would skip the aliases TEXT, INT and NINT.
        for member_name, member in cls.__members__.items():
            if cddl_type == member_name.lower():
                return member.value

        # Arrays: "[+ T]" (one or more) or "[* T]" (zero or more).
        if cddl_type.startswith("["):
            inner = cddl_type.strip("[]+* ")
            return f"list[{cls.get_annotation(inner)}]"

        # Maps/dicts are not modelled field by field here.
        if cddl_type.startswith("{"):
            return "dict[str, Any]"

        # Default to Any for unknown types
        return "Any"
| |
| |
| @dataclass |
| class CddlCommand: |
| """Represents a CDDL command definition.""" |
| |
| module: str |
| name: str |
| params: dict[str, str] = field(default_factory=dict) |
| required_params: set[str] = field(default_factory=set) |
| result: str | None = None |
| description: str = "" |
| |
| def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str: |
| """Generate Python method code for this command. |
| |
| Args: |
| enhancements: Dictionary with enhancement rules for this method |
| """ |
| enhancements = enhancements or {} |
| method_name = self._camel_to_snake(self.name) |
| |
| # Build parameter list with type hints |
| # Check if there's a params_override for user-friendly named arguments |
| params_to_use = self.params |
| if "params_override" in enhancements: |
| params_to_use = enhancements["params_override"] |
| |
| param_strs = [] |
| param_names = [] # Keep track of parameter names for later use |
| for param_name, param_type in params_to_use.items(): |
| if param_type in ["bool", "str", "int"]: |
| python_type = param_type |
| else: |
| python_type = CddlType.get_annotation(param_type) |
| snake_param = self._camel_to_snake(param_name) |
| param_names.append((param_name, snake_param)) |
| param_strs.append(f"{snake_param}: {python_type} | None = None") |
| |
| if param_strs: |
| # Check if full signature would exceed line length limit (120 chars) |
| single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):" |
| if len(single_line_signature) > 120: |
| # Format parameters on multiple lines |
| body = f" def {method_name}(\n" |
| body += " self,\n" |
| for i, param_str in enumerate(param_strs): |
| if i < len(param_strs) - 1: |
| body += f" {param_str},\n" |
| else: |
| body += f" {param_str},\n" |
| body += " ):\n" |
| else: |
| param_list = "self, " + ", ".join(param_strs) |
| body = f" def {method_name}({param_list}):\n" |
| else: |
| body = f" def {method_name}(self):\n" |
| docstring = enhancements.get("docstring") or self.description or f"Execute {self.module}.{self.name}." |
| body += _emit_docstring(docstring, 8) |
| |
| # Add automatic validation for required parameters |
| # (This is used unless there's no required_params, in which case all params are optional) |
| if self.required_params: |
| method_snake = self._camel_to_snake(self.name) |
| for param_name, snake_param in param_names: |
| if param_name in self.required_params: |
| body += f" if {snake_param} is None:\n" |
| msg = f"{method_snake}() missing required argument:" |
| error_message = f"{msg} {snake_param!r}" |
| body += f" raise TypeError({error_message!r})\n" |
| body += "\n" |
| |
| # Add validation if specified in enhancements (for additional business logic validation) |
| if "validate" in enhancements: |
| validate_func = enhancements["validate"] |
| # Build parameter list for validation function |
| param_args = ", ".join(f"{snake}={snake}" for _, snake in param_names) |
| body += f" {validate_func}({param_args})\n" |
| body += "\n" |
| |
| # Add transformation and preprocessing |
| # First, check if any transform is needed |
| if "transform" in enhancements: |
| transform_spec = enhancements["transform"] |
| if isinstance(transform_spec, dict): |
| # New format with explicit function and result parameter |
| transform_func = transform_spec.get("func") |
| result_param = transform_spec.get("result_param", "params") |
| input_params = [ |
| transform_spec.get(k) for k in ["allowed", "destination_folder"] if transform_spec.get(k) |
| ] |
| |
| if transform_func and result_param: |
| body += f" {result_param} = None\n" |
| param_args = ", ".join(input_params) |
| body += f" {result_param} = {transform_func}({param_args})\n" |
| body += "\n" |
| else: |
| # Legacy format for backward compatibility |
| transform_func = transform_spec |
| if self.name == "setDownloadBehavior": |
| body += " download_behavior = None\n" |
| body += f" download_behavior = {transform_func}(allowed, destination_folder)\n" |
| body += "\n" |
| |
| # Add preprocessing for serialization (check for to_bidi_dict method) |
| if "preprocess" in enhancements: |
| preprocess_rules = enhancements["preprocess"] |
| for param_name, preprocess_type in preprocess_rules.items(): |
| snake_param = self._camel_to_snake(param_name) |
| if preprocess_type == "check_serialize_method": |
| body += f" if {snake_param} and hasattr({snake_param}, 'to_bidi_dict'):\n" |
| body += f" {snake_param} = {snake_param}.to_bidi_dict()\n" |
| body += "\n" |
| |
| # Build params dict |
| body += " params = {\n" |
| |
| # If there's a transform with a result parameter, map it to the BiDi protocol name |
| if "transform" in enhancements and isinstance(enhancements["transform"], dict): |
| transform_spec = enhancements["transform"] |
| result_param = transform_spec.get("result_param") |
| |
| # Map the result parameter to the original CDDL parameter name |
| if result_param == "download_behavior": |
| body += ' "downloadBehavior": download_behavior,\n' |
| # Add remaining parameters that weren't part of the transform |
| for cddl_param_name in self.params: |
| if cddl_param_name not in ["downloadBehavior"]: |
| snake_name = self._camel_to_snake(cddl_param_name) |
| body += f' "{cddl_param_name}": {snake_name},\n' |
| else: |
| # Standard parameter mapping from CDDL |
| for param_name, snake_param in param_names: |
| body += f' "{param_name}": {snake_param},\n' |
| |
| body += " }\n" |
| body += " params = {k: v for k, v in params.items() if v is not None}\n" |
| body += f' cmd = command_builder("{self.module}.{self.name}", params)\n' |
| body += " result = self._conn.execute(cmd)\n" |
| |
| # Add response handling for extraction/deserialization |
| if "extract_field" in enhancements: |
| extract_field = enhancements["extract_field"] |
| extract_property = enhancements.get("extract_property") |
| |
| # Check if we also need to deserialize the extracted field |
| deserialize_rules = enhancements.get("deserialize", {}) |
| |
| if extract_property: |
| # Extract property from list items |
| body += f' if result and "{extract_field}" in result:\n' |
| body += f' items = result.get("{extract_field}", [])\n' |
| body += " return [\n" |
| body += f' item.get("{extract_property}")\n' |
| body += " for item in items\n" |
| body += " if isinstance(item, dict)\n" |
| body += " ]\n" |
| body += " return []\n" |
| elif extract_field in deserialize_rules: |
| # Extract field and deserialize to typed objects |
| type_name = deserialize_rules[extract_field] |
| body += f' if result and "{extract_field}" in result:\n' |
| body += f' items = result.get("{extract_field}", [])\n' |
| body += " return [\n" |
| body += f" {type_name}(\n" |
| body += self._generate_field_args(extract_field, type_name) |
| body += " )\n" |
| body += " for item in items\n" |
| body += " if isinstance(item, dict)\n" |
| body += " ]\n" |
| body += " return []\n" |
| else: |
| # Simple field extraction (return the value directly, not wrapped in result dict) |
| body += f' if result and "{extract_field}" in result:\n' |
| body += f' extracted = result.get("{extract_field}")\n' |
| body += " return extracted\n" |
| body += " return result\n" |
| elif "deserialize" in enhancements: |
| # Deserialize response to typed objects (legacy, without extract_field) |
| deserialize_rules = enhancements["deserialize"] |
| for response_field, type_name in deserialize_rules.items(): |
| body += f' if result and "{response_field}" in result:\n' |
| body += f' items = result.get("{response_field}", [])\n' |
| body += " return [\n" |
| body += f" {type_name}(\n" |
| body += self._generate_field_args(response_field, type_name) |
| body += " )\n" |
| body += " for item in items\n" |
| body += " if isinstance(item, dict)\n" |
| body += " ]\n" |
| body += " return []\n" |
| else: |
| # No special response handling, just return the result |
| body += " return result\n" |
| |
| return body |
| |
| def _generate_field_args(self, response_field: str, type_name: str) -> str: |
| """Generate constructor arguments for deserializing response objects. |
| |
| For now, this handles ClientWindowInfo and Info specifically. |
| Could be extended to be more generic. |
| """ |
| if type_name == "ClientWindowInfo": |
| return ( |
| ' active=item.get("active"),\n' |
| ' client_window=item.get("clientWindow"),\n' |
| ' height=item.get("height"),\n' |
| ' state=item.get("state"),\n' |
| ' width=item.get("width"),\n' |
| ' x=item.get("x"),\n' |
| ' y=item.get("y")\n' |
| ) |
| elif type_name == "Info": |
| return ( |
| ' children=_deserialize_info_list(item.get("children", [])),\n' |
| ' client_window=item.get("clientWindow"),\n' |
| ' context=item.get("context"),\n' |
| ' original_opener=item.get("originalOpener"),\n' |
| ' url=item.get("url"),\n' |
| ' user_context=item.get("userContext"),\n' |
| ' parent=item.get("parent")\n' |
| ) |
| # For other types, return empty |
| return "" |
| |
| @staticmethod |
| def _camel_to_snake(name: str) -> str: |
| """Convert camelCase to snake_case.""" |
| name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) |
| return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() |
| |
| |
| @dataclass |
| class CddlTypeDefinition: |
| """Represents a CDDL type definition.""" |
| |
| module: str |
| name: str |
| fields: dict[str, str] = field(default_factory=dict) |
| description: str = "" |
| |
| def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str: |
| """Generate Python dataclass code for this type. |
| |
| Args: |
| enhancements: Dictionary containing dataclass_methods and method_docstrings |
| """ |
| enhancements = enhancements or {} |
| dataclass_methods = enhancements.get("dataclass_methods", {}) |
| method_docstrings = enhancements.get("method_docstrings", {}) |
| |
| # Generate class name from type name (keep it as-is, don't split on underscores) |
| class_name = self.name |
| code = "@dataclass\n" |
| code += f"class {class_name}:\n" |
| class_docstrings = enhancements.get("class_docstrings", {}) |
| class_doc = _docstring_text(class_docstrings.get(class_name), class_name, self.description) |
| code += _emit_docstring(class_doc, 4) |
| code += "\n" |
| |
| if not self.fields: |
| code += " pass\n" |
| else: |
| for field_name, field_type in self.fields.items(): |
| # Convert CDDL type to Python type |
| python_type = self._get_python_type(field_type) |
| snake_name = CddlCommand._camel_to_snake(field_name) |
| |
| # Check if the CDDL field type is a quoted string literal (e.g., type: "key") |
| # These are discriminant fields: auto-populate and exclude from __init__ |
| # so callers don't need to pass them as positional or keyword arguments. |
| literal_match = re.match(r'^"([^"]+)"$', field_type.strip()) |
| if literal_match: |
| literal_value = literal_match.group(1) |
| code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n' |
| # Check if this field is a list type (using lowercase 'list[' from Python 3.10+ syntax) |
| elif python_type.startswith("list["): |
| # Remove the trailing ' | None' from list types since default_factory=list ensures non-None |
| type_annotation = python_type.replace(" | None", "") |
| code += f" {snake_name}: {type_annotation} = field(default_factory=list)\n" |
| # Check if this field is a dict type (using lowercase 'dict[' from Python 3.10+ syntax) |
| elif python_type.startswith("dict["): |
| # Remove the trailing ' | None' from dict types since default_factory=dict ensures non-None |
| type_annotation = python_type.replace(" | None", "") |
| code += f" {snake_name}: {type_annotation} = field(default_factory=dict)\n" |
| else: |
| code += f" {snake_name}: {python_type} = None\n" |
| |
| # Add custom methods if defined for this class |
| if class_name in dataclass_methods: |
| code += "\n" |
| methods_dict = dataclass_methods[class_name] |
| docstrings_dict = method_docstrings.get(class_name, {}) |
| |
| for method_name in methods_dict: |
| method_impl = methods_dict[method_name] |
| docstring = docstrings_dict.get(method_name, "") |
| code += f" def {method_name}(self):\n" |
| if docstring: |
| code += f' """{docstring}"""\n' |
| code += f" {method_impl}\n" |
| code += "\n" |
| |
| return code |
| |
| @staticmethod |
| def _get_python_type(cddl_type: str) -> str: |
| """Convert CDDL type to Python type annotation using Python 3.10+ syntax.""" |
| cddl_type = cddl_type.strip().lower() |
| |
| # Handle basic types |
| type_mapping = { |
| "tstr": "str", |
| "text": "str", |
| "uint": "int", |
| "int": "int", |
| "nint": "int", |
| "bool": "bool", |
| "null": "None", |
| } |
| |
| for cddl, python in type_mapping.items(): |
| if cddl_type == cddl: |
| # Use Python 3.10+ union syntax: type | None |
| return f"{python} | None" |
| |
| # Handle arrays |
| if cddl_type.startswith("["): |
| inner = cddl_type.strip("[]+ ") |
| inner_type = CddlTypeDefinition._get_python_type(inner) |
| # Remove " | None" from inner type since it might be wrapped |
| if " | None" in inner_type: |
| inner_base = inner_type.replace(" | None", "") |
| return f"list[{inner_base} | None] | None" |
| return f"list[{inner_type}] | None" |
| |
| # Handle maps/dicts |
| if cddl_type.startswith("{"): |
| return "dict[str, Any] | None" |
| |
| # Default to Any for unknown/complex types |
| return "Any | None" |
| |
| |
| @dataclass |
| class CddlEnum: |
| """Represents a CDDL enum definition (string union).""" |
| |
| module: str |
| name: str |
| values: list[str] = field(default_factory=list) |
| description: str = "" |
| |
| def to_python_class(self, enhancements: dict[str, Any] | None = None) -> str: |
| """Generate Python enum class code. |
| |
| Generates a simple class with string constants to match the existing |
| pattern in the codebase (e.g., ClientWindowState). |
| """ |
| enhancements = enhancements or {} |
| class_name = self.name |
| class_docstrings = enhancements.get("class_docstrings", {}) |
| class_doc = _docstring_text(class_docstrings.get(class_name), class_name, self.description) |
| code = f"class {class_name}:\n" |
| code += _emit_docstring(class_doc, 4) |
| code += "\n" |
| |
| for value in self.values: |
| # Convert value to UPPER_SNAKE_CASE constant name |
| const_name = self._value_to_const_name(value) |
| code += f' {const_name} = "{value}"\n' |
| |
| return code |
| |
| @staticmethod |
| def _value_to_const_name(value: str) -> str: |
| """Convert enum string value to constant name. |
| |
| Examples: |
| "none" -> "NONE" |
| "portrait-primary" -> "PORTRAIT_PRIMARY" |
| "interactive" -> "INTERACTIVE" |
| """ |
| # Replace hyphens with underscores |
| const_name = value.replace("-", "_") |
| # Convert to uppercase |
| return const_name.upper() |
| |
| |
@dataclass
class CddlEvent:
    """Represents a CDDL event definition (incoming message from browser)."""

    module: str
    name: str
    method: str
    params_type: str
    description: str = ""

    def to_python_dataclass(self) -> str:
        """Generate a module-level alias binding this event to its params type.

        Despite the method name, no dataclass is emitted. The output is a comment
        naming the BiDi method, followed by ``<name> = globals().get('<Type>', dict)``,
        so the alias falls back to ``dict`` when the params type is not defined in
        the generated module.
        """
        # Drop any module prefix: "browsingContext.Info" -> "Info".
        type_name = self.params_type.rpartition(".")[2]

        # Events typed as NavigationInfo are pointed at BaseNavigationInfo instead.
        # Presumably NavigationInfo is itself an alias of BaseNavigationInfo in the
        # generated module — confirm against the manifest.
        if type_name == "NavigationInfo":
            type_name = "BaseNavigationInfo"

        header_line = f"# Event: {self.method}\n"
        alias_line = f"{self.name} = globals().get('{type_name}', dict) # Fallback to dict if type not defined\n"
        return header_line + alias_line
| |
| |
| @dataclass |
| class CddlModule: |
| """Represents a CDDL module (e.g., script, network, browsing_context).""" |
| |
| name: str |
| commands: list[CddlCommand] = field(default_factory=list) |
| types: list[CddlTypeDefinition] = field(default_factory=list) |
| enums: list[CddlEnum] = field(default_factory=list) |
| events: list[CddlEvent] = field(default_factory=list) |
| |
| @staticmethod |
| def _convert_method_to_event_name(method_suffix: str) -> str: |
| """Convert BiDi method suffix to friendly event name. |
| |
| Examples: |
| "contextCreated" -> "context_created" |
| "navigationStarted" -> "navigation_started" |
| "userPromptOpened" -> "user_prompt_opened" |
| """ |
| # Convert camelCase to snake_case |
| s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix) |
| return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower() |
| |
| def generate_code(self, enhancements: dict[str, Any] | None = None) -> str: |
| """Generate Python code for this module. |
| |
| Args: |
| enhancements: Dictionary with module-level enhancements |
| """ |
| enhancements = enhancements or {} |
| module_docstring = enhancements.get("module_docstring", "") |
| code = _MODULE_HEADER_COMMENTS.format(self.name) |
| if module_docstring: |
| code += _emit_docstring(module_docstring, 0) + "\n" |
| code += _MODULE_HEADER_IMPORTS |
| |
| # Collect needed imports to avoid duplicates |
| needs_command_builder = bool(self.commands) |
| needs_dataclass = self.commands or self.types or self.events |
| needs_callable = self.events |
| |
| stdlib_imports = [] |
| local_imports = [] |
| |
| # Add imports (field import will be added conditionally after code generation) |
| if needs_callable: |
| stdlib_imports.append("from collections.abc import Callable") |
| if needs_dataclass: |
| stdlib_imports.append("from dataclasses import dataclass") |
| stdlib_imports.append("from typing import Any") |
| |
| if needs_command_builder: |
| local_imports.append("from selenium.webdriver.common.bidi.common import command_builder") |
| if self.events: |
| local_imports.append( |
| "from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager" |
| ) |
| |
| code += "\n".join(stdlib_imports) + "\n" |
| if local_imports: |
| code += "\n" + "\n".join(local_imports) + "\n" |
| |
| code += "\n" |
| |
| # Add helper function definitions from enhancements |
| # Collect all referenced helper functions (validate, transform) |
| helper_funcs_to_add = set() |
| for cmd in self.commands: |
| method_name_snake = cmd._camel_to_snake(cmd.name) |
| method_enhancements = enhancements.get(method_name_snake, {}) |
| if "validate" in method_enhancements: |
| helper_funcs_to_add.add(("validate", method_enhancements["validate"])) |
| if "transform" in method_enhancements and isinstance(method_enhancements["transform"], dict): |
| transform_spec = method_enhancements["transform"] |
| if "func" in transform_spec: |
| helper_funcs_to_add.add(("transform", transform_spec["func"])) |
| |
| # Generate helper functions if needed |
| if helper_funcs_to_add: |
| for func_type, func_name in sorted(helper_funcs_to_add): |
| if func_type == "validate" and func_name == "validate_download_behavior": |
| code += """def validate_download_behavior( |
| allowed: bool | None, |
| destination_folder: str | None, |
| user_contexts: Any | None = None, |
| ) -> None: |
| \"\"\"Validate download behavior parameters. |
| |
| Args: |
| allowed: Whether downloads are allowed |
| destination_folder: Destination folder for downloads |
| user_contexts: Optional list of user contexts |
| |
| Raises: |
| ValueError: If parameters are invalid |
| \"\"\" |
| if allowed is True and not destination_folder: |
| raise ValueError("destination_folder is required when allowed=True") |
| if allowed is False and destination_folder: |
| raise ValueError("destination_folder should not be provided when allowed=False") |
| |
| |
| """ |
| elif func_type == "transform" and func_name == "transform_download_params": |
| code += """def transform_download_params( |
| allowed: bool | None, |
| destination_folder: str | None, |
| ) -> dict[str, Any] | None: |
| \"\"\"Transform download parameters into download_behavior object. |
| |
| Args: |
| allowed: Whether downloads are allowed |
| destination_folder: Destination folder for downloads (accepts str or |
| pathlib.Path; will be coerced to str) |
| |
| Returns: |
| Dictionary representing the download_behavior object, or None if allowed is None |
| \"\"\" |
| if allowed is True: |
| return { |
| "type": "allowed", |
| # Coerce pathlib.Path (or any path-like) to str so the BiDi |
| # protocol always receives a plain JSON string. |
| "destinationFolder": str(destination_folder) if destination_folder is not None else None, |
| } |
| elif allowed is False: |
| return {"type": "denied"} |
| else: # None — reset to browser default (sent as JSON null) |
| return None |
| |
| |
| """ |
| |
| # Generate enums first (excluding those in exclude_types) |
| exclude_types = set(enhancements.get("exclude_types", [])) |
| |
| # Also exclude any types that have extra_dataclasses overrides |
| # Extract class names from extra_dataclasses strings |
| for extra_cls in enhancements.get("extra_dataclasses", []): |
| # Match "class ClassName:" pattern |
| match = re.search(r"class\s+(\w+)\s*:", extra_cls) |
| if match: |
| exclude_types.add(match.group(1)) |
| |
| for enum_def in self.enums: |
| if enum_def.name in exclude_types: |
| continue |
| code += enum_def.to_python_class(enhancements) |
| code += "\n\n" |
| |
| # Emit module-level aliases from enhancement manifest (e.g. LogLevel = Level) |
| for alias, target in enhancements.get("aliases", {}).items(): |
| code += f"{alias} = {target}\n\n" |
| |
| # Generate type dataclasses, skipping any overridden by extra_dataclasses |
| for type_def in self.types: |
| if type_def.name in exclude_types: |
| continue |
| code += type_def.to_python_dataclass(enhancements) |
| code += "\n\n" |
| |
| # Emit extra dataclasses from enhancement manifest (non-CDDL additions) |
| for extra_cls in enhancements.get("extra_dataclasses", []): |
| code += extra_cls |
| code += "\n\n" |
| |
| # Emit extra type aliases from enhancement manifest (e.g., union types for events) |
| for extra_alias in enhancements.get("extra_type_aliases", []): |
| code += extra_alias |
| code += "\n\n" |
| |
| # NOTE: Don't generate event type aliases here - they reference types that may not be defined yet |
| # They will be generated after the class definition instead |
| |
| # Generate EVENT_NAME_MAPPING for the module (before the module class) |
| if self.events: |
| # Generate EVENT_NAME_MAPPING for the module |
| code += "# BiDi Event Name to Parameter Type Mapping\n" |
| code += "EVENT_NAME_MAPPING = {\n" |
| for event_def in self.events: |
| # Convert method name to user-friendly event name |
| # e.g., "browsingContext.contextCreated" -> "context_created" |
| method_parts = event_def.method.split(".") |
| if len(method_parts) == 2: |
| event_name = self._convert_method_to_event_name(method_parts[1]) |
| code += f' "{event_name}": "{event_def.method}",\n' |
| # Extra events not in the CDDL spec (e.g. Chromium-specific events) |
| for extra_evt in enhancements.get("extra_events", []): |
| code += f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n' |
| code += "}\n\n" |
| |
| # Add custom method function definitions before the class (for browsingContext) |
| if self.name == "browsingContext": |
| # Add helper function for recursive Info deserialization |
| code += """def _deserialize_info_list(items: list) -> list | None: |
| \"\"\"Recursively deserialize a list of dicts to Info objects. |
| |
| Args: |
| items: List of dicts from the API response |
| |
| Returns: |
| List of Info objects with properly nested children, or None if empty |
| \"\"\" |
| if not items or not isinstance(items, list): |
| return None |
| |
| result = [] |
| for item in items: |
| if isinstance(item, dict): |
| # Recursively deserialize children only if the key exists in response |
| children_list = None |
| if "children" in item: |
| children_list = _deserialize_info_list(item.get("children", [])) |
| info = Info( |
| children=children_list, |
| client_window=item.get("clientWindow"), |
| context=item.get("context"), |
| original_opener=item.get("originalOpener"), |
| url=item.get("url"), |
| user_context=item.get("userContext"), |
| parent=item.get("parent"), |
| ) |
| result.append(info) |
| return result if result else None |
| |
| |
| """ |
| code += "\n\n" |
| |
| # EventConfig, _EventWrapper, and _EventManager are imported from |
| # selenium.webdriver.common.bidi._event_manager (see import section above) |
| # rather than being duplicated inline in every generated module. |
| if False: # placeholder to preserve indentation structure |
| pass |
| |
| # Generate class |
| # Convert module name (camelCase or snake_case) to proper class name (PascalCase) |
| class_name = module_name_to_class_name(self.name) |
| class_docstrings = enhancements.get("class_docstrings", {}) |
| module_class_doc = _docstring_text( |
| class_docstrings.get(class_name), |
| class_name, |
| f"WebDriver BiDi {self.name} module.", |
| ) |
| code += f"class {class_name}:\n" |
| code += _emit_docstring(module_class_doc, 4) |
| code += "\n" |
| |
| # Add EVENT_CONFIGS dict if there are events |
| if self.events: |
| code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined |
| |
| if self.name == "script": |
| code += " def __init__(self, conn, driver=None) -> None:\n" |
| code += " self._conn = conn\n" |
| code += " self._driver = driver\n" |
| else: |
| code += " def __init__(self, conn) -> None:\n" |
| code += " self._conn = conn\n" |
| |
| # Initialize _event_manager if there are events |
| if self.events: |
| code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n" |
| |
| # Append extra init code from enhancements (e.g. self.intercepts = []) |
| for init_line in enhancements.get("extra_init_code", []): |
| code += f" {init_line}\n" |
| |
| code += "\n" |
| |
| # Generate command methods |
| exclude_methods = enhancements.get("exclude_methods", []) |
| |
| # Automatically exclude methods that are defined in extra_methods |
| # to prevent generating duplicates |
| if "extra_methods" in enhancements: |
| for extra_method in enhancements["extra_methods"]: |
| # Extract method name from "def method_name(" |
| match = re.search(r"def\s+(\w+)\s*\(", extra_method) |
| if match: |
| exclude_methods = list(exclude_methods) + [match.group(1)] |
| |
| if self.commands: |
| command_docstrings = enhancements.get("command_docstrings", {}) |
| for command in self.commands: |
| # Get method-specific enhancements |
| # Convert command name to snake_case to match enhancement manifest keys |
| method_name_snake = command._camel_to_snake(command.name) |
| if method_name_snake in exclude_methods: |
| continue |
| method_enhancements = enhancements.get(method_name_snake, {}) |
| # Inject command_docstrings entry if no per-method docstring is set |
| if method_name_snake in command_docstrings and "docstring" not in method_enhancements: |
| method_enhancements = {**method_enhancements, "docstring": command_docstrings[method_name_snake]} |
| code += command.to_python_method(method_enhancements) |
| code += "\n" |
| elif not self.events and not enhancements.get("extra_methods", []): |
| code += " pass\n" |
| |
| # Emit extra methods from enhancement manifest |
| for extra_method in enhancements.get("extra_methods", []): |
| code += extra_method |
| code += "\n" |
| |
| # Add delegating event handler methods if events are present |
| if self.events: |
| code += """ |
| def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int: |
| \"\"\"Add an event handler. |
| |
| Args: |
| event: The event to subscribe to. |
| callback: The callback function to execute on event. |
| contexts: The context IDs to subscribe to (optional). |
| |
| Returns: |
| The callback ID. |
| \"\"\" |
| return self._event_manager.add_event_handler(event, callback, contexts) |
| |
| def remove_event_handler(self, event: str, callback_id: int) -> None: |
| \"\"\"Remove an event handler. |
| |
| Args: |
| event: The event to unsubscribe from. |
| callback_id: The callback ID. |
| \"\"\" |
| return self._event_manager.remove_event_handler(event, callback_id) |
| |
| def clear_event_handlers(self) -> None: |
| \"\"\"Clear all event handlers.\"\"\" |
| return self._event_manager.clear_event_handlers() |
| """ |
| |
| # Generate event info type aliases AFTER the class definition |
| # This ensures all types are available when we create the aliases |
| if self.events: |
| code += "\n# Event Info Type Aliases\n" |
| # Check for explicit event_type_aliases in the enhancement manifest |
| event_type_aliases = enhancements.get("event_type_aliases", {}) |
| for event_def in self.events: |
| # Convert method name to user-friendly event name |
| method_parts = event_def.method.split(".") |
| if len(method_parts) == 2: |
| event_name = self._convert_method_to_event_name(method_parts[1]) |
| # Check if there's an explicit alias defined in the enhancement manifest |
| if event_name in event_type_aliases: |
| # Use the alias directly |
| type_name = event_type_aliases[event_name] |
| code += f"# Event: {event_def.method}\n" |
| code += f"{event_def.name} = {type_name}\n" |
| else: |
| # Fall back to the original behavior |
| code += event_def.to_python_dataclass() |
| code += "\n" |
| |
| # Now populate EVENT_CONFIGS after the aliases are defined |
| code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n" |
| # Use globals() to look up types dynamically to handle missing types gracefully |
| code += "_globals = globals()\n" |
| code += f"{class_name}.EVENT_CONFIGS = {{\n" |
| for event_def in self.events: |
| # Convert method name to user-friendly event name |
| method_parts = event_def.method.split(".") |
| if len(method_parts) == 2: |
| event_name = self._convert_method_to_event_name(method_parts[1]) |
| # Try to get event class from globals, default to dict if not found |
| getter = f'_globals.get("{event_def.name}", dict)' |
| condition = f'_globals.get("{event_def.name}")' |
| event_class = f"{getter} if {condition} else dict" |
| |
| # Build the entry line and check if it exceeds 120 chars |
| single_line = ( |
| f' "{event_name}": EventConfig("{event_name}", "{event_def.method}", {event_class}),' |
| ) |
| |
| if len(single_line) > 120: |
| # Break into multiple lines |
| code += f' "{event_name}": EventConfig(\n' |
| code += f' "{event_name}",\n' |
| code += f' "{event_def.method}",\n' |
| code += f" {event_class},\n" |
| code += " ),\n" |
| else: |
| code += single_line + "\n" |
| # Extra events not in the CDDL spec |
| for extra_evt in enhancements.get("extra_events", []): |
| ek = extra_evt["event_key"] |
| be = extra_evt["bidi_event"] |
| ec = extra_evt["event_class"] |
| code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n' |
| code += "}\n" |
| |
| # Check if field() is actually used in the generated code |
| # If so, add the field import after the dataclass import |
| if "field(" in code: |
| # Find where to insert the field import |
| # It should go after "from dataclasses import dataclass" line |
| dataclass_import_pattern = r"from dataclasses import dataclass\n" |
| if re.search(dataclass_import_pattern, code): |
| code = re.sub( |
| dataclass_import_pattern, |
| "from dataclasses import dataclass, field\n", |
| code, |
| count=1, |
| ) |
| elif "from dataclasses import" not in code: |
| # If there's no dataclasses import yet, add field import after typing |
| code = code.replace( |
| "from typing import Any\n", |
| "from dataclasses import field\nfrom typing import Any\n", |
| ) |
| |
| return code |
| |
| |
| class CddlParser: |
| """Parse CDDL specification files.""" |
| |
| def __init__(self, cddl_path: str): |
| """Initialize parser with CDDL file path.""" |
| self.cddl_path = Path(cddl_path) |
| self.content = "" |
| self.modules: dict[str, CddlModule] = {} |
| self.definitions: dict[str, str] = {} |
| self.event_names: set[str] = set() # Names of definitions that are events |
| self._read_file() |
| |
| def _read_file(self) -> None: |
| """Read and preprocess CDDL file.""" |
| if not self.cddl_path.exists(): |
| raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}") |
| |
| with open(self.cddl_path, encoding="utf-8") as f: |
| self.content = f.read() |
| |
| logger.info(f"Loaded CDDL file: {self.cddl_path}") |
| |
| def parse(self) -> dict[str, CddlModule]: |
| """Parse CDDL content and return modules.""" |
| # Remove comments |
| content = self._remove_comments(self.content) |
| |
| # Extract all definitions |
| self._extract_definitions(content) |
| |
| # Extract event names from event union definitions |
| self._extract_event_names() |
| |
| # Extract type definitions by module |
| self._extract_types() |
| |
| # Extract event definitions by module |
| self._extract_events() |
| |
| # Extract command definitions by module |
| self._extract_commands() |
| |
| # If no modules found, create a default one from the filename |
| if not self.modules: |
| module_name = self.cddl_path.stem |
| default_module = CddlModule(name=module_name) |
| self.modules[module_name] = default_module |
| logger.warning(f"No modules found in CDDL, creating default: {module_name}") |
| |
| return self.modules |
| |
| def _remove_comments(self, content: str) -> str: |
| """Remove comments from CDDL content.""" |
| # CDDL uses ; for comments to end of line |
| lines = content.split("\n") |
| cleaned = [] |
| for line in lines: |
| if ";" in line and not line.strip().startswith(";"): |
| line = line[: line.index(";")] |
| elif line.strip().startswith(";"): |
| continue |
| cleaned.append(line) |
| return "\n".join(cleaned) |
| |
| def _extract_definitions(self, content: str) -> None: |
| """Extract CDDL definitions (type definitions, commands, etc.).""" |
| # Match pattern: Name = Definition |
| # Handles multiline definitions properly |
| pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)?\s*=|\Z)" |
| |
| for match in re.finditer(pattern, content, re.DOTALL): |
| name = match.group(1).strip() |
| definition = match.group(2).strip() |
| self.definitions[name] = definition |
| logger.debug(f"Extracted definition: {name}") |
| |
| def _extract_event_names(self) -> None: |
| """Extract event names from event union definitions. |
| |
| Event union definitions follow pattern: |
| module.ModuleEvent = ( |
| module.EventName1 // |
| module.EventName2 // |
| ... |
| ) |
| """ |
| for def_name, def_content in self.definitions.items(): |
| # Check if this looks like an event union (name ends with "Event") and |
| # contains a module-qualified reference like "module.EventName". |
| # Handles both single-item (no //) and multi-item (// separated) unions. |
| if "Event" in def_name and re.search(r"\w+\.\w+", def_content): |
| # Extract event names from the union (works for single and multi-item) |
| event_refs = re.findall(r"(\w+\.\w+)", def_content) |
| for event_ref in event_refs: |
| self.event_names.add(event_ref) |
| logger.debug(f"Identified event: {event_ref} (from {def_name})") |
| |
| def _extract_types(self) -> None: |
| """Extract type definitions from parsed definitions.""" |
| # Type definitions follow pattern: module.TypeName = { field: type, ... } |
| # They have dots in the name and curly braces in the content |
| # But they DON'T have method: "..." pattern (which means it's not a command) |
| # Enums follow pattern: module.EnumName = "value1" / "value2" / ... |
| |
| for def_name, def_content in self.definitions.items(): |
| # Skip if not a namespaced name (e.g., skip "EmptyParams", "Extensible") |
| if "." not in def_name: |
| continue |
| |
| # Skip if it's a command (contains method: pattern) |
| if "method:" in def_content: |
| continue |
| |
| # Extract module.TypeName |
| if "." in def_name: |
| module_name, type_name = def_name.rsplit(".", 1) |
| |
| # Create module if not exists |
| if module_name not in self.modules: |
| self.modules[module_name] = CddlModule(name=module_name) |
| |
| # Check if this is an enum (string union with /) |
| if self._is_enum_definition(def_content): |
| # Extract enum values |
| values = self._extract_enum_values(def_content) |
| if values: |
| enum_def = CddlEnum( |
| module=module_name, |
| name=type_name, |
| values=values, |
| description=f"{type_name}", |
| ) |
| self.modules[module_name].enums.append(enum_def) |
| logger.debug(f"Found enum: {def_name} with {len(values)} values") |
| else: |
| # Extract fields from type definition |
| fields = self._extract_type_fields(def_content) |
| |
| if fields: # Only create type if it has fields |
| type_def = CddlTypeDefinition( |
| module=module_name, |
| name=type_name, |
| fields=fields, |
| description=f"{type_name}", |
| ) |
| self.modules[module_name].types.append(type_def) |
| logger.debug(f"Found type: {def_name} with {len(fields)} fields") |
| |
| def _is_enum_definition(self, definition: str) -> bool: |
| """Check if a definition is an enum (string union with /). |
| |
| Enums are defined as: "value1" / "value2" / "value3" |
| """ |
| # Clean whitespace |
| clean_def = definition.strip() |
| |
| # Must not have curly braces (that would be a type definition) |
| if "{" in clean_def or "}" in clean_def: |
| return False |
| |
| # Must contain the union operator / surrounded by quotes |
| # Pattern: "something" / "something_else" |
| return " / " in clean_def and '"' in clean_def |
| |
| def _extract_enum_values(self, enum_definition: str) -> list[str]: |
| """Extract individual values from an enum definition. |
| |
| Enums are defined as: "value1" / "value2" / "value3" |
| Can span multiple lines. |
| """ |
| values = [] |
| |
| # Clean the definition and extract quoted strings |
| # Split by / and extract quoted values |
| parts = enum_definition.split("/") |
| |
| for part in parts: |
| part = part.strip() |
| |
| # Extract quoted string - use search instead of match to find quotes anywhere |
| match = re.search(r'"([^"]*)"', part) |
| if match: |
| value = match.group(1) |
| values.append(value) |
| logger.debug(f"Extracted enum value: {value}") |
| |
| return values |
| |
| @staticmethod |
| def _normalize_cddl_type(field_type: str) -> str: |
| """Normalize a CDDL type expression to a simple Python-compatible form. |
| |
| Strips CDDL control operators (.ge, .le, .gt, .lt, .default, etc.) and |
| replaces interval/constraint expressions with their base types so that |
| the caller can safely check for nested struct syntax. |
| |
| Examples: |
| '(float .ge 0.0) .default 1.0' -> 'float' |
| '(float .ge 0.0) / null' -> 'float / null' |
| '(0.0...360.0) / null' -> 'float / null' |
| '-90.0..90.0' -> 'float' |
| 'float / null .default null' -> 'float / null' |
| """ |
| result = field_type |
| # Remove trailing .default <value> annotations |
| result = re.sub(r"\s*\.default\s+\S+", "", result) |
| # Replace parenthesised constraint expressions: (baseType .operator ...) -> baseType |
| result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result) |
| # Replace parenthesised numeric interval types: (0.0...360.0) -> float |
| result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result) |
| # Replace bare numeric interval types: -90.0..90.0 -> float |
| result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result) |
| return result.strip() |
| |
| def _extract_type_fields(self, type_definition: str) -> dict[str, str]: |
| """Extract fields from a type definition block.""" |
| fields = {} |
| |
| # Remove outer braces |
| clean_def = type_definition.strip() |
| if clean_def.startswith("{"): |
| clean_def = clean_def[1:] |
| if clean_def.endswith("}"): |
| clean_def = clean_def[:-1] |
| |
| # Parse each line for field: type patterns |
| for line in clean_def.split("\n"): |
| line = line.strip() |
| if not line or "Extensible" in line or line.startswith("//"): |
| continue |
| |
| # Match pattern: [?] fieldName: type |
| match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) |
| if not match: |
| # Try without optional marker |
| match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) |
| |
| if match: |
| field_name = match.group(1).strip() |
| field_type = match.group(2).strip() |
| normalized_type = self._normalize_cddl_type(field_type) |
| |
| # Skip lines that are part of nested definitions |
| if "{" not in normalized_type and "(" not in normalized_type: |
| fields[field_name] = normalized_type |
| logger.debug(f"Extracted field {field_name}: {normalized_type}") |
| |
| return fields |
| |
| def _extract_events(self) -> None: |
| """Extract event definitions from parsed definitions. |
| |
| Events are definitions that: |
| 1. Are listed in an event union (e.g., BrowsingContextEvent) |
| 2. Have method: "..." and params: ... fields |
| |
| Event pattern: module.EventName = (method: "module.eventName", params: module.ParamType) |
| """ |
| # Find definitions that are in the event_names set |
| event_pattern = re.compile(r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)") |
| |
| for def_name, def_content in self.definitions.items(): |
| # Skip if not identified as an event |
| if def_name not in self.event_names: |
| continue |
| |
| # Extract method and params |
| match = event_pattern.search(def_content) |
| if match: |
| method = match.group(1) # e.g., "browsingContext.contextCreated" |
| params_type = match.group(2) # e.g., "browsingContext.Info" |
| |
| # Extract module name from method |
| if "." in method: |
| module_name, _ = method.split(".", 1) |
| |
| # Create module if not exists |
| if module_name not in self.modules: |
| self.modules[module_name] = CddlModule(name=module_name) |
| |
| # Extract event name from definition name (e.g., browsingContext.ContextCreated) |
| _, event_name = def_name.rsplit(".", 1) |
| |
| # Create event |
| event = CddlEvent( |
| module=module_name, |
| name=event_name, |
| method=method, |
| params_type=params_type, |
| description=f"Event: {method}", |
| ) |
| |
| self.modules[module_name].events.append(event) |
| logger.debug(f"Found event: {def_name} (method={method}, params={params_type})") |
| |
| def _extract_commands(self) -> None: |
| """Extract command definitions from parsed definitions.""" |
| # Find command definitions that follow pattern: module.Command = (method: "...", params: ...) |
| command_pattern = re.compile(r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)") |
| |
| for def_name, def_content in self.definitions.items(): |
| # Skip definitions that are events (they share the same pattern) |
| if def_name in self.event_names: |
| continue |
| matches = list(command_pattern.finditer(def_content)) |
| if matches: |
| for match in matches: |
| method = match.group(1) # e.g., "session.new" |
| params_type = match.group(2) # e.g., "session.NewParameters" |
| |
| # Extract module name from method |
| if "." in method: |
| module_name, command_name = method.split(".", 1) |
| |
| # Create module if not exists |
| if module_name not in self.modules: |
| self.modules[module_name] = CddlModule(name=module_name) |
| |
| # Extract parameters and required parameters |
| params, required_params = self._extract_parameters_and_required(params_type) |
| |
| # Create command |
| cmd = CddlCommand( |
| module=module_name, |
| name=command_name, |
| params=params, |
| required_params=required_params, |
| description=f"Execute {method}", |
| ) |
| |
| self.modules[module_name].commands.append(cmd) |
| logger.debug(f"Found command: {method} with params {params_type}") |
| |
| def _extract_parameters(self, params_type: str, _seen: set[str] | None = None) -> dict[str, str]: |
| """Extract parameters from a parameter type definition. |
| |
| Handles both struct types ({...}) and top-level union types (TypeA / TypeB), |
| merging all fields from each alternative as optional parameters. |
| """ |
| params, _ = self._extract_parameters_and_required(params_type, _seen) |
| return params |
| |
| def _extract_parameters_and_required( |
| self, params_type: str, _seen: set[str] | None = None |
| ) -> tuple[dict[str, str], set[str]]: |
| """Extract parameters and track which are required from a parameter type definition. |
| |
| Returns: |
| Tuple of (params dict, required_params set) |
| """ |
| params = {} |
| required = set() |
| |
| if _seen is None: |
| _seen = set() |
| if params_type in _seen: |
| return params, required |
| _seen.add(params_type) |
| |
| if params_type not in self.definitions: |
| logger.debug(f"Parameter type not found: {params_type}") |
| return params, required |
| |
| definition = self.definitions[params_type] |
| |
| # Handle top-level type alias that is a union of other named types: |
| # e.g. session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest |
| # These definitions contain a single line with "/" separating type names |
| # (not the double-slash "//" used for command unions). |
| stripped = definition.strip() |
| if not stripped.startswith("{") and "/" in stripped and "//" not in stripped: |
| # Each token separated by "/" should be a named type reference |
| alternatives = [a.strip() for a in stripped.split("/") if a.strip()] |
| all_named = all(re.match(r"^[\w.]+$", a) for a in alternatives) |
| if all_named: |
| # For union types, collect parameters from all alternatives |
| # but treat them as optional since the caller only needs to pass one alternative |
| for alt_type in alternatives: |
| alt_params, _ = self._extract_parameters_and_required(alt_type, _seen) |
| params.update(alt_params) |
| # Note: We intentionally DON'T add to required, since these are union alternatives |
| return params, required |
| |
| # Remove the outer curly braces and split by comma |
| # Then parse each line for key: type patterns |
| clean_def = stripped |
| if clean_def.startswith("{"): |
| clean_def = clean_def[1:] |
| if clean_def.endswith("}"): |
| clean_def = clean_def[:-1] |
| |
| # Split by newlines and process each line |
| for line in clean_def.split("\n"): |
| line = line.strip() |
| if not line or "Extensible" in line: |
| continue |
| |
| # Match pattern: [?] name: type |
| # Check if parameter has optional marker (?) |
| is_optional = line.startswith("?") |
| |
| # Using a simple pattern that handles optional prefix |
| match = re.match(r"\?\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) |
| if not match: |
| # Try without optional marker |
| match = re.match(r"(\w+)\s*:\s*(.+?)(?:,\s*)?$", line) |
| |
| if match: |
| param_name = match.group(1).strip() |
| param_type = match.group(2).strip() |
| normalized_type = self._normalize_cddl_type(param_type) |
| |
| # Skip lines that are part of nested definitions |
| if "{" not in normalized_type and "(" not in normalized_type: |
| params[param_name] = normalized_type |
| if not is_optional: |
| required.add(param_name) |
| logger.debug( |
| f"Extracted param {param_name}: {normalized_type} " |
| f"(required={not is_optional}) from {params_type}" |
| ) |
| |
| return params, required |
| |
| |
def module_name_to_class_name(module_name: str) -> str:
    """Return the PascalCase class name for a BiDi module name.

    ``browsing_context`` -> ``BrowsingContext`` (each underscore-separated word
    is passed through ``str.capitalize``) and ``browsingContext`` ->
    ``BrowsingContext`` (only the first character is upper-cased). An empty
    name yields an empty string.
    """
    if not module_name:
        return ""
    if "_" not in module_name:
        # camelCase: keep the rest of the name untouched.
        return module_name[:1].upper() + module_name[1:]
    return "".join(map(str.capitalize, module_name.split("_")))
| |
| |
def module_name_to_filename(module_name: str) -> str:
    """Convert a BiDi module name to a Python module filename (snake_case).

    Handles both camelCase (``browsingContext``) and snake_case
    (``browsing_context``) input; snake_case names are returned unchanged.

    Special cases (explicit mapping):
        - browsingContext -> browsing_context
        - webExtension -> webextension

    Any other camelCase name is split on case boundaries and lower-cased,
    e.g. ``myModuleName`` -> ``my_module_name``, ``HTTPModule`` -> ``http_module``.
    """
    camel_to_snake_map = {
        "browsingContext": "browsing_context",
        "webExtension": "webextension",
    }

    if module_name in camel_to_snake_map:
        return camel_to_snake_map[module_name]

    if "_" in module_name:
        # Already snake_case
        return module_name

    # Two passes: first split before a capitalised word ("yModule" -> "y_Module"),
    # then split any remaining lower/digit -> upper boundary ("eName" -> "e_Name").
    s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", module_name)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
| |
| |
def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None:
    """Write ``__init__.py`` re-exporting one class per generated module.

    The file starts with ``SHARED_HEADER`` and ``from __future__ import
    annotations``, then has one import per module and an ``__all__`` list, both
    in sorted module-name order.

    Args:
        output_path: Directory in which ``__init__.py`` is created or overwritten.
        modules: Parsed modules keyed by module name; only the keys are used.
    """
    init_path = output_path / "__init__.py"

    ordered_names = sorted(modules)
    class_names = [module_name_to_class_name(name) for name in ordered_names]

    import_lines = [
        f"from selenium.webdriver.common.bidi.{module_name_to_filename(name)} import {cls}\n"
        for name, cls in zip(ordered_names, class_names)
    ]
    export_lines = [f'    "{cls}",\n' for cls in class_names]

    code = (
        f"{SHARED_HEADER}\n\nfrom __future__ import annotations\n\n"
        + "".join(import_lines)
        + "\n__all__ = [\n"
        + "".join(export_lines)
        + "]\n"
    )

    with open(init_path, "w", encoding="utf-8") as f:
        f.write(code)

    logger.info(f"Generated: {init_path}")
| |
| |
def generate_common_file(output_path: Path) -> None:
    """Write ``common.py`` containing the shared ``command_builder`` helper.

    The file content is a fixed template (Apache license header, module
    docstring and one generator function) and does not depend on the parsed
    CDDL.

    Args:
        output_path: Directory in which ``common.py`` is created or overwritten.
    """
    common_path = output_path / "common.py"

    # Emitted verbatim: indentation inside these literals is the indentation of
    # the generated Python source.
    code = (
        "# Licensed to the Software Freedom Conservancy (SFC) under one\n"
        "# or more contributor license agreements.  See the NOTICE file\n"
        "# distributed with this work for additional information\n"
        "# regarding copyright ownership.  The SFC licenses this file\n"
        "# to you under the Apache License, Version 2.0 (the\n"
        '# "License"); you may not use this file except in compliance\n'
        "# with the License.  You may obtain a copy of the License at\n"
        "#\n"
        "#   http://www.apache.org/licenses/LICENSE-2.0\n"
        "#\n"
        "# Unless required by applicable law or agreed to in writing,\n"
        "# software distributed under the License is distributed on an\n"
        '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n'
        "# KIND, either express or implied.  See the License for the\n"
        "# specific language governing permissions and limitations\n"
        "# under the License.\n"
        "\n"
        '"""Common utilities for BiDi command construction."""\n'
        "\n"
        "from __future__ import annotations\n"
        "\n"
        "from collections.abc import Generator\n"
        "from typing import Any\n"
        "\n"
        "\n"
        "def command_builder(\n"
        "    method: str, params: dict[str, Any] | None = None\n"
        ") -> Generator[dict[str, Any], Any, Any]:\n"
        '    """Build a BiDi command generator.\n'
        "\n"
        "    Args:\n"
        '        method: The BiDi method name (e.g., "session.status", "browser.close")\n'
        "        params: The parameters for the command. If omitted, an empty\n"
        "            dictionary is sent.\n"
        "\n"
        "    Yields:\n"
        "        A dictionary representing the BiDi command\n"
        "\n"
        "    Returns:\n"
        "        The result from the BiDi command execution\n"
        '    """\n'
        "    if params is None:\n"
        "        params = {}\n"
        '    result = yield {"method": method, "params": params}\n'
        "    return result\n"
    )

    with open(common_path, "w", encoding="utf-8") as f:
        f.write(code)

    logger.info(f"Generated: {common_path}")
| |
| |
def generate_console_file(output_path: Path) -> None:
    """Write ``console.py`` containing the ``Console`` enum.

    ``Console`` is a plain ``Enum`` with members ``ALL``, ``LOG`` and ``ERROR``
    (values ``"all"``, ``"log"``, ``"error"``). The file content is a fixed
    template and does not depend on the parsed CDDL.

    Args:
        output_path: Directory in which ``console.py`` is created or overwritten.
    """
    console_path = output_path / "console.py"

    # Emitted verbatim: indentation inside these literals is the indentation of
    # the generated Python source.
    code = (
        "# Licensed to the Software Freedom Conservancy (SFC) under one\n"
        "# or more contributor license agreements.  See the NOTICE file\n"
        "# distributed with this work for additional information\n"
        "# regarding copyright ownership.  The SFC licenses this file\n"
        "# to you under the Apache License, Version 2.0 (the\n"
        '# "License"); you may not use this file except in compliance\n'
        "# with the License.  You may obtain a copy of the License at\n"
        "#\n"
        "#   http://www.apache.org/licenses/LICENSE-2.0\n"
        "#\n"
        "# Unless required by applicable law or agreed to in writing,\n"
        "# software distributed under the License is distributed on an\n"
        '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n'
        "# KIND, either express or implied.  See the License for the\n"
        "# specific language governing permissions and limitations\n"
        "# under the License.\n"
        "\n"
        "from enum import Enum\n"
        "\n"
        "\n"
        "class Console(Enum):\n"
        '    ALL = "all"\n'
        '    LOG = "log"\n'
        '    ERROR = "error"\n'
    )

    with open(console_path, "w", encoding="utf-8") as f:
        f.write(code)

    logger.info(f"Generated: {console_path}")
| |
| |
def _remove_stale_generated_files(output_path: Path) -> None:
    """Delete previously generated top-level ``*.py`` files from ``output_path``.

    Files named in the preserved set and files whose name starts with ``_`` are
    kept. Static helper modules such as ``cdp.py`` are staged into this
    directory by Bazel (``create-bidi-src.extra_srcs``) and must survive.
    """
    preserved_python_files = {"py.typed", "cdp.py"}
    for file_path in output_path.glob("*.py"):
        if file_path.name in preserved_python_files or file_path.name.startswith("_"):
            continue
        file_path.unlink()
        logger.debug(f"Removed: {file_path}")


def main(
    cddl_file: str,
    output_dir: str,
    spec_version: str = "1.0",
    enhancements_manifest: str | None = None,
) -> None:
    """Run the generator end to end.

    Steps: parse the CDDL file, remove stale generated files from the output
    directory, write one module per parsed CDDL module, then write
    ``__init__.py``, ``common.py``, ``console.py`` and the ``py.typed`` marker.

    Args:
        cddl_file: Path to CDDL specification file
        output_dir: Output directory for generated modules (created if missing)
        spec_version: BiDi spec version (only logged here)
        enhancements_manifest: Path to enhancement manifest Python file
    """
    output_path = Path(output_dir).resolve()
    output_path.mkdir(parents=True, exist_ok=True)

    logger.info(f"WebDriver BiDi Code Generator v{__version__}")
    logger.info(f"Input CDDL: {cddl_file}")
    logger.info(f"Output directory: {output_path}")
    logger.info(f"Spec version: {spec_version}")

    manifest = load_enhancements_manifest(enhancements_manifest)
    if manifest:
        logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}")

    modules = CddlParser(cddl_file).parse()
    logger.info(f"Parsed {len(modules)} modules")

    _remove_stale_generated_files(output_path)

    for module_name, module in sorted(modules.items()):
        module_path = output_path / f"{module_name_to_filename(module_name)}.py"
        # Per-module enhancements plus the manifest-wide dataclass templates and docstrings.
        full_module_enhancements = {
            **manifest.get("enhancements", {}).get(module_name, {}),
            "dataclass_methods": manifest.get("dataclass_methods", {}),
            "method_docstrings": manifest.get("method_docstrings", {}),
        }
        with open(module_path, "w", encoding="utf-8") as f:
            f.write(module.generate_code(full_module_enhancements))
        logger.info(f"Generated: {module_path}")

    generate_init_file(output_path, modules)
    generate_common_file(output_path)
    generate_console_file(output_path)

    py_typed_path = output_path / "py.typed"
    py_typed_path.touch()
    logger.info(f"Generated type marker: {py_typed_path}")

    logger.info("Code generation complete!")
| |
| |
if __name__ == "__main__":
    # Command-line entry point: parse arguments, run main(), exit 0 on success / 1 on failure.
    parser = argparse.ArgumentParser(description="Generate Python WebDriver BiDi modules from CDDL specification")
    parser.add_argument("cddl_file", help="Path to CDDL specification file")
    parser.add_argument("output_dir", help="Output directory for generated Python modules")
    parser.add_argument(
        "spec_version",
        nargs="?",
        default="1.0",
        help="BiDi spec version (default: 1.0)",
    )
    parser.add_argument(
        "--enhancements-manifest",
        default=None,
        help="Path to enhancement manifest Python file (optional)",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging")
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger("generate_bidi").setLevel(logging.DEBUG)

    exit_code = 0
    try:
        main(args.cddl_file, args.output_dir, args.spec_version, args.enhancements_manifest)
    except Exception as e:
        # Top-level boundary: log with traceback and report failure via the exit status.
        logger.exception(f"Generation failed: {e}")
        exit_code = 1
    sys.exit(exit_code)