blob: 9f88c720dd2b5372d8158dc39d2c6f8769cca025 [file] [edit]
# Licensed to the Software Freedom Conservancy (SFC) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The SFC licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Generate Python WebDriver BiDi command modules from CDDL specification.

This generator reads CDDL (Concise Data Definition Language) specification files
and produces Python type definitions and command classes that conform to the
WebDriver BiDi protocol.

Usage:
    python generate_bidi.py <cddl_file> <output_dir> <spec_version>

Example:
    python generate_bidi.py local.cddl ./selenium/webdriver/common/bidi 1.0
"""
import argparse
import importlib.util
import logging
import re
import sys
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from textwrap import indent as tw_indent
from typing import Any
__version__ = "1.0.0"

# Logging setup: basicConfig() runs at import time and configures the root
# logger at INFO; all generator messages go through the named logger below.
log_level = logging.INFO
logging.basicConfig(level=log_level)
logger = logging.getLogger("generate_bidi")

# File headers
# Comment banner reused as the first lines of _MODULE_HEADER_COMMENTS.
SHARED_HEADER = """# DO NOT EDIT THIS FILE!
#
# This file is generated from the WebDriver BiDi specification. If you need to make
# changes, edit the generator and regenerate all of the modules."""

# Split header: the comments section and the imports section are kept separate
# so that a module docstring can be injected between them (a docstring placed
# before the imports becomes the generated module's real __doc__).
# The doubled braces survive the f-string as a literal "{}" placeholder, which
# is filled in with the module name through str.format() when code is generated.
_MODULE_HEADER_COMMENTS = f"""{SHARED_HEADER}
#
# WebDriver BiDi module: {{}}
"""
_MODULE_HEADER_IMPORTS = "from __future__ import annotations\n\n"
def indent(s: str, n: int) -> str:
    """Return ``s`` with ``n`` spaces prepended to each line.

    Delegates to ``textwrap.indent``, so lines consisting solely of
    whitespace are left without the prefix.
    """
    padding = " " * n
    return tw_indent(s, padding)
def _docstring_text(custom: str | None, fallback_name: str, fallback_desc: str = "") -> str:
"""Select the appropriate raw docstring text (no triple-quotes).
Priority: manifest custom string > CDDL description (if different from name) > 'ClassName.'
"""
if custom:
return custom.strip()
if fallback_desc and fallback_desc != fallback_name:
return fallback_desc
return f"{fallback_name}."
def _emit_docstring(text: str, indent_width: int) -> str:
r"""Produce a PEP 257-compliant docstring block with a trailing newline.
Single-line output: <indent>\"\"\"text.\"\"\"\n
Multi-line output:
<indent>\"\"\"
<indent>First line.
<indent>
<indent>Continuation.
<indent>\"\"\"
The opening and closing triple-quotes each occupy their own line for
multi-line strings so that inspect.getdoc() / Sphinx dedent cleanly.
"""
prefix = " " * indent_width
stripped = text.strip()
lines = stripped.splitlines()
if len(lines) <= 1:
return f'{prefix}"""{stripped}"""\n'
content = "\n".join(f"{prefix}{line}" if line.strip() else "" for line in lines)
return f'{prefix}"""\n{content}\n{prefix}"""\n'
def load_enhancements_manifest(manifest_path: str | None) -> dict[str, Any]:
"""Load enhancement manifest from a Python file.
Args:
manifest_path: Path to Python file containing ENHANCEMENTS dict
Returns:
Dictionary with enhancement rules, or empty dict if no manifest provided
"""
if not manifest_path:
return {}
manifest_file = Path(manifest_path)
if not manifest_file.exists():
logger.warning(f"Enhancement manifest not found: {manifest_path}")
return {}
try:
spec = importlib.util.spec_from_file_location("bidi_enhancements", manifest_file)
if spec is None or spec.loader is None:
logger.warning(f"Could not load manifest: {manifest_path}")
return {}
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
enhancements = getattr(module, "ENHANCEMENTS", {})
dataclass_methods = getattr(module, "DATACLASS_METHOD_TEMPLATES", {})
method_docstrings = getattr(module, "DATACLASS_METHOD_DOCSTRINGS", {})
logger.info(f"Loaded enhancement manifest from: {manifest_path}")
logger.debug(f"Enhancements for modules: {list(enhancements.keys())}")
return {
"enhancements": enhancements,
"dataclass_methods": dataclass_methods,
"method_docstrings": method_docstrings,
}
except Exception as e:
logger.error(f"Failed to load enhancement manifest: {e}", exc_info=True)
return {}
class CddlType(Enum):
    """CDDL primitive type names mapped to Python annotation strings.

    Several members share a value (``TEXT`` with ``TSTR``; ``INT`` and
    ``NINT`` with ``UINT``), which makes the later ones *aliases* in
    ``enum`` terms.  Aliases are skipped when iterating the class directly,
    so name lookups must go through ``__members__``.
    """

    TSTR = "str"  # text string
    TEXT = "str"  # text (alias)
    UINT = "int"  # unsigned integer
    INT = "int"  # signed integer
    NINT = "int"  # negative integer
    BOOL = "bool"  # boolean
    NULL = "None"  # null
    ANY = "Any"  # any type

    @classmethod
    def get_annotation(cls, cddl_type: str) -> str:
        """Get the Python type annotation for a CDDL type string.

        Args:
            cddl_type: CDDL type text; matched case-insensitively after
                stripping surrounding whitespace.

        Returns:
            The mapped primitive name, ``list[...]`` for a type starting with
            ``[``, ``dict[str, Any]`` for one starting with ``{``, and
            ``"Any"`` for anything else.
        """
        normalized = cddl_type.strip().lower()

        # Basic types. __members__ includes aliases, unlike iterating `cls`,
        # so "text", "int" and "nint" resolve instead of falling back to Any.
        for member_name, member in cls.__members__.items():
            if normalized == member_name.lower():
                return member.value

        # Composite types
        if normalized.startswith("["):  # Array
            inner = normalized.strip("[]+ ")
            return f"list[{cls.get_annotation(inner)}]"
        if normalized.startswith("{"):  # Map/Dict
            return "dict[str, Any]"

        # Unknown types
        return "Any"
