blob: 3dc5311473263752ec53a81e6d3564765f7a7ea9 [file] [log] [blame]
#! /usr/bin/env python
# Copyright 2015 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import cStringIO
import functools
import imp
import inspect
import itertools
import operator
import os
import re
import sys
from catapult_base import refactor
from catapult_base.refactor_util import move
from telemetry.internal.util import command_line
from telemetry.internal.util import path
# All folders (and a few individual files) dependent on Telemetry, found using
# a code search. Each entry is a tuple of path components that gets joined onto
# path.GetChromiumSrcDir() when BASE_DIRS is built.
_RELATIVE_BASE_DIRS = (
    ('chrome', 'test', 'telemetry'),
    ('content', 'test', 'gpu'),
    ('tools', 'bisect-manual-test.py'),
    ('tools', 'chrome_proxy'),
    ('tools', 'perf'),
    ('tools', 'profile_chrome', 'perf_controller.py'),
    ('tools', 'run-bisect-manual-test.py'),
    ('third_party', 'skia', 'tools', 'skp', 'page_sets'),
    ('third_party', 'trace-viewer'),
)

# The Telemetry directory itself, followed by one joined path per entry of
# _RELATIVE_BASE_DIRS.
# Note that this is not the same as the directory that imports are relative to.
BASE_DIRS = [path.GetTelemetryDir()] + [
    os.path.join(path.GetChromiumSrcDir(), *dir_path)
    for dir_path in _RELATIVE_BASE_DIRS]
def SortImportGroups(module_path):
  """Sort every import group of the given Python module, rewriting the file.

  An import group is a run of adjacent import statements with no non-import
  line between them. Following the Google Python Style Guide, each group is
  ordered "lexicographically, ignoring case, according to each module's full
  package path."

  Args:
    module_path: Path of the module file whose import groups are sorted.
  """
  _TransformImportGroups(module_path, transformation=_SortImportGroup)
def _SortImportGroup(import_group):
  """Return one import group ordered by lower-cased full module path.

  Args:
    import_group: A list of imports, each a sequence of
        (indent, root, module, alias, suffix). root is falsy (None or '') for
        a plain "import module" statement.

  Returns:
    A new list with the same imports, sorted case-insensitively by
    "root.module" (or just "module" when there is no root). The sort is
    stable, so imports with equal keys keep their relative order. The input
    list is not modified.
  """
  def _SortKey(import_entry):
    _, root, module, _, _ = import_entry
    # "from a.b import c" sorts as "a.b.c"; a bare "import c" sorts as "c".
    full_module = root + '.' + module if root else module
    return full_module.lower()

  # A key function replaces the former cmp= comparator: the resulting order is
  # identical, each key is computed only once per import, and key= is accepted
  # by both Python 2 and Python 3 (cmp= was removed in Python 3).
  return sorted(import_group, key=_SortKey)
