blob: e8c144bfb474ed279358bace89a84df6c6c9f669 [file] [log] [blame]
#! /usr/bin/env python
# Copyright 2015 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import cStringIO
import imp
import inspect
import os
import re
import sys
from telemetry.core import command_line
from telemetry.util import path
# Every location that depends on Telemetry, as found by a code search. The
# first entry is the Telemetry directory itself; the rest live under the
# Chromium source root. Note that a few entries name a single .py file rather
# than a directory.
BASE_DIRS = (path.GetTelemetryDir(),) + tuple(
    os.path.join(path.GetChromiumSrcDir(), *components)
    for components in (
        ('chrome', 'test', 'telemetry'),
        ('content', 'test', 'gpu'),
        ('tools', 'bisect-manual-test.py'),
        ('tools', 'chrome_proxy'),
        ('tools', 'perf'),
        ('tools', 'profile_chrome', 'perf_controller.py'),
        ('tools', 'run-bisect-manual-test.py'),
        ('third_party', 'skia', 'tools', 'skp', 'page_sets'),
        ('third_party', 'trace-viewer'),
    ))
def SortImportGroups(module_path):
  """Rewrite a Python module so every import group is in sorted order.

  Adjacent import statements (with no non-import line between them) form one
  group. Each group is ordered independently following the Google Python Style
  Guide rule: "lexicographically, ignoring case, according to each module's
  full package path."

  Args:
    module_path: Path of the Python file to rewrite in place.
  """
  _TransformImportGroups(module_path, transformation=_SortImportGroup)
def _SortImportGroup(import_group):
  """Return the imports of one group ordered by full module path.

  Ordering is case-insensitive on "<root>.<module>" (or just "<module>" when
  the import has no "from <root>" part). The sort is stable, so imports whose
  keys compare equal keep their relative order.

  Args:
    import_group: A list of imports, each a 5-item sequence of
        (indent, root, module, alias, suffix). root may be None or empty.

  Returns:
    A new sorted list; the input list is not modified.
  """
  def _FullModuleKey(import_line):
    _, root, module, _, _ = import_line
    full_module = root + '.' + module if root else module
    return full_module.lower()

  # A key function yields the same order as the former cmp-based comparator
  # (which only compared these derived strings) and, unlike cmp=, also works
  # on Python 3.
  return sorted(import_group, key=_FullModuleKey)
def _TransformImportGroups(module_path, transformation):
  """Apply a transformation to each group of imports in the given module.

  An import is a list of [indent, root, module, alias, suffix], serialized as
  <indent>from <root> import <module> as <alias><suffix>. The "from <root>"
  and "as <alias>" parts are left out when root or alias is None or empty.
  A group is a run of consecutive lines that each match the import regex.

  The file is only written when the transformed text differs from the text
  that was read. A missing newline at the end of the file is preserved.

  Args:
    module_path: The module to apply transformations on.
    transformation: A function that takes in an import group and returns a
        modified import group. An import group is a list of imports.

  Returns:
    True iff the module was modified, and False otherwise.
  """
  # Captures (indent, root, module, alias, suffix). Optional parts that are
  # absent from the line are reported as None.
  pattern = re.compile(
      r'(\s*)(?:from ((?:[a-z0-9_]+\.)*[a-z0-9_]+) )?'
      r'import ((?:[a-z0-9_]+\.)*[A-Za-z0-9_]+)(?: as ([A-Za-z0-9_]+))?(.*)')

  def _SerializeImports(import_group):
    """Return the text pieces of the transformed group, one line per import."""
    pieces = []
    for indent, root, module, alias, suffix in transformation(import_group):
      pieces.append(indent)
      if root:
        pieces.extend(('from ', root, ' '))
      pieces.extend(('import ', module))
      if alias:
        pieces.extend((' as ', alias))
      # The trailing (.*) of the regex never captures the newline; re-add it.
      pieces.extend((suffix, '\n'))
    return pieces

  # Read the file once: the same lines serve as the baseline for detecting
  # changes and as the input of the rewrite.
  with open(module_path, 'r') as module_file:
    original_lines = module_file.readlines()
  original_file = ''.join(original_lines)

  updated_pieces = []
  import_group = []
  for line in original_lines:
    import_match = pattern.match(line)
    if import_match:
      # Lists (not tuples) so that transformations may edit imports in place.
      import_group.append(list(import_match.groups()))
      continue
    # A non-import line closes the group collected so far, if any.
    if import_group:
      updated_pieces.extend(_SerializeImports(import_group))
      import_group = []
    updated_pieces.append(line)

  if import_group:
    # The file ends with an import group.
    trailing_pieces = _SerializeImports(import_group)
    if trailing_pieces and not original_file.endswith('\n'):
      # Do not invent a final newline the file did not have; otherwise an
      # untouched file would be rewritten and reported as modified.
      trailing_pieces.pop()
    updated_pieces.extend(trailing_pieces)

  updated_file = ''.join(updated_pieces)
  if updated_file == original_file:
    return False

  with open(module_path, 'w') as module_file:
    module_file.write(updated_file)
  return True