@dataclass
class CddlCommand:
"""Represents a CDDL command definition."""
module: str
name: str
params: dict[str, str] = field(default_factory=dict)
required_params: set[str] = field(default_factory=set)
result: str | None = None
description: str = ""
def to_python_method(self, enhancements: dict[str, Any] | None = None) -> str:
"""Generate Python method code for this command.
Args:
enhancements: Dictionary with enhancement rules for this method
"""
enhancements = enhancements or {}
method_name = self._camel_to_snake(self.name)
# Build parameter list with type hints
# Check if there's a params_override for user-friendly named arguments
params_to_use = self.params
if "params_override" in enhancements:
params_to_use = enhancements["params_override"]
param_strs = []
param_names = [] # Keep track of parameter names for later use
for param_name, param_type in params_to_use.items():
if param_type in ["bool", "str", "int"]:
python_type = param_type
else:
python_type = CddlType.get_annotation(param_type)
snake_param = self._camel_to_snake(param_name)
param_names.append((param_name, snake_param))
param_strs.append(f"{snake_param}: {python_type} | None = None")
if param_strs:
# Check if full signature would exceed line length limit (120 chars)
single_line_signature = f" def {method_name}(self, {', '.join(param_strs)}):"
if len(single_line_signature) > 120:
# Format parameters on multiple lines
body = f" def {method_name}(\n"
body += " self,\n"
for i, param_str in enumerate(param_strs):
if i < len(param_strs) - 1:
body += f" {param_str},\n"
else:
body += f" {param_str},\n"
body += " ):\n"
else:
param_list = "self, " + ", ".join(param_strs)
body = f" def {method_name}({param_list}):\n"
else:
body = f" def {method_name}(self):\n"
docstring = enhancements.get("docstring") or self.description or f"Execute {self.module}.{self.name}."
body += _emit_docstring(docstring, 8)
# Add automatic validation for required parameters
# (This is used unless there's no required_params, in which case all params are optional)
if self.required_params:
method_snake = self._camel_to_snake(self.name)
for param_name, snake_param in param_names:
if param_name in self.required_params:
body += f" if {snake_param} is None:\n"
msg = f"{method_snake}() missing required argument:"
error_message = f"{msg} {snake_param!r}"
body += f" raise TypeError({error_message!r})\n"
body += "\n"
# Add validation if specified in enhancements (for additional business logic validation)
if "validate" in enhancements:
validate_func = enhancements["validate"]
# Build parameter list for validation function
param_args = ", ".join(f"{snake}={snake}" for _, snake in param_names)
body += f" {validate_func}({param_args})\n"
body += "\n"
# Add transformation and preprocessing
# First, check if any transform is needed
if "transform" in enhancements:
transform_spec = enhancements["transform"]
if isinstance(transform_spec, dict):
# New format with explicit function and result parameter
transform_func = transform_spec.get("func")
result_param = transform_spec.get("result_param", "params")
input_params = [
transform_spec.get(k) for k in ["allowed", "destination_folder"] if transform_spec.get(k)
]
if transform_func and result_param:
body += f" {result_param} = None\n"
param_args = ", ".join(input_params)
body += f" {result_param} = {transform_func}({param_args})\n"
body += "\n"
else:
# Legacy format for backward compatibility
transform_func = transform_spec
if self.name == "setDownloadBehavior":
body += " download_behavior = None\n"
body += f" download_behavior = {transform_func}(allowed, destination_folder)\n"
body += "\n"
# Add preprocessing for serialization (check for to_bidi_dict method)
if "preprocess" in enhancements:
preprocess_rules = enhancements["preprocess"]
for param_name, preprocess_type in preprocess_rules.items():
snake_param = self._camel_to_snake(param_name)
if preprocess_type == "check_serialize_method":
body += f" if {snake_param} and hasattr({snake_param}, 'to_bidi_dict'):\n"
body += f" {snake_param} = {snake_param}.to_bidi_dict()\n"
body += "\n"
# Build params dict
body += " params = {\n"
# If there's a transform with a result parameter, map it to the BiDi protocol name
if "transform" in enhancements and isinstance(enhancements["transform"], dict):
transform_spec = enhancements["transform"]
result_param = transform_spec.get("result_param")
# Map the result parameter to the original CDDL parameter name
if result_param == "download_behavior":
body += ' "downloadBehavior": download_behavior,\n'
# Add remaining parameters that weren't part of the transform
for cddl_param_name in self.params:
if cddl_param_name not in ["downloadBehavior"]:
snake_name = self._camel_to_snake(cddl_param_name)
body += f' "{cddl_param_name}": {snake_name},\n'
else:
# Standard parameter mapping from CDDL
for param_name, snake_param in param_names:
body += f' "{param_name}": {snake_param},\n'
body += " }\n"
body += " params = {k: v for k, v in params.items() if v is not None}\n"
body += f' cmd = command_builder("{self.module}.{self.name}", params)\n'
body += " result = self._conn.execute(cmd)\n"
# Add response handling for extraction/deserialization
if "extract_field" in enhancements:
extract_field = enhancements["extract_field"]
extract_property = enhancements.get("extract_property")
# Check if we also need to deserialize the extracted field
deserialize_rules = enhancements.get("deserialize", {})
if extract_property:
# Extract property from list items
body += f' if result and "{extract_field}" in result:\n'
body += f' items = result.get("{extract_field}", [])\n'
body += " return [\n"
body += f' item.get("{extract_property}")\n'
body += " for item in items\n"
body += " if isinstance(item, dict)\n"
body += " ]\n"
body += " return []\n"
elif extract_field in deserialize_rules:
# Extract field and deserialize to typed objects
type_name = deserialize_rules[extract_field]
body += f' if result and "{extract_field}" in result:\n'
body += f' items = result.get("{extract_field}", [])\n'
body += " return [\n"
body += f" {type_name}(\n"
body += self._generate_field_args(extract_field, type_name)
body += " )\n"
body += " for item in items\n"
body += " if isinstance(item, dict)\n"
body += " ]\n"
body += " return []\n"
else:
# Simple field extraction (return the value directly, not wrapped in result dict)
body += f' if result and "{extract_field}" in result:\n'
body += f' extracted = result.get("{extract_field}")\n'
body += " return extracted\n"
body += " return result\n"
elif "deserialize" in enhancements:
# Deserialize response to typed objects (legacy, without extract_field)
deserialize_rules = enhancements["deserialize"]
for response_field, type_name in deserialize_rules.items():
body += f' if result and "{response_field}" in result:\n'
body += f' items = result.get("{response_field}", [])\n'
body += " return [\n"
body += f" {type_name}(\n"
body += self._generate_field_args(response_field, type_name)
body += " )\n"
body += " for item in items\n"
body += " if isinstance(item, dict)\n"
body += " ]\n"
body += " return []\n"
else:
# No special response handling, just return the result
body += " return result\n"
return body
def _generate_field_args(self, response_field: str, type_name: str) -> str:
"""Generate constructor arguments for deserializing response objects.
For now, this handles ClientWindowInfo and Info specifically.
Could be extended to be more generic.
"""
if type_name == "ClientWindowInfo":
return (
' active=item.get("active"),\n'
' client_window=item.get("clientWindow"),\n'
' height=item.get("height"),\n'
' state=item.get("state"),\n'
' width=item.get("width"),\n'
' x=item.get("x"),\n'
' y=item.get("y")\n'
)
elif type_name == "Info":
return (
' children=_deserialize_info_list(item.get("children", [])),\n'
' client_window=item.get("clientWindow"),\n'
' context=item.get("context"),\n'
' original_opener=item.get("originalOpener"),\n'
' url=item.get("url"),\n'
' user_context=item.get("userContext"),\n'
' parent=item.get("parent")\n'
)
# For other types, return empty
return ""
@staticmethod
def _camel_to_snake(name: str) -> str:
"""Convert camelCase to snake_case."""
name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
@dataclass
class CddlTypeDefinition:
"""Represents a CDDL type definition."""
module: str
name: str
fields: dict[str, str] = field(default_factory=dict)
description: str = ""
def to_python_dataclass(self, enhancements: dict[str, Any] | None = None) -> str:
"""Generate Python dataclass code for this type.
Args:
enhancements: Dictionary containing dataclass_methods and method_docstrings
"""
enhancements = enhancements or {}
dataclass_methods = enhancements.get("dataclass_methods", {})
method_docstrings = enhancements.get("method_docstrings", {})
# Generate class name from type name (keep it as-is, don't split on underscores)
class_name = self.name
code = "@dataclass\n"
code += f"class {class_name}:\n"
class_docstrings = enhancements.get("class_docstrings", {})
class_doc = _docstring_text(class_docstrings.get(class_name), class_name, self.description)
code += _emit_docstring(class_doc, 4)
code += "\n"
if not self.fields:
code += " pass\n"
else:
for field_name, field_type in self.fields.items():
# Convert CDDL type to Python type
python_type = self._get_python_type(field_type)
snake_name = CddlCommand._camel_to_snake(field_name)
# Check if the CDDL field type is a quoted string literal (e.g., type: "key")
# These are discriminant fields: auto-populate and exclude from __init__
# so callers don't need to pass them as positional or keyword arguments.
literal_match = re.match(r'^"([^"]+)"$', field_type.strip())
if literal_match:
literal_value = literal_match.group(1)
code += f' {snake_name}: str = field(default="{literal_value}", init=False)\n'
# Check if this field is a list type (using lowercase 'list[' from Python 3.10+ syntax)
elif python_type.startswith("list["):
# Remove the trailing ' | None' from list types since default_factory=list ensures non-None
type_annotation = python_type.replace(" | None", "")
code += f" {snake_name}: {type_annotation} = field(default_factory=list)\n"
# Check if this field is a dict type (using lowercase 'dict[' from Python 3.10+ syntax)
elif python_type.startswith("dict["):
# Remove the trailing ' | None' from dict types since default_factory=dict ensures non-None
type_annotation = python_type.replace(" | None", "")
code += f" {snake_name}: {type_annotation} = field(default_factory=dict)\n"
else:
code += f" {snake_name}: {python_type} = None\n"
# Add custom methods if defined for this class
if class_name in dataclass_methods:
code += "\n"
methods_dict = dataclass_methods[class_name]
docstrings_dict = method_docstrings.get(class_name, {})
for method_name in methods_dict:
method_impl = methods_dict[method_name]
docstring = docstrings_dict.get(method_name, "")
code += f" def {method_name}(self):\n"
if docstring:
code += f' """{docstring}"""\n'
code += f" {method_impl}\n"
code += "\n"
return code
@staticmethod
def _get_python_type(cddl_type: str) -> str:
"""Convert CDDL type to Python type annotation using Python 3.10+ syntax."""
cddl_type = cddl_type.strip().lower()
# Handle basic types
type_mapping = {
"tstr": "str",
"text": "str",
"uint": "int",
"int": "int",
"nint": "int",
"bool": "bool",
"null": "None",
}
for cddl, python in type_mapping.items():
if cddl_type == cddl:
# Use Python 3.10+ union syntax: type | None
return f"{python} | None"
# Handle arrays
if cddl_type.startswith("["):
inner = cddl_type.strip("[]+ ")
inner_type = CddlTypeDefinition._get_python_type(inner)
# Remove " | None" from inner type since it might be wrapped
if " | None" in inner_type:
inner_base = inner_type.replace(" | None", "")
return f"list[{inner_base} | None] | None"
return f"list[{inner_type}] | None"
# Handle maps/dicts
if cddl_type.startswith("{"):
return "dict[str, Any] | None"
# Default to Any for unknown/complex types
return "Any | None"
@dataclass
class CddlEnum:
"""Represents a CDDL enum definition (string union)."""
module: str
name: str
values: list[str] = field(default_factory=list)
description: str = ""
def to_python_class(self, enhancements: dict[str, Any] | None = None) -> str:
"""Generate Python enum class code.
Generates a simple class with string constants to match the existing
pattern in the codebase (e.g., ClientWindowState).
"""
enhancements = enhancements or {}
class_name = self.name
class_docstrings = enhancements.get("class_docstrings", {})
class_doc = _docstring_text(class_docstrings.get(class_name), class_name, self.description)
code = f"class {class_name}:\n"
code += _emit_docstring(class_doc, 4)
code += "\n"
for value in self.values:
# Convert value to UPPER_SNAKE_CASE constant name
const_name = self._value_to_const_name(value)
code += f' {const_name} = "{value}"\n'
return code
@staticmethod
def _value_to_const_name(value: str) -> str:
"""Convert enum string value to constant name.
Examples:
"none" -> "NONE"
"portrait-primary" -> "PORTRAIT_PRIMARY"
"interactive" -> "INTERACTIVE"
"""
# Replace hyphens with underscores
const_name = value.replace("-", "_")
# Convert to uppercase
return const_name.upper()
@dataclass
class CddlEvent:
    """Represents a CDDL event definition (incoming message from browser).

    Attributes:
        module: BiDi module the event belongs to.
        name: Name of the generated event-info alias.
        method: Full BiDi event method (e.g. "browsingContext.contextCreated").
        params_type: CDDL name of the event's parameter type, optionally
            prefixed with a module (e.g. "browsingContext.Info").
        description: CDDL description of the event.
    """

    module: str
    name: str
    method: str
    params_type: str
    description: str = ""

    def to_python_dataclass(self) -> str:
        """Generate the module-level alias for the event info type.

        Despite the name, no dataclass is emitted: the output is a comment
        line plus an assignment that resolves the parameter type through
        ``globals().get()`` so that a missing type falls back to ``dict``.

        Returns:
            Two lines of source, ending in a newline.
        """
        # Drop any module prefix: "browsingContext.Info" -> "Info".
        type_name = self.params_type.rsplit(".", 1)[-1]

        # Special case: NavigationInfo resolves to BaseNavigationInfo
        # (NavigationInfo itself is created as an alias to it).
        if type_name == "NavigationInfo":
            type_name = "BaseNavigationInfo"

        return (
            f"# Event: {self.method}\n"
            f"{self.name} = globals().get('{type_name}', dict)  # Fallback to dict if type not defined\n"
        )
