| """ |
| The `ast` module helps Python applications to process trees of the Python |
| abstract syntax grammar. The abstract syntax itself might change with |
| each Python release; this module helps to find out programmatically what |
| the current grammar looks like and allows modifications of it. |
| |
| An abstract syntax tree can be generated by passing `ast.PyCF_ONLY_AST` as |
| a flag to the `compile()` builtin function or by using the `parse()` |
| function from this module. The result will be a tree of objects whose |
| classes all inherit from `ast.AST`. |
| |
| A modified abstract syntax tree can be compiled into a Python code object |
| using the built-in `compile()` function. |
| |
| Additionally various helper functions are provided that make working with |
| the trees simpler. The main intention of the helper functions and this |
| module in general is to provide an easy to use interface for libraries |
| that work tightly with the python syntax (template engines for example). |
| |
| :copyright: Copyright 2008 by Armin Ronacher. |
| :license: Python License. |
| """ |
| from _ast import * |
| |
| |
def parse(source, filename='<unknown>', mode='exec', *,
          type_comments=False, feature_version=None, optimize=-1):
    """
    Compile *source* into an AST node and return it.

    This is a thin wrapper around the built-in compile() that always passes
    the PyCF_ONLY_AST flag.

    If *type_comments* is true, PyCF_TYPE_COMMENTS is added to the flags so
    that type comments are returned where the syntax allows them.
    If *optimize* is greater than zero, PyCF_OPTIMIZED_AST is added as well;
    *optimize* itself is always forwarded to compile().

    *feature_version* may be None (forwarded as -1), a ``(major, minor)``
    tuple whose major component must be 3, or an int giving the minor
    version for 3.x.  ValueError is raised for any other major version.
    """
    flags = PyCF_ONLY_AST
    flags |= PyCF_OPTIMIZED_AST if optimize > 0 else 0
    flags |= PyCF_TYPE_COMMENTS if type_comments else 0

    if feature_version is None:
        minor_version = -1
    elif isinstance(feature_version, tuple):
        # Expected to be a 2-tuple; only the minor part reaches compile().
        major, minor_version = feature_version
        if major != 3:
            raise ValueError(f"Unsupported major version: {major}")
    else:
        # Anything else is taken to be the minor version number itself.
        minor_version = feature_version

    return compile(source, filename, mode, flags,
                   _feature_version=minor_version, optimize=optimize)
| |
| |
def literal_eval(node_or_string):
    """
    Evaluate a Python literal given either as source text or as an AST node.

    The input may only be built from these literal structures: strings,
    bytes, numbers, tuples, lists, dicts, sets, booleans, and None.
    A string has its leading spaces and tabs removed and is parsed in
    'eval' mode; an `Expression` node is unwrapped to its body.

    Caution: A complex expression can overflow the C stack and cause a crash.
    """
    target = node_or_string
    if isinstance(target, str):
        target = parse(target.lstrip(" \t"), mode='eval')
    if isinstance(target, Expression):
        target = target.body
    return _convert_literal(target)
| |
| |
def _convert_literal(node):
    """
    Turn a literal AST node into the Python value it denotes.

    Helper for `literal_eval`.  Raises ValueError ("malformed node or
    string ...") for any node that is not one of the accepted literal forms.
    """
    if isinstance(node, Constant):
        return node.value

    if isinstance(node, Dict) and len(node.keys) == len(node.values):
        # Key is converted before its value, pair by pair.
        return {
            _convert_literal(key): _convert_literal(value)
            for key, value in zip(node.keys, node.values)
        }

    # Plain containers: convert every element and rebuild the container.
    for node_type, factory in ((Tuple, tuple), (List, list), (Set, set)):
        if isinstance(node, node_type):
            return factory(map(_convert_literal, node.elts))

    # The call `set()` with no arguments is the only accepted call.
    if isinstance(node, Call) and isinstance(node.func, Name):
        if node.func.id == 'set' and node.args == node.keywords == []:
            return set()

    # Signed numbers: +N / -N where N is exactly int, float or complex.
    if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
        if isinstance(node.operand, Constant):
            operand = node.operand.value
            if type(operand) in (int, float, complex):
                if isinstance(node.op, UAdd):
                    return + operand
                return - operand

    # Complex literals written as <real> +/- <imaginary>.
    if isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
        if (isinstance(node.left, (Constant, UnaryOp))
                and isinstance(node.right, Constant)):
            left = _convert_literal(node.left)
            if type(left) in (int, float):
                # The right side is only converted once the left qualifies.
                right = _convert_literal(node.right)
                if type(right) is complex:
                    if isinstance(node.op, Add):
                        return left + right
                    return left - right

    lineno = getattr(node, 'lineno', None)
    location = f' on line {lineno}' if lineno else ''
    raise ValueError("malformed node or string" + location + f': {node!r}')
