blob: bd4d808e4c65d49f87d2ad72ec16ea5f17c59933 [file] [log] [blame]
#!/usr/bin/env python
#
# Copyright 2016-present the Material Components for iOS authors. All Rights
# Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import sys
from collections import namedtuple
# Prefix of all MDC umbrella headers.
UMBRELLA_HEADER_PREFIX = 'Material'

# Import prefixes that should not be checked.
IMPORT_FILTER_PREFIXES = (
    # Objective-C imports
    'Availability.h',
    'CoreFoundation/',
    'CoreGraphics/',
    'Foundation/',
    'QuartzCore/',
    'UIKit/',
    'objc/',
    # Third-party imports
    'MDF',
    'Motion',
)

# Import suffixes that should not be checked.
IMPORT_FILTER_SUFFIXES = (
    # Color themers do not have umbrella headers.
    'ColorThemer.h',
)

# Regex to match imports of two kinds:
# 1) #import "Foo.h" -- path is regex group 2, i.e. match.groups()[1]
# 2) #import <Foo.h> -- path is regex group 3, i.e. match.groups()[2]
# Regex group 1 spans the whole operand including its delimiters.
# The path character classes exclude the closing delimiter so that text
# following the import on the same line (e.g. a comment that mentions another
# header) is never swallowed into the captured path.
IMPORT_RE = re.compile(r'#import (\"([^"]*\.h)\"|\<([^>]*\.h)\>)')
def umbrella_header(component):
    """Builds the umbrella header file name of the component.

    The name is the umbrella header prefix followed by the component name
    and the '.h' extension.
    """
    header_template = '{}{}.h'
    return header_template.format(UMBRELLA_HEADER_PREFIX, component)
def framework_umbrella_header(component):
    """Builds the framework-style umbrella header path of the component.

    This is the component's umbrella header placed under the
    'MaterialComponents/' directory.
    """
    component_header = umbrella_header(component)
    return 'MaterialComponents/{}'.format(component_header)
# Context handed to every check function.
#   checked_file - the name of the file that is being checked.
#   module_files - list of files in the sub-component currently analyzed.
CheckContext = namedtuple(
    'CheckContext',
    ('checked_file', 'module_files'),
)
class ImportChecks(object):
    """A class containing factory methods for various checks and filters.

    Every factory returns a function ``check(import_str, check_context)``
    producing a ``(stop, error_msg)`` pair: ``stop`` is truthy when no further
    checks should run for this import, and ``error_msg`` is either None or a
    description of the violation.
    """

    @staticmethod
    def import_filter():
        """Returns a function to filter Objective-C and third-party framework imports."""
        def check(import_str, _):
            # startswith/endswith accept a tuple of alternatives.
            is_filtered = (import_str.startswith(IMPORT_FILTER_PREFIXES) or
                           import_str.endswith(IMPORT_FILTER_SUFFIXES))
            return is_filtered, None
        return check

    @staticmethod
    def module_import_filter():
        """Returns a function to filter same module file imports."""
        def check(import_str, check_context):
            # Only match whole trailing path components. A raw suffix test
            # would treat "Foo.h" as a module file whenever the module
            # contains e.g. "MDCFoo.h".
            boundary_suffixes = ('/' + import_str, os.sep + import_str)
            in_module = any(
                module_file == import_str or
                module_file.endswith(boundary_suffixes)
                for module_file in check_context.module_files)
            return in_module, None
        return check

    @staticmethod
    def non_umbrella_header_check():
        """Returns a function to check that an umbrella header is imported."""
        def check(import_str, check_context):
            if not import_str.startswith(UMBRELLA_HEADER_PREFIX):
                error_msg = '{} imports a non-umbrella header: {}'.format(
                    check_context.checked_file, import_str)
                return True, error_msg
            return False, None
        return check

    @staticmethod
    def own_umbrella_header_check(umbrella_header):
        """Returns a function to check that checked component umbrella header is not used."""
        def check(import_str, check_context):
            if import_str == umbrella_header:
                return True, '{} imports an own umbrella header: {}'.format(
                    check_context.checked_file, import_str)
            return False, None
        return check
class ImportChecker(object):
    """Runs an ordered list of check functions over import strings."""

    def __init__(self, check_functions):
        self.check_functions = check_functions

    def check_import(self, import_str, check_context):
        """Returns the error message of the first failed check, or None.

        Checks run in order. A check that yields an error message ends the run
        with that message; a check that asks to stop without an error ends the
        run with no error.
        """
        for run_check in self.check_functions:
            should_stop, message = run_check(import_str, check_context)
            if message:
                return message
            if should_stop:
                return None
        return None

    @staticmethod
    def src_checker(component):
        """Returns a checker for the component's 'src' code."""
        own_headers = (
            umbrella_header(component),
            framework_umbrella_header(component),
        )
        checks = [ImportChecks.import_filter()]
        # 'src' files must not import their own component's umbrella headers.
        checks.extend(
            ImportChecks.own_umbrella_header_check(header)
            for header in own_headers)
        checks.append(ImportChecks.module_import_filter())
        checks.append(ImportChecks.non_umbrella_header_check())
        return ImportChecker(checks)

    @staticmethod
    def general_checker():
        """Returns a general purpose checker."""
        checks = [
            ImportChecks.import_filter(),
            ImportChecks.module_import_filter(),
            ImportChecks.non_umbrella_header_check(),
        ]
        return ImportChecker(checks)