@dataclass
class CddlModule:
"""Represents a CDDL module (e.g., script, network, browsing_context)."""
name: str
commands: list[CddlCommand] = field(default_factory=list)
types: list[CddlTypeDefinition] = field(default_factory=list)
enums: list[CddlEnum] = field(default_factory=list)
events: list[CddlEvent] = field(default_factory=list)
@staticmethod
def _convert_method_to_event_name(method_suffix: str) -> str:
"""Convert BiDi method suffix to friendly event name.
Examples:
"contextCreated" -> "context_created"
"navigationStarted" -> "navigation_started"
"userPromptOpened" -> "user_prompt_opened"
"""
# Convert camelCase to snake_case
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", method_suffix)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
def generate_code(self, enhancements: dict[str, Any] | None = None) -> str:
"""Generate Python code for this module.
Args:
enhancements: Dictionary with module-level enhancements
"""
enhancements = enhancements or {}
module_docstring = enhancements.get("module_docstring", "")
code = _MODULE_HEADER_COMMENTS.format(self.name)
if module_docstring:
code += _emit_docstring(module_docstring, 0) + "\n"
code += _MODULE_HEADER_IMPORTS
# Collect needed imports to avoid duplicates
needs_command_builder = bool(self.commands)
needs_dataclass = self.commands or self.types or self.events
needs_callable = self.events
stdlib_imports = []
local_imports = []
# Add imports (field import will be added conditionally after code generation)
if needs_callable:
stdlib_imports.append("from collections.abc import Callable")
if needs_dataclass:
stdlib_imports.append("from dataclasses import dataclass")
stdlib_imports.append("from typing import Any")
if needs_command_builder:
local_imports.append("from selenium.webdriver.common.bidi.common import command_builder")
if self.events:
local_imports.append(
"from selenium.webdriver.common.bidi._event_manager import EventConfig, _EventWrapper, _EventManager"
)
code += "\n".join(stdlib_imports) + "\n"
if local_imports:
code += "\n" + "\n".join(local_imports) + "\n"
code += "\n"
# Add helper function definitions from enhancements
# Collect all referenced helper functions (validate, transform)
helper_funcs_to_add = set()
for cmd in self.commands:
method_name_snake = cmd._camel_to_snake(cmd.name)
method_enhancements = enhancements.get(method_name_snake, {})
if "validate" in method_enhancements:
helper_funcs_to_add.add(("validate", method_enhancements["validate"]))
if "transform" in method_enhancements and isinstance(method_enhancements["transform"], dict):
transform_spec = method_enhancements["transform"]
if "func" in transform_spec:
helper_funcs_to_add.add(("transform", transform_spec["func"]))
# Generate helper functions if needed
if helper_funcs_to_add:
for func_type, func_name in sorted(helper_funcs_to_add):
if func_type == "validate" and func_name == "validate_download_behavior":
code += """def validate_download_behavior(
allowed: bool | None,
destination_folder: str | None,
user_contexts: Any | None = None,
) -> None:
\"\"\"Validate download behavior parameters.
Args:
allowed: Whether downloads are allowed
destination_folder: Destination folder for downloads
user_contexts: Optional list of user contexts
Raises:
ValueError: If parameters are invalid
\"\"\"
if allowed is True and not destination_folder:
raise ValueError("destination_folder is required when allowed=True")
if allowed is False and destination_folder:
raise ValueError("destination_folder should not be provided when allowed=False")
"""
elif func_type == "transform" and func_name == "transform_download_params":
code += """def transform_download_params(
allowed: bool | None,
destination_folder: str | None,
) -> dict[str, Any] | None:
\"\"\"Transform download parameters into download_behavior object.
Args:
allowed: Whether downloads are allowed
destination_folder: Destination folder for downloads (accepts str or
pathlib.Path; will be coerced to str)
Returns:
Dictionary representing the download_behavior object, or None if allowed is None
\"\"\"
if allowed is True:
return {
"type": "allowed",
# Coerce pathlib.Path (or any path-like) to str so the BiDi
# protocol always receives a plain JSON string.
"destinationFolder": str(destination_folder) if destination_folder is not None else None,
}
elif allowed is False:
return {"type": "denied"}
else: # None — reset to browser default (sent as JSON null)
return None
"""
# Generate enums first (excluding those in exclude_types)
exclude_types = set(enhancements.get("exclude_types", []))
# Also exclude any types that have extra_dataclasses overrides
# Extract class names from extra_dataclasses strings
for extra_cls in enhancements.get("extra_dataclasses", []):
# Match "class ClassName:" pattern
match = re.search(r"class\s+(\w+)\s*:", extra_cls)
if match:
exclude_types.add(match.group(1))
for enum_def in self.enums:
if enum_def.name in exclude_types:
continue
code += enum_def.to_python_class(enhancements)
code += "\n\n"
# Emit module-level aliases from enhancement manifest (e.g. LogLevel = Level)
for alias, target in enhancements.get("aliases", {}).items():
code += f"{alias} = {target}\n\n"
# Generate type dataclasses, skipping any overridden by extra_dataclasses
for type_def in self.types:
if type_def.name in exclude_types:
continue
code += type_def.to_python_dataclass(enhancements)
code += "\n\n"
# Emit extra dataclasses from enhancement manifest (non-CDDL additions)
for extra_cls in enhancements.get("extra_dataclasses", []):
code += extra_cls
code += "\n\n"
# Emit extra type aliases from enhancement manifest (e.g., union types for events)
for extra_alias in enhancements.get("extra_type_aliases", []):
code += extra_alias
code += "\n\n"
# NOTE: Don't generate event type aliases here - they reference types that may not be defined yet
# They will be generated after the class definition instead
# Generate EVENT_NAME_MAPPING for the module (before the module class)
if self.events:
# Generate EVENT_NAME_MAPPING for the module
code += "# BiDi Event Name to Parameter Type Mapping\n"
code += "EVENT_NAME_MAPPING = {\n"
for event_def in self.events:
# Convert method name to user-friendly event name
# e.g., "browsingContext.contextCreated" -> "context_created"
method_parts = event_def.method.split(".")
if len(method_parts) == 2:
event_name = self._convert_method_to_event_name(method_parts[1])
code += f' "{event_name}": "{event_def.method}",\n'
# Extra events not in the CDDL spec (e.g. Chromium-specific events)
for extra_evt in enhancements.get("extra_events", []):
code += f' "{extra_evt["event_key"]}": "{extra_evt["bidi_event"]}",\n'
code += "}\n\n"
# Add custom method function definitions before the class (for browsingContext)
if self.name == "browsingContext":
# Add helper function for recursive Info deserialization
code += """def _deserialize_info_list(items: list) -> list | None:
\"\"\"Recursively deserialize a list of dicts to Info objects.
Args:
items: List of dicts from the API response
Returns:
List of Info objects with properly nested children, or None if empty
\"\"\"
if not items or not isinstance(items, list):
return None
result = []
for item in items:
if isinstance(item, dict):
# Recursively deserialize children only if the key exists in response
children_list = None
if "children" in item:
children_list = _deserialize_info_list(item.get("children", []))
info = Info(
children=children_list,
client_window=item.get("clientWindow"),
context=item.get("context"),
original_opener=item.get("originalOpener"),
url=item.get("url"),
user_context=item.get("userContext"),
parent=item.get("parent"),
)
result.append(info)
return result if result else None
"""
code += "\n\n"
# EventConfig, _EventWrapper, and _EventManager are imported from
# selenium.webdriver.common.bidi._event_manager (see import section above)
# rather than being duplicated inline in every generated module.
if False: # placeholder to preserve indentation structure
pass
# Generate class
# Convert module name (camelCase or snake_case) to proper class name (PascalCase)
class_name = module_name_to_class_name(self.name)
class_docstrings = enhancements.get("class_docstrings", {})
module_class_doc = _docstring_text(
class_docstrings.get(class_name),
class_name,
f"WebDriver BiDi {self.name} module.",
)
code += f"class {class_name}:\n"
code += _emit_docstring(module_class_doc, 4)
code += "\n"
# Add EVENT_CONFIGS dict if there are events
if self.events:
code += " EVENT_CONFIGS: dict[str, EventConfig] = {}\n" # Will be populated after types are defined
if self.name == "script":
code += " def __init__(self, conn, driver=None) -> None:\n"
code += " self._conn = conn\n"
code += " self._driver = driver\n"
else:
code += " def __init__(self, conn) -> None:\n"
code += " self._conn = conn\n"
# Initialize _event_manager if there are events
if self.events:
code += " self._event_manager = _EventManager(conn, self.EVENT_CONFIGS)\n"
# Append extra init code from enhancements (e.g. self.intercepts = [])
for init_line in enhancements.get("extra_init_code", []):
code += f" {init_line}\n"
code += "\n"
# Generate command methods
exclude_methods = enhancements.get("exclude_methods", [])
# Automatically exclude methods that are defined in extra_methods
# to prevent generating duplicates
if "extra_methods" in enhancements:
for extra_method in enhancements["extra_methods"]:
# Extract method name from "def method_name("
match = re.search(r"def\s+(\w+)\s*\(", extra_method)
if match:
exclude_methods = list(exclude_methods) + [match.group(1)]
if self.commands:
command_docstrings = enhancements.get("command_docstrings", {})
for command in self.commands:
# Get method-specific enhancements
# Convert command name to snake_case to match enhancement manifest keys
method_name_snake = command._camel_to_snake(command.name)
if method_name_snake in exclude_methods:
continue
method_enhancements = enhancements.get(method_name_snake, {})
# Inject command_docstrings entry if no per-method docstring is set
if method_name_snake in command_docstrings and "docstring" not in method_enhancements:
method_enhancements = {**method_enhancements, "docstring": command_docstrings[method_name_snake]}
code += command.to_python_method(method_enhancements)
code += "\n"
elif not self.events and not enhancements.get("extra_methods", []):
code += " pass\n"
# Emit extra methods from enhancement manifest
for extra_method in enhancements.get("extra_methods", []):
code += extra_method
code += "\n"
# Add delegating event handler methods if events are present
if self.events:
code += """
def add_event_handler(self, event: str, callback: Callable, contexts: list[str] | None = None) -> int:
\"\"\"Add an event handler.
Args:
event: The event to subscribe to.
callback: The callback function to execute on event.
contexts: The context IDs to subscribe to (optional).
Returns:
The callback ID.
\"\"\"
return self._event_manager.add_event_handler(event, callback, contexts)
def remove_event_handler(self, event: str, callback_id: int) -> None:
\"\"\"Remove an event handler.
Args:
event: The event to unsubscribe from.
callback_id: The callback ID.
\"\"\"
return self._event_manager.remove_event_handler(event, callback_id)
def clear_event_handlers(self) -> None:
\"\"\"Clear all event handlers.\"\"\"
return self._event_manager.clear_event_handlers()
"""
# Generate event info type aliases AFTER the class definition
# This ensures all types are available when we create the aliases
if self.events:
code += "\n# Event Info Type Aliases\n"
# Check for explicit event_type_aliases in the enhancement manifest
event_type_aliases = enhancements.get("event_type_aliases", {})
for event_def in self.events:
# Convert method name to user-friendly event name
method_parts = event_def.method.split(".")
if len(method_parts) == 2:
event_name = self._convert_method_to_event_name(method_parts[1])
# Check if there's an explicit alias defined in the enhancement manifest
if event_name in event_type_aliases:
# Use the alias directly
type_name = event_type_aliases[event_name]
code += f"# Event: {event_def.method}\n"
code += f"{event_def.name} = {type_name}\n"
else:
# Fall back to the original behavior
code += event_def.to_python_dataclass()
code += "\n"
# Now populate EVENT_CONFIGS after the aliases are defined
code += "\n# Populate EVENT_CONFIGS with event configuration mappings\n"
# Use globals() to look up types dynamically to handle missing types gracefully
code += "_globals = globals()\n"
code += f"{class_name}.EVENT_CONFIGS = {{\n"
for event_def in self.events:
# Convert method name to user-friendly event name
method_parts = event_def.method.split(".")
if len(method_parts) == 2:
event_name = self._convert_method_to_event_name(method_parts[1])
# Try to get event class from globals, default to dict if not found
getter = f'_globals.get("{event_def.name}", dict)'
condition = f'_globals.get("{event_def.name}")'
event_class = f"{getter} if {condition} else dict"
# Build the entry line and check if it exceeds 120 chars
single_line = (
f' "{event_name}": EventConfig("{event_name}", "{event_def.method}", {event_class}),'
)
if len(single_line) > 120:
# Break into multiple lines
code += f' "{event_name}": EventConfig(\n'
code += f' "{event_name}",\n'
code += f' "{event_def.method}",\n'
code += f" {event_class},\n"
code += " ),\n"
else:
code += single_line + "\n"
# Extra events not in the CDDL spec
for extra_evt in enhancements.get("extra_events", []):
ek = extra_evt["event_key"]
be = extra_evt["bidi_event"]
ec = extra_evt["event_class"]
code += f' "{ek}": EventConfig("{ek}", "{be}", _globals.get("{ec}", dict)),\n'
code += "}\n"
# Check if field() is actually used in the generated code
# If so, add the field import after the dataclass import
if "field(" in code:
# Find where to insert the field import
# It should go after "from dataclasses import dataclass" line
dataclass_import_pattern = r"from dataclasses import dataclass\n"
if re.search(dataclass_import_pattern, code):
code = re.sub(
dataclass_import_pattern,
"from dataclasses import dataclass, field\n",
code,
count=1,
)
elif "from dataclasses import" not in code:
# If there's no dataclasses import yet, add field import after typing
code = code.replace(
"from typing import Any\n",
"from dataclasses import field\nfrom typing import Any\n",
)
return code
class CddlParser:
    """Parse a CDDL specification file into per-module definitions.

    The file is read when the parser is constructed. ``parse()`` then strips
    comments, splits the text into ``Name = Definition`` entries and sorts
    those entries into enums, types, events and commands, grouped by the
    module prefix of their dotted names.
    """

    # ``method: "...", params: Type`` pair; command and event definitions share it.
    _METHOD_PARAMS_PATTERN = re.compile(r"method:\s*['\"]([^'\"]+)['\"],\s*params:\s*(\w+(?:\.\w+)*)")
    # ``[?] name: type[,]`` line inside a struct body (the line is already stripped).
    _FIELD_LINE_PATTERN = re.compile(r"\??\s*(\w+)\s*:\s*(.+?)(?:,\s*)?$")
    # A "/" followed by an even number of double quotes, i.e. a "/" that is
    # not inside a double-quoted string literal.
    _UNQUOTED_SLASH_PATTERN = re.compile(r'/(?=(?:[^"]*"[^"]*")*[^"]*$)')

    def __init__(self, cddl_path: str):
        """Initialize the parser and load the CDDL file.

        Args:
            cddl_path: Path to the CDDL specification file.

        Raises:
            FileNotFoundError: If ``cddl_path`` does not exist.
        """
        self.cddl_path = Path(cddl_path)
        self.content = ""
        self.modules: dict[str, CddlModule] = {}
        self.definitions: dict[str, str] = {}
        self.event_names: set[str] = set()  # Names of definitions that are events
        self._read_file()

    def _read_file(self) -> None:
        """Read the CDDL file into ``self.content``.

        Raises:
            FileNotFoundError: If the file does not exist.
        """
        if not self.cddl_path.exists():
            raise FileNotFoundError(f"CDDL file not found: {self.cddl_path}")
        with open(self.cddl_path, encoding="utf-8") as f:
            self.content = f.read()
        logger.info("Loaded CDDL file: %s", self.cddl_path)

    def parse(self) -> dict[str, CddlModule]:
        """Parse the loaded CDDL content.

        Returns:
            Mapping of module name to the module populated with its enums,
            types, events and commands. When nothing is recognised, a single
            empty module named after the file stem is returned.
        """
        content = self._remove_comments(self.content)
        self._extract_definitions(content)
        # Event names must be known before commands are extracted, because
        # events and commands share the same ``method``/``params`` shape.
        self._extract_event_names()
        self._extract_types()
        self._extract_events()
        self._extract_commands()
        if not self.modules:
            module_name = self.cddl_path.stem
            self.modules[module_name] = CddlModule(name=module_name)
            logger.warning("No modules found in CDDL, creating default: %s", module_name)
        return self.modules

    @staticmethod
    def _find_comment_start(line: str) -> int:
        """Return the index of the first ``;`` outside a double-quoted string.

        Args:
            line: A single line of CDDL text.

        Returns:
            The index of the comment marker, or ``-1`` if the line has none.
        """
        in_string = False
        escaped = False
        for index, char in enumerate(line):
            if in_string:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_string = False
            elif char == '"':
                in_string = True
            elif char == ";":
                return index
        return -1

    def _remove_comments(self, content: str) -> str:
        """Remove ``;`` comments from CDDL content.

        Lines that consist only of a comment are dropped entirely; trailing
        comments are cut off and the text before them is kept as is. A ``;``
        inside a double-quoted string literal is not treated as a comment.

        Args:
            content: Raw CDDL text.

        Returns:
            The text with comments removed.
        """
        cleaned = []
        for line in content.split("\n"):
            comment_start = self._find_comment_start(line)
            if comment_start >= 0:
                line = line[:comment_start]
                if not line.strip():
                    # Nothing but a comment on this line.
                    continue
            cleaned.append(line)
        return "\n".join(cleaned)

    def _extract_definitions(self, content: str) -> None:
        """Split content into ``Name = Definition`` entries.

        Each entry runs until the next line that starts a new
        ``Name =`` assignment (or the end of the text) and is stored, stripped,
        in ``self.definitions`` under its (possibly dotted) name.

        Args:
            content: Comment-free CDDL text.
        """
        # The lookahead accepts the same dotted-name shape as the name group,
        # so a name with several dots also terminates the previous entry.
        pattern = r"(\w+(?:\.\w+)*)\s*=\s*(.+?)(?=\n\w+(?:\.\w+)*\s*=|\Z)"
        for match in re.finditer(pattern, content, re.DOTALL):
            name = match.group(1).strip()
            self.definitions[name] = match.group(2).strip()
            logger.debug("Extracted definition: %s", name)

    def _extract_event_names(self) -> None:
        """Collect the names of event definitions from event unions.

        Event union definitions follow the pattern::

            module.ModuleEvent = (
                module.EventName1 //
                module.EventName2 //
                ...
            )

        Every ``module.Name`` reference found in a definition whose name
        contains "Event" is added to ``self.event_names``. This covers both
        single-item and ``//``-separated unions.
        """
        for def_name, def_content in self.definitions.items():
            if "Event" not in def_name:
                continue
            for event_ref in re.findall(r"(\w+\.\w+)", def_content):
                self.event_names.add(event_ref)
                logger.debug("Identified event: %s (from %s)", event_ref, def_name)

    def _get_or_create_module(self, module_name: str) -> CddlModule:
        """Return the module registered under ``module_name``, creating it if needed."""
        if module_name not in self.modules:
            self.modules[module_name] = CddlModule(name=module_name)
        return self.modules[module_name]

    def _extract_types(self) -> None:
        """Extract enum and struct type definitions.

        Only namespaced definitions (``module.TypeName``) that do not contain
        ``method:`` are considered. String unions become enums; everything
        else is parsed as a struct and kept only when it yields fields.
        """
        for def_name, def_content in self.definitions.items():
            # Skip non-namespaced names (e.g. "EmptyParams") and commands/events.
            if "." not in def_name or "method:" in def_content:
                continue
            module_name, type_name = def_name.rsplit(".", 1)
            module = self._get_or_create_module(module_name)
            if self._is_enum_definition(def_content):
                values = self._extract_enum_values(def_content)
                if values:
                    module.enums.append(
                        CddlEnum(
                            module=module_name,
                            name=type_name,
                            values=values,
                            description=f"{type_name}",
                        )
                    )
                    logger.debug("Found enum: %s with %s values", def_name, len(values))
            else:
                fields = self._extract_type_fields(def_content)
                if fields:  # Only create type if it has fields
                    module.types.append(
                        CddlTypeDefinition(
                            module=module_name,
                            name=type_name,
                            fields=fields,
                            description=f"{type_name}",
                        )
                    )
                    logger.debug("Found type: %s with %s fields", def_name, len(fields))

    def _is_enum_definition(self, definition: str) -> bool:
        """Check whether a definition is a string union such as ``"a" / "b"``.

        Args:
            definition: The right-hand side of a CDDL definition.

        Returns:
            True when the text has no curly braces, contains `` / `` and
            contains at least one double quote.
        """
        clean_def = definition.strip()
        # Curly braces mean a struct type definition, not an enum.
        if "{" in clean_def or "}" in clean_def:
            return False
        return " / " in clean_def and '"' in clean_def

    def _extract_enum_values(self, enum_definition: str) -> list[str]:
        """Extract the string values of an enum definition.

        The definition is split into alternatives on every ``/`` that is not
        inside a double-quoted string, so a value that itself contains a slash
        is kept intact. The first quoted string of each alternative is taken.
        The definition may span multiple lines.

        Args:
            enum_definition: Text such as ``"value1" / "value2"``.

        Returns:
            The values in order of appearance.
        """
        values = []
        for part in self._UNQUOTED_SLASH_PATTERN.split(enum_definition):
            match = re.search(r'"([^"]*)"', part.strip())
            if match:
                value = match.group(1)
                values.append(value)
                logger.debug("Extracted enum value: %s", value)
        return values

    @staticmethod
    def _normalize_cddl_type(field_type: str) -> str:
        """Normalize a CDDL type expression to a simpler form.

        Strips ``.default`` annotations, replaces parenthesised constraint
        expressions with their base type and numeric intervals with ``float``
        so that the caller can safely check for nested struct syntax.

        Examples:
            '(float .ge 0.0) .default 1.0' -> 'float'
            '(float .ge 0.0) / null' -> 'float / null'
            '(0.0...360.0) / null' -> 'float / null'
            '-90.0..90.0' -> 'float'
            'float / null .default null' -> 'float / null'

        Args:
            field_type: The raw CDDL type expression.

        Returns:
            The normalized, stripped type expression.
        """
        result = field_type
        # Remove trailing .default <value> annotations
        result = re.sub(r"\s*\.default\s+\S+", "", result)
        # Replace parenthesised constraint expressions: (baseType .operator ...) -> baseType
        result = re.sub(r"\((\w+)\s+\.\w+[^)]*\)", r"\1", result)
        # Replace parenthesised numeric interval types: (0.0...360.0) -> float
        result = re.sub(r"\(-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?\)", "float", result)
        # Replace bare numeric interval types: -90.0..90.0 -> float
        result = re.sub(r"-?\d+(?:\.\d+)?\.{2,3}-?\d+(?:\.\d+)?", "float", result)
        return result.strip()

    @staticmethod
    def _strip_outer_braces(definition: str) -> str:
        """Drop one leading ``{`` and one trailing ``}`` from ``definition`` if present."""
        if definition.startswith("{"):
            definition = definition[1:]
        if definition.endswith("}"):
            definition = definition[:-1]
        return definition

    def _parse_field_line(self, line: str) -> tuple[str, str] | None:
        """Parse one stripped ``[?] name: type`` line of a struct body.

        Args:
            line: A stripped line from a struct definition.

        Returns:
            ``(name, normalized_type)``, or None when the line is not a field
            or its type still contains ``{`` or ``(`` after normalization
            (such lines belong to nested definitions).
        """
        match = self._FIELD_LINE_PATTERN.match(line)
        if not match:
            return None
        name = match.group(1).strip()
        normalized_type = self._normalize_cddl_type(match.group(2).strip())
        if "{" in normalized_type or "(" in normalized_type:
            return None
        return name, normalized_type

    def _extract_type_fields(self, type_definition: str) -> dict[str, str]:
        """Extract fields from a struct type definition.

        Args:
            type_definition: The definition text, usually ``{ ... }``.

        Returns:
            Mapping of field name to normalized CDDL type. Lines mentioning
            ``Extensible`` and lines starting with ``//`` are ignored.
        """
        fields = {}
        body = self._strip_outer_braces(type_definition.strip())
        for line in body.split("\n"):
            line = line.strip()
            if not line or "Extensible" in line or line.startswith("//"):
                continue
            parsed = self._parse_field_line(line)
            if parsed is None:
                continue
            field_name, normalized_type = parsed
            fields[field_name] = normalized_type
            logger.debug("Extracted field %s: %s", field_name, normalized_type)
        return fields

    def _extract_events(self) -> None:
        """Extract event definitions.

        A definition is an event when its name is in ``self.event_names`` and
        its content contains ``method: "...", params: ...``, e.g.::

            module.EventName = (method: "module.eventName", params: module.ParamType)

        The owning module is taken from the method string; the event name is
        the last component of the definition name.
        """
        for def_name, def_content in self.definitions.items():
            if def_name not in self.event_names:
                continue
            match = self._METHOD_PARAMS_PATTERN.search(def_content)
            if not match:
                continue
            method = match.group(1)  # e.g., "browsingContext.contextCreated"
            params_type = match.group(2)  # e.g., "browsingContext.Info"
            if "." not in method:
                continue
            module_name, _ = method.split(".", 1)
            module = self._get_or_create_module(module_name)
            # e.g. browsingContext.ContextCreated -> ContextCreated
            _, event_name = def_name.rsplit(".", 1)
            module.events.append(
                CddlEvent(
                    module=module_name,
                    name=event_name,
                    method=method,
                    params_type=params_type,
                    description=f"Event: {method}",
                )
            )
            logger.debug("Found event: %s (method=%s, params=%s)", def_name, method, params_type)

    def _extract_commands(self) -> None:
        """Extract command definitions.

        Every ``method: "module.command", params: Type`` pair found in a
        definition that is not an event produces one command in the module
        named by the method prefix.
        """
        for def_name, def_content in self.definitions.items():
            # Events share the same pattern and were handled separately.
            if def_name in self.event_names:
                continue
            for match in self._METHOD_PARAMS_PATTERN.finditer(def_content):
                method = match.group(1)  # e.g., "session.new"
                params_type = match.group(2)  # e.g., "session.NewParameters"
                if "." not in method:
                    continue
                module_name, command_name = method.split(".", 1)
                module = self._get_or_create_module(module_name)
                params, required_params = self._extract_parameters_and_required(params_type)
                module.commands.append(
                    CddlCommand(
                        module=module_name,
                        name=command_name,
                        params=params,
                        required_params=required_params,
                        description=f"Execute {method}",
                    )
                )
                logger.debug("Found command: %s with params %s", method, params_type)

    def _extract_parameters(self, params_type: str, _seen: set[str] | None = None) -> dict[str, str]:
        """Extract parameters from a parameter type definition.

        Handles both struct types ({...}) and top-level union types
        (TypeA / TypeB); see ``_extract_parameters_and_required``.

        Args:
            params_type: Name of the parameter type definition.
            _seen: Type names already visited (recursion guard).

        Returns:
            Mapping of parameter name to normalized CDDL type.
        """
        params, _ = self._extract_parameters_and_required(params_type, _seen)
        return params

    def _extract_parameters_and_required(
        self, params_type: str, _seen: set[str] | None = None
    ) -> tuple[dict[str, str], set[str]]:
        """Extract parameters and track which of them are required.

        Args:
            params_type: Name of the parameter type definition.
            _seen: Type names already visited; prevents infinite recursion on
                self-referencing unions.

        Returns:
            Tuple of (params dict, required_params set). Unknown or already
            visited types yield two empty containers. Parameters gathered
            from the alternatives of a union are never marked required.
        """
        params: dict[str, str] = {}
        required: set[str] = set()
        if _seen is None:
            _seen = set()
        if params_type in _seen:
            return params, required
        _seen.add(params_type)
        if params_type not in self.definitions:
            logger.debug("Parameter type not found: %s", params_type)
            return params, required
        stripped = self.definitions[params_type].strip()
        # Top-level alias that is a union of other named types, e.g.
        # session.UnsubscribeByAttributesRequest / session.UnsubscribeByIDRequest
        # (single "/" separators, not the "//" used for command unions).
        if not stripped.startswith("{") and "/" in stripped and "//" not in stripped:
            alternatives = [a.strip() for a in stripped.split("/") if a.strip()]
            if all(re.match(r"^[\w.]+$", a) for a in alternatives):
                # The caller passes only one alternative, so every collected
                # parameter stays optional.
                for alt_type in alternatives:
                    alt_params, _ = self._extract_parameters_and_required(alt_type, _seen)
                    params.update(alt_params)
                return params, required
        for line in self._strip_outer_braces(stripped).split("\n"):
            line = line.strip()
            if not line or "Extensible" in line:
                continue
            parsed = self._parse_field_line(line)
            if parsed is None:
                continue
            param_name, normalized_type = parsed
            # A leading "?" marks the parameter as optional.
            is_optional = line.startswith("?")
            params[param_name] = normalized_type
            if not is_optional:
                required.add(param_name)
            logger.debug(
                "Extracted param %s: %s (required=%s) from %s",
                param_name,
                normalized_type,
                not is_optional,
                params_type,
            )
        return params, required