| |
| |
def dump(
    node, annotate_fields=True, include_attributes=False,
    *,
    indent=None, show_empty=False,
):
    """
    Return a formatted dump of the tree in node.  This is mainly useful for
    debugging purposes.  If annotate_fields is true (by default),
    the returned string will show the names and the values for fields.
    If annotate_fields is false, the result string will be more compact by
    omitting unambiguous field names.  Attributes such as line
    numbers and column offsets are not dumped by default.  If this is wanted,
    include_attributes can be set to true.  If indent is a non-negative
    integer or string, then the tree will be pretty-printed with that indent
    level. None (the default) selects the single line representation.
    If show_empty is False, then empty lists and fields that are None
    will be omitted from the output for better readability; a `Load`
    instance stored in a field typed as `expr_context` is omitted as well.

    Raises TypeError if *node* is not an `AST` instance.
    """
    def _format(node, level=0):
        # Returns a (text, simple) pair.  `simple` tells the caller whether
        # the text may be kept on one line together with its siblings.
        if indent is not None:
            level += 1
            prefix = '\n' + indent * level
            sep = ',\n' + indent * level
        else:
            prefix = ''
            sep = ', '
        if isinstance(node, AST):
            cls = type(node)
            args = []
            # Positional reprs of values that were omitted as "empty"; they
            # are only emitted if a later field is written positionally.
            args_buffer = []
            allsimple = True
            # Once true, every following field is written as name=value.
            keywords = annotate_fields
            for name in node._fields:
                try:
                    value = getattr(node, name)
                except AttributeError:
                    # Field missing on the instance: skip it and switch to
                    # keyword form for the remaining fields.
                    keywords = True
                    continue
                if value is None and getattr(cls, name, ...) is None:
                    # Value is None and the class-level attribute is None
                    # too: skip it and switch to keyword form.
                    keywords = True
                    continue
                if not show_empty:
                    if value == []:
                        # NOTE(review): `_field_types` is supplied by the AST
                        # class, not by this file; the check looks for a
                        # list-typed field via its `__origin__`.
                        field_type = cls._field_types.get(name, object)
                        if getattr(field_type, '__origin__', ...) is list:
                            if not keywords:
                                args_buffer.append(repr(value))
                            continue
                    elif isinstance(value, Load):
                        field_type = cls._field_types.get(name, object)
                        if field_type is expr_context:
                            if not keywords:
                                args_buffer.append(repr(value))
                            continue
                    if not keywords:
                        # A field is about to be written positionally, so
                        # the omitted values before it must be written too.
                        args.extend(args_buffer)
                        args_buffer = []
                value, simple = _format(value, level)
                allsimple = allsimple and simple
                if keywords:
                    args.append('%s=%s' % (name, value))
                else:
                    args.append(value)
            if include_attributes and node._attributes:
                # Attributes are always written in name=value form.
                for name in node._attributes:
                    try:
                        value = getattr(node, name)
                    except AttributeError:
                        continue
                    if value is None and getattr(cls, name, ...) is None:
                        continue
                    value, simple = _format(value, level)
                    allsimple = allsimple and simple
                    args.append('%s=%s' % (name, value))
            if allsimple and len(args) <= 3:
                # Short node with only simple arguments: keep it on one line.
                # The node itself counts as simple only if it has no args.
                return '%s(%s)' % (node.__class__.__name__, ', '.join(args)), not args
            return '%s(%s%s)' % (node.__class__.__name__, prefix, sep.join(args)), False
        elif isinstance(node, list):
            if not node:
                return '[]', True
            return '[%s%s]' % (prefix, sep.join(_format(x, level)[0] for x in node)), False
        # Any other value is rendered with repr() and is always simple.
        return repr(node), True

    if not isinstance(node, AST):
        raise TypeError('expected AST, got %r' % node.__class__.__name__)
    if indent is not None and not isinstance(indent, str):
        # A non-string indent is treated as a number of spaces.
        indent = ' ' * indent
    return _format(node)[0]