def _ListFiles(base_directory, should_include_dir, should_include_file):
  """Collect the files below a directory that pass the given name filters.

  Args:
    base_directory: Directory at which the recursive walk starts.
    should_include_dir: Predicate called with a bare directory name; the walk
        does not descend into directories it rejects.
    should_include_file: Predicate called with a bare file name; only accepted
        files are returned.

  Returns:
    A sorted list of paths, each joined onto the directory it was found in.
  """
  found = []
  for current_dir, subdir_names, file_names in os.walk(base_directory):
    # Assign through the slice so os.walk sees the pruned list and skips the
    # rejected subdirectories.
    subdir_names[:] = [name for name in subdir_names
                       if should_include_dir(name)]
    for file_name in file_names:
      if should_include_file(file_name):
        found.append(os.path.join(current_dir, file_name))
  found.sort()
  return found
class Count(command_line.Command):
  """Print the number of public modules."""
  # NOTE(review): the class docstring is kept as-is on purpose; command_line
  # may surface it as the subcommand's help text -- confirm before rewording.

  def Run(self, args):
    # Walk the Telemetry directory, keeping only what the two name-based
    # predicates below accept, and report how many files remain.
    public_modules = _ListFiles(
        path.GetTelemetryDir(), self._IsPublicApiDir, self._IsPublicApiFile)
    print(len(public_modules))
    return 0

  @staticmethod
  def _IsPublicApiDir(dir_name):
    """Return True unless the directory name marks it as non-public."""
    # Names starting with '.' or '_' are excluded.
    if dir_name[0] in ('.', '_'):
      return False
    if dir_name.startswith('internal'):
      return False
    return dir_name != 'third_party'

  @staticmethod
  def _IsPublicApiFile(file_name):
    """Return True for .py files that are neither dot-files nor unit tests."""
    if file_name[0] == '.':
      return False
    stem, extension = os.path.splitext(file_name)
    return extension == '.py' and not stem.endswith('_unittest')