def module_name_to_class_name(module_name: str) -> str:
    """Return the PascalCase class name for a module name.

    Accepts camelCase (``browsingContext``) as well as snake_case
    (``browsing_context``) input; both become ``BrowsingContext``. An empty
    name yields an empty string.
    """
    if "_" not in module_name:
        # camelCase: only the first character needs upper-casing.
        return module_name[:1].upper() + module_name[1:]
    # snake_case: capitalize every underscore-separated word and glue them.
    return "".join(map(str.capitalize, module_name.split("_")))
def module_name_to_filename(module_name: str) -> str:
    """Return the snake_case Python filename stem for a module name.

    Known names are looked up in a fixed table first:
    - browsingContext -> browsing_context
    - webExtension -> webextension

    Names that already contain an underscore are returned unchanged; any
    other name is converted from camelCase (``myModuleName`` ->
    ``my_module_name``).
    """
    explicit_names = {
        "browsingContext": "browsing_context",
        "webExtension": "webextension",
    }
    explicit = explicit_names.get(module_name)
    if explicit is not None:
        return explicit
    if "_" in module_name:
        # Already snake_case
        return module_name
    # First split before capitalised words, then between a lower-case
    # letter or digit and a following capital.
    word_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", module_name)
    fully_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", word_split)
    return fully_split.lower()
