blob: d0c79f9a9409597529c53c12d0b7ae0eded7be44 [file] [log] [blame]
# -*- coding: utf-8 -*-
# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Contains functionality used to implement a partial mock."""
from __future__ import print_function
import collections
import os
import re
import mock
import six
from chromite.lib import cros_build_lib
from chromite.lib import cros_logging as logging
from chromite.lib import osutils
def _PredicateSplit(func, iterable):
"""Splits an iterable into two groups based on a predicate return value.
Args:
func: A functor that takes an item as its argument and returns a boolean
value indicating which group the item belongs.
iterable: The collection to split.
Returns:
A tuple containing two lists, the first containing items that func()
returned True for, and the second containing items that func() returned
False for.
"""
trues, falses = [], []
for x in iterable:
(trues if func(x) else falses).append(x)
return trues, falses
class Comparator(object):
"""Base class for all comparators."""
def Match(self, arg):
"""Match the comparator against an argument."""
raise NotImplementedError('method must be implemented by a subclass.')
def Equals(self, rhs):
"""Returns whether rhs compares the same thing."""
return isinstance(rhs, type(self)) and self.__dict__ == rhs.__dict__
def __eq__(self, rhs):
return self.Equals(rhs)
def __ne__(self, rhs):
return not self.Equals(rhs)
class In(Comparator):
"""Checks whether an item (or key) is in a list (or dict) parameter."""
def __init__(self, key):
"""Initialize.
Args:
key: Any thing that could be in a list or a key in a dict
"""
Comparator.__init__(self)
self._key = key
def Match(self, arg):
try:
return self._key in arg
except TypeError:
return False
def __repr__(self):
return '<sequence or map containing %r>' % str(self._key)
class InOrder(Comparator):
"""Checks whether every items of a list exists in a list/dict parameter."""
def __init__(self, items):
"""Constructor.
Args:
items: A list of things that could be in a list or a key in a dict
"""
super(InOrder, self).__init__()
self.items = items
def Match(self, arg):
"""Checks if args' item matches all expected items in sequence.
Args:
arg: parameter list.
Returns:
True if all expected items are matched.
"""
items = list(self.items)
to_match = items.pop(0)
for a in arg:
if to_match == a:
if len(items) < 1:
return True
to_match = items.pop(0)
return False
def __repr__(self):
return '<sequence or map containing %r>' % str(self.items)
class Regex(Comparator):
"""Checks if a string matches a regular expression."""
def __init__(self, pattern, flags=0):
"""Initialize.
Args:
pattern: is the regular expression to search for
flags: passed to re.compile function as the second argument
"""
Comparator.__init__(self)
self.pattern = pattern
self.flags = flags
self.regex = re.compile(pattern, flags=flags)
def Match(self, arg):
try:
return self.regex.search(arg) is not None
except TypeError:
return False
def __repr__(self):
s = '<regular expression %r' % self.regex.pattern
if self.regex.flags:
s += ', flags=%d' % self.regex.flags
s += '>'
return s
class ListRegex(Regex):
"""Checks if any string from an iterable matches a regular expression.
This can be used to match on a list. For example, a regex of
'dump_fmap -p .*bios.bin' will match ['dump_fmap', '-p', 'mainbios.bin'].
"""
@staticmethod
def _ProcessArg(arg):
if not isinstance(arg, six.string_types):
return ' '.join(arg)
return arg
def Match(self, arg):
try:
return self.regex.search(self._ProcessArg(arg)) is not None
except TypeError:
return False
class Ignore(Comparator):
"""Used when we don't care about an argument of a method call."""
def Match(self, _arg):
return True
def __repr__(self):
return '<IgnoreArg>'
class HasString(str):
"""A substring matcher for mock assertion.
It overrides str's '==' operator so that
HasString('substring') == 'A sentence with substring'
It is used for mock.assert_called_with(). Note that it is not a Comparator.
Examples:
some_mock.assert_called_with(
partial_mock.HasString('need_this_keyword'))
"""
def __eq__(self, target):
return self in target
def _RecursiveCompare(lhs, rhs):
"""Compare parameter specs recursively.
Args:
lhs: Left Hand Side parameter spec to compare.
rhs: Right Hand Side parameter spec to compare.
equality: In the case of comparing Comparator objects, True means we call
the Equals() function. We call Match() if set to False (default).
"""
if isinstance(lhs, Comparator):
return lhs.Match(rhs)
elif isinstance(lhs, (tuple, list)):
return (isinstance(rhs, type(lhs)) and
len(lhs) == len(rhs) and
all(_RecursiveCompare(i, j) for i, j in zip(lhs, rhs)))
elif isinstance(lhs, dict):
return _RecursiveCompare(sorted(lhs.items()), sorted(rhs.items()))
else:
return lhs == rhs
def ListContains(small, big, strict=False):
"""Looks for a sublist within a bigger list.
Args:
small: The sublist or string to search for.
big: The list to search in.
strict: If True, all items in list must be adjacent.
"""
if strict:
for i in range(len(big) - len(small) + 1):
if _RecursiveCompare(small, big[i:i + len(small)]):
return True
return False
else:
j = 0
for s in small:
for j in range(j, len(big)):
if _RecursiveCompare(s, big[j]):
j += 1
break
else:
return False
return True
def DictContains(small, big):
"""Looks for a subset within a dictionary.
Args:
small: The sub-dict to search for.
big: The dict to search in.
"""
for k, v in small.items():
if k not in big or not _RecursiveCompare(v, big[k]):
return False
return True
class MockedCallResults(object):
"""Implements internal result specification for partial mocks.
Used with the PartialMock class.
Internal results are different from external results (return values,
side effects, exceptions, etc.) for functions. Internal results are
*used* by the partial mock to generate external results. Often internal
results represent the external results of the dependencies of the function
being partially mocked. Of course, the partial mock can just pass through
the internal results to become external results.
"""
Params = collections.namedtuple('Params', ['args', 'kwargs'])
MockedCall = collections.namedtuple(
'MockedCall', ['params', 'strict', 'result', 'side_effect'])
def __init__(self, name):
"""Initialize.
Args:
name: The name given to the mock. Will be used in debug output.
"""
self.name = name
self.mocked_calls = []
self.default_result, self.default_side_effect = None, None
@staticmethod
def AssertArgs(args, kwargs):
"""Verify arguments are of expected type."""
assert isinstance(args, (tuple))
if kwargs:
assert isinstance(kwargs, dict)
def AddResultForParams(self, args, result, kwargs=None, side_effect=None,
strict=True):
"""Record the internal results of a given partial mock call.
Args:
args: A list containing the positional args an invocation must have for
it to match the internal result. The list can contain instances of
meta-args (such as IgnoreArg, Regex, In, etc.). Positional argument
matching is always *strict*, meaning extra positional arguments in
the invocation are not allowed.
result: The internal result that will be matched for the command
invocation specified.
kwargs: A dictionary containing the keyword args an invocation must have
for it to match the internal result. The dictionary can contain
instances of meta-args (such as IgnoreArg, Regex, In, etc.). Keyword
argument matching is by default *strict*, but can be modified by the
|strict| argument.
side_effect: A functor that gets called every time a partially mocked
function is invoked. The arguments the partial mock is invoked with are
passed to the functor. This is similar to how side effects work for
mocks.
strict: Specifies whether keyword are matched strictly. With strict
matching turned on, any keyword args a partial mock is invoked with that
are not specified in |kwargs| will cause the match to fail.
"""
self.AssertArgs(args, kwargs)
if kwargs is None:
kwargs = {}
params = self.Params(args=args, kwargs=kwargs)
dup, filtered = _PredicateSplit(
lambda mc: mc.params == params, self.mocked_calls)
new = self.MockedCall(params=params, strict=strict, result=result,
side_effect=side_effect)
filtered.append(new)
self.mocked_calls = filtered
if dup:
logging.debug('%s: replacing mock for arguments %r:\n%r -> %r',
self.name, params, dup, new)
def SetDefaultResult(self, result, side_effect=None):
"""Set the default result for an unmatched partial mock call.
Args:
result: See AddResultsForParams.
side_effect: See AddResultsForParams.
"""
self.default_result, self.default_side_effect = result, side_effect
def LookupResult(self, args, kwargs=None, hook_args=None, hook_kwargs=None):
"""For a given mocked function call lookup the recorded internal results.
Args:
args: A list containing positional args the function was called with.
kwargs: A dict containing keyword args the function was called with.
hook_args: A list of positional args to call the hook with.
hook_kwargs: A dict of key/value args to call the hook with.
Returns:
The recorded result for the invocation.
Raises:
AssertionError when the call is not mocked, or when there is more
than one mock that matches.
"""
def filter_fn(mc):
if mc.strict:
return _RecursiveCompare(mc.params, params)
return (DictContains(mc.params.kwargs, kwargs) and
_RecursiveCompare(mc.params.args, args))
def is_exception(obj):
"""Returns True if obj is an exception instance or class."""
return (
isinstance(obj, BaseException) or
isinstance(obj, type) and issubclass(obj, BaseException))
self.AssertArgs(args, kwargs)
if kwargs is None:
kwargs = {}
params = self.Params(args, kwargs)
matched, _ = _PredicateSplit(filter_fn, self.mocked_calls)
if len(matched) > 1:
raise AssertionError(
'%s: args %r matches more than one mock:\n%s'
% (self.name, params, '\n'.join([repr(c) for c in matched])))
elif matched:
side_effect, result = matched[0].side_effect, matched[0].result
elif (self.default_result, self.default_side_effect) != (None, None):
side_effect, result = self.default_side_effect, self.default_result
else:
raise AssertionError('%s: %r not mocked!' % (self.name, params))
if is_exception(side_effect):
raise side_effect
if side_effect:
assert hook_args is not None
assert hook_kwargs is not None
hook_result = side_effect(*hook_args, **hook_kwargs)
if hook_result is not None:
return hook_result
return result
class PartialMock(object):
"""Provides functionality for partially mocking out a function or method.
Partial mocking is useful in cases where the side effects of a function or
method are complex, and so re-using the logic of the function with
*dependencies* mocked out is preferred over mocking out the entire function
and re-implementing the side effect (return value, state modification) logic
in the test. It is also useful for creating re-usable mocks.
Methods mocked out will retain the same spec as their original, but will still
be bound to this class rather than the mocked class. In the example below,
this is why the mocked print function has an extra |inst| argument after the
first |self| argument.
Examples:
# Defined in chromite/lib/foo.py.
class SomeClass(object):
def print(self, msg):
...
def write(self, fd, msg):
...
# Defined in chromite/lib/foo_unittest.py.
class SomeClassMock(partial_mock.PartialMock):
TARGET = 'chromite.lib.foo.SomeClass'
ATTRS = ('print',)
# NB: |self| refers to the instance of |SomeClassMock| while |inst| refers
# to the instance of |SomeClass|. This allows access to state in either.
def print(self, inst, msg):
...
Attributes:
TARGET: The import spec for the target object to be (partially) mocked.
This is like the target argument to mock.patch().
ATTRS: A tuple of attributes on |TARGET| to mock out. Each attribute must
be defined with a corresponding signature. You may mock out as many or
few attributes as needed (hence, "partial mock").
"""
# The import spec for the object being mocked.
TARGET = None
# Tuples of attribute names on the target object to mock.
ATTRS = None
def __init__(self, create_tempdir=False):
"""Initialize.
Args:
create_tempdir: If set to True, the partial mock will create its own
temporary directory when start() is called, and will set self.tempdir to
the path of the directory. The directory is deleted when stop() is
called.
"""
self.backup = {}
self.patchers = {}
self.patched = {}
self.external_patchers = []
self.create_tempdir = create_tempdir
# Set when start() is called.
self._tempdir_obj = None
self.tempdir = None
self.__saved_env__ = None
self.started = False
self._results = {}
if not all([self.TARGET, self.ATTRS]) and any([self.TARGET, self.ATTRS]):
raise AssertionError('TARGET=%r but ATTRS=%r!'
% (self.TARGET, self.ATTRS))
if self.ATTRS is not None:
# pylint: disable=not-an-iterable
for attr in self.ATTRS:
self._results[attr] = MockedCallResults(attr)
def __enter__(self):
return self.start()
def __exit__(self, exc_type, exc_value, traceback):
self.stop()
def PreStart(self):
"""Called at the beginning of start(). Child classes can override this.
If __init__ was called with |create_tempdir| set, then self.tempdir will
point to an existing temporary directory when this function is called.
"""
def PreStop(self):
"""Called at the beginning of stop(). Child classes can override this.
If __init__ was called with |create_tempdir| set, then self.tempdir will
not be deleted until after this function returns.
"""
def StartPatcher(self, patcher):
"""PartialMock will stop the patcher when stop() is called."""
self.external_patchers.append(patcher)
return patcher.start()
def PatchObject(self, *args, **kwargs):
"""Create and start a mock.patch.object().
stop() will be called automatically during tearDown.
"""
return self.StartPatcher(mock.patch.object(*args, **kwargs))
def _start(self):
if not all([self.TARGET, self.ATTRS]):
return
name, member = self.TARGET.rsplit('.', 1)
module = __import__(name)
# __import__('foo.bar') returns foo, so...
for bit in name.split('.')[1:]:
module = getattr(module, bit)
cls = getattr(module, member)
for attr in self.ATTRS: # pylint: disable=not-an-iterable
self.backup[attr] = getattr(cls, attr)
src_attr = '_target%s' % attr if attr.startswith('__') else attr
if hasattr(self.backup[attr], 'reset_mock'):
raise AssertionError(
'You are trying to nest mock contexts - this is currently '
'unsupported by PartialMock.')
if callable(self.backup[attr]):
patcher = mock.patch.object(cls, attr, autospec=True,
side_effect=getattr(self, src_attr))
else:
patcher = mock.patch.object(cls, attr, getattr(self, src_attr))
self.patched[attr] = patcher.start()
self.patchers[attr] = patcher
return self
def start(self):
"""Activates the mock context."""
try:
self.__saved_env__ = os.environ.copy()
self.tempdir = None
if self.create_tempdir:
self._tempdir_obj = osutils.TempDir(set_global=True)
self.tempdir = self._tempdir_obj.tempdir
self.started = True
self.PreStart()
return self._start()
except:
self.stop()
raise
def stop(self):
"""Restores namespace to the unmocked state."""
try:
if self.__saved_env__ is not None:
osutils.SetEnvironment(self.__saved_env__)
tasks = ([self.PreStop] + [p.stop for p in self.patchers.values()] +
[p.stop for p in self.external_patchers])
if self._tempdir_obj is not None:
tasks += [self._tempdir_obj.Cleanup]
cros_build_lib.SafeRun(tasks)
finally:
self.started = False
self.tempdir, self._tempdir_obj = None, None
def UnMockAttr(self, attr):
"""Unsetting the mock of an attribute/function."""
self.patchers.pop(attr).stop()
def CheckAttr(f):
"""Automatically set mock_attr based on class default.
This function decorator automatically sets the mock_attr keyword argument
based on the class default. The mock_attr specifies which mocked attribute
a given function is referring to.
Raises an AssertionError if mock_attr is left unspecified.
"""
def new_f(self, *args, **kwargs):
mock_attr = kwargs.pop('mock_attr', None)
if mock_attr is None:
mock_attr = self.DEFAULT_ATTR
if self.DEFAULT_ATTR is None:
raise AssertionError(
'mock_attr not specified, and no default configured.')
return f(self, *args, mock_attr=mock_attr, **kwargs)
return new_f
class PartialCmdMock(PartialMock):
"""Base class for mocking functions that wrap command line functionality.
Implements mocking for functions that shell out. The internal results are
'returncode', 'output', 'error'.
"""
DEFAULT_ATTR = None
# TODO(crbug.com/1006587): Drop redundant arguments & backwards compat APIs.
@CheckAttr
def SetDefaultCmdResult(self, returncode=0, output=None, error=None,
stdout=None, stderr=None,
side_effect=None, mock_attr=None):
"""Specify the default command result if no command is matched.
Args:
returncode: See AddCmdResult.
output: (Deprecated) Alias to stdout.
error: (Deprecated) Alias to stderr.
stdout: See AddCmdResult.
stderr: See AddCmdResult.
side_effect: See MockedCallResults.AddResultForParams
mock_attr: Which attributes's mock is being referenced.
"""
if stdout is None:
stdout = output
elif output is not None:
raise ValueError('Only specify |stdout|, not |output|')
if stdout is None:
stdout = ''
if stderr is None:
stderr = error
elif error is not None:
raise ValueError('Only specify |stderr|, not |error|')
if stderr is None:
stderr = ''
result = cros_build_lib.CommandResult(
returncode=returncode, stdout=stdout, stderr=stderr)
self._results[mock_attr].SetDefaultResult(result, side_effect)
# TODO(crbug.com/1006587): Drop redundant arguments & backwards compat APIs.
@CheckAttr
def AddCmdResult(self, cmd, returncode=0, output=None, error=None,
stdout=None, stderr=None, kwargs=None, strict=False,
side_effect=None, mock_attr=None):
"""Specify the result to simulate for a given command.
Args:
cmd: The command string or list to record a result for.
returncode: The returncode of the command (on the command line).
output: (Deprecated) Alias to stdout.
error: (Deprecated) Alias to stderr.
stdout: The stdout output of the command.
stderr: The stderr output of the command.
kwargs: Keyword arguments that the function needs to be invoked with.
strict: Defaults to False. See MockedCallResults.AddResultForParams.
side_effect: See MockedCallResults.AddResultForParams
mock_attr: Which attributes's mock is being referenced.
"""
if stdout is None:
stdout = output
elif output is not None:
raise ValueError('Only specify |stdout|, not |output|')
if stdout is None:
stdout = ''
if stderr is None:
stderr = error
elif error is not None:
raise ValueError('Only specify |stderr|, not |error|')
if stderr is None:
stderr = ''
result = cros_build_lib.CommandResult(
returncode=returncode, stdout=stdout, stderr=stderr)
self._results[mock_attr].AddResultForParams(
(cmd,), result, kwargs=kwargs, side_effect=side_effect, strict=strict)
@CheckAttr
def CommandContains(self, args, cmd_arg_index=-1, mock_attr=None, **kwargs):
"""Verify that at least one command contains the specified args.
Args:
args: Set of expected command-line arguments.
cmd_arg_index: The index of the command list in the positional call_args.
Defaults to the last positional argument.
kwargs: Set of expected keyword arguments.
mock_attr: Which attributes's mock is being referenced.
"""
for call_args, call_kwargs in self.patched[mock_attr].call_args_list:
if (ListContains(args, call_args[cmd_arg_index],
strict=isinstance(args, str)) and
DictContains(kwargs, call_kwargs)):
return True
return False
@CheckAttr
def assertCommandContains(self, args=(), expected=True, mock_attr=None,
**kwargs):
"""Assert that run was called with the specified args.
This verifies that at least one of the run calls contains the
specified arguments on the command line.
Args:
args: Set of expected command-line arguments.
expected: If False, instead verify that none of the run calls
contained the specified arguments.
**kwargs: Set of expected keyword arguments.
mock_attr: Which attributes's mock is being referenced.
"""
if bool(expected) != self.CommandContains(args, **kwargs):
if expected:
msg = 'Expected to find %r in any of:\n%s'
else:
msg = 'Expected to not find %r in any of:\n%s'
patched = self.patched[mock_attr]
cmds = '\n'.join(repr(x) for x in patched.call_args_list)
raise AssertionError(msg % (mock.call(args, **kwargs), cmds))
@CheckAttr
def assertCommandCalled(self, args=(), mock_attr=None, **kwargs):
"""Assert that run was called with the specified args.
This verifies that at least one of the run calls exactly
matches the specified command line and misc-arguments.
Args:
args: Set of expected command-line arguments.
mock_attr: Which attributes's mock is being referenced.
**kwargs: Set of expected keyword arguments.
"""
call = mock.call(args, **kwargs)
patched = self.patched[mock_attr]
for icall in patched.call_args_list:
if call == icall:
return
cmds = '\n'.join(repr(x) for x in patched.call_args_list)
raise AssertionError('Expected to find %r in any of:\n%s' % (call, cmds))
@property
@CheckAttr
def call_count(self, mock_attr=None):
"""Return the number of times we've been called."""
return self.patched[mock_attr].call_count
@property
@CheckAttr
def call_args_list(self, mock_attr=None):
"""Return the list of args we've been called with."""
return self.patched[mock_attr].call_args_list