def imports_in_file(file_path):
    """Returns the list of header paths imported by the file at file_path.

    Only lines that start with an import statement recognized by IMPORT_RE
    contribute; the quoted form is looked at before the angle-bracket form.
    """
    found = []
    with open(file_path) as source:
        for line in source:
            match = IMPORT_RE.match(line)
            if match is None:
                # This line is not an import statement.
                continue
            captured = match.groups()
            header = captured[1] if captured[1] else captured[2]
            if header:
                found.append(header)
    return found
def list_objc_files(path, skip_dirs=()):
    """Returns a list of Objective-C files in the path.

    Args:
      path: Root directory to walk.
      skip_dirs: Names of directories to leave out, together with everything
        nested below them.

    Returns:
      Paths, relative to path, of the '.h' and '.m' files found.
    """
    files_relative_paths = []
    for dirpath, dirnames, files in os.walk(path):
        # Prune in place so os.walk never descends into a skipped directory.
        # Testing only the immediate parent name would still list files in
        # nested sub-directories of a skipped directory.
        dirnames[:] = [d for d in dirnames if d not in skip_dirs]
        # Kept from the previous behavior: this also covers the walk root
        # itself, which pruning cannot reach.
        if os.path.basename(dirpath) in skip_dirs:
            continue
        for f in files:
            if f.endswith(('.h', '.m')):
                files_relative_paths.append(
                    os.path.relpath(os.path.join(dirpath, f), path))
    return files_relative_paths
def check_src_path(src_path, component_name):
    """Returns whether the component's 'src' directory has import errors.

    Capitalized directories directly under src_path are checked on their own
    as separate modules; the remaining files are checked with the 'src'
    checker built for component_name.
    """
    sub_modules = []
    submodules_have_errors = False
    for entry in os.listdir(src_path):
        entry_path = os.path.join(src_path, entry)
        # If a directory in 'src' is capitalized, treat it as a separate module.
        if not (os.path.isdir(entry_path) and entry[0].isupper()):
            continue
        sub_modules.append(entry)
        # Always run the check so every sub-module reports its errors.
        if check_path(entry_path):
            submodules_have_errors = True
    src_checker = ImportChecker.src_checker(component_name)
    src_has_errors = check_path(src_path,
                                checker=src_checker,
                                excludes=sub_modules)
    return src_has_errors or submodules_have_errors
def check_path(path, checker=None, excludes=()):
    """Returns whether imports in files at path have import errors.

    Args:
      path: Directory whose Objective-C files are checked.
      checker: ImportChecker run over every import. A general purpose checker
        is created when omitted.
      excludes: Directory names handed to list_objc_files as skip_dirs.

    Returns:
      True if at least one import error was found (each one is printed),
      False otherwise.
    """
    if checker is None:
        # Built per call rather than once at function definition time.
        checker = ImportChecker.general_checker()
    has_errors = False
    files = list_objc_files(path, skip_dirs=excludes)
    for f in files:
        check_context = CheckContext(checked_file=os.path.basename(f),
                                     module_files=files)
        for import_str in imports_in_file(os.path.join(path, f)):
            error = checker.check_import(import_str, check_context)
            if error:
                # A single pre-formatted argument prints the same text under
                # the Python 2 print statement and the Python 3 print function.
                print('ERROR: {}'.format(error))
                has_errors = True
    return has_errors
def main(args):
    """Checks the imports of one component and exits with the result.

    Args:
      args: Command line arguments; args[0] is the path to the component.

    Exits with status -1 when the component path is missing or any import
    error was found, and with 0 otherwise.
    """
    if len(args) < 1:
        # Single argument: valid and identical under Python 2 and Python 3.
        print('ERROR: Component path should be passed as an input parameter.')
        sys.exit(-1)
    component_path = os.path.normpath(args[0])
    component_name = os.path.basename(component_path)
    src_path = os.path.join(component_path, 'src')
    examples_path = os.path.join(component_path, 'examples')
    # Run both checks unconditionally. Combining the calls with a
    # short-circuiting `or` would skip the examples (and hide their errors)
    # whenever 'src' already has errors.
    src_has_errors = check_src_path(src_path, component_name)
    examples_have_errors = check_path(examples_path)
    has_errors = src_has_errors or examples_have_errors
    sys.exit(-1 if has_errors else 0)
# Run the check only when executed as a script, not when imported.
if __name__ == '__main__':
    command_line_args = sys.argv[1:]  # drop the script name
    main(command_line_args)