def generate_init_file(output_path: Path, modules: dict[str, CddlModule]) -> None:
    """Write ``__init__.py`` for the generated package.

    The file holds the shared header, one import per module (sorted by
    module name) and an ``__all__`` list naming every imported class.

    Args:
        output_path: Directory the file is written to.
        modules: Parsed modules keyed by module name.
    """
    init_path = output_path / "__init__.py"
    ordered_names = sorted(modules.keys())
    class_names = [module_name_to_class_name(name) for name in ordered_names]
    pieces = [
        f"""{SHARED_HEADER}
from __future__ import annotations
"""
    ]
    pieces.extend(
        f"from selenium.webdriver.common.bidi.{module_name_to_filename(name)} import {class_name}\n"
        for name, class_name in zip(ordered_names, class_names)
    )
    pieces.append("\n__all__ = [\n")
    pieces.extend(f' "{class_name}",\n' for class_name in class_names)
    pieces.append("]\n")
    init_path.write_text("".join(pieces), encoding="utf-8")
    logger.info(f"Generated: {init_path}")
def generate_common_file(output_path: Path) -> None:
    """Write ``common.py``, the shared ``command_builder`` helper module.

    Args:
        output_path: Directory the file is written to.
    """
    destination = output_path / "common.py"
    # Every entry is emitted verbatim, in order.
    source_lines = (
        "# Licensed to the Software Freedom Conservancy (SFC) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The SFC licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        '# "License"); you may not use this file except in compliance\n',
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "# http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n',
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License.\n",
        "\n",
        '"""Common utilities for BiDi command construction."""\n',
        "\n",
        "from __future__ import annotations\n",
        "\n",
        "from collections.abc import Generator\n",
        "from typing import Any\n",
        "\n",
        "\n",
        "def command_builder(\n",
        " method: str, params: dict[str, Any] | None = None\n",
        ") -> Generator[dict[str, Any], Any, Any]:\n",
        ' """Build a BiDi command generator.\n',
        "\n",
        " Args:\n",
        ' method: The BiDi method name (e.g., "session.status", "browser.close")\n',
        " params: The parameters for the command. If omitted, an empty\n",
        " dictionary is sent.\n",
        "\n",
        " Yields:\n",
        " A dictionary representing the BiDi command\n",
        "\n",
        " Returns:\n",
        " The result from the BiDi command execution\n",
        ' """\n',
        " if params is None:\n",
        " params = {}\n",
        ' result = yield {"method": method, "params": params}\n',
        " return result\n",
    )
    destination.write_text("".join(source_lines), encoding="utf-8")
    logger.info(f"Generated: {destination}")
