blob: ad2f54dcbafbb3978b6db9da6c3b17a19468a67d [file] [log] [blame]
#!/usr/bin/env python
# Copyright 2014 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""
Collection of unit tests for the 'infra_libs.memoize' library.
"""
import unittest
from infra_libs import memoize
class MemoTestCase(unittest.TestCase):
  """Base test case that records which pieces of code actually executed.

  Code under test calls `tag()` from inside its body. Because `tag()` fails
  when the same tag is recorded twice, a body that executes more than once
  with the same tag fails the test. `assertTagged()` then verifies the exact
  set of tags that were recorded.
  """

  def setUp(self):
    """Starts every test with an empty set of recorded tags."""
    self.tagged = set()

  def tag(self, tag=None):
    """Records `tag`, failing the test if it has already been recorded.

    Args:
      tag: hashable marker identifying one execution. Defaults to None.
    """
    self.assertNotIn(
        tag, self.tagged,
        'tag %r was already recorded; the tagged code ran more than once' %
        (tag,))
    self.tagged.add(tag)

  def assertTagged(self, *tags):
    """Asserts that the set of recorded tags is exactly `tags`.

    When called without arguments, the single default tag (None) is expected,
    which is what a bare `tag()` call records.

    Args:
      *tags: the tags expected to have been recorded; order is irrelevant.
    """
    if not tags:
      tags = (None,)
    self.assertEqual(set(tags), self.tagged)

  def clearTagged(self):
    """Forgets every tag recorded so far."""
    self.tagged.clear()
class FunctionTestCase(MemoTestCase):
  """Tests `memoize.memo` on plain functions, classmethods and staticmethods.

  Every decorated body calls `self.tag(...)`, which fails on a repeated tag.
  A body that executes again for arguments it has already seen therefore
  fails the test, and `assertTagged()` pins the exact set of executions.
  """

  def testFuncNoArgs(self):
    """A zero-argument function body executes once across repeated calls."""
    @memoize.memo()
    def func():
      self.tag()
      return 'foo'

    for _ in range(10):
      self.assertEqual(func(), 'foo')
    self.assertTagged()

  def testFuncAllArgs(self):
    """The body executes once per distinct (a, b) argument pair."""
    @memoize.memo()
    def func(a, b):
      self.tag((a, b))
      return a + b

    # Execute multiple rounds of two unique function executions.
    for _ in range(10):
      self.assertEqual(func(1, 2), 3)
      self.assertEqual(func(3, 4), 7)
    self.assertTagged(
        (1, 2),
        (3, 4),
    )

  def testFuncIgnoreArgs(self):
    """With 'b' ignored, calls differing only in `b` share one execution.

    The second call for a given `a` is expected to return the value computed
    by the first one, even though `b` differs.
    """
    # Note the trailing comma: ('b') would be the plain string 'b', not a
    # one-element tuple.
    @memoize.memo(ignore=('b',))
    def func(a, b):
      self.tag(a)
      return a + b

    # Execute multiple rounds of two unique function executions.
    for _ in range(10):
      self.assertEqual(func(1, 1), 2)
      self.assertEqual(func(1, 2), 2)
      self.assertEqual(func(2, 1), 3)
      self.assertEqual(func(2, 2), 3)
    self.assertTagged(
        1,
        2,
    )

  def testOldClassMethod(self):
    """A memoized classmethod declared on a class without an explicit base."""
    class Test:
      # Disable 'no __init__ method' warning | pylint: disable=W0232
      # pylint: disable=old-style-class
      @classmethod
      @memoize.memo()
      def func(cls, a):
        self.tag(a)
        return a

    # Execute multiple rounds of two unique function executions.
    for _ in range(10):
      self.assertEqual(Test.func(1), 1)
      self.assertEqual(Test.func(2), 2)
    self.assertTagged(
        1,
        2,
    )

  def testNewClassMethod(self):
    """A memoized classmethod declared on an `object` subclass."""
    class Test(object):
      # Disable 'no __init__ method' warning | pylint: disable=W0232
      @classmethod
      @memoize.memo()
      def func(cls, a):
        self.tag(a)
        return a

    # Execute multiple rounds of two unique function executions.
    for _ in range(10):
      self.assertEqual(Test.func(1), 1)
      self.assertEqual(Test.func(2), 2)
    self.assertTagged(
        1,
        2,
    )

  def testOldClassStaticMethod(self):
    """A memoized staticmethod declared on a class without an explicit base."""
    class Test:
      # Disable 'no __init__ method' warning | pylint: disable=W0232
      # pylint: disable=old-style-class
      @staticmethod
      @memoize.memo()
      def func(a):
        self.tag(a)
        return a

    # Execute multiple rounds of two unique function executions.
    for _ in range(10):
      self.assertEqual(Test.func(1), 1)
      self.assertEqual(Test.func(2), 2)
    self.assertTagged(
        1,
        2,
    )

  def testNewClassStaticMethod(self):
    """A memoized staticmethod declared on an `object` subclass."""
    class Test(object):
      # Disable 'no __init__ method' warning | pylint: disable=W0232
      @staticmethod
      @memoize.memo()
      def func(a):
        self.tag(a)
        return a

    # Execute multiple rounds of two unique function executions.
    for _ in range(10):
      self.assertEqual(Test.func(1), 1)
      self.assertEqual(Test.func(2), 2)
    self.assertTagged(
        1,
        2,
    )

  def testClearAllArgs(self):
    """`memo_clear(...)` drops only the entry matching the given arguments.

    After each clear, both calls are repeated and only the cleared one is
    expected to execute (and therefore tag) again.
    """
    @memoize.memo()
    def func(a, b=10):
      self.tag((a, b))
      return a + b

    # First round: both argument combinations execute.
    self.assertEqual(func(1), 11)
    self.assertEqual(func(1, b=0), 1)
    self.assertTagged(
        (1, 10),
        (1, 0),
    )

    # Clear (1): only the default-`b` call executes again.
    self.clearTagged()
    func.memo_clear(1)
    self.assertEqual(func(1), 11)
    self.assertEqual(func(1, b=0), 1)
    self.assertTagged(
        (1, 10),
    )

    # Clear (1, b=0): only the explicit b=0 call executes again.
    self.clearTagged()
    func.memo_clear(1, b=0)
    self.assertEqual(func(1), 11)
    self.assertEqual(func(1, b=0), 1)
    self.assertTagged(
        (1, 0),
    )
class MemoInstanceMethodTestCase(MemoTestCase):
  """Tests `memoize.memo` applied to instance methods.

  The helper base classes raise from `__hash__`, so any attempt to hash the
  instance itself while memoizing makes the test error out.
  """

  class TestBaseOld:
    """Unhashable helper declared without an explicit base class."""
    # pylint: disable=old-style-class

    def __init__(self, test_case, name):
      self.test_case = test_case
      self.name = name

    def __hash__(self):
      # Prevent this instance from being used as a memo key
      raise NotImplementedError()

  class TestBaseNew(object):
    """Unhashable helper declared as an `object` subclass."""

    def __init__(self, test_case, name):
      self.test_case = test_case
      self.name = name

    def __hash__(self):
      # Prevent this instance from being used as a memo key
      raise NotImplementedError()

  class TestHash(object):
    """Helper whose `__hash__` is itself memoized.

    The assertion inside `__hash__` fails if the body ever executes a second
    time on the same instance.
    """

    def __init__(self):
      self._counter = 0

    @memoize.memo()
    def __hash__(self):
      assert self._counter == 0
      self._counter += 1
      return self._counter

  def testOldClassNoArgs(self):
    """An argument-less method executes once per instance (no-base class)."""
    class Test(self.TestBaseOld):
      # Disable 'hash not overridden' warning | pylint: disable=W0223
      @memoize.memo()
      def func(self):
        self.test_case.tag(self.name)
        return 'foo'

    t0 = Test(self, 't0')
    t1 = Test(self, 't1')
    for _ in range(10):
      self.assertEqual(t0.func(), 'foo')
      self.assertEqual(t1.func(), 'foo')
    self.assertTagged(
        't0',
        't1',
    )

  def testNewClassNoArgs(self):
    """An argument-less method executes once per instance (`object` class)."""
    class Test(self.TestBaseNew):
      # Disable 'hash not overridden' warning | pylint: disable=W0223
      @memoize.memo()
      def func(self):
        self.test_case.tag(self.name)
        return 'foo'

    t0 = Test(self, 't0')
    t1 = Test(self, 't1')
    for _ in range(10):
      self.assertEqual(t0.func(), 'foo')
      self.assertEqual(t1.func(), 'foo')
    self.assertTagged(
        't0',
        't1',
    )

  def testOldClassArgs(self):
    """A method executes once per (instance, a, b) on a no-base class."""
    class Test(self.TestBaseOld):
      # Disable 'hash not overridden' warning | pylint: disable=W0223
      @memoize.memo()
      def func(self, a, b):
        self.test_case.tag((self.name, a, b))
        return a + b

    t0 = Test(self, 't0')
    t1 = Test(self, 't1')
    for _ in range(10):
      self.assertEqual(t0.func(1, 2), 3)
      self.assertEqual(t0.func(1, 3), 4)
      self.assertEqual(t1.func(1, 2), 3)
      self.assertEqual(t1.func(1, 3), 4)
    self.assertTagged(
        ('t0', 1, 2),
        ('t0', 1, 3),
        ('t1', 1, 2),
        ('t1', 1, 3),
    )

  def testNewClassArgs(self):
    """A method executes once per (instance, a, b) on an `object` class."""
    class Test(self.TestBaseNew):
      # Disable 'hash not overridden' warning | pylint: disable=W0223
      @memoize.memo()
      def func(self, a, b):
        self.test_case.tag((self.name, a, b))
        return a + b

    t0 = Test(self, 't0')
    t1 = Test(self, 't1')
    for _ in range(10):
      self.assertEqual(t0.func(1, 2), 3)
      self.assertEqual(t0.func(1, 3), 4)
      self.assertEqual(t1.func(1, 2), 3)
      self.assertEqual(t1.func(1, 3), 4)
    self.assertTagged(
        ('t0', 1, 2),
        ('t0', 1, 3),
        ('t1', 1, 2),
        ('t1', 1, 3),
    )

  def testNewClassDirectCall(self):
    """Re-binding the bound method through `__get__` still hits the memo."""
    class Test(self.TestBaseNew):
      # Disable 'hash not overridden' warning | pylint: disable=W0223
      @memoize.memo()
      def func(self, a, b):
        self.test_case.tag((self.name, a, b))
        return a + b

    t0 = Test(self, 't0')
    for _ in range(10):
      self.assertEqual(t0.func.__get__(t0)(1, 2), 3)
    self.assertTagged(
        ('t0', 1, 2),
    )

  def testClear(self):
    """`memo_clear` on a bound method makes the next call execute again."""
    class Test(self.TestBaseNew):
      # Disable 'hash not overridden' warning | pylint: disable=W0223
      @memoize.memo()
      def func(self, a):
        self.test_case.tag((self.name, a))
        return a

    # Call '10'; it should be tagged
    t = Test(self, 'test')
    t.func(10)
    self.assertTagged(
        ('test', 10),
    )

    # Clear
    self.clearTagged()
    t.func.memo_clear(10)

    # Call '10' again; it should be tagged again
    t.func(10)
    self.assertTagged(
        ('test', 10),
    )

  def testOverrideHash(self):
    """A memoized `__hash__` returns the same value on repeated `hash()`."""
    t = self.TestHash()
    self.assertEqual(hash(t), 1)
    self.assertEqual(hash(t), 1)
class MemoClassMethodTestCase(MemoTestCase):
  """Tests handling of the 'cls' and 'self' parameters"""

  class Test(object):
    """Helper exposing a memoized method, classmethod and property getter."""

    def __init__(self, test_case, name):
      self.test_case = test_case
      self.name = name
      self._value = 0

    @memoize.memo(ignore=('tag',))
    def func(self, a, tag):
      self.test_case.tag(tag)
      return a

    @classmethod
    @memoize.memo(ignore=('test_case', 'tag'))
    def class_func(cls, test_case, memo_value, tag):
      test_case.tag(tag)
      return memo_value

    @property
    @memoize.memo()
    def prop(self):
      self.test_case.tag(self.name)
      return self._value

    @prop.setter
    def prop(self, value):
      self._value = value

    @prop.deleter
    def prop(self):
      self._value = 0

  class TestWithEquals(Test):
    """`Test` variant whose instances all hash to 7 and compare equal."""

    def __hash__(self):
      return 7

    def __eq__(self, other):
      return type(other) is type(self)

  def testClassMethodNoEquals(self):
    """Classmethod entries are cleared per `memo_value`, or all at once.

    Only the calls tagged 't0', 't2', 't4' and 't5' are expected to execute;
    't1' and 't3' repeat a `memo_value` that has not been cleared.
    """
    self.assertEqual(self.Test.class_func(self, 1, 't0'), 1)
    self.assertEqual(self.Test.class_func(self, 1, 't1'), 1)
    self.assertEqual(self.Test.class_func(self, 2, 't2'), 2)

    # Clear only the memo_value=2 entry.
    self.Test.class_func.memo_clear(self.Test, self, 2, None)
    self.assertEqual(self.Test.class_func(self, 1, 't3'), 1)
    self.assertEqual(self.Test.class_func(self, 2, 't4'), 2)

    # Clear with no arguments, then repeat memo_value=1.
    self.Test.class_func.memo_clear()
    self.assertEqual(self.Test.class_func(self, 1, 't5'), 1)

    self.assertTagged(
        't0',
        't2',
        't4',
        't5',
    )

  def testInstanceMethodNoEquals(self):
    """Clearing through `t1` leaves the entry reached through `t0` intact."""
    t0 = self.Test(self, 't0')
    t1 = self.Test(self, 't1')

    self.assertEqual(t0.func(1, 't0.0'), 1)
    self.assertEqual(t1.func(1, 't1.0'), 1)

    t1.func.memo_clear(1, None)
    self.assertEqual(t0.func(1, 't0.1'), 1)
    self.assertEqual(t1.func(1, 't1.1'), 1)

    self.assertTagged(
        't0.0',
        't1.0',
        't1.1',
    )

  def testInstanceMethodWithEquals(self):
    """Same as above, for instances that are equal and share a hash."""
    t0 = self.TestWithEquals(self, 't0')
    t1 = self.TestWithEquals(self, 't1')
    self.assertEqual(hash(t0), 7)
    self.assertTrue(t0 == t1)

    self.assertEqual(t0.func(1, 't0.0'), 1)
    self.assertEqual(t1.func(1, 't1.0'), 1)

    t1.func.memo_clear(1, None)
    self.assertEqual(t0.func(1, 't0.1'), 1)
    self.assertEqual(t1.func(1, 't1.1'), 1)

    self.assertTagged(
        't0.0',
        't1.0',
        't1.1',
    )

  def testProperty(self):
    """A memoized property getter executes once per instance."""
    t0 = self.Test(self, 't0')
    t1 = self.Test(self, 't1')

    # The property can be set.
    t0.prop = 1024
    t1.prop = 1337
    # The deleter resets the stored value to 0.
    del t1.prop

    for _ in range(10):
      self.assertEqual(t0.prop, 1024)
      self.assertEqual(t1.prop, 0)
    self.assertTagged(
        't0',
        't1',
    )
if __name__ == '__main__':
  # Executed as a script: hand control to the unittest command-line runner.
  unittest.main()