| |
| |
def copy_location(new_node, old_node):
    """
    Transfer the source position of *old_node* onto *new_node*.

    The attributes considered are `lineno`, `col_offset`, `end_lineno` and
    `end_col_offset`; each one is copied only when both nodes list it in
    `_attributes`.  *new_node* is returned.
    """
    location_attrs = ('lineno', 'col_offset', 'end_lineno', 'end_col_offset')
    for attr in location_attrs:
        if attr not in old_node._attributes:
            continue
        if attr not in new_node._attributes:
            continue
        value = getattr(old_node, attr, None)
        if value is None:
            # Only the optional end_* attributes are copied when they are
            # None, and only if old_node actually has them.
            if not hasattr(old_node, attr) or not attr.startswith("end_"):
                continue
        setattr(new_node, attr, value)
    return new_node
| |
| |
def fix_missing_locations(node):
    """
    Fill in missing location attributes throughout the tree rooted at *node*.

    compile() expects `lineno` and `col_offset` on every node that supports
    them, which is tedious to provide for generated nodes.  This helper walks
    the tree recursively and gives each node that lacks a location attribute
    the value in effect on its parent; the root starts from line 1, column 0.
    Nodes that already carry a value keep it and pass it on to their
    children.  *node* is returned.
    """
    def _fix(current, line, col, end_line, end_col):
        supported = current._attributes
        if 'lineno' in supported:
            if hasattr(current, 'lineno'):
                line = current.lineno
            else:
                current.lineno = line
        if 'end_lineno' in supported:
            # An end position that is present but None counts as missing.
            if getattr(current, 'end_lineno', None) is not None:
                end_line = current.end_lineno
            else:
                current.end_lineno = end_line
        if 'col_offset' in supported:
            if hasattr(current, 'col_offset'):
                col = current.col_offset
            else:
                current.col_offset = col
        if 'end_col_offset' in supported:
            if getattr(current, 'end_col_offset', None) is not None:
                end_col = current.end_col_offset
            else:
                current.end_col_offset = end_col
        for child in iter_child_nodes(current):
            _fix(child, line, col, end_line, end_col)

    _fix(node, 1, 0, 1, 0)
    return node
| |
| |
def increment_lineno(node, n=1):
    """
    Shift the line numbers of every node in the tree rooted at *node* by *n*.

    Both `lineno` and `end_lineno` are adjusted, which is useful to "move
    code" to a different location in a file.  *node* is returned.
    """
    for current in walk(node):
        if isinstance(current, TypeIgnore):
            # For TypeIgnore, lineno is a field of the node itself rather
            # than an attribute, so it is handled separately.
            current.lineno = getattr(current, 'lineno', 0) + n
        else:
            supported = current._attributes
            if 'lineno' in supported:
                current.lineno = getattr(current, 'lineno', 0) + n
            if "end_lineno" in supported:
                end = getattr(current, "end_lineno", 0)
                # An end_lineno of None is left untouched.
                if end is not None:
                    current.end_lineno = end + n
    return node
| |
| |
def iter_fields(node):
    """
    Generate ``(fieldname, value)`` pairs for the fields of *node*.

    Field names are taken from ``node._fields`` in order; names that are not
    present on *node* are silently skipped.
    """
    for name in node._fields:
        try:
            value = getattr(node, name)
        except AttributeError:
            continue
        yield name, value
| |
| |
def iter_child_nodes(node):
    """
    Generate the direct children of *node*.

    A child is any field value that is itself a node, or any node found
    inside a field whose value is a list.
    """
    for _name, value in iter_fields(node):
        if isinstance(value, AST):
            yield value
        elif isinstance(value, list):
            yield from (item for item in value if isinstance(item, AST))
| |
| |
def get_docstring(node, clean=True):
    """
    Return the docstring of *node*, or None when it has none.

    *node* must be a module, class, function or async function definition;
    any other node raises TypeError.  The docstring is the string constant
    of a leading expression statement in the node's body.

    If *clean* is true, the text is passed through `inspect.cleandoc`, which
    expands tabs and removes the whitespace that can be uniformly stripped
    from the second line onwards.
    """
    if not isinstance(node, (AsyncFunctionDef, FunctionDef, ClassDef, Module)):
        raise TypeError("%r can't have docstrings" % node.__class__.__name__)

    body = node.body
    if not body or not isinstance(body[0], Expr):
        return None

    candidate = body[0].value
    if not (isinstance(candidate, Constant) and isinstance(candidate.value, str)):
        return None

    text = candidate.value
    if clean:
        import inspect
        text = inspect.cleandoc(text)
    return text