def _TransformImportGroups(module_path, transformation):
  """Apply a transformation to each group of imports in the given module.

  An import is a list of [indent, root, module, alias, suffix], serialized as
  <indent>from <root> import <module> as <alias><suffix>. root and alias are
  None when the statement has no "from" or "as" part. An import group is a
  run of adjacent lines that all match the import pattern.

  Every serialized import line is terminated with a newline, so a file whose
  final line is an import lacking a trailing newline gains one (and is
  therefore reported as modified).

  Args:
    module_path: The module to apply transformations on.
    transformation: A function that takes in an import group and returns a
        modified import group. An import group is a list of imports.

  Returns:
    True iff the module was modified, and False otherwise. The file is only
    rewritten when its contents actually change.
  """
  def _SerializeImports(import_group):
    """Return the transformed group as a list of source lines."""
    serialized = []
    for indent, root, module, alias, suffix in transformation(import_group):
      parts = [indent]
      if root:
        parts.extend(('from ', root, ' '))
      parts.extend(('import ', module))
      if alias:
        parts.extend((' as ', alias))
      # The pattern's suffix group stops before the line break, so the
      # newline has to be restored here.
      parts.extend((suffix, '\n'))
      serialized.append(''.join(parts))
    return serialized

  # This regex produces a tuple of (indent, root, module, alias, suffix).
  regex = (r'(\s*)(?:from ((?:[a-z0-9_]+\.)*[a-z0-9_]+) )?'
           r'import ((?:[a-z0-9_]+\.)*[A-Za-z0-9_]+)(?: as ([A-Za-z0-9_]+))?(.*)')
  pattern = re.compile(regex)

  # Read the file exactly once. The same snapshot is both transformed and used
  # for the "did anything change" comparison below.
  with open(module_path, 'r') as module_file:
    original_lines = module_file.readlines()
  original_file = ''.join(original_lines)

  # Locate imports using the regex, group them, and transform each group.
  updated_lines = []
  import_group = []
  for line in original_lines:
    import_match = pattern.match(line)
    if import_match:
      import_group.append(list(import_match.groups()))
      continue

    # A non-import line ends the current group (if any); flush it first.
    if import_group:
      updated_lines.extend(_SerializeImports(import_group))
      import_group = []
    updated_lines.append(line)

  # Flush a group that runs up to the end of the file.
  if import_group:
    updated_lines.extend(_SerializeImports(import_group))

  updated_file = ''.join(updated_lines)
  if original_file == updated_file:
    return False

  with open(module_path, 'w') as module_file:
    module_file.write(updated_file)
  return True
def _CountInterfaces(module):
  """Return the number of Class plus Function children found in a module.

  Args:
    module: An object exposing FindChildren(node_type), which is iterated
        once for refactor.Class and once for refactor.Function.

  Returns:
    The combined count of both kinds of children, as an int.
  """
  class_count = sum(1 for _ in module.FindChildren(refactor.Class))
  function_count = sum(1 for _ in module.FindChildren(refactor.Function))
  return class_count + function_count
def _IsSourceDir(dir_name):
  """Return False for dot-prefixed directories and for 'third_party'.

  Args:
    dir_name: A non-empty directory name (the first character is inspected).
  """
  if dir_name[0] == '.':
    return False
  return dir_name != 'third_party'
def _IsPythonModule(file_name):
  """Return True iff os.path.splitext() reports a '.py' extension."""
  extension = os.path.splitext(file_name)[1]
  return extension == '.py'
# Command that lists the files under path.GetTelemetryDir() accepted by the
# _IsPublicApiDir / _IsPublicApiFile filters and prints either how many there
# are ('modules', the default) or the total number of classes and functions
# counted in them by _CountInterfaces ('interfaces').
class Count(command_line.Command):
  """Print the number of public modules."""

  @classmethod
  def AddCommandLineArgs(cls, parser):
    # 'type' is optional; omitting it is the same as passing 'modules'.
    parser.add_argument('type', nargs='?', choices=('interfaces', 'modules'),
                        default='modules')

  def Run(self, args):
    # Prints the requested count to stdout and returns 0 as the exit status.
    module_paths = path.ListFiles(
        path.GetTelemetryDir(), self._IsPublicApiDir, self._IsPublicApiFile)

    # print(x) with a single argument prints exactly what the Python 2
    # statement "print x" does, and is also valid Python 3.
    if args.type == 'modules':
      print(len(module_paths))
    elif args.type == 'interfaces':
      # sum() yields 0 when no module was found, whereas reduce() without an
      # initial value raises TypeError on an empty sequence.
      print(sum(refactor.Transform(_CountInterfaces, module_paths)))
    return 0

  @staticmethod
  def _IsPublicApiDir(dir_name):
    # Rejects names starting with '.' or '_' as well as the 'internal' and
    # 'third_party' directories. dir_name must be non-empty.
    return (dir_name[0] not in ('.', '_') and
            dir_name not in ('internal', 'third_party'))

  @staticmethod
  def _IsPublicApiFile(file_name):
    # Accepts '.py' files that are not dot-prefixed and whose base name does
    # not end in '_unittest'. file_name must be non-empty.
    root, ext = os.path.splitext(file_name)
    return (file_name[0] != '.' and
            not root.endswith('_unittest') and ext == '.py')