class Mv(command_line.Command):
  """Move modules or packages."""
  # NOTE(review): the class docstring is kept as-is on purpose; command_line
  # may surface it as the subcommand's help text -- confirm before rewording.

  @classmethod
  def AddCommandLineArgs(cls, parser):
    """Register one or more source paths followed by a destination path."""
    parser.add_argument('source', nargs='+')
    parser.add_argument('destination')

  @classmethod
  def ProcessCommandLineArgs(cls, parser, args):
    """Validate the parsed paths, reporting problems through parser.error.

    Checks that every source and the destination exist and lie in one of
    BASE_DIRS, that the destination is a directory when several sources are
    given, and that no source would be moved into itself.
    """
    # NOTE(review): path.IsSubpath(a, b) is assumed to report whether a lies
    # under b -- confirm against telemetry.util.path.
    for source in args.source:
      # Ensure source path exists.
      if not os.path.exists(source):
        parser.error('"%s" not found.' % source)

      # Ensure source path is in one of the BASE_DIRS.
      if not any(path.IsSubpath(source, base_dir) for base_dir in BASE_DIRS):
        parser.error(
            'Source path "%s" is not in any of the base dirs.' % source)

    # Ensure destination path exists.
    if not os.path.exists(args.destination):
      parser.error('"%s" not found.' % args.destination)

    # Ensure destination path is in one of the BASE_DIRS.
    if not any(path.IsSubpath(args.destination, base_dir)
               for base_dir in BASE_DIRS):
      parser.error('Destination path "%s" is not in any of the base dirs.' %
                   args.destination)

    # If there are multiple source paths, ensure destination is a directory.
    if len(args.source) > 1 and not os.path.isdir(args.destination):
      parser.error('Target "%s" is not a directory.' % args.destination)

    # Ensure destination is not in any of the source paths.
    for source in args.source:
      if path.IsSubpath(args.destination, source):
        parser.error('Cannot move "%s" to a subdirectory of itself, "%s".' %
                     (source, args.destination))

  def Run(self, args):
    """Move the sources, then rewrite imports of the moved modules.

    Args:
      args: Parsed arguments with a `source` list and a `destination` path,
          expected to have passed ProcessCommandLineArgs.

    Returns:
      0, used as the exit code.
    """
    # Find the base dir that contains the destination. Validation guarantees
    # that one of them matches.
    for dest_base_dir in BASE_DIRS:
      if path.IsSubpath(args.destination, dest_base_dir):
        break

    # Evaluated before anything is moved; it decides both how new module names
    # are derived and how the rename is carried out.
    destination_is_dir = os.path.isdir(args.destination)

    # Get a list of old and new module names for renaming imports.
    moved_modules = {}
    for source in args.source:
      for source_base_dir in BASE_DIRS:
        if path.IsSubpath(source, source_base_dir):
          break

      source_dir = os.path.dirname(os.path.normpath(source))

      if os.path.isdir(source):
        source_files = _ListFiles(source,
                                  self._IsSourceDir, self._IsPythonModule)
      else:
        source_files = (source,)

      for source_file_path in source_files:
        # Dotted module names are the paths relative to the containing base
        # dir, with the extension dropped.
        source_rel_path = os.path.relpath(source_file_path, source_base_dir)
        source_module_name = os.path.splitext(
            source_rel_path)[0].replace(os.sep, '.')

        if destination_is_dir:
          # The source keeps its name (and inner layout) below destination.
          source_tree = os.path.relpath(source_file_path, source_dir)
          dest_path = os.path.join(args.destination, source_tree)
        else:
          # The single source is renamed onto the destination path itself, so
          # the new module is named after the destination, not after
          # <destination>/<source name>.
          dest_path = args.destination
        dest_rel_path = os.path.relpath(dest_path, dest_base_dir)
        dest_module_name = os.path.splitext(
            dest_rel_path)[0].replace(os.sep, '.')

        moved_modules[source_module_name] = dest_module_name

    # Move things!
    if destination_is_dir:
      for source in args.source:
        destination_path = os.path.join(
            args.destination, os.path.split(os.path.normpath(source))[1])
        os.rename(source, destination_path)
    else:
      # Internal invariant established by ProcessCommandLineArgs.
      assert len(args.source) == 1
      # Index rather than pop() so the caller's args are left intact.
      os.rename(args.source[0], args.destination)

    # Update imports!
    def _UpdateImportGroup(import_group):
      """Point imports of moved modules at their new location.

      Matching import lines are edited in place; a group is re-sorted only if
      at least one of its imports changed.
      """
      modified = False
      for import_line in import_group:
        _, root, module, _, _ = import_line
        full_module = root + '.' + module if root else module

        if full_module not in moved_modules:
          continue

        modified = True

        # Update import line. new_root is '' for a top-level module, which is
        # serialized as a plain "import <module>".
        new_root, _, new_module = moved_modules[full_module].rpartition('.')
        import_line[1] = new_root
        import_line[2] = new_module

      if modified:
        return _SortImportGroup(import_group)
      else:
        return import_group

    for base_dir in BASE_DIRS:
      if os.path.isfile(base_dir):
        # Some BASE_DIRS entries are single files. os.walk yields nothing for
        # those, so they would otherwise never have their imports updated.
        module_paths = [base_dir] if self._IsPythonModule(base_dir) else []
      else:
        module_paths = _ListFiles(base_dir,
                                  self._IsSourceDir, self._IsPythonModule)
      for module_path in module_paths:
        if not _TransformImportGroups(module_path, _UpdateImportGroup):
          continue

        # TODO(dtu): Update occurrences.

    print(moved_modules)

    return 0

  @staticmethod
  def _IsSourceDir(dir_name):
    """Return True for directories to search: not dot-dirs, not third_party."""
    return dir_name[0] != '.' and dir_name != 'third_party'

  @staticmethod
  def _IsPythonModule(file_name):
    """Return True when the name has a .py extension."""
    _, ext = os.path.splitext(file_name)
    return ext == '.py'
class RefactorCommand(command_line.SubcommandCommand):
  # The command classes this tool offers; presumably dispatched on by
  # command_line.SubcommandCommand -- confirm against that class.
  commands = (Count, Mv)
if __name__ == '__main__':
  # Use the value returned by main() as the process exit status.
  raise SystemExit(RefactorCommand.main())