| |
| |
_line_pattern = None
def _splitlines_no_ff(source, maxlines=None):
    """Split *source* into lines, breaking only on \\r\\n, \\n and \\r.

    Form feed and other characters do not end a line; this mimics how the
    Python parser splits source code.  Line endings are kept.  When
    *maxlines* is given, at most that many lines are returned.
    """
    global _line_pattern
    if _line_pattern is None:
        # Compiled on first use so that importing `ast` stays cheap.
        import re
        _line_pattern = re.compile(r"(.*?(?:\r\n|\n|\r|$))")

    collected = []
    for found in _line_pattern.finditer(source):
        if maxlines is not None and len(collected) >= maxlines:
            break
        collected.append(found[0])
    return collected
| |
| |
def _pad_whitespace(source):
    r"""Replace all chars except '\f\t' in a line with spaces.

    Form feeds and tabs are kept as they are and every other character
    becomes a single space, so the result has the same length as *source*.
    """
    # One join instead of repeated `+=`, which copies the partial result on
    # every iteration on interpreters without in-place string concatenation.
    return ''.join(c if c in '\f\t' else ' ' for c in source)
| |
| |
def get_source_segment(source, node, *, padded=False):
    """Return the part of *source* that produced *node*.

    None is returned when any of the location attributes `lineno`,
    `end_lineno`, `col_offset` or `end_col_offset` is missing, or when one
    of the end positions is None.

    If *padded* is `True`, the first line of a multi-line statement will
    be padded with spaces to match its original position.
    """
    try:
        if node.end_lineno is None or node.end_col_offset is None:
            return None
        # Line numbers are 1-based; convert them to list indexes.
        first_index = node.lineno - 1
        last_index = node.end_lineno - 1
        start_col = node.col_offset
        end_col = node.end_col_offset
    except AttributeError:
        return None

    lines = _splitlines_no_ff(source, maxlines=last_index + 1)

    # The column values are applied to the encoded line, i.e. they are used
    # as byte offsets rather than character offsets.
    if last_index == first_index:
        return lines[first_index].encode()[start_col:end_col].decode()

    head = lines[first_index].encode()
    if padded:
        padding = _pad_whitespace(head[:start_col].decode())
    else:
        padding = ''

    first = padding + head[start_col:].decode()
    last = lines[last_index].encode()[:end_col].decode()
    middle = lines[first_index + 1:last_index]
    return ''.join([first, *middle, last])
| |
| |
def walk(node):
    """
    Generate *node* and every node below it, in no specified order.

    This is useful if you only want to modify nodes in place and don't care
    about the context.
    """
    from collections import deque
    pending = deque((node,))
    while pending:
        current = pending.popleft()
        # Children are queued before the node is handed to the caller.
        pending.extend(iter_child_nodes(current))
        yield current