def generate_console_file(output_path: Path) -> None:
    """Write ``console.py``, which defines the ``Console`` enum helper.

    Args:
        output_path: Directory the file is written to.
    """
    destination = output_path / "console.py"
    # Every entry is emitted verbatim, in order.
    source_lines = (
        "# Licensed to the Software Freedom Conservancy (SFC) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The SFC licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        '# "License"); you may not use this file except in compliance\n',
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "# http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        '# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n',
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License.\n",
        "\n",
        "from enum import Enum\n",
        "\n",
        "\n",
        "class Console(Enum):\n",
        ' ALL = "all"\n',
        ' LOG = "log"\n',
        ' ERROR = "error"\n',
    )
    destination.write_text("".join(source_lines), encoding="utf-8")
    logger.info(f"Generated: {destination}")
def main(
    cddl_file: str,
    output_dir: str,
    spec_version: str = "1.0",
    enhancements_manifest: str | None = None,
) -> None:
    """Generate the BiDi modules for one CDDL specification.

    Steps, in order: create the output directory, load the enhancement
    manifest, parse the CDDL file, delete previously generated ``*.py``
    files, write one file per module, then write ``__init__.py``,
    ``common.py``, ``console.py`` and the ``py.typed`` marker.

    Args:
        cddl_file: Path to CDDL specification file
        output_dir: Output directory for generated modules
        spec_version: BiDi spec version (only logged here)
        enhancements_manifest: Path to enhancement manifest Python file
    """
    destination = Path(output_dir).resolve()
    destination.mkdir(parents=True, exist_ok=True)

    for banner in (
        f"WebDriver BiDi Code Generator v{__version__}",
        f"Input CDDL: {cddl_file}",
        f"Output directory: {destination}",
        f"Spec version: {spec_version}",
    ):
        logger.info(banner)

    manifest = load_enhancements_manifest(enhancements_manifest)
    if manifest:
        logger.info(f"Loaded enhancement manifest from: {enhancements_manifest}")

    modules = CddlParser(cddl_file).parse()
    logger.info(f"Parsed {len(modules)} modules")

    # Remove stale generated files, but keep static helper modules that are
    # staged by Bazel (for example cdp.py) as part of
    # create-bidi-src.extra_srcs, and anything whose name starts with "_".
    keep_names = {"py.typed", "cdp.py"}
    for stale in destination.glob("*.py"):
        if stale.name in keep_names or stale.name.startswith("_"):
            continue
        stale.unlink()
        logger.debug(f"Removed: {stale}")

    # One snake_case-named file per module.
    for module_name, module in sorted(modules.items()):
        target = destination / f"{module_name_to_filename(module_name)}.py"
        # Module-specific enhancements plus the manifest-wide dataclass
        # methods and docstrings.
        merged = dict(manifest.get("enhancements", {}).get(module_name, {}))
        merged["dataclass_methods"] = manifest.get("dataclass_methods", {})
        merged["method_docstrings"] = manifest.get("method_docstrings", {})
        with open(target, "w", encoding="utf-8") as handle:
            handle.write(module.generate_code(merged))
        logger.info(f"Generated: {target}")

    generate_init_file(destination, modules)
    generate_common_file(destination)
    generate_console_file(destination)

    # Create py.typed marker
    marker = destination / "py.typed"
    marker.touch()
    logger.info(f"Generated type marker: {marker}")
    logger.info("Code generation complete!")
if __name__ == "__main__":
    # Command-line entry point: parse arguments, run main(), and turn the
    # outcome into a process exit code (0 on success, 1 on any Exception).
    parser = argparse.ArgumentParser(description="Generate Python WebDriver BiDi modules from CDDL specification")
    parser.add_argument("cddl_file", help="Path to CDDL specification file")
    parser.add_argument("output_dir", help="Output directory for generated Python modules")
    parser.add_argument("spec_version", nargs="?", default="1.0", help="BiDi spec version (default: 1.0)")
    parser.add_argument(
        "--enhancements-manifest",
        default=None,
        help="Path to enhancement manifest Python file (optional)",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging")
    args = parser.parse_args()

    if args.verbose:
        logging.getLogger("generate_bidi").setLevel(logging.DEBUG)

    try:
        main(
            args.cddl_file,
            args.output_dir,
            args.spec_version,
            args.enhancements_manifest,
        )
    except Exception as e:
        # Log the full traceback, then signal failure to the caller.
        logger.error(f"Generation failed: {e}", exc_info=True)
        sys.exit(1)
    else:
        sys.exit(0)