# blob: d5aa41b577186b2f08ea1dbb9068d70e92151abc [file] [log] [blame]
# Copyright 2012 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""A very very simple mock object harness."""
# Wildcard marker for expected arguments: MockFunctionCall.VerifyEquals skips
# the comparison for any expected argument that compares equal to this value.
# Because the marker is the empty string, an expected argument of '' also
# behaves as a wildcard. It is also used as the placeholder return value of
# observed calls and is rendered as '_' by MockFunctionCall.__repr__.
DONT_CARE = ''
class MockFunctionCall(object):
    """A single function call: either a programmed expectation or an observed call.

    All configuration methods return ``self`` so they can be chained, e.g.
    ``MockFunctionCall('f').WithArgs(1).WillReturn(2).WhenCalled(handler)``.

    Attributes:
        name: Name of the function being called.
        args: Tuple of positional arguments. In an expectation, an argument
            that compares equal to DONT_CARE matches any actual argument.
        return_value: Value stored by WillReturn; None until set.
        when_called_handlers: Callables registered through WhenCalled, in
            registration order.
    """

    def __init__(self, name):
        """Create a call record for the function called *name*, with no args."""
        self.name = name
        self.args = tuple()
        self.return_value = None
        self.when_called_handlers = []

    def WithArgs(self, *args):
        """Set the positional arguments of this call and return self."""
        self.args = args
        return self

    def WillReturn(self, value):
        """Set the return value associated with this call and return self."""
        self.return_value = value
        return self

    def WhenCalled(self, handler):
        """Register *handler* as a callback for this call and return self.

        Handlers are only stored here; whoever replays the call decides how
        to invoke them.
        """
        self.when_called_handlers.append(handler)
        # Return self like WithArgs/WillReturn so the fluent chain is not
        # broken by registering a handler.
        return self

    def VerifyEquals(self, got):
        """Check that the observed call *got* matches this expectation.

        The names and the number of arguments must be equal, and every
        expected argument must either compare equal to DONT_CARE or compare
        equal to the corresponding argument of *got*.

        Args:
            got: The observed call; must expose ``name`` and ``args``.

        Raises:
            Exception: If the name, the argument count, or any argument that
                is not a wildcard differs.
        """
        if self.name != got.name or len(self.args) != len(got.args):
            raise Exception('Self %s, got %s' % (repr(self), repr(got)))
        # Lengths are known to be equal here, so zip pairs every argument.
        for expected_arg, actual_arg in zip(self.args, got.args):
            if expected_arg == DONT_CARE:
                continue
            if expected_arg != actual_arg:
                raise Exception('Self %s, got %s' % (repr(self), repr(got)))

    def __repr__(self):
        """Render as ``name(args)`` or ``name(args)->return_value``.

        Wildcard arguments are shown as ``_``; the ``->`` suffix is omitted
        when the return value is None or DONT_CARE.
        """
        def arg_to_text(a):
            if a == DONT_CARE:
                return '_'
            return repr(a)

        args_text = ', '.join(arg_to_text(a) for a in self.args)
        if self.return_value in (None, DONT_CARE):
            return '%s(%s)' % (self.name, args_text)
        return '%s(%s)->%s' % (self.name, args_text, repr(self.return_value))
class MockTrace(object):
    """Ordered list of expected calls plus a cursor to the next one due."""

    def __init__(self):
        """Start with an empty trace and the cursor at the first position."""
        # Index into expected_calls of the next call that should happen.
        self.next_call_index = 0
        # Expected calls, in the order they were programmed.
        self.expected_calls = list()
class MockObject(object):
    """Mock whose hooked methods are checked against a programmed call trace.

    Expectations are programmed with ExpectCall. Each hooked method call is
    then verified, in order, against the next expected call in the trace. A
    mock constructed with a parent shares the parent's trace object.
    """

    def __init__(self, parent_mock=None):
        """Create a mock, sharing *parent_mock*'s trace when one is given."""
        if not parent_mock:
            self._trace = MockTrace()
        else:
            self._trace = parent_mock._trace  # pylint: disable=protected-access

    def __setattr__(self, name, value):
        """Set an attribute; after init only hooks and MockObjects are allowed.

        Before ``_trace`` exists (i.e. during construction) any value is
        accepted. Afterwards the value must carry an ``is_hook`` attribute or
        be a MockObject, otherwise the assertion fails.
        """
        initialized = hasattr(self, '_trace')
        if initialized and not hasattr(value, 'is_hook'):
            assert isinstance(value, MockObject)
        object.__setattr__(self, name, value)

    def SetAttribute(self, name, value):
        """Set attribute *name* to *value* through the normal setattr path."""
        setattr(self, name, value)

    def ExpectCall(self, func_name, *args):
        """Append an expected call to the trace and return it for chaining.

        Installs a hook for *func_name* if this object has no attribute of
        that name yet. Asserts that no hooked call has been consumed from the
        trace so far.

        Returns:
            The new MockFunctionCall, already configured with *args*.
        """
        assert self._trace.next_call_index == 0
        if not hasattr(self, func_name):
            self._install_hook(func_name)
        expectation = MockFunctionCall(func_name).WithArgs(*args)
        self._trace.expected_calls.append(expectation)
        return expectation

    def _install_hook(self, func_name):
        """Attach a callable named *func_name* that replays the trace."""
        def handler(*args, **_):
            # Keyword arguments are accepted but take no part in matching.
            observed = MockFunctionCall(func_name)
            observed.WithArgs(*args)
            observed.WillReturn(DONT_CARE)

            trace = self._trace
            position = trace.next_call_index
            if position >= len(trace.expected_calls):
                raise Exception(
                    'Call to %s was not expected, at end of programmed trace.' %
                    repr(observed))

            expected = trace.expected_calls[position]
            expected.VerifyEquals(observed)
            # Advance only after the call has been verified.
            trace.next_call_index += 1
            for callback in expected.when_called_handlers:
                callback(*args)
            return expected.return_value

        # The marker lets __setattr__ accept this plain function.
        handler.is_hook = True
        setattr(self, func_name, handler)