| |
| |
def compare(
    a,
    b,
    /,
    *,
    compare_attributes=False,
):
    """Recursively compare two ASTs and return True if they are equal.

    Two nodes are equal when they have exactly the same type and all of
    their fields compare equal.  A field missing on both nodes is treated
    as equal; a field missing on only one of them makes the nodes differ.

    compare_attributes affects whether AST attributes are considered
    in the comparison.  If compare_attributes is False (default), then
    attributes are ignored.  Otherwise they must all be equal.  Ignoring
    attributes is useful to check whether the ASTs are structurally equal
    but might differ in whitespace or similar details.
    """

    sentinel = object()  # handle the possibility of a missing attribute/field

    def _compare(a, b):
        # Compare two fields on an AST object, which may themselves be
        # AST objects, lists of AST objects, or primitive ASDL types
        # like identifiers and constants.
        if isinstance(a, AST):
            return compare(
                a,
                b,
                compare_attributes=compare_attributes,
            )
        if isinstance(a, list):
            # If a field is repeated, then both objects will represent
            # the value as a list.  A non-list counterpart can never match:
            # without this check len() would raise for None or a node, and
            # an empty list would equal any other empty sized value.
            if not isinstance(b, list) or len(a) != len(b):
                return False
            return all(
                _compare(a_item, b_item) for a_item, b_item in zip(a, b)
            )
        # Primitive values must match in exact type as well as in value.
        return type(a) is type(b) and a == b

    def _compare_fields(a, b):
        if a._fields != b._fields:
            return False
        for field in a._fields:
            a_field = getattr(a, field, sentinel)
            b_field = getattr(b, field, sentinel)
            if a_field is sentinel and b_field is sentinel:
                # both nodes are missing a field at runtime
                continue
            if a_field is sentinel or b_field is sentinel:
                # one of the nodes is missing a field
                return False
            if not _compare(a_field, b_field):
                return False
        return True

    def _compare_attributes(a, b):
        if a._attributes != b._attributes:
            return False
        # Attributes are always ints.
        for attr in a._attributes:
            a_attr = getattr(a, attr, sentinel)
            b_attr = getattr(b, attr, sentinel)
            if a_attr is sentinel and b_attr is sentinel:
                # both nodes are missing an attribute at runtime
                continue
            # An attribute missing on one side only compares as different,
            # because the sentinel equals nothing but itself.
            if a_attr != b_attr:
                return False
        return True

    if type(a) is not type(b):
        return False
    if not _compare_fields(a, b):
        return False
    if compare_attributes and not _compare_attributes(a, b):
        return False
    return True
| |
| |
class NodeVisitor(object):
    """
    Base class for walking an abstract syntax tree node by node.

    For every node handed to `visit`, a visitor method is looked up and
    called; whatever it returns is forwarded by `visit`.

    This class is meant to be subclassed, with the subclass adding visitor
    methods.

    By default the visitor method for a node is named ``'visit_'`` followed
    by the class name of the node, so a `TryFinally` node is handled by
    `visit_TryFinally`.  Override `visit` to change this.  When no visitor
    method exists for a node, `generic_visit` is used instead.

    Don't use the `NodeVisitor` if you want to apply changes to nodes during
    traversing.  For this a special visitor exists (`NodeTransformer`) that
    allows modifications.
    """

    def visit(self, node):
        """Dispatch *node* to its visitor method and return the result."""
        handler = getattr(
            self, 'visit_' + node.__class__.__name__, self.generic_visit)
        return handler(node)

    def generic_visit(self, node):
        """Visit every child of *node*; used when no visitor method exists."""
        for _field, value in iter_fields(node):
            if isinstance(value, AST):
                self.visit(value)
            elif isinstance(value, list):
                for item in value:
                    if isinstance(item, AST):
                        self.visit(item)
| |
| |
class NodeTransformer(NodeVisitor):
    """
    A :class:`NodeVisitor` subclass that can rewrite the tree while walking.

    The value returned by a visitor method decides what happens to the
    visited node: ``None`` removes the node from its location, any other
    value replaces it.  Returning the original node leaves it in place.

    Here is an example transformer that rewrites all occurrences of name
    lookups (``foo``) to ``data['foo']``::

       class RewriteName(NodeTransformer):

           def visit_Name(self, node):
               return Subscript(
                   value=Name(id='data', ctx=Load()),
                   slice=Constant(value=node.id),
                   ctx=node.ctx
               )

    Keep in mind that if the node you're operating on has child nodes you
    must either transform the child nodes yourself or call the
    :meth:`generic_visit` method for the node first.

    For nodes that were part of a collection of statements (that applies to
    all statement nodes), the visitor may also return a list of nodes rather
    than just a single node.

    Usually you use the transformer like this::

       node = YourTransformer().visit(node)
    """

    def generic_visit(self, node):
        for field, old_value in iter_fields(node):
            if isinstance(old_value, list):
                replacement = []
                for item in old_value:
                    if not isinstance(item, AST):
                        # Non-node list items are kept untouched.
                        replacement.append(item)
                        continue
                    result = self.visit(item)
                    if result is None:
                        # The visitor asked for this node to be dropped.
                        continue
                    if isinstance(result, AST):
                        replacement.append(result)
                    else:
                        # Anything else is spliced in as several items.
                        replacement.extend(result)
                # Update in place so the node keeps the same list object.
                old_value[:] = replacement
            elif isinstance(old_value, AST):
                new_node = self.visit(old_value)
                if new_node is None:
                    delattr(node, field)
                else:
                    setattr(node, field, new_node)
        return node
| |
class slice(AST):
    """Deprecated AST node class."""