def _TelemetryFiles():
  """Return a sorted list of the files listed under every entry of BASE_DIRS.

  Each base dir is listed with path.ListFiles, using _IsSourceDir as the
  directory filter and _IsPythonModule as the file filter; the per-directory
  results are then concatenated and sorted.
  """
  listings = [
      path.ListFiles(base_dir,
                     should_include_dir=_IsSourceDir,
                     should_include_file=_IsPythonModule)
      for base_dir in BASE_DIRS]
  return sorted(itertools.chain.from_iterable(listings))
# Command that validates its source/target paths against BASE_DIRS, delegates
# the move to move.Run, and then re-sorts the import groups of every file
# returned by _TelemetryFiles().
class Mv(command_line.Command):
  """Move modules or packages."""

  @classmethod
  def AddCommandLineArgs(cls, parser):
    # One or more source paths followed by a single target path.
    parser.add_argument('source', nargs='+')
    parser.add_argument('target')

  @staticmethod
  def _IsInBaseDirs(file_path):
    # True iff file_path is a subpath of at least one entry of BASE_DIRS.
    return any(path.IsSubpath(file_path, base_dir) for base_dir in BASE_DIRS)

  @classmethod
  def ProcessCommandLineArgs(cls, parser, args):
    # Every problem is reported through parser.error().

    # Check source file paths.
    for source_path in args.source:
      # Ensure source path exists.
      if not os.path.exists(source_path):
        parser.error('"%s" not found.' % source_path)

      # Ensure source path is in one of the BASE_DIRS.
      if not cls._IsInBaseDirs(source_path):
        # The offending path is interpolated into the message; previously the
        # format argument was missing and a literal "%s" was shown.
        parser.error('"%s" is not in any of the base dirs.' % source_path)

    # Ensure target directory exists. The target itself may not exist yet
    # as long as its parent directory does.
    if not (os.path.exists(args.target) or
            os.path.exists(os.path.dirname(args.target))):
      parser.error('"%s" not found.' % args.target)

    # Ensure target path is in one of the BASE_DIRS.
    if not cls._IsInBaseDirs(args.target):
      parser.error('"%s" is not in any of the base dirs.' % args.target)

    # If there are multiple source paths, ensure target is a directory.
    if len(args.source) > 1 and not os.path.isdir(args.target):
      parser.error('Target "%s" is not a directory.' % args.target)

    # Ensure target is not in any of the source paths.
    for source_path in args.source:
      if path.IsSubpath(args.target, source_path):
        parser.error('Cannot move "%s" to a subdirectory of itself, "%s".' %
                     (source_path, args.target))

  def Run(self, args):
    # Returns 0 as the exit status.
    move.Run(args.source, args.target, _TelemetryFiles())
    # The file list is computed again after the move rather than reused --
    # presumably because the move changes which paths exist.
    for module_path in _TelemetryFiles():
      SortImportGroups(module_path)
    return 0
# Command that sorts the import groups of every file that path.ListFiles
# returns for the given targets, or for all of BASE_DIRS when no target is
# given.
class Sort(command_line.Command):
  """Sort imports."""

  @classmethod
  def AddCommandLineArgs(cls, parser):
    # Zero or more paths to process.
    parser.add_argument('target', nargs='*')

  @classmethod
  def ProcessCommandLineArgs(cls, parser, args):
    # Report each target that does not exist on disk via parser.error().
    for target in args.target:
      if os.path.exists(target):
        continue
      parser.error('"%s" not found.' % target)

  def Run(self, args):
    # An empty target list means "process every base dir".
    search_roots = args.target or BASE_DIRS
    for search_root in search_roots:
      module_paths = path.ListFiles(search_root, _IsSourceDir, _IsPythonModule)
      for module_path in module_paths:
        SortImportGroups(module_path)
    return 0
class RefactorCommand(command_line.SubcommandCommand):
  # The subcommand classes registered with this top-level command.
  commands = (Count, Mv, Sort)
# Script entry point: run the command and use its result as the exit status.
if __name__ == '__main__':
  exit_code = RefactorCommand.main()
  sys.exit(exit_code)