class Index(slice):
    """Deprecated AST node class.  Use the index value directly instead.

    Calling this class returns *value* unchanged; keyword arguments are
    accepted and ignored.
    """
    def __new__(cls, value, **kwargs):
        return value


class ExtSlice(slice):
    """Deprecated AST node class.  Use ast.Tuple instead.

    Calling this class returns a `Tuple` in `Load` context whose elements
    are the items of *dims*.
    """
    def __new__(cls, dims=(), **kwargs):
        elements = list(dims)
        return Tuple(elements, Load(), **kwargs)
| |
# The ast module may be loaded more than once; install the deprecated
# `dims` property on Tuple only the first time.
if not hasattr(Tuple, 'dims'):
    # Backward-compatibility shim: `dims` is an alias of `elts`.
    # It will be removed in future.

    def _dims_getter(self):
        """Deprecated. Use elts instead."""
        return self.elts

    def _dims_setter(self, value):
        self.elts = value

    Tuple.dims = property(fget=_dims_getter, fset=_dims_setter)
| |
class Suite(mod):
    """Deprecated AST node class.  Unused in Python 3.

    Subclass of `mod` that defines nothing of its own.
    """

class AugLoad(expr_context):
    """Deprecated AST node class.  Unused in Python 3.

    Subclass of `expr_context` that defines nothing of its own.
    """

class AugStore(expr_context):
    """Deprecated AST node class.  Unused in Python 3.

    Subclass of `expr_context` that defines nothing of its own.
    """

class Param(expr_context):
    """Deprecated AST node class.  Unused in Python 3.

    Subclass of `expr_context` that defines nothing of its own.
    """
| |
| |
| def unparse(ast_obj): |
| global _Unparser |
| try: |
| unparser = _Unparser() |
| except NameError: |
| from _ast_unparse import Unparser as _Unparser |
| unparser = _Unparser() |
| return unparser.visit(ast_obj) |
| |
| |
def main(args=None):
    """
    Command-line entry point: parse a file and print the dump of its AST.

    *args* is the list of command-line arguments handed to argparse; with
    None, argparse falls back to its own default.  Source is read as bytes
    from the named file, or from standard input when the file is '-'.
    """
    import argparse
    import sys

    parser = argparse.ArgumentParser(color=True)
    parser.add_argument('infile', nargs='?', default='-',
                        help='the file to parse; defaults to stdin')
    parser.add_argument('-m', '--mode', default='exec',
                        choices=('exec', 'single', 'eval', 'func_type'),
                        help='specify what kind of code must be parsed')
    parser.add_argument('--no-type-comments', default=True, action='store_false',
                        help="don't add information about type comments")
    parser.add_argument('-a', '--include-attributes', action='store_true',
                        help='include attributes such as line numbers and '
                             'column offsets')
    parser.add_argument('-i', '--indent', type=int, default=3,
                        help='indentation of nodes (number of spaces)')
    parser.add_argument('--feature-version',
                        type=str, default=None, metavar='VERSION',
                        help='Python version in the format 3.x '
                             '(for example, 3.10)')
    parser.add_argument('-O', '--optimize',
                        type=int, default=-1, metavar='LEVEL',
                        help='optimization level for parser (default -1)')
    parser.add_argument('--show-empty', default=False, action='store_true',
                        help='show empty lists and fields in dump output')
    options = parser.parse_args(args)

    if options.infile == '-':
        filename = '<stdin>'
        source = sys.stdin.buffer.read()
    else:
        filename = options.infile
        with open(filename, 'rb') as stream:
            source = stream.read()

    # Turn "--feature-version 3.x" into the (major, minor) tuple parse() takes.
    feature_version = None
    if options.feature_version:
        try:
            major, minor = (
                int(part) for part in options.feature_version.split('.', 1)
            )
        except ValueError:
            parser.error('Invalid format for --feature-version; '
                         'expected format 3.x (for example, 3.10)')

        feature_version = (major, minor)

    # `no_type_comments` defaults to True and the flag stores False, so the
    # value can be passed straight through as `type_comments`.
    tree = parse(source, filename, options.mode,
                 type_comments=options.no_type_comments,
                 feature_version=feature_version,
                 optimize=options.optimize)
    text = dump(tree, include_attributes=options.include_attributes,
                indent=options.indent, show_empty=options.show_empty)
    print(text)


if __name__ == '__main__':
    main()