|  | import abc | 
|  | import builtins | 
|  | import collections | 
|  | import collections.abc | 
|  | import copy | 
|  | from itertools import permutations | 
|  | import pickle | 
|  | from random import choice | 
|  | import sys | 
|  | from test import support | 
|  | import threading | 
|  | import time | 
|  | import typing | 
|  | import unittest | 
|  | import unittest.mock | 
|  | from weakref import proxy | 
|  | import contextlib | 
|  |  | 
|  | import functools | 
|  |  | 
# Load two independent copies of functools so both implementations can be
# tested side by side: ``py_functools`` with the ``_functools`` accelerator
# blocked (pure-Python code paths) and ``c_functools`` with ``_functools``
# freshly imported.  NOTE(review): the ``if c_functools`` guards elsewhere in
# this file suggest import_fresh_module returns a falsy value when the
# accelerator cannot be imported -- confirm against test.support.
py_functools = support.import_fresh_module('functools', blocked=['_functools'])
c_functools = support.import_fresh_module('functools', fresh=['_functools'])

# A fresh copy of decimal, importing the ``_decimal`` C module fresh as well.
decimal = support.import_fresh_module('decimal', fresh=['_decimal'])
|  |  | 
@contextlib.contextmanager
def replaced_module(name, replacement):
    """Temporarily install *replacement* as ``sys.modules[name]``.

    The module registered under *name* when the ``with`` block is entered is
    put back on exit, including when the body raises.  *name* must already
    be present in ``sys.modules``; otherwise a KeyError is raised before
    anything is changed.
    """
    saved = sys.modules[name]
    sys.modules[name] = replacement
    try:
        yield
    finally:
        # Always restore, whatever happened inside the with-block.
        sys.modules[name] = saved
|  |  | 
def capture(*args, **kw):
    """Echo back the call's arguments.

    Returns the 2-tuple ``(args, kw)``: the positional arguments as a tuple
    and the keyword arguments as a dict, exactly as received.
    """
    received = (args, kw)
    return received
|  |  | 
|  |  | 
def signature(part):
    """Summarise a partial object as ``(func, args, keywords, __dict__)``.

    Two partials with equal signatures wrap the same callable with the same
    frozen arguments and carry the same instance attributes.
    """
    fields = ('func', 'args', 'keywords', '__dict__')
    return tuple(getattr(part, field) for field in fields)
|  |  | 
class MyTuple(tuple):
    """Behaviour-free tuple subclass.

    test_setstate_subclasses passes one as a partial's ``args`` to check that
    the stored value is converted to an exact ``tuple``.
    """
|  |  | 
class BadTuple(tuple):
    """Tuple subclass whose ``+`` yields a ``list`` instead of a tuple.

    test_setstate_subclasses uses it to check that calling a partial still
    produces a real tuple of positional arguments even when ``args`` has a
    misbehaving ``__add__``.
    """

    def __add__(self, other):
        return [*self, *other]
|  |  | 
class MyDict(dict):
    """Behaviour-free dict subclass.

    test_setstate_subclasses passes one as a partial's ``keywords`` to check
    that the stored value is converted to an exact ``dict``.
    """
|  |  | 
|  |  | 
class TestPartial:
    """Implementation-agnostic tests for ``functools.partial``.

    This is a mixin: concrete subclasses combine it with
    ``unittest.TestCase`` and supply two class attributes:

    * ``partial`` -- the partial type under test;
    * ``AllowPickle`` -- a context-manager class entered around tests that
      pickle partial objects.
    """

    def _expected_repr_name(self):
        """Return the type name expected at the start of ``repr(partial)``.

        The stock C and pure-Python implementations both report
        ``functools.partial``; subclasses report their own ``__name__``.
        The C module may be unavailable (``c_functools`` falsy), so it is
        consulted only when present -- consistent with the other
        ``if c_functools`` guards in this module.
        """
        stock = [py_functools.partial]
        if c_functools:
            stock.append(c_functools.partial)
        if self.partial in stock:
            return 'functools.partial'
        return self.partial.__name__

    def test_basic_examples(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        self.assertTrue(callable(p))
        self.assertEqual(p(3, 4, b=30, c=40),
                         ((1, 2, 3, 4), dict(a=10, b=30, c=40)))
        p = self.partial(map, lambda x: x*10)
        self.assertEqual(list(p([1, 2, 3, 4])), [10, 20, 30, 40])

    def test_attributes(self):
        p = self.partial(capture, 1, 2, a=10, b=20)
        # attributes should be readable
        self.assertEqual(p.func, capture)
        self.assertEqual(p.args, (1, 2))
        self.assertEqual(p.keywords, dict(a=10, b=20))

    def test_argument_checking(self):
        self.assertRaises(TypeError, self.partial)     # need at least a func arg
        # A non-callable first argument must be rejected, either when the
        # partial is built or, at the latest, when it is called.
        with self.assertRaises(TypeError,
                               msg='First arg not checked for callability'):
            self.partial(2)()

    def test_protection_of_callers_dict_argument(self):
        # a caller's dictionary should not be altered by partial
        def func(a=10, b=20):
            return a
        d = {'a': 3}
        p = self.partial(func, a=5)
        self.assertEqual(p(**d), 3)
        self.assertEqual(d, {'a': 3})
        p(b=7)
        self.assertEqual(d, {'a': 3})

    def test_kwargs_copy(self):
        # Issue #29532: Altering a kwarg dictionary passed to a constructor
        # should not affect a partial object after creation
        d = {'a': 3}
        p = self.partial(capture, **d)
        self.assertEqual(p(), ((), {'a': 3}))
        d['a'] = 5
        self.assertEqual(p(), ((), {'a': 3}))

    def test_arg_combinations(self):
        # exercise special code paths for zero args in either partial
        # object or the caller
        p = self.partial(capture)
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(1, 2), ((1, 2), {}))
        p = self.partial(capture, 1, 2)
        self.assertEqual(p(), ((1, 2), {}))
        self.assertEqual(p(3, 4), ((1, 2, 3, 4), {}))

    def test_kw_combinations(self):
        # exercise special code paths for no keyword args in
        # either the partial object or the caller
        p = self.partial(capture)
        self.assertEqual(p.keywords, {})
        self.assertEqual(p(), ((), {}))
        self.assertEqual(p(a=1), ((), {'a': 1}))
        p = self.partial(capture, a=1)
        self.assertEqual(p.keywords, {'a': 1})
        self.assertEqual(p(), ((), {'a': 1}))
        self.assertEqual(p(b=2), ((), {'a': 1, 'b': 2}))
        # keyword args in the call override those in the partial object
        self.assertEqual(p(a=3, b=2), ((), {'a': 3, 'b': 2}))

    def test_positional(self):
        # make sure positional arguments are captured correctly
        for args in [(), (0,), (0, 1), (0, 1, 2), (0, 1, 2, 3)]:
            p = self.partial(capture, *args)
            got, empty = p('x')
            # Separate assertions give a useful diff on failure.
            self.assertEqual(got, args + ('x',))
            self.assertEqual(empty, {})

    def test_keyword(self):
        # make sure keyword arguments are captured correctly
        for a in ['a', 0, None, 3.5]:
            p = self.partial(capture, a=a)
            empty, got = p(x=None)
            self.assertEqual(got, {'a': a, 'x': None})
            self.assertEqual(empty, ())

    def test_no_side_effects(self):
        # make sure there are no side effects that affect subsequent calls
        p = self.partial(capture, 0, a=1)
        args1, kw1 = p(1, b=2)
        self.assertEqual(args1, (0, 1))
        self.assertEqual(kw1, {'a': 1, 'b': 2})
        args2, kw2 = p()
        self.assertEqual(args2, (0,))
        self.assertEqual(kw2, {'a': 1})

    def test_error_propagation(self):
        def f(x, y):
            x / y
        self.assertRaises(ZeroDivisionError, self.partial(f, 1, 0))
        self.assertRaises(ZeroDivisionError, self.partial(f, 1), 0)
        self.assertRaises(ZeroDivisionError, self.partial(f), 1, 0)
        self.assertRaises(ZeroDivisionError, self.partial(f, y=0), 1)

    def test_weakref(self):
        f = self.partial(int, base=16)
        p = proxy(f)
        self.assertEqual(f.func, p.func)
        f = None
        # Implementations without reference counting need not reclaim the
        # partial as soon as its last reference is dropped.
        support.gc_collect()
        self.assertRaises(ReferenceError, getattr, p, 'func')

    def test_with_bound_and_unbound_methods(self):
        data = list(map(str, range(10)))
        join = self.partial(str.join, '')
        self.assertEqual(join(data), '0123456789')
        join = self.partial(''.join)
        self.assertEqual(join(data), '0123456789')

    def test_nested_optimization(self):
        # A partial of a partial should be flattened into a single partial.
        partial = self.partial
        inner = partial(signature, 'asdf')
        nested = partial(inner, bar=True)
        flat = partial(signature, 'asdf', bar=True)
        self.assertEqual(signature(nested), signature(flat))

    def test_nested_partial_with_attribute(self):
        # see issue 25137
        partial = self.partial

        def foo(bar):
            return bar

        p = partial(foo, 'first')
        p2 = partial(p, 'second')
        p2.new_attr = 'spam'
        self.assertEqual(p2.new_attr, 'spam')

    def test_repr(self):
        args = (object(), object())
        args_repr = ', '.join(repr(a) for a in args)
        kwargs = {'a': object(), 'b': object()}
        # Keyword order in the repr is not pinned down, so accept either.
        kwargs_reprs = ['a={a!r}, b={b!r}'.format_map(kwargs),
                        'b={b!r}, a={a!r}'.format_map(kwargs)]
        name = self._expected_repr_name()

        f = self.partial(capture)
        self.assertEqual(f'{name}({capture!r})', repr(f))

        f = self.partial(capture, *args)
        self.assertEqual(f'{name}({capture!r}, {args_repr})', repr(f))

        f = self.partial(capture, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

        f = self.partial(capture, *args, **kwargs)
        self.assertIn(repr(f),
                      [f'{name}({capture!r}, {args_repr}, {kwargs_repr})'
                       for kwargs_repr in kwargs_reprs])

    def test_recursive_repr(self):
        name = self._expected_repr_name()

        # func refers back to the partial itself.
        f = self.partial(capture)
        f.__setstate__((f, (), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(...)' % (name,))
        finally:
            # Break the self-reference so the object can be freed.
            f.__setstate__((capture, (), {}, {}))

        # A positional argument refers back to the partial.
        f = self.partial(capture)
        f.__setstate__((capture, (f,), {}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, ...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

        # A keyword argument refers back to the partial.
        f = self.partial(capture)
        f.__setstate__((capture, (), {'a': f}, {}))
        try:
            self.assertEqual(repr(f), '%s(%r, a=...)' % (name, capture,))
        finally:
            f.__setstate__((capture, (), {}, {}))

    def test_pickle(self):
        with self.AllowPickle():
            f = self.partial(signature, ['asdf'], bar=[True])
            f.attr = []
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                f_copy = pickle.loads(pickle.dumps(f, proto))
                self.assertEqual(signature(f_copy), signature(f))

    def test_copy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.copy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A shallow copy shares the underlying containers.
        self.assertIs(f_copy.attr, f.attr)
        self.assertIs(f_copy.args, f.args)
        self.assertIs(f_copy.keywords, f.keywords)

    def test_deepcopy(self):
        f = self.partial(signature, ['asdf'], bar=[True])
        f.attr = []
        f_copy = copy.deepcopy(f)
        self.assertEqual(signature(f_copy), signature(f))
        # A deep copy duplicates containers and their mutable contents.
        self.assertIsNot(f_copy.attr, f.attr)
        self.assertIsNot(f_copy.args, f.args)
        self.assertIsNot(f_copy.args[0], f.args[0])
        self.assertIsNot(f_copy.keywords, f.keywords)
        self.assertIsNot(f_copy.keywords['bar'], f.keywords['bar'])

    def test_setstate(self):
        f = self.partial(signature)
        f.__setstate__((capture, (1,), dict(a=10), dict(attr=[])))

        self.assertEqual(signature(f),
                         (capture, (1,), dict(a=10), dict(attr=[])))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # A None instance dict is treated as empty.
        f.__setstate__((capture, (1,), dict(a=10), None))

        self.assertEqual(signature(f), (capture, (1,), dict(a=10), {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'a': 10, 'b': 20}))

        # None keywords means "no frozen keywords".
        f.__setstate__((capture, (1,), None, None))
        # NOTE(review): the signature check below was disabled upstream;
        # the reason is not visible here.
        #self.assertEqual(signature(f), (capture, (1,), {}, {}))
        self.assertEqual(f(2, b=20), ((1, 2), {'b': 20}))
        self.assertEqual(f(2), ((1, 2), {}))
        self.assertEqual(f(), ((1,), {}))

        f.__setstate__((capture, (), {}, None))
        self.assertEqual(signature(f), (capture, (), {}, {}))
        self.assertEqual(f(2, b=20), ((2,), {'b': 20}))
        self.assertEqual(f(2), ((2,), {}))
        self.assertEqual(f(), ((), {}))

    def test_setstate_errors(self):
        f = self.partial(signature)
        # Wrong length, wrong container type, non-callable func, and
        # wrongly-typed args/keywords must all be rejected.
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), {}, {}, None))
        self.assertRaises(TypeError, f.__setstate__, [capture, (), {}, None])
        self.assertRaises(TypeError, f.__setstate__, (None, (), {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, None, {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, [], {}, None))
        self.assertRaises(TypeError, f.__setstate__, (capture, (), [], None))

    def test_setstate_subclasses(self):
        f = self.partial(signature)
        f.__setstate__((capture, MyTuple((1,)), MyDict(a=10), None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), dict(a=10), {}))
        # Subclass instances are normalised to the exact builtin types.
        self.assertIs(type(s[1]), tuple)
        self.assertIs(type(s[2]), dict)
        r = f()
        self.assertEqual(r, ((1,), {'a': 10}))
        self.assertIs(type(r[0]), tuple)
        self.assertIs(type(r[1]), dict)

        f.__setstate__((capture, BadTuple((1,)), {}, None))
        s = signature(f)
        self.assertEqual(s, (capture, (1,), {}, {}))
        self.assertIs(type(s[1]), tuple)
        r = f(2)
        self.assertEqual(r, ((1, 2), {}))
        self.assertIs(type(r[0]), tuple)

    def test_recursive_pickle(self):
        with self.AllowPickle():
            # A partial whose func is itself cannot be pickled.
            f = self.partial(capture)
            f.__setstate__((f, (), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    with self.assertRaises(RecursionError):
                        pickle.dumps(f, proto)
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Self-references through args round-trip as self-references.
            f = self.partial(capture)
            f.__setstate__((capture, (f,), {}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.args[0], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

            # Likewise through keywords.
            f = self.partial(capture)
            f.__setstate__((capture, (), {'a': f}, {}))
            try:
                for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                    f_copy = pickle.loads(pickle.dumps(f, proto))
                    try:
                        self.assertIs(f_copy.keywords['a'], f_copy)
                    finally:
                        f_copy.__setstate__((capture, (), {}, {}))
            finally:
                f.__setstate__((capture, (), {}, {}))

    def test_setstate_refcount(self):
        """Issue 6083: reference-counting bug in __setstate__.

        The state is a non-tuple sequence, which must be rejected with
        TypeError.
        """
        class BadSequence:
            def __len__(self):
                return 4
            def __getitem__(self, key):
                if key == 0:
                    return max
                elif key == 1:
                    return tuple(range(1000000))
                elif key in (2, 3):
                    return {}
                raise IndexError

        f = self.partial(object)
        self.assertRaises(TypeError, f.__setstate__, BadSequence())
|  |  | 
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialC(TestPartial, unittest.TestCase):
    """Shared partial tests run against the C accelerator, plus checks that
    only apply to the C implementation."""

    if c_functools:
        partial = c_functools.partial

    class AllowPickle:
        """No-op context manager: the C partial needs no module swapping
        to be pickled."""

        def __enter__(self):
            return self

        def __exit__(self, type, value, tb):
            # Never suppress an exception raised inside the with-block.
            return False

    def test_attributes_unwritable(self):
        # func, args and keywords are read-only on the C partial.
        p = self.partial(capture, 1, 2, a=10, b=20)
        for attr, value in (('func', map),
                            ('args', (1, 2)),
                            ('keywords', dict(a=1, b=2))):
            self.assertRaises(AttributeError, setattr, p, attr, value)

        # Its instance dictionary must not be removable either.
        p = self.partial(hex)
        try:
            del p.__dict__
        except TypeError:
            pass
        else:
            self.fail('partial object allowed __dict__ to be deleted')

    def test_manually_adding_non_string_keyword(self):
        p = self.partial(capture)
        # Inject a non-string key straight into the keywords dict: repr()
        # must cope with it, while calling the partial must reject it.
        p.keywords[1234] = 'value'
        text = repr(p)
        for fragment in ('1234', "'value'"):
            self.assertIn(fragment, text)
        with self.assertRaises(TypeError):
            p()

    def test_keystr_replaces_value(self):
        p = self.partial(capture)

        class MutatesYourDict:
            def __str__(self):
                p.keywords[self] = ['sth2']
                return 'astr'

        # Formatting the key swaps out its value mid-repr; the original
        # value must stay alive at least until it has been formatted.
        p.keywords[MutatesYourDict()] = ['sth']
        text = repr(p)
        for fragment in ('astr', "['sth']"):
            self.assertIn(fragment, text)
|  |  | 
|  |  | 
class TestPartialPy(TestPartial, unittest.TestCase):
    """Shared partial tests run against the pure-Python implementation."""

    partial = py_functools.partial

    class AllowPickle:
        """Make ``sys.modules['functools']`` the pure-Python module while active.

        pickle locates classes by module and qualified name, so the
        pure-Python ``partial`` class can only be pickled while
        ``functools`` resolves to the module that defines it.
        """

        def __init__(self):
            self._swap = replaced_module("functools", py_functools)

        def __enter__(self):
            return self._swap.__enter__()

        def __exit__(self, type, value, tb):
            return self._swap.__exit__(type, value, tb)
|  |  | 
if c_functools:
    class CPartialSubclass(c_functools.partial):
        """Trivial subclass of the C partial; adds no behaviour."""


class PyPartialSubclass(py_functools.partial):
    """Trivial subclass of the pure-Python partial; adds no behaviour."""
|  |  | 
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestPartialCSubclass(TestPartialC):
    """Rerun the C-partial tests using a trivial subclass of it."""

    # partial subclasses are not optimized for nested calls, so the
    # inherited flattening test does not apply.  Setting the attribute to
    # None hides it, because unittest only collects callable test_* names.
    test_nested_optimization = None

    if c_functools:
        partial = CPartialSubclass
|  |  | 
class TestPartialPySubclass(TestPartialPy):
    """Rerun the pure-Python partial tests using a trivial subclass of it."""

    partial = PyPartialSubclass
|  |  | 
class TestPartialMethod(unittest.TestCase):
    """Tests for ``functools.partialmethod`` used as a class attribute."""

    class A(object):
        # Each attribute pre-binds different arguments to capture(), so a
        # test can read back exactly what the eventual call received.
        nothing = functools.partialmethod(capture)
        positional = functools.partialmethod(capture, 1)
        keywords = functools.partialmethod(capture, a=2)
        both = functools.partialmethod(capture, 3, b=4)

        # A partialmethod wrapping another partialmethod.
        nested = functools.partialmethod(positional, 5)

        # A partialmethod wrapping a functools.partial object.
        over_partial = functools.partialmethod(functools.partial(capture, c=6), 7)

        # partialmethods wrapping other descriptors.
        static = functools.partialmethod(staticmethod(capture), 8)
        cls = functools.partialmethod(classmethod(capture), d=9)

    a = A()

    def test_arg_combinations(self):
        self.assertEqual(self.a.nothing(), ((self.a,), {}))
        self.assertEqual(self.a.nothing(5), ((self.a, 5), {}))
        self.assertEqual(self.a.nothing(c=6), ((self.a,), {'c': 6}))
        self.assertEqual(self.a.nothing(5, c=6), ((self.a, 5), {'c': 6}))

        self.assertEqual(self.a.positional(), ((self.a, 1), {}))
        self.assertEqual(self.a.positional(5), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.positional(c=6), ((self.a, 1), {'c': 6}))
        self.assertEqual(self.a.positional(5, c=6), ((self.a, 1, 5), {'c': 6}))

        self.assertEqual(self.a.keywords(), ((self.a,), {'a': 2}))
        self.assertEqual(self.a.keywords(5), ((self.a, 5), {'a': 2}))
        self.assertEqual(self.a.keywords(c=6), ((self.a,), {'a': 2, 'c': 6}))
        self.assertEqual(self.a.keywords(5, c=6), ((self.a, 5), {'a': 2, 'c': 6}))

        self.assertEqual(self.a.both(), ((self.a, 3), {'b': 4}))
        self.assertEqual(self.a.both(5), ((self.a, 3, 5), {'b': 4}))
        self.assertEqual(self.a.both(c=6), ((self.a, 3), {'b': 4, 'c': 6}))
        self.assertEqual(self.a.both(5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

        # Accessed on the class, the instance is passed explicitly.
        self.assertEqual(self.A.both(self.a, 5, c=6), ((self.a, 3, 5), {'b': 4, 'c': 6}))

    def test_nested(self):
        self.assertEqual(self.a.nested(), ((self.a, 1, 5), {}))
        self.assertEqual(self.a.nested(6), ((self.a, 1, 5, 6), {}))
        self.assertEqual(self.a.nested(d=7), ((self.a, 1, 5), {'d': 7}))
        self.assertEqual(self.a.nested(6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

        self.assertEqual(self.A.nested(self.a, 6, d=7), ((self.a, 1, 5, 6), {'d': 7}))

    def test_over_partial(self):
        self.assertEqual(self.a.over_partial(), ((self.a, 7), {'c': 6}))
        self.assertEqual(self.a.over_partial(5), ((self.a, 7, 5), {'c': 6}))
        self.assertEqual(self.a.over_partial(d=8), ((self.a, 7), {'c': 6, 'd': 8}))
        self.assertEqual(self.a.over_partial(5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

        self.assertEqual(self.A.over_partial(self.a, 5, d=8), ((self.a, 7, 5), {'c': 6, 'd': 8}))

    def test_bound_method_introspection(self):
        obj = self.a
        self.assertIs(obj.both.__self__, obj)
        self.assertIs(obj.nested.__self__, obj)
        self.assertIs(obj.over_partial.__self__, obj)
        # A classmethod-based partialmethod binds to the class either way.
        self.assertIs(obj.cls.__self__, self.A)
        self.assertIs(self.A.cls.__self__, self.A)

    def test_unbound_method_retrieval(self):
        obj = self.A
        self.assertFalse(hasattr(obj.both, "__self__"))
        self.assertFalse(hasattr(obj.nested, "__self__"))
        self.assertFalse(hasattr(obj.over_partial, "__self__"))
        self.assertFalse(hasattr(obj.static, "__self__"))
        self.assertFalse(hasattr(self.a.static, "__self__"))

    def test_descriptors(self):
        for obj in [self.A, self.a]:
            with self.subTest(obj=obj):
                self.assertEqual(obj.static(), ((8,), {}))
                self.assertEqual(obj.static(5), ((8, 5), {}))
                self.assertEqual(obj.static(d=8), ((8,), {'d': 8}))
                self.assertEqual(obj.static(5, d=8), ((8, 5), {'d': 8}))

                self.assertEqual(obj.cls(), ((self.A,), {'d': 9}))
                self.assertEqual(obj.cls(5), ((self.A, 5), {'d': 9}))
                self.assertEqual(obj.cls(c=8), ((self.A,), {'c': 8, 'd': 9}))
                self.assertEqual(obj.cls(5, c=8), ((self.A, 5), {'c': 8, 'd': 9}))

    def test_overriding_keywords(self):
        # Call-time keywords override the frozen ones.
        self.assertEqual(self.a.keywords(a=3), ((self.a,), {'a': 3}))
        self.assertEqual(self.A.keywords(self.a, a=3), ((self.a,), {'a': 3}))

    def test_invalid_args(self):
        with self.assertRaises(TypeError):
            class B(object):
                method = functools.partialmethod(None, 1)

    def test_repr(self):
        self.assertEqual(repr(vars(self.A)['both']),
                         'functools.partialmethod({}, 3, b=4)'.format(capture))

    def test_abstract(self):
        # Abstract must be *created by* ABCMeta.  Subclassing ABCMeta (as
        # this test previously did) makes Abstract a metaclass, so the ABC
        # bookkeeping checked below was never exercised.
        class Abstract(metaclass=abc.ABCMeta):

            @abc.abstractmethod
            def add(self, x, y):
                pass

            add5 = functools.partialmethod(add, 5)

        self.assertTrue(Abstract.add.__isabstractmethod__)
        self.assertTrue(Abstract.add5.__isabstractmethod__)
        # ABCMeta must recognise the partialmethod as abstract too.
        self.assertEqual(Abstract.__abstractmethods__, frozenset({'add', 'add5'}))
        with self.assertRaises(TypeError):
            Abstract()

        for func in [self.A.static, self.A.cls, self.A.over_partial, self.A.nested, self.A.both]:
            self.assertFalse(getattr(func, '__isabstractmethod__', False))
|  |  | 
|  |  | 
|  | class TestUpdateWrapper(unittest.TestCase): | 
|  |  | 
|  | def check_wrapper(self, wrapper, wrapped, | 
|  | assigned=functools.WRAPPER_ASSIGNMENTS, | 
|  | updated=functools.WRAPPER_UPDATES): | 
|  | # Check attributes were assigned | 
|  | for name in assigned: | 
|  | self.assertIs(getattr(wrapper, name), getattr(wrapped, name)) | 
|  | # Check attributes were updated | 
|  | for name in updated: | 
|  | wrapper_attr = getattr(wrapper, name) | 
|  | wrapped_attr = getattr(wrapped, name) | 
|  | for key in wrapped_attr: | 
|  | if name == "__dict__" and key == "__wrapped__": | 
|  | # __wrapped__ is overwritten by the update code | 
|  | continue | 
|  | self.assertIs(wrapped_attr[key], wrapper_attr[key]) | 
|  | # Check __wrapped__ | 
|  | self.assertIs(wrapper.__wrapped__, wrapped) | 
|  |  | 
|  |  | 
|  | def _default_update(self): | 
|  | def f(a:'This is a new annotation'): | 
|  | """This is a test""" | 
|  | pass | 
|  | f.attr = 'This is also a test' | 
|  | f.__wrapped__ = "This is a bald faced lie" | 
|  | def wrapper(b:'This is the prior annotation'): | 
|  | pass | 
|  | functools.update_wrapper(wrapper, f) | 
|  | return wrapper, f | 
|  |  | 
|  | def test_default_update(self): | 
|  | wrapper, f = self._default_update() | 
|  | self.check_wrapper(wrapper, f) | 
|  | self.assertIs(wrapper.__wrapped__, f) | 
|  | self.assertEqual(wrapper.__name__, 'f') | 
|  | self.assertEqual(wrapper.__qualname__, f.__qualname__) | 
|  | self.assertEqual(wrapper.attr, 'This is also a test') | 
|  | self.assertEqual(wrapper.__annotations__['a'], 'This is a new annotation') | 
|  | self.assertNotIn('b', wrapper.__annotations__) | 
|  |  | 
|  | @unittest.skipIf(sys.flags.optimize >= 2, | 
|  | "Docstrings are omitted with -O2 and above") | 
|  | def test_default_update_doc(self): | 
|  | wrapper, f = self._default_update() | 
|  | self.assertEqual(wrapper.__doc__, 'This is a test') | 
|  |  | 
|  | def test_no_update(self): | 
|  | def f(): | 
|  | """This is a test""" | 
|  | pass | 
|  | f.attr = 'This is also a test' | 
|  | def wrapper(): | 
|  | pass | 
|  | functools.update_wrapper(wrapper, f, (), ()) | 
|  | self.check_wrapper(wrapper, f, (), ()) | 
|  | self.assertEqual(wrapper.__name__, 'wrapper') | 
|  | self.assertNotEqual(wrapper.__qualname__, f.__qualname__) | 
|  | self.assertEqual(wrapper.__doc__, None) | 
|  | self.assertEqual(wrapper.__annotations__, {}) | 
|  | self.assertFalse(hasattr(wrapper, 'attr')) | 
|  |  | 
|  | def test_selective_update(self): | 
|  | def f(): | 
|  | pass | 
|  | f.attr = 'This is a different test' | 
|  | f.dict_attr = dict(a=1, b=2, c=3) | 
|  | def wrapper(): | 
|  | pass | 
|  | wrapper.dict_attr = {} | 
|  | assign = ('attr',) | 
|  | update = ('dict_attr',) | 
|  | functools.update_wrapper(wrapper, f, assign, update) | 
|  | self.check_wrapper(wrapper, f, assign, update) | 
|  | self.assertEqual(wrapper.__name__, 'wrapper') | 
|  | self.assertNotEqual(wrapper.__qualname__, f.__qualname__) | 
|  | self.assertEqual(wrapper.__doc__, None) | 
|  | self.assertEqual(wrapper.attr, 'This is a different test') | 
|  | self.assertEqual(wrapper.dict_attr, f.dict_attr) | 
|  |  | 
|  | def test_missing_attributes(self): | 
|  | def f(): | 
|  | pass | 
|  | def wrapper(): | 
|  | pass | 
|  | wrapper.dict_attr = {} | 
|  | assign = ('attr',) | 
|  | update = ('dict_attr',) | 
|  | # Missing attributes on wrapped object are ignored | 
|  | functools.update_wrapper(wrapper, f, assign, update) | 
|  | self.assertNotIn('attr', wrapper.__dict__) | 
|  | self.assertEqual(wrapper.dict_attr, {}) | 
|  | # Wrapper must have expected attributes for updating | 
|  | del wrapper.dict_attr | 
|  | with self.assertRaises(AttributeError): | 
|  | functools.update_wrapper(wrapper, f, assign, update) | 
|  | wrapper.dict_attr = 1 | 
|  | with self.assertRaises(AttributeError): | 
|  | functools.update_wrapper(wrapper, f, assign, update) | 
|  |  | 
|  | @support.requires_docstrings | 
|  | @unittest.skipIf(sys.flags.optimize >= 2, | 
|  | "Docstrings are omitted with -O2 and above") | 
|  | def test_builtin_update(self): | 
|  | # Test for bug #1576241 | 
|  | def wrapper(): | 
|  | pass | 
|  | functools.update_wrapper(wrapper, max) | 
|  | self.assertEqual(wrapper.__name__, 'max') | 
|  | self.assertTrue(wrapper.__doc__.startswith('max(')) | 
|  | self.assertEqual(wrapper.__annotations__, {}) | 
|  |  | 
|  |  | 
class TestWraps(TestUpdateWrapper):
    """Rerun the update_wrapper tests through the ``functools.wraps`` decorator."""

    def _default_update(self):
        """Return ``(wrapper, f)`` where wrapper is decorated with ``wraps(f)``."""
        def f():
            """This is a test"""
        # These must be set before decorating: wraps copies at decoration time.
        f.attr = 'This is also a test'
        f.__wrapped__ = "This is still a bald faced lie"

        @functools.wraps(f)
        def wrapper():
            pass

        return wrapper, f

    def test_default_update(self):
        wrapper, f = self._default_update()
        self.check_wrapper(wrapper, f)
        self.assertEqual(wrapper.__name__, 'f')
        self.assertEqual(wrapper.__qualname__, f.__qualname__)
        self.assertEqual(wrapper.attr, 'This is also a test')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_default_update_doc(self):
        wrapper = self._default_update()[0]
        self.assertEqual(wrapper.__doc__, 'This is a test')

    def test_no_update(self):
        def f():
            """This is a test"""
        f.attr = 'This is also a test'

        nothing = ()

        @functools.wraps(f, nothing, nothing)
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, nothing, nothing)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertFalse(hasattr(wrapper, 'attr'))

    def test_selective_update(self):
        def f():
            pass
        f.attr = 'This is a different test'
        f.dict_attr = dict(a=1, b=2, c=3)

        def add_dict_attr(func):
            # Give the wrapper an (empty) mapping for wraps() to update.
            func.dict_attr = {}
            return func

        assign, update = ('attr',), ('dict_attr',)

        @functools.wraps(f, assign, update)
        @add_dict_attr
        def wrapper():
            pass

        self.check_wrapper(wrapper, f, assign, update)
        self.assertEqual(wrapper.__name__, 'wrapper')
        self.assertNotEqual(wrapper.__qualname__, f.__qualname__)
        self.assertIsNone(wrapper.__doc__)
        self.assertEqual(wrapper.attr, 'This is a different test')
        self.assertEqual(wrapper.dict_attr, f.dict_attr)
|  |  | 
|  |  | 
|  | class TestReduce: | 
|  | def test_reduce(self): | 
|  | class Squares: | 
|  | def __init__(self, max): | 
|  | self.max = max | 
|  | self.sofar = [] | 
|  |  | 
|  | def __len__(self): | 
|  | return len(self.sofar) | 
|  |  | 
|  | def __getitem__(self, i): | 
|  | if not 0 <= i < self.max: raise IndexError | 
|  | n = len(self.sofar) | 
|  | while n <= i: | 
|  | self.sofar.append(n*n) | 
|  | n += 1 | 
|  | return self.sofar[i] | 
|  | def add(x, y): | 
|  | return x + y | 
|  | self.assertEqual(self.reduce(add, ['a', 'b', 'c'], ''), 'abc') | 
|  | self.assertEqual( | 
|  | self.reduce(add, [['a', 'c'], [], ['d', 'w']], []), | 
|  | ['a','c','d','w'] | 
|  | ) | 
|  | self.assertEqual(self.reduce(lambda x, y: x*y, range(2,8), 1), 5040) | 
|  | self.assertEqual( | 
|  | self.reduce(lambda x, y: x*y, range(2,21), 1), | 
|  | 2432902008176640000 | 
|  | ) | 
|  | self.assertEqual(self.reduce(add, Squares(10)), 285) | 
|  | self.assertEqual(self.reduce(add, Squares(10), 0), 285) | 
|  | self.assertEqual(self.reduce(add, Squares(0), 0), 0) | 
|  | self.assertRaises(TypeError, self.reduce) | 
|  | self.assertRaises(TypeError, self.reduce, 42, 42) | 
|  | self.assertRaises(TypeError, self.reduce, 42, 42, 42) | 
|  | self.assertEqual(self.reduce(42, "1"), "1") # func is never called with one item | 
|  | self.assertEqual(self.reduce(42, "", "1"), "1") # func is never called with one item | 
|  | self.assertRaises(TypeError, self.reduce, 42, (42, 42)) | 
|  | self.assertRaises(TypeError, self.reduce, add, []) # arg 2 must not be empty sequence with no initial value | 
|  | self.assertRaises(TypeError, self.reduce, add, "") | 
|  | self.assertRaises(TypeError, self.reduce, add, ()) | 
|  | self.assertRaises(TypeError, self.reduce, add, object()) | 
|  |  | 
|  | class TestFailingIter: | 
|  | def __iter__(self): | 
|  | raise RuntimeError | 
|  | self.assertRaises(RuntimeError, self.reduce, add, TestFailingIter()) | 
|  |  | 
|  | self.assertEqual(self.reduce(add, [], None), None) | 
|  | self.assertEqual(self.reduce(add, [], 42), 42) | 
|  |  | 
|  | class BadSeq: | 
|  | def __getitem__(self, index): | 
|  | raise ValueError | 
|  | self.assertRaises(ValueError, self.reduce, 42, BadSeq()) | 
|  |  | 
|  | # Test reduce()'s use of iterators. | 
|  | def test_iterator_usage(self): | 
|  | class SequenceClass: | 
|  | def __init__(self, n): | 
|  | self.n = n | 
|  | def __getitem__(self, i): | 
|  | if 0 <= i < self.n: | 
|  | return i | 
|  | else: | 
|  | raise IndexError | 
|  |  | 
|  | from operator import add | 
|  | self.assertEqual(self.reduce(add, SequenceClass(5)), 10) | 
|  | self.assertEqual(self.reduce(add, SequenceClass(5), 42), 52) | 
|  | self.assertRaises(TypeError, self.reduce, add, SequenceClass(0)) | 
|  | self.assertEqual(self.reduce(add, SequenceClass(0), 42), 42) | 
|  | self.assertEqual(self.reduce(add, SequenceClass(1)), 0) | 
|  | self.assertEqual(self.reduce(add, SequenceClass(1), 42), 42) | 
|  |  | 
|  | d = {"one": 1, "two": 2, "three": 3} | 
|  | self.assertEqual(self.reduce(add, d), "".join(d.keys())) | 
|  |  | 
|  |  | 
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestReduceC(TestReduce, unittest.TestCase):
    """Run the shared reduce tests against the C accelerator."""

    if c_functools:
        # A builtin function does not bind as a method, so unlike the
        # pure-Python version no staticmethod wrapper is needed.
        reduce = c_functools.reduce
|  |  | 
|  |  | 
class TestReducePy(TestReduce, unittest.TestCase):
    """Run the shared reduce tests against the pure-Python implementation."""

    # staticmethod stops the plain Python function from binding to self.
    reduce = staticmethod(py_functools.reduce)
|  |  | 
|  |  | 
|  | class TestCmpToKey: | 
|  |  | 
|  | def test_cmp_to_key(self): | 
|  | def cmp1(x, y): | 
|  | return (x > y) - (x < y) | 
|  | key = self.cmp_to_key(cmp1) | 
|  | self.assertEqual(key(3), key(3)) | 
|  | self.assertGreater(key(3), key(1)) | 
|  | self.assertGreaterEqual(key(3), key(3)) | 
|  |  | 
|  | def cmp2(x, y): | 
|  | return int(x) - int(y) | 
|  | key = self.cmp_to_key(cmp2) | 
|  | self.assertEqual(key(4.0), key('4')) | 
|  | self.assertLess(key(2), key('35')) | 
|  | self.assertLessEqual(key(2), key('35')) | 
|  | self.assertNotEqual(key(2), key('35')) | 
|  |  | 
|  | def test_cmp_to_key_arguments(self): | 
|  | def cmp1(x, y): | 
|  | return (x > y) - (x < y) | 
|  | key = self.cmp_to_key(mycmp=cmp1) | 
|  | self.assertEqual(key(obj=3), key(obj=3)) | 
|  | self.assertGreater(key(obj=3), key(obj=1)) | 
|  | with self.assertRaises((TypeError, AttributeError)): | 
|  | key(3) > 1    # rhs is not a K object | 
|  | with self.assertRaises((TypeError, AttributeError)): | 
|  | 1 < key(3)    # lhs is not a K object | 
|  | with self.assertRaises(TypeError): | 
|  | key = self.cmp_to_key()             # too few args | 
|  | with self.assertRaises(TypeError): | 
|  | key = self.cmp_to_key(cmp1, None)   # too many args | 
|  | key = self.cmp_to_key(cmp1) | 
|  | with self.assertRaises(TypeError): | 
|  | key()                                    # too few args | 
|  | with self.assertRaises(TypeError): | 
|  | key(None, None)                          # too many args | 
|  |  | 
|  | def test_bad_cmp(self): | 
|  | def cmp1(x, y): | 
|  | raise ZeroDivisionError | 
|  | key = self.cmp_to_key(cmp1) | 
|  | with self.assertRaises(ZeroDivisionError): | 
|  | key(3) > key(1) | 
|  |  | 
|  | class BadCmp: | 
|  | def __lt__(self, other): | 
|  | raise ZeroDivisionError | 
|  | def cmp1(x, y): | 
|  | return BadCmp() | 
|  | with self.assertRaises(ZeroDivisionError): | 
|  | key(3) > key(1) | 
|  |  | 
|  | def test_obj_field(self): | 
|  | def cmp1(x, y): | 
|  | return (x > y) - (x < y) | 
|  | key = self.cmp_to_key(mycmp=cmp1) | 
|  | self.assertEqual(key(50).obj, 50) | 
|  |  | 
|  | def test_sort_int(self): | 
|  | def mycmp(x, y): | 
|  | return y - x | 
|  | self.assertEqual(sorted(range(5), key=self.cmp_to_key(mycmp)), | 
|  | [4, 3, 2, 1, 0]) | 
|  |  | 
|  | def test_sort_int_str(self): | 
|  | def mycmp(x, y): | 
|  | x, y = int(x), int(y) | 
|  | return (x > y) - (x < y) | 
|  | values = [5, '3', 7, 2, '0', '1', 4, '10', 1] | 
|  | values = sorted(values, key=self.cmp_to_key(mycmp)) | 
|  | self.assertEqual([int(value) for value in values], | 
|  | [0, 1, 1, 2, 3, 4, 5, 7, 10]) | 
|  |  | 
|  | def test_hash(self): | 
|  | def mycmp(x, y): | 
|  | return y - x | 
|  | key = self.cmp_to_key(mycmp) | 
|  | k = key(10) | 
|  | self.assertRaises(TypeError, hash, k) | 
|  | self.assertNotIsInstance(k, collections.abc.Hashable) | 
|  |  | 
|  |  | 
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestCmpToKeyC(TestCmpToKey, unittest.TestCase):
    # Runs the shared TestCmpToKey checks against the C accelerator.  The
    # builtin is not bound on attribute access, so it is stored as-is; the
    # guard keeps the class body importable when c_functools is unavailable
    # (the decorator skips the class in that case).
    if c_functools is not None:
        cmp_to_key = c_functools.cmp_to_key
|  |  | 
|  |  | 
class TestCmpToKeyPy(TestCmpToKey, unittest.TestCase):
    # Pure-Python cmp_to_key: staticmethod() keeps the plain function from
    # being bound to the test instance when looked up as self.cmp_to_key.
    cmp_to_key = staticmethod(py_functools.cmp_to_key)
|  |  | 
|  |  | 
|  | class TestTotalOrdering(unittest.TestCase): | 
|  |  | 
|  | def test_total_ordering_lt(self): | 
|  | @functools.total_ordering | 
|  | class A: | 
|  | def __init__(self, value): | 
|  | self.value = value | 
|  | def __lt__(self, other): | 
|  | return self.value < other.value | 
|  | def __eq__(self, other): | 
|  | return self.value == other.value | 
|  | self.assertTrue(A(1) < A(2)) | 
|  | self.assertTrue(A(2) > A(1)) | 
|  | self.assertTrue(A(1) <= A(2)) | 
|  | self.assertTrue(A(2) >= A(1)) | 
|  | self.assertTrue(A(2) <= A(2)) | 
|  | self.assertTrue(A(2) >= A(2)) | 
|  | self.assertFalse(A(1) > A(2)) | 
|  |  | 
|  | def test_total_ordering_le(self): | 
|  | @functools.total_ordering | 
|  | class A: | 
|  | def __init__(self, value): | 
|  | self.value = value | 
|  | def __le__(self, other): | 
|  | return self.value <= other.value | 
|  | def __eq__(self, other): | 
|  | return self.value == other.value | 
|  | self.assertTrue(A(1) < A(2)) | 
|  | self.assertTrue(A(2) > A(1)) | 
|  | self.assertTrue(A(1) <= A(2)) | 
|  | self.assertTrue(A(2) >= A(1)) | 
|  | self.assertTrue(A(2) <= A(2)) | 
|  | self.assertTrue(A(2) >= A(2)) | 
|  | self.assertFalse(A(1) >= A(2)) | 
|  |  | 
|  | def test_total_ordering_gt(self): | 
|  | @functools.total_ordering | 
|  | class A: | 
|  | def __init__(self, value): | 
|  | self.value = value | 
|  | def __gt__(self, other): | 
|  | return self.value > other.value | 
|  | def __eq__(self, other): | 
|  | return self.value == other.value | 
|  | self.assertTrue(A(1) < A(2)) | 
|  | self.assertTrue(A(2) > A(1)) | 
|  | self.assertTrue(A(1) <= A(2)) | 
|  | self.assertTrue(A(2) >= A(1)) | 
|  | self.assertTrue(A(2) <= A(2)) | 
|  | self.assertTrue(A(2) >= A(2)) | 
|  | self.assertFalse(A(2) < A(1)) | 
|  |  | 
|  | def test_total_ordering_ge(self): | 
|  | @functools.total_ordering | 
|  | class A: | 
|  | def __init__(self, value): | 
|  | self.value = value | 
|  | def __ge__(self, other): | 
|  | return self.value >= other.value | 
|  | def __eq__(self, other): | 
|  | return self.value == other.value | 
|  | self.assertTrue(A(1) < A(2)) | 
|  | self.assertTrue(A(2) > A(1)) | 
|  | self.assertTrue(A(1) <= A(2)) | 
|  | self.assertTrue(A(2) >= A(1)) | 
|  | self.assertTrue(A(2) <= A(2)) | 
|  | self.assertTrue(A(2) >= A(2)) | 
|  | self.assertFalse(A(2) <= A(1)) | 
|  |  | 
|  | def test_total_ordering_no_overwrite(self): | 
|  | # new methods should not overwrite existing | 
|  | @functools.total_ordering | 
|  | class A(int): | 
|  | pass | 
|  | self.assertTrue(A(1) < A(2)) | 
|  | self.assertTrue(A(2) > A(1)) | 
|  | self.assertTrue(A(1) <= A(2)) | 
|  | self.assertTrue(A(2) >= A(1)) | 
|  | self.assertTrue(A(2) <= A(2)) | 
|  | self.assertTrue(A(2) >= A(2)) | 
|  |  | 
|  | def test_no_operations_defined(self): | 
|  | with self.assertRaises(ValueError): | 
|  | @functools.total_ordering | 
|  | class A: | 
|  | pass | 
|  |  | 
|  | def test_type_error_when_not_implemented(self): | 
|  | # bug 10042; ensure stack overflow does not occur | 
|  | # when decorated types return NotImplemented | 
|  | @functools.total_ordering | 
|  | class ImplementsLessThan: | 
|  | def __init__(self, value): | 
|  | self.value = value | 
|  | def __eq__(self, other): | 
|  | if isinstance(other, ImplementsLessThan): | 
|  | return self.value == other.value | 
|  | return False | 
|  | def __lt__(self, other): | 
|  | if isinstance(other, ImplementsLessThan): | 
|  | return self.value < other.value | 
|  | return NotImplemented | 
|  |  | 
|  | @functools.total_ordering | 
|  | class ImplementsGreaterThan: | 
|  | def __init__(self, value): | 
|  | self.value = value | 
|  | def __eq__(self, other): | 
|  | if isinstance(other, ImplementsGreaterThan): | 
|  | return self.value == other.value | 
|  | return False | 
|  | def __gt__(self, other): | 
|  | if isinstance(other, ImplementsGreaterThan): | 
|  | return self.value > other.value | 
|  | return NotImplemented | 
|  |  | 
|  | @functools.total_ordering | 
|  | class ImplementsLessThanEqualTo: | 
|  | def __init__(self, value): | 
|  | self.value = value | 
|  | def __eq__(self, other): | 
|  | if isinstance(other, ImplementsLessThanEqualTo): | 
|  | return self.value == other.value | 
|  | return False | 
|  | def __le__(self, other): | 
|  | if isinstance(other, ImplementsLessThanEqualTo): | 
|  | return self.value <= other.value | 
|  | return NotImplemented | 
|  |  | 
|  | @functools.total_ordering | 
|  | class ImplementsGreaterThanEqualTo: | 
|  | def __init__(self, value): | 
|  | self.value = value | 
|  | def __eq__(self, other): | 
|  | if isinstance(other, ImplementsGreaterThanEqualTo): | 
|  | return self.value == other.value | 
|  | return False | 
|  | def __ge__(self, other): | 
|  | if isinstance(other, ImplementsGreaterThanEqualTo): | 
|  | return self.value >= other.value | 
|  | return NotImplemented | 
|  |  | 
|  | @functools.total_ordering | 
|  | class ComparatorNotImplemented: | 
|  | def __init__(self, value): | 
|  | self.value = value | 
|  | def __eq__(self, other): | 
|  | if isinstance(other, ComparatorNotImplemented): | 
|  | return self.value == other.value | 
|  | return False | 
|  | def __lt__(self, other): | 
|  | return NotImplemented | 
|  |  | 
|  | with self.subTest("LT < 1"), self.assertRaises(TypeError): | 
|  | ImplementsLessThan(-1) < 1 | 
|  |  | 
|  | with self.subTest("LT < LE"), self.assertRaises(TypeError): | 
|  | ImplementsLessThan(0) < ImplementsLessThanEqualTo(0) | 
|  |  | 
|  | with self.subTest("LT < GT"), self.assertRaises(TypeError): | 
|  | ImplementsLessThan(1) < ImplementsGreaterThan(1) | 
|  |  | 
|  | with self.subTest("LE <= LT"), self.assertRaises(TypeError): | 
|  | ImplementsLessThanEqualTo(2) <= ImplementsLessThan(2) | 
|  |  | 
|  | with self.subTest("LE <= GE"), self.assertRaises(TypeError): | 
|  | ImplementsLessThanEqualTo(3) <= ImplementsGreaterThanEqualTo(3) | 
|  |  | 
|  | with self.subTest("GT > GE"), self.assertRaises(TypeError): | 
|  | ImplementsGreaterThan(4) > ImplementsGreaterThanEqualTo(4) | 
|  |  | 
|  | with self.subTest("GT > LT"), self.assertRaises(TypeError): | 
|  | ImplementsGreaterThan(5) > ImplementsLessThan(5) | 
|  |  | 
|  | with self.subTest("GE >= GT"), self.assertRaises(TypeError): | 
|  | ImplementsGreaterThanEqualTo(6) >= ImplementsGreaterThan(6) | 
|  |  | 
|  | with self.subTest("GE >= LE"), self.assertRaises(TypeError): | 
|  | ImplementsGreaterThanEqualTo(7) >= ImplementsLessThanEqualTo(7) | 
|  |  | 
|  | with self.subTest("GE when equal"): | 
|  | a = ComparatorNotImplemented(8) | 
|  | b = ComparatorNotImplemented(8) | 
|  | self.assertEqual(a, b) | 
|  | with self.assertRaises(TypeError): | 
|  | a >= b | 
|  |  | 
|  | with self.subTest("LE when equal"): | 
|  | a = ComparatorNotImplemented(9) | 
|  | b = ComparatorNotImplemented(9) | 
|  | self.assertEqual(a, b) | 
|  | with self.assertRaises(TypeError): | 
|  | a <= b | 
|  |  | 
|  | def test_pickle(self): | 
|  | for proto in range(pickle.HIGHEST_PROTOCOL + 1): | 
|  | for name in '__lt__', '__gt__', '__le__', '__ge__': | 
|  | with self.subTest(method=name, proto=proto): | 
|  | method = getattr(Orderable_LT, name) | 
|  | method_copy = pickle.loads(pickle.dumps(method, proto)) | 
|  | self.assertIs(method_copy, method) | 
|  |  | 
@functools.total_ordering
class Orderable_LT:
    # Module-level fixture for TestTotalOrdering.test_pickle: it defines only
    # __lt__ and __eq__ and lets total_ordering supply the rest.
    def __init__(self, value):
        self.value = value

    def __lt__(self, other):
        mine, theirs = self.value, other.value
        return mine < theirs

    def __eq__(self, other):
        mine, theirs = self.value, other.value
        return mine == theirs
|  |  | 
|  |  | 
|  | class TestLRU: | 
|  |  | 
|  | def test_lru(self): | 
|  | def orig(x, y): | 
|  | return 3 * x + y | 
|  | f = self.module.lru_cache(maxsize=20)(orig) | 
|  | hits, misses, maxsize, currsize = f.cache_info() | 
|  | self.assertEqual(maxsize, 20) | 
|  | self.assertEqual(currsize, 0) | 
|  | self.assertEqual(hits, 0) | 
|  | self.assertEqual(misses, 0) | 
|  |  | 
|  | domain = range(5) | 
|  | for i in range(1000): | 
|  | x, y = choice(domain), choice(domain) | 
|  | actual = f(x, y) | 
|  | expected = orig(x, y) | 
|  | self.assertEqual(actual, expected) | 
|  | hits, misses, maxsize, currsize = f.cache_info() | 
|  | self.assertTrue(hits > misses) | 
|  | self.assertEqual(hits + misses, 1000) | 
|  | self.assertEqual(currsize, 20) | 
|  |  | 
|  | f.cache_clear()   # test clearing | 
|  | hits, misses, maxsize, currsize = f.cache_info() | 
|  | self.assertEqual(hits, 0) | 
|  | self.assertEqual(misses, 0) | 
|  | self.assertEqual(currsize, 0) | 
|  | f(x, y) | 
|  | hits, misses, maxsize, currsize = f.cache_info() | 
|  | self.assertEqual(hits, 0) | 
|  | self.assertEqual(misses, 1) | 
|  | self.assertEqual(currsize, 1) | 
|  |  | 
|  | # Test bypassing the cache | 
|  | self.assertIs(f.__wrapped__, orig) | 
|  | f.__wrapped__(x, y) | 
|  | hits, misses, maxsize, currsize = f.cache_info() | 
|  | self.assertEqual(hits, 0) | 
|  | self.assertEqual(misses, 1) | 
|  | self.assertEqual(currsize, 1) | 
|  |  | 
|  | # test size zero (which means "never-cache") | 
|  | @self.module.lru_cache(0) | 
|  | def f(): | 
|  | nonlocal f_cnt | 
|  | f_cnt += 1 | 
|  | return 20 | 
|  | self.assertEqual(f.cache_info().maxsize, 0) | 
|  | f_cnt = 0 | 
|  | for i in range(5): | 
|  | self.assertEqual(f(), 20) | 
|  | self.assertEqual(f_cnt, 5) | 
|  | hits, misses, maxsize, currsize = f.cache_info() | 
|  | self.assertEqual(hits, 0) | 
|  | self.assertEqual(misses, 5) | 
|  | self.assertEqual(currsize, 0) | 
|  |  | 
|  | # test size one | 
|  | @self.module.lru_cache(1) | 
|  | def f(): | 
|  | nonlocal f_cnt | 
|  | f_cnt += 1 | 
|  | return 20 | 
|  | self.assertEqual(f.cache_info().maxsize, 1) | 
|  | f_cnt = 0 | 
|  | for i in range(5): | 
|  | self.assertEqual(f(), 20) | 
|  | self.assertEqual(f_cnt, 1) | 
|  | hits, misses, maxsize, currsize = f.cache_info() | 
|  | self.assertEqual(hits, 4) | 
|  | self.assertEqual(misses, 1) | 
|  | self.assertEqual(currsize, 1) | 
|  |  | 
|  | # test size two | 
|  | @self.module.lru_cache(2) | 
|  | def f(x): | 
|  | nonlocal f_cnt | 
|  | f_cnt += 1 | 
|  | return x*10 | 
|  | self.assertEqual(f.cache_info().maxsize, 2) | 
|  | f_cnt = 0 | 
|  | for x in 7, 9, 7, 9, 7, 9, 8, 8, 8, 9, 9, 9, 8, 8, 8, 7: | 
|  | #    *  *              *                          * | 
|  | self.assertEqual(f(x), x*10) | 
|  | self.assertEqual(f_cnt, 4) | 
|  | hits, misses, maxsize, currsize = f.cache_info() | 
|  | self.assertEqual(hits, 12) | 
|  | self.assertEqual(misses, 4) | 
|  | self.assertEqual(currsize, 2) | 
|  |  | 
|  | def test_lru_bug_35780(self): | 
|  | # C version of the lru_cache was not checking to see if | 
|  | # the user function call has already modified the cache | 
|  | # (this arises in recursive calls and in multi-threading). | 
|  | # This cause the cache to have orphan links not referenced | 
|  | # by the cache dictionary. | 
|  |  | 
|  | once = True                 # Modified by f(x) below | 
|  |  | 
|  | @self.module.lru_cache(maxsize=10) | 
|  | def f(x): | 
|  | nonlocal once | 
|  | rv = f'.{x}.' | 
|  | if x == 20 and once: | 
|  | once = False | 
|  | rv = f(x) | 
|  | return rv | 
|  |  | 
|  | # Fill the cache | 
|  | for x in range(15): | 
|  | self.assertEqual(f(x), f'.{x}.') | 
|  | self.assertEqual(f.cache_info().currsize, 10) | 
|  |  | 
|  | # Make a recursive call and make sure the cache remains full | 
|  | self.assertEqual(f(20), '.20.') | 
|  | self.assertEqual(f.cache_info().currsize, 10) | 
|  |  | 
|  | def test_lru_hash_only_once(self): | 
|  | # To protect against weird reentrancy bugs and to improve | 
|  | # efficiency when faced with slow __hash__ methods, the | 
|  | # LRU cache guarantees that it will only call __hash__ | 
|  | # only once per use as an argument to the cached function. | 
|  |  | 
|  | @self.module.lru_cache(maxsize=1) | 
|  | def f(x, y): | 
|  | return x * 3 + y | 
|  |  | 
|  | # Simulate the integer 5 | 
|  | mock_int = unittest.mock.Mock() | 
|  | mock_int.__mul__ = unittest.mock.Mock(return_value=15) | 
|  | mock_int.__hash__ = unittest.mock.Mock(return_value=999) | 
|  |  | 
|  | # Add to cache:  One use as an argument gives one call | 
|  | self.assertEqual(f(mock_int, 1), 16) | 
|  | self.assertEqual(mock_int.__hash__.call_count, 1) | 
|  | self.assertEqual(f.cache_info(), (0, 1, 1, 1)) | 
|  |  | 
|  | # Cache hit: One use as an argument gives one additional call | 
|  | self.assertEqual(f(mock_int, 1), 16) | 
|  | self.assertEqual(mock_int.__hash__.call_count, 2) | 
|  | self.assertEqual(f.cache_info(), (1, 1, 1, 1)) | 
|  |  | 
|  | # Cache eviction: No use as an argument gives no additional call | 
|  | self.assertEqual(f(6, 2), 20) | 
|  | self.assertEqual(mock_int.__hash__.call_count, 2) | 
|  | self.assertEqual(f.cache_info(), (1, 2, 1, 1)) | 
|  |  | 
|  | # Cache miss: One use as an argument gives one additional call | 
|  | self.assertEqual(f(mock_int, 1), 16) | 
|  | self.assertEqual(mock_int.__hash__.call_count, 3) | 
|  | self.assertEqual(f.cache_info(), (1, 3, 1, 1)) | 
|  |  | 
|  | def test_lru_reentrancy_with_len(self): | 
|  | # Test to make sure the LRU cache code isn't thrown-off by | 
|  | # caching the built-in len() function.  Since len() can be | 
|  | # cached, we shouldn't use it inside the lru code itself. | 
|  | old_len = builtins.len | 
|  | try: | 
|  | builtins.len = self.module.lru_cache(4)(len) | 
|  | for i in [0, 0, 1, 2, 3, 3, 4, 5, 6, 1, 7, 2, 1]: | 
|  | self.assertEqual(len('abcdefghijklmn'[:i]), i) | 
|  | finally: | 
|  | builtins.len = old_len | 
|  |  | 
|  | def test_lru_star_arg_handling(self): | 
|  | # Test regression that arose in ea064ff3c10f | 
|  | @functools.lru_cache() | 
|  | def f(*args): | 
|  | return args | 
|  |  | 
|  | self.assertEqual(f(1, 2), (1, 2)) | 
|  | self.assertEqual(f((1, 2)), ((1, 2),)) | 
|  |  | 
|  | def test_lru_type_error(self): | 
|  | # Regression test for issue #28653. | 
|  | # lru_cache was leaking when one of the arguments | 
|  | # wasn't cacheable. | 
|  |  | 
|  | @functools.lru_cache(maxsize=None) | 
|  | def infinite_cache(o): | 
|  | pass | 
|  |  | 
|  | @functools.lru_cache(maxsize=10) | 
|  | def limited_cache(o): | 
|  | pass | 
|  |  | 
|  | with self.assertRaises(TypeError): | 
|  | infinite_cache([]) | 
|  |  | 
|  | with self.assertRaises(TypeError): | 
|  | limited_cache([]) | 
|  |  | 
|  | def test_lru_with_maxsize_none(self): | 
|  | @self.module.lru_cache(maxsize=None) | 
|  | def fib(n): | 
|  | if n < 2: | 
|  | return n | 
|  | return fib(n-1) + fib(n-2) | 
|  | self.assertEqual([fib(n) for n in range(16)], | 
|  | [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) | 
|  | self.assertEqual(fib.cache_info(), | 
|  | self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) | 
|  | fib.cache_clear() | 
|  | self.assertEqual(fib.cache_info(), | 
|  | self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) | 
|  |  | 
|  | def test_lru_with_maxsize_negative(self): | 
|  | @self.module.lru_cache(maxsize=-10) | 
|  | def eq(n): | 
|  | return n | 
|  | for i in (0, 1): | 
|  | self.assertEqual([eq(n) for n in range(150)], list(range(150))) | 
|  | self.assertEqual(eq.cache_info(), | 
|  | self.module._CacheInfo(hits=0, misses=300, maxsize=0, currsize=0)) | 
|  |  | 
|  | def test_lru_with_exceptions(self): | 
|  | # Verify that user_function exceptions get passed through without | 
|  | # creating a hard-to-read chained exception. | 
|  | # http://bugs.python.org/issue13177 | 
|  | for maxsize in (None, 128): | 
|  | @self.module.lru_cache(maxsize) | 
|  | def func(i): | 
|  | return 'abc'[i] | 
|  | self.assertEqual(func(0), 'a') | 
|  | with self.assertRaises(IndexError) as cm: | 
|  | func(15) | 
|  | self.assertIsNone(cm.exception.__context__) | 
|  | # Verify that the previous exception did not result in a cached entry | 
|  | with self.assertRaises(IndexError): | 
|  | func(15) | 
|  |  | 
|  | def test_lru_with_types(self): | 
|  | for maxsize in (None, 128): | 
|  | @self.module.lru_cache(maxsize=maxsize, typed=True) | 
|  | def square(x): | 
|  | return x * x | 
|  | self.assertEqual(square(3), 9) | 
|  | self.assertEqual(type(square(3)), type(9)) | 
|  | self.assertEqual(square(3.0), 9.0) | 
|  | self.assertEqual(type(square(3.0)), type(9.0)) | 
|  | self.assertEqual(square(x=3), 9) | 
|  | self.assertEqual(type(square(x=3)), type(9)) | 
|  | self.assertEqual(square(x=3.0), 9.0) | 
|  | self.assertEqual(type(square(x=3.0)), type(9.0)) | 
|  | self.assertEqual(square.cache_info().hits, 4) | 
|  | self.assertEqual(square.cache_info().misses, 4) | 
|  |  | 
|  | def test_lru_with_keyword_args(self): | 
|  | @self.module.lru_cache() | 
|  | def fib(n): | 
|  | if n < 2: | 
|  | return n | 
|  | return fib(n=n-1) + fib(n=n-2) | 
|  | self.assertEqual( | 
|  | [fib(n=number) for number in range(16)], | 
|  | [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610] | 
|  | ) | 
|  | self.assertEqual(fib.cache_info(), | 
|  | self.module._CacheInfo(hits=28, misses=16, maxsize=128, currsize=16)) | 
|  | fib.cache_clear() | 
|  | self.assertEqual(fib.cache_info(), | 
|  | self.module._CacheInfo(hits=0, misses=0, maxsize=128, currsize=0)) | 
|  |  | 
|  | def test_lru_with_keyword_args_maxsize_none(self): | 
|  | @self.module.lru_cache(maxsize=None) | 
|  | def fib(n): | 
|  | if n < 2: | 
|  | return n | 
|  | return fib(n=n-1) + fib(n=n-2) | 
|  | self.assertEqual([fib(n=number) for number in range(16)], | 
|  | [0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89, 144, 233, 377, 610]) | 
|  | self.assertEqual(fib.cache_info(), | 
|  | self.module._CacheInfo(hits=28, misses=16, maxsize=None, currsize=16)) | 
|  | fib.cache_clear() | 
|  | self.assertEqual(fib.cache_info(), | 
|  | self.module._CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)) | 
|  |  | 
|  | def test_kwargs_order(self): | 
|  | # PEP 468: Preserving Keyword Argument Order | 
|  | @self.module.lru_cache(maxsize=10) | 
|  | def f(**kwargs): | 
|  | return list(kwargs.items()) | 
|  | self.assertEqual(f(a=1, b=2), [('a', 1), ('b', 2)]) | 
|  | self.assertEqual(f(b=2, a=1), [('b', 2), ('a', 1)]) | 
|  | self.assertEqual(f.cache_info(), | 
|  | self.module._CacheInfo(hits=0, misses=2, maxsize=10, currsize=2)) | 
|  |  | 
|  | def test_lru_cache_decoration(self): | 
|  | def f(zomg: 'zomg_annotation'): | 
|  | """f doc string""" | 
|  | return 42 | 
|  | g = self.module.lru_cache()(f) | 
|  | for attr in self.module.WRAPPER_ASSIGNMENTS: | 
|  | self.assertEqual(getattr(g, attr), getattr(f, attr)) | 
|  |  | 
|  | def test_lru_cache_threaded(self): | 
|  | n, m = 5, 11 | 
|  | def orig(x, y): | 
|  | return 3 * x + y | 
|  | f = self.module.lru_cache(maxsize=n*m)(orig) | 
|  | hits, misses, maxsize, currsize = f.cache_info() | 
|  | self.assertEqual(currsize, 0) | 
|  |  | 
|  | start = threading.Event() | 
|  | def full(k): | 
|  | start.wait(10) | 
|  | for _ in range(m): | 
|  | self.assertEqual(f(k, 0), orig(k, 0)) | 
|  |  | 
|  | def clear(): | 
|  | start.wait(10) | 
|  | for _ in range(2*m): | 
|  | f.cache_clear() | 
|  |  | 
|  | orig_si = sys.getswitchinterval() | 
|  | support.setswitchinterval(1e-6) | 
|  | try: | 
|  | # create n threads in order to fill cache | 
|  | threads = [threading.Thread(target=full, args=[k]) | 
|  | for k in range(n)] | 
|  | with support.start_threads(threads): | 
|  | start.set() | 
|  |  | 
|  | hits, misses, maxsize, currsize = f.cache_info() | 
|  | if self.module is py_functools: | 
|  | # XXX: Why can be not equal? | 
|  | self.assertLessEqual(misses, n) | 
|  | self.assertLessEqual(hits, m*n - misses) | 
|  | else: | 
|  | self.assertEqual(misses, n) | 
|  | self.assertEqual(hits, m*n - misses) | 
|  | self.assertEqual(currsize, n) | 
|  |  | 
|  | # create n threads in order to fill cache and 1 to clear it | 
|  | threads = [threading.Thread(target=clear)] | 
|  | threads += [threading.Thread(target=full, args=[k]) | 
|  | for k in range(n)] | 
|  | start.clear() | 
|  | with support.start_threads(threads): | 
|  | start.set() | 
|  | finally: | 
|  | sys.setswitchinterval(orig_si) | 
|  |  | 
|  | def test_lru_cache_threaded2(self): | 
|  | # Simultaneous call with the same arguments | 
|  | n, m = 5, 7 | 
|  | start = threading.Barrier(n+1) | 
|  | pause = threading.Barrier(n+1) | 
|  | stop = threading.Barrier(n+1) | 
|  | @self.module.lru_cache(maxsize=m*n) | 
|  | def f(x): | 
|  | pause.wait(10) | 
|  | return 3 * x | 
|  | self.assertEqual(f.cache_info(), (0, 0, m*n, 0)) | 
|  | def test(): | 
|  | for i in range(m): | 
|  | start.wait(10) | 
|  | self.assertEqual(f(i), 3 * i) | 
|  | stop.wait(10) | 
|  | threads = [threading.Thread(target=test) for k in range(n)] | 
|  | with support.start_threads(threads): | 
|  | for i in range(m): | 
|  | start.wait(10) | 
|  | stop.reset() | 
|  | pause.wait(10) | 
|  | start.reset() | 
|  | stop.wait(10) | 
|  | pause.reset() | 
|  | self.assertEqual(f.cache_info(), (0, (i+1)*n, m*n, i+1)) | 
|  |  | 
|  | def test_lru_cache_threaded3(self): | 
|  | @self.module.lru_cache(maxsize=2) | 
|  | def f(x): | 
|  | time.sleep(.01) | 
|  | return 3 * x | 
|  | def test(i, x): | 
|  | with self.subTest(thread=i): | 
|  | self.assertEqual(f(x), 3 * x, i) | 
|  | threads = [threading.Thread(target=test, args=(i, v)) | 
|  | for i, v in enumerate([1, 2, 2, 3, 2])] | 
|  | with support.start_threads(threads): | 
|  | pass | 
|  |  | 
|  | def test_need_for_rlock(self): | 
|  | # This will deadlock on an LRU cache that uses a regular lock | 
|  |  | 
|  | @self.module.lru_cache(maxsize=10) | 
|  | def test_func(x): | 
|  | 'Used to demonstrate a reentrant lru_cache call within a single thread' | 
|  | return x | 
|  |  | 
|  | class DoubleEq: | 
|  | 'Demonstrate a reentrant lru_cache call within a single thread' | 
|  | def __init__(self, x): | 
|  | self.x = x | 
|  | def __hash__(self): | 
|  | return self.x | 
|  | def __eq__(self, other): | 
|  | if self.x == 2: | 
|  | test_func(DoubleEq(1)) | 
|  | return self.x == other.x | 
|  |  | 
|  | test_func(DoubleEq(1))                      # Load the cache | 
|  | test_func(DoubleEq(2))                      # Load the cache | 
|  | self.assertEqual(test_func(DoubleEq(2)),    # Trigger a re-entrant __eq__ call | 
|  | DoubleEq(2))               # Verify the correct return value | 
|  |  | 
|  | def test_early_detection_of_bad_call(self): | 
|  | # Issue #22184 | 
|  | with self.assertRaises(TypeError): | 
|  | @functools.lru_cache | 
|  | def f(): | 
|  | pass | 
|  |  | 
|  | def test_lru_method(self): | 
|  | class X(int): | 
|  | f_cnt = 0 | 
|  | @self.module.lru_cache(2) | 
|  | def f(self, x): | 
|  | self.f_cnt += 1 | 
|  | return x*10+self | 
|  | a = X(5) | 
|  | b = X(5) | 
|  | c = X(7) | 
|  | self.assertEqual(X.f.cache_info(), (0, 0, 2, 0)) | 
|  |  | 
|  | for x in 1, 2, 2, 3, 1, 1, 1, 2, 3, 3: | 
|  | self.assertEqual(a.f(x), x*10 + 5) | 
|  | self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 0, 0)) | 
|  | self.assertEqual(X.f.cache_info(), (4, 6, 2, 2)) | 
|  |  | 
|  | for x in 1, 2, 1, 1, 1, 1, 3, 2, 2, 2: | 
|  | self.assertEqual(b.f(x), x*10 + 5) | 
|  | self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 0)) | 
|  | self.assertEqual(X.f.cache_info(), (10, 10, 2, 2)) | 
|  |  | 
|  | for x in 2, 1, 1, 1, 1, 2, 1, 3, 2, 1: | 
|  | self.assertEqual(c.f(x), x*10 + 7) | 
|  | self.assertEqual((a.f_cnt, b.f_cnt, c.f_cnt), (6, 4, 5)) | 
|  | self.assertEqual(X.f.cache_info(), (15, 15, 2, 2)) | 
|  |  | 
|  | self.assertEqual(a.f.cache_info(), X.f.cache_info()) | 
|  | self.assertEqual(b.f.cache_info(), X.f.cache_info()) | 
|  | self.assertEqual(c.f.cache_info(), X.f.cache_info()) | 
|  |  | 
|  | def test_pickle(self): | 
|  | cls = self.__class__ | 
|  | for f in cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth: | 
|  | for proto in range(pickle.HIGHEST_PROTOCOL + 1): | 
|  | with self.subTest(proto=proto, func=f): | 
|  | f_copy = pickle.loads(pickle.dumps(f, proto)) | 
|  | self.assertIs(f_copy, f) | 
|  |  | 
|  | def test_copy(self): | 
|  | cls = self.__class__ | 
|  | def orig(x, y): | 
|  | return 3 * x + y | 
|  | part = self.module.partial(orig, 2) | 
|  | funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, | 
|  | self.module.lru_cache(2)(part)) | 
|  | for f in funcs: | 
|  | with self.subTest(func=f): | 
|  | f_copy = copy.copy(f) | 
|  | self.assertIs(f_copy, f) | 
|  |  | 
|  | def test_deepcopy(self): | 
|  | cls = self.__class__ | 
|  | def orig(x, y): | 
|  | return 3 * x + y | 
|  | part = self.module.partial(orig, 2) | 
|  | funcs = (cls.cached_func[0], cls.cached_meth, cls.cached_staticmeth, | 
|  | self.module.lru_cache(2)(part)) | 
|  | for f in funcs: | 
|  | with self.subTest(func=f): | 
|  | f_copy = copy.deepcopy(f) | 
|  | self.assertIs(f_copy, f) | 
|  |  | 
|  |  | 
# Module-level so that pickle can locate it by its qualified name
# (TestLRUPy exposes it through ``cached_func``).
@py_functools.lru_cache()
def py_cached_func(x, y):
    scaled = 3 * x
    return scaled + y
|  |  | 
# Module-level so that pickle can locate it by its qualified name
# (TestLRUC exposes it through ``cached_func``).  Guarded like the other
# C-specific fixtures in this file: when c_functools is unavailable,
# ``c_functools.lru_cache`` would raise AttributeError and abort the import
# of the whole test module instead of just skipping the C tests.
if c_functools:
    @c_functools.lru_cache()
    def c_cached_func(x, y):
        return 3 * x + y
|  |  | 
|  |  | 
class TestLRUPy(TestLRU, unittest.TestCase):
    # Run the TestLRU mixin against the pure-Python lru_cache.  The cached
    # callables are class attributes so that the pickle/copy tests can
    # retrieve them via the class and look them up again by qualified name.
    module = py_functools
    cached_func = (py_cached_func,)

    @module.lru_cache()
    def cached_meth(self, x, y):
        return 3 * x + y

    @staticmethod
    @module.lru_cache()
    def cached_staticmeth(x, y):
        return 3 * x + y
|  |  | 
|  |  | 
@unittest.skipUnless(c_functools, 'requires the C _functools module')
class TestLRUC(TestLRU, unittest.TestCase):
    # Run the TestLRU mixin against the C lru_cache.  As with the other C
    # variants in this file, the class is skipped when c_functools is
    # unavailable, and its body is guarded: it dereferences c_functools
    # (``module.lru_cache``) at class-creation time.
    if c_functools:
        module = c_functools
        cached_func = c_cached_func,

        @module.lru_cache()
        def cached_meth(self, x, y):
            return 3 * x + y

        @staticmethod
        @module.lru_cache()
        def cached_staticmeth(x, y):
            return 3 * x + y
|  |  | 
|  |  | 
|  | class TestSingleDispatch(unittest.TestCase): | 
|  | def test_simple_overloads(self): | 
|  | @functools.singledispatch | 
|  | def g(obj): | 
|  | return "base" | 
|  | def g_int(i): | 
|  | return "integer" | 
|  | g.register(int, g_int) | 
|  | self.assertEqual(g("str"), "base") | 
|  | self.assertEqual(g(1), "integer") | 
|  | self.assertEqual(g([1,2,3]), "base") | 
|  |  | 
|  | def test_mro(self): | 
|  | @functools.singledispatch | 
|  | def g(obj): | 
|  | return "base" | 
|  | class A: | 
|  | pass | 
|  | class C(A): | 
|  | pass | 
|  | class B(A): | 
|  | pass | 
|  | class D(C, B): | 
|  | pass | 
|  | def g_A(a): | 
|  | return "A" | 
|  | def g_B(b): | 
|  | return "B" | 
|  | g.register(A, g_A) | 
|  | g.register(B, g_B) | 
|  | self.assertEqual(g(A()), "A") | 
|  | self.assertEqual(g(B()), "B") | 
|  | self.assertEqual(g(C()), "A") | 
|  | self.assertEqual(g(D()), "B") | 
|  |  | 
|  | def test_register_decorator(self): | 
|  | @functools.singledispatch | 
|  | def g(obj): | 
|  | return "base" | 
|  | @g.register(int) | 
|  | def g_int(i): | 
|  | return "int %s" % (i,) | 
|  | self.assertEqual(g(""), "base") | 
|  | self.assertEqual(g(12), "int 12") | 
|  | self.assertIs(g.dispatch(int), g_int) | 
|  | self.assertIs(g.dispatch(object), g.dispatch(str)) | 
|  | # Note: in the assert above this is not g. | 
|  | # @singledispatch returns the wrapper. | 
|  |  | 
|  | def test_wrapping_attributes(self): | 
|  | @functools.singledispatch | 
|  | def g(obj): | 
|  | "Simple test" | 
|  | return "Test" | 
|  | self.assertEqual(g.__name__, "g") | 
|  | if sys.flags.optimize < 2: | 
|  | self.assertEqual(g.__doc__, "Simple test") | 
|  |  | 
@unittest.skipUnless(decimal, 'requires _decimal')
@support.cpython_only
def test_c_classes(self):
    # Dispatch on exception classes from the C decimal module: the
    # DecimalException registration covers its subclasses until a more
    # specific registration for Subnormal takes over for that type only.
    @functools.singledispatch
    def g(obj):
        return "base"

    @g.register(decimal.DecimalException)
    def on_decimal_exception(obj):
        return obj.args

    subnormal = decimal.Subnormal("Exponent < Emin")
    rounded = decimal.Rounded("Number got rounded")
    self.assertEqual(g(subnormal), ("Exponent < Emin",))
    self.assertEqual(g(rounded), ("Number got rounded",))

    @g.register(decimal.Subnormal)
    def on_subnormal(obj):
        return "Too small to care."

    self.assertEqual(g(subnormal), "Too small to care.")
    self.assertEqual(g(rounded), ("Number got rounded",))
|  |  | 
def test_compose_mro(self):
    """_compose_mro must return the same MRO for every haystack ordering.

    Each case feeds every permutation of a candidate class list (the
    "haystack") to functools._compose_mro and expects a single result.
    """
    c = collections.abc
    mro = functools._compose_mro
    bases = [c.Sequence, c.MutableMapping, c.Mapping, c.Set]
    for haystack in permutations(bases):
        m = mro(dict, haystack)
        self.assertEqual(m, [dict, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])
    bases = [c.Container, c.Mapping, c.MutableMapping, collections.OrderedDict]
    for haystack in permutations(bases):
        m = mro(collections.ChainMap, haystack)
        self.assertEqual(m, [collections.ChainMap, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])

    # If there's a generic function with implementations registered for
    # both Sized and Container, passing a defaultdict to it results in an
    # ambiguous dispatch which will cause a RuntimeError (see
    # test_mro_conflicts).
    bases = [c.Container, c.Sized, str]
    for haystack in permutations(bases):
        # Pass the permutation itself: a fixed list here would only
        # re-check one ordering on every iteration.
        m = mro(collections.defaultdict, haystack)
        self.assertEqual(m, [collections.defaultdict, dict, c.Sized,
                             c.Container, object])

    # MutableSequence below is registered directly on D. In other words, it
    # precedes MutableMapping which means single dispatch will always
    # choose MutableSequence here.
    class D(collections.defaultdict):
        pass
    c.MutableSequence.register(D)
    bases = [c.MutableSequence, c.MutableMapping]
    for haystack in permutations(bases):
        # Same as above: use the permutation, not the unpermuted list.
        m = mro(D, haystack)
        self.assertEqual(m, [D, c.MutableSequence, c.Sequence, c.Reversible,
                             collections.defaultdict, dict, c.MutableMapping, c.Mapping,
                             c.Collection, c.Sized, c.Iterable, c.Container,
                             object])

    # Container and Callable are registered on different base classes and
    # a generic function supporting both should always pick the Callable
    # implementation if a C instance is passed.
    class C(collections.defaultdict):
        def __call__(self):
            pass
    bases = [c.Sized, c.Callable, c.Container, c.Mapping]
    for haystack in permutations(bases):
        m = mro(C, haystack)
        self.assertEqual(m, [C, c.Callable, collections.defaultdict, dict, c.Mapping,
                             c.Collection, c.Sized, c.Iterable,
                             c.Container, object])
|  |  | 
def test_register_abc(self):
    # Registrations accumulate on one generic function; after each one, the
    # five samples (dict, list, set, frozenset, tuple) must dispatch to the
    # most specific implementation registered so far.
    c = collections.abc
    d = {"a": "b"}
    l = [1, 2, 3]
    s = {object(), None}
    f = frozenset(s)
    t = (1, 2, 3)
    samples = (d, l, s, f, t)

    @functools.singledispatch
    def g(obj):
        return "base"

    self.assertEqual(tuple(g(obj) for obj in samples), ("base",) * 5)

    # (class to register, value its impl returns, expected for d, l, s, f, t)
    steps = [
        (c.Sized, "sized", ("sized",) * 5),
        (c.MutableMapping, "mutablemapping",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        # ChainMap: irrelevant ABCs registered, nothing changes.
        (collections.ChainMap, "chainmap",
         ("mutablemapping", "sized", "sized", "sized", "sized")),
        (c.MutableSequence, "mutablesequence",
         ("mutablemapping", "mutablesequence", "sized", "sized", "sized")),
        (c.MutableSet, "mutableset",
         ("mutablemapping", "mutablesequence", "mutableset", "sized", "sized")),
        # Mapping is not specific enough to beat MutableMapping for d.
        (c.Mapping, "mapping",
         ("mutablemapping", "mutablesequence", "mutableset", "sized", "sized")),
        (c.Sequence, "sequence",
         ("mutablemapping", "mutablesequence", "mutableset", "sized", "sequence")),
        (c.Set, "set",
         ("mutablemapping", "mutablesequence", "mutableset", "set", "sequence")),
        (dict, "dict",
         ("dict", "mutablesequence", "mutableset", "set", "sequence")),
        (list, "list",
         ("dict", "list", "mutableset", "set", "sequence")),
        (set, "concrete-set",
         ("dict", "list", "concrete-set", "set", "sequence")),
        (frozenset, "frozen-set",
         ("dict", "list", "concrete-set", "frozen-set", "sequence")),
        (tuple, "tuple",
         ("dict", "list", "concrete-set", "frozen-set", "tuple")),
    ]
    for cls, result, expected in steps:
        # Bind *result* now so each lambda returns its own label.
        g.register(cls, lambda obj, _result=result: _result)
        self.assertEqual(tuple(g(obj) for obj in samples), expected)
|  |  | 
def test_c3_abc(self):
    # _c3_mro must weave implied/registered ABCs into X's MRO, independent
    # of the order the candidate ABCs are supplied in.
    c = collections.abc
    c3_mro = functools._c3_mro

    class A(object):
        pass

    class B(A):
        def __len__(self):
            return 0   # implies Sized

    @c.Container.register
    class C(object):
        pass

    class D(object):
        pass   # unrelated

    class X(D, C, B):
        def __call__(self):
            pass   # implies Callable

    expected = [X, c.Callable, D, C, c.Container, B, c.Sized, A, object]
    for abcs in permutations([c.Sized, c.Callable, c.Container]):
        self.assertEqual(c3_mro(X, abcs=abcs), expected)
    # unrelated ABCs don't appear in the resulting MRO
    many_abcs = [c.Mapping, c.Sized, c.Callable, c.Container, c.Iterable]
    self.assertEqual(c3_mro(X, abcs=many_abcs), expected)
|  |  | 
def test_false_meta(self):
    # see issue23572: MetaA.__len__ returning 0 makes the class A itself
    # falsy; dispatch must still find the implementation registered for A.
    class MetaA(type):
        def __len__(self):
            return 0

    class A(metaclass=MetaA):
        pass

    class AA(A):
        pass

    @functools.singledispatch
    def fun(a):
        return 'base A'

    @fun.register(A)
    def _(a):
        return 'fun A'

    self.assertEqual(fun(AA()), 'fun A')
|  |  | 
def test_mro_conflicts(self):
    """Dispatch between ABCs that are explicit bases, registered, or inferred.

    Note: the ``c.<ABC>.register(...)`` calls below add the local classes
    to the ABCs' registries; statement order matters throughout.
    """
    c = collections.abc
    @functools.singledispatch
    def g(arg):
        return "base"
    # O has Sized as an explicit base class.
    class O(c.Sized):
        def __len__(self):
            return 0
    o = O()
    self.assertEqual(g(o), "base")
    g.register(c.Iterable, lambda arg: "iterable")
    g.register(c.Container, lambda arg: "container")
    g.register(c.Sized, lambda arg: "sized")
    g.register(c.Set, lambda arg: "set")
    self.assertEqual(g(o), "sized")
    c.Iterable.register(O)
    self.assertEqual(g(o), "sized")   # because it's explicitly in __mro__
    c.Container.register(O)
    self.assertEqual(g(o), "sized")   # see above: Sized is in __mro__
    c.Set.register(O)
    self.assertEqual(g(o), "set")     # because c.Set is a subclass of
                                      # c.Sized and c.Container
    # P has no explicit ABC bases; every ABC relationship is registered.
    class P:
        pass
    p = P()
    self.assertEqual(g(p), "base")
    c.Iterable.register(P)
    self.assertEqual(g(p), "iterable")
    c.Container.register(P)
    # Iterable and Container are now both registered on P, so dispatch is
    # ambiguous; the message may name the two classes in either order.
    with self.assertRaises(RuntimeError) as re_one:
        g(p)
    self.assertIn(
        str(re_one.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Iterable'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Iterable'> "
          "or <class 'collections.abc.Container'>")),
    )
    class Q(c.Sized):
        def __len__(self):
            return 0
    q = Q()
    self.assertEqual(g(q), "sized")
    c.Iterable.register(Q)
    self.assertEqual(g(q), "sized")   # because it's explicitly in __mro__
    c.Set.register(Q)
    self.assertEqual(g(q), "set")     # because c.Set is a subclass of
                                      # c.Sized and c.Iterable
    # h has implementations for Sized and Container only.
    @functools.singledispatch
    def h(arg):
        return "base"
    @h.register(c.Sized)
    def _(arg):
        return "sized"
    @h.register(c.Container)
    def _(arg):
        return "container"
    # Even though Sized and Container are explicit bases of MutableMapping,
    # this ABC is implicitly registered on defaultdict which makes all of
    # MutableMapping's bases implicit as well from defaultdict's
    # perspective.
    with self.assertRaises(RuntimeError) as re_two:
        h(collections.defaultdict(lambda: 0))
    self.assertIn(
        str(re_two.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Sized'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
          "or <class 'collections.abc.Container'>")),
    )
    # R is a defaultdict subclass registered directly as a MutableSequence;
    # that direct registration wins over the inherited MutableMapping.
    class R(collections.defaultdict):
        pass
    c.MutableSequence.register(R)
    @functools.singledispatch
    def i(arg):
        return "base"
    @i.register(c.MutableMapping)
    def _(arg):
        return "mapping"
    @i.register(c.MutableSequence)
    def _(arg):
        return "sequence"
    r = R()
    self.assertEqual(i(r), "sequence")
    class S:
        pass
    class T(S, c.Sized):
        def __len__(self):
            return 0
    t = T()
    self.assertEqual(h(t), "sized")
    c.Container.register(T)
    self.assertEqual(h(t), "sized")   # because it's explicitly in the MRO
    class U:
        def __len__(self):
            return 0
    u = U()
    self.assertEqual(h(u), "sized")   # implicit Sized subclass inferred
                                      # from the existence of __len__()
    c.Container.register(U)
    # There is no preference for registered versus inferred ABCs.
    with self.assertRaises(RuntimeError) as re_three:
        h(u)
    self.assertIn(
        str(re_three.exception),
        (("Ambiguous dispatch: <class 'collections.abc.Container'> "
          "or <class 'collections.abc.Sized'>"),
         ("Ambiguous dispatch: <class 'collections.abc.Sized'> "
          "or <class 'collections.abc.Container'>")),
    )
    # V lists Sized before the plain class S; j has implementations for S
    # and Container but not for Sized.
    class V(c.Sized, S):
        def __len__(self):
            return 0
    @functools.singledispatch
    def j(arg):
        return "base"
    @j.register(S)
    def _(arg):
        return "s"
    @j.register(c.Container)
    def _(arg):
        return "container"
    v = V()
    self.assertEqual(j(v), "s")
    c.Container.register(V)
    self.assertEqual(j(v), "container")   # because it ends up right after
                                          # Sized in the MRO
|  |  | 
def test_cache_invalidation(self):
    """Check when singledispatch's per-type dispatch cache is filled, hit
    and cleared.

    weakref.WeakKeyDictionary is temporarily replaced by a factory that
    returns ``td``, so the cache used by ``g`` (created inside the ``with``
    block) is ``td`` and every cache read/write is recorded.
    """
    from collections import UserDict
    import weakref

    # Mapping that logs keys: get_ops on successful reads (a missing key
    # raises before logging), set_ops on writes.
    class TracingDict(UserDict):
        def __init__(self, *args, **kwargs):
            super(TracingDict, self).__init__(*args, **kwargs)
            self.set_ops = []
            self.get_ops = []
        def __getitem__(self, key):
            result = self.data[key]
            self.get_ops.append(key)
            return result
        def __setitem__(self, key, value):
            self.set_ops.append(key)
            self.data[key] = value
        # Clearing empties the stored entries without logging anything.
        def clear(self):
            self.data.clear()

    td = TracingDict()
    with support.swap_attr(weakref, "WeakKeyDictionary", lambda: td):
        c = collections.abc
        @functools.singledispatch
        def g(arg):
            return "base"
        d = {}
        l = []
        # First call per type misses the cache and stores one entry.
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(g(l), "base")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [])
        self.assertEqual(td.set_ops, [dict, list])
        self.assertEqual(td.data[dict], g.registry[object])
        self.assertEqual(td.data[list], g.registry[object])
        self.assertEqual(td.data[dict], td.data[list])
        # Repeated calls are served from the cache (logged as gets only).
        self.assertEqual(g(l), "base")
        self.assertEqual(g(d), "base")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list])
        # Registering an implementation empties the cache.
        g.register(list, lambda arg: "list")
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "base")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict])
        self.assertEqual(td.data[dict],
                         functools._find_impl(dict, g.registry))
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        self.assertEqual(td.data[list],
                         functools._find_impl(list, g.registry))
        class X:
            pass
        c.MutableMapping.register(X)   # Will not invalidate the cache,
                                       # not using ABCs yet.
        self.assertEqual(g(d), "base")
        self.assertEqual(g(l), "list")
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list])
        # Registering an ABC implementation also empties the cache.
        g.register(c.Sized, lambda arg: "sized")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "sized")
        self.assertEqual(len(td), 1)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict])
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        self.assertEqual(td.get_ops, [list, dict, dict, list])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        self.assertEqual(g(l), "list")
        self.assertEqual(g(d), "sized")
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        # g.dispatch() reads the cache just like a call does.
        g.dispatch(list)
        g.dispatch(dict)
        self.assertEqual(td.get_ops, [list, dict, dict, list, list, dict,
                                      list, dict])
        self.assertEqual(td.set_ops, [dict, list, dict, list, dict, list])
        c.MutableSet.register(X)       # Will invalidate the cache.
        self.assertEqual(len(td), 2)   # Stale cache.
        # The stale cache is only dropped on the next lookup, which then
        # re-caches just the requested type.
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 1)
        g.register(c.MutableMapping, lambda arg: "mutablemapping")
        self.assertEqual(len(td), 0)
        self.assertEqual(g(d), "mutablemapping")
        self.assertEqual(len(td), 1)
        self.assertEqual(g(l), "list")
        self.assertEqual(len(td), 2)
        g.register(dict, lambda arg: "dict")
        self.assertEqual(g(d), "dict")
        self.assertEqual(g(l), "list")
        # Explicit clearing empties the cache as well.
        g._clear_cache()
        self.assertEqual(len(td), 0)
|  |  | 
def test_annotations(self):
    @functools.singledispatch
    def i(arg):
        return "base"

    # Plain @register takes the dispatch type from the first parameter's
    # annotation, given either as a class object...
    @i.register
    def _(arg: collections.abc.Mapping):
        return "mapping"

    # ...or as a string (forward-reference) annotation.
    @i.register
    def _(arg: "collections.abc.Sequence"):
        return "sequence"

    cases = [
        (None, "base"),
        ({"a": 1}, "mapping"),
        ([1, 2, 3], "sequence"),
        ((1, 2, 3), "sequence"),
        ("str", "sequence"),
    ]
    for value, expected in cases:
        self.assertEqual(i(value), expected)

    # Registering classes as callables doesn't work with annotations,
    # you need to pass the type explicitly.
    @i.register(str)
    class _:
        def __init__(self, arg):
            self.arg = arg

        def __eq__(self, other):
            return self.arg == other
    self.assertEqual(i("str"), "str")
|  |  | 
def test_method_register(self):
    # Each implementation records which branch ran on the receiving instance.
    class A:
        @functools.singledispatchmethod
        def t(self, arg):
            self.arg = "base"
        @t.register(int)
        def _(self, arg):
            self.arg = "int"
        @t.register(str)
        def _(self, arg):
            self.arg = "str"
    a = A()

    for value, branch in ((0, "int"), ('', "str"), (0.0, "base")):
        a.t(value)
        self.assertEqual(a.arg, branch)
        # The call must only touch *a*; a fresh instance stays untouched.
        self.assertFalse(hasattr(A(), 'arg'))
|  |  | 
def test_staticmethod_register(self):
    """singledispatchmethod over staticmethods dispatches on the first
    argument, whether accessed through the class or an instance."""
    class A:
        @functools.singledispatchmethod
        @staticmethod
        def t(arg):
            return arg
        @t.register(int)
        @staticmethod
        def _(arg):
            return isinstance(arg, int)
        @t.register(str)
        @staticmethod
        def _(arg):
            return isinstance(arg, str)
    a = A()

    self.assertTrue(A.t(0))
    self.assertTrue(A.t(''))
    self.assertEqual(A.t(0.0), 0.0)
    # Previously *a* was created but never used; access through an
    # instance must dispatch exactly like access through the class.
    self.assertTrue(a.t(0))
    self.assertTrue(a.t(''))
    self.assertEqual(a.t(0.0), 0.0)
|  |  | 
def test_classmethod_register(self):
    # Each classmethod implementation builds a new A labelled with the
    # branch that handled the call.
    class A:
        def __init__(self, arg):
            self.arg = arg

        @functools.singledispatchmethod
        @classmethod
        def t(cls, arg):
            return cls("base")
        @t.register(int)
        @classmethod
        def _(cls, arg):
            return cls("int")
        @t.register(str)
        @classmethod
        def _(cls, arg):
            return cls("str")

    for value, branch in ((0, "int"), ('', "str"), (0.0, "base")):
        self.assertEqual(A.t(value).arg, branch)
|  |  | 
def test_callable_register(self):
    # Implementations may also be registered after the class body, through
    # the attribute looked up on the class (A.t.register).
    class A:
        def __init__(self, arg):
            self.arg = arg

        @functools.singledispatchmethod
        @classmethod
        def t(cls, arg):
            return cls("base")

    @A.t.register(int)
    @classmethod
    def _(cls, arg):
        return cls("int")
    @A.t.register(str)
    @classmethod
    def _(cls, arg):
        return cls("str")

    cases = [(0, "int"), ('', "str"), (0.0, "base")]
    for value, branch in cases:
        self.assertEqual(A.t(value).arg, branch)
|  |  | 
def test_abstractmethod_register(self):
    """An abstract singledispatchmethod must make its class abstract."""
    # Abstract must *use* ABCMeta as its metaclass. Subclassing ABCMeta (as
    # before) made Abstract a metaclass itself, so abstractness was never
    # actually enforced on a class.
    class Abstract(metaclass=abc.ABCMeta):

        @functools.singledispatchmethod
        @abc.abstractmethod
        def add(self, x, y):
            pass

    self.assertTrue(Abstract.add.__isabstractmethod__)
    self.assertTrue(Abstract.__dict__['add'].__isabstractmethod__)

    # ABCMeta must see the dispatcher as abstract and refuse instantiation.
    with self.assertRaises(TypeError):
        Abstract()
|  |  | 
def test_type_ann_register(self):
    # Plain @t.register reads the dispatch type from the annotation on the
    # parameter following self.
    class A:
        @functools.singledispatchmethod
        def t(self, arg):
            return "base"
        @t.register
        def _(self, arg: int):
            return "int"
        @t.register
        def _(self, arg: str):
            return "str"
    instance = A()

    for value, branch in ((0, "int"), ('', "str"), (0.0, "base")):
        self.assertEqual(instance.t(value), branch)
|  |  | 
def test_invalid_registrations(self):
    """register() must reject non-classes and unannotated functions with a
    TypeError whose message has a fixed prefix and suffix."""
    msg_prefix = "Invalid first argument to `register()`: "
    msg_suffix = (
        ". Use either `@register(some_class)` or plain `@register` on an "
        "annotated function."
    )
    @functools.singledispatch
    def i(arg):
        return "base"
    # A non-class value passed to register(...) is rejected.
    with self.assertRaises(TypeError) as exc:
        @i.register(42)
        def _(arg):
            return "I annotated with a non-type"
    self.assertTrue(str(exc.exception).startswith(msg_prefix + "42"))
    self.assertTrue(str(exc.exception).endswith(msg_suffix))
    # Plain @register on a function with no annotation is rejected. The
    # expected text contains the function's qualified name, so it depends
    # on this method (and its class) keeping their current names.
    with self.assertRaises(TypeError) as exc:
        @i.register
        def _(arg):
            return "I forgot to annotate"
    self.assertTrue(str(exc.exception).startswith(msg_prefix +
        "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
    ))
    self.assertTrue(str(exc.exception).endswith(msg_suffix))

    # FIXME: The following will only work after PEP 560 is implemented.
    # Everything after this early return is intentionally unreachable
    # until then.
    return

    with self.assertRaises(TypeError) as exc:
        @i.register
        def _(arg: typing.Iterable[str]):
            # At runtime, dispatching on generics is impossible.
            # When registering implementations with singledispatch, avoid
            # types from `typing`. Instead, annotate with regular types
            # or ABCs.
            return "I annotated with a generic collection"
    self.assertTrue(str(exc.exception).startswith(msg_prefix +
        "<function TestSingleDispatch.test_invalid_registrations.<locals>._"
    ))
    self.assertTrue(str(exc.exception).endswith(msg_suffix))
|  |  | 
def test_invalid_positional_argument(self):
    # With no positional argument there is nothing to dispatch on, so the
    # call must fail with a TypeError naming the function.
    @functools.singledispatch
    def f(*args):
        pass

    with self.assertRaisesRegex(TypeError,
                                'f requires at least 1 positional argument'):
        f()
|  |  | 
|  |  | 
class CachedCostItem:
    # Fixture: ``cost`` bumps the counter once (under a lock) and then
    # stays cached on the instance.
    _cost = 1

    def __init__(self):
        self.lock = py_functools.RLock()

    @py_functools.cached_property
    def cost(self):
        """The cost of the item."""
        with self.lock:
            bumped = self._cost + 1
            self._cost = bumped
        return bumped
|  |  | 
|  |  | 
class OptionallyCachedCostItem:
    # Fixture: the same computation exposed uncached (get_cost) and cached
    # under a different attribute name (cached_cost).
    _cost = 1

    def get_cost(self):
        """The cost of the item."""
        new_cost = self._cost + 1
        self._cost = new_cost
        return new_cost

    cached_cost = py_functools.cached_property(get_cost)
|  |  | 
|  |  | 
class CachedCostItemWait:
    # Fixture for the threaded test: computing ``cost`` first waits (up to
    # 1 second) on *event*, then bumps the per-instance counter under a lock.

    def __init__(self, event):
        self.event = event
        self.lock = py_functools.RLock()
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        self.event.wait(1)
        with self.lock:
            self._cost += 1
            result = self._cost
        return result
|  |  | 
|  |  | 
class CachedCostItemWithSlots:
    # Fixture: a slotted class has no instance __dict__, so cached_property
    # cannot store its value and accessing ``cost`` must raise TypeError.

    # A one-element tuple. The previous ``('_cost')`` was a plain string,
    # which only worked because __slots__ also accepts a single name as str.
    __slots__ = ('_cost',)

    def __init__(self):
        self._cost = 1

    @py_functools.cached_property
    def cost(self):
        raise RuntimeError('never called, slots not supported')
|  |  | 
|  |  | 
class TestCachedProperty(unittest.TestCase):
    """Tests for the pure-Python functools.cached_property."""

    def test_cached(self):
        item = CachedCostItem()
        computed = item.cost
        self.assertEqual(computed, 2)
        # A second access reuses the cached value instead of recomputing.
        self.assertEqual(item.cost, 2) # not 3

    def test_cached_attribute_name_differs_from_func_name(self):
        item = OptionallyCachedCostItem()
        # get_cost() recomputes on every call; cached_cost computes once.
        readings = [
            (item.get_cost, 2),
            (lambda: item.cached_cost, 3),
            (item.get_cost, 4),
            (lambda: item.cached_cost, 3),
        ]
        for read, expected in readings:
            self.assertEqual(read(), expected)

    def test_threaded(self):
        go = threading.Event()
        item = CachedCostItemWait(go)
        num_threads = 3

        # A tiny switch interval makes the threads interleave aggressively.
        saved_interval = sys.getswitchinterval()
        sys.setswitchinterval(1e-6)
        try:
            workers = [threading.Thread(target=lambda: item.cost)
                       for _ in range(num_threads)]
            with support.start_threads(workers):
                go.set()
        finally:
            sys.setswitchinterval(saved_interval)

        # Despite concurrent first accesses, the value was computed once.
        self.assertEqual(item.cost, 2)

    def test_object_with_slots(self):
        item = CachedCostItemWithSlots()
        expected = "No '__dict__' attribute on 'CachedCostItemWithSlots' instance to cache 'cost' property."
        with self.assertRaisesRegex(TypeError, expected):
            item.cost

    def test_immutable_dict(self):
        class MyMeta(type):
            @py_functools.cached_property
            def prop(self):
                return True

        class MyClass(metaclass=MyMeta):
            pass

        # A class's __dict__ is a read-only mappingproxy.
        expected = ("The '__dict__' attribute on 'MyMeta' instance does not "
                    "support item assignment for caching 'prop' property.")
        with self.assertRaisesRegex(TypeError, expected):
            MyClass.prop

    def test_reuse_different_names(self):
        """Disallow this case because decorated function a would not be cached."""
        with self.assertRaises(RuntimeError) as ctx:
            class ReusedCachedProperty:
                @py_functools.cached_property
                def a(self):
                    pass

                b = a

        expected = TypeError("Cannot assign the same cached_property to two different names ('a' and 'b').")
        self.assertEqual(str(ctx.exception.__context__), str(expected))

    def test_reuse_same_name(self):
        """Reusing a cached_property on different classes under the same name is OK."""
        calls = 0

        @py_functools.cached_property
        def _cp(_self):
            nonlocal calls
            calls += 1
            return calls

        class A:
            cp = _cp

        class B:
            cp = _cp

        a, b = A(), B()

        # Each instance caches its own value; a keeps its first result.
        self.assertEqual(a.cp, 1)
        self.assertEqual(b.cp, 2)
        self.assertEqual(a.cp, 1)

    def test_set_name_not_called(self):
        cp = py_functools.cached_property(lambda s: None)

        class Foo:
            pass

        # Assigning after class creation skips __set_name__.
        Foo.cp = cp

        with self.assertRaisesRegex(
            TypeError,
            "Cannot use cached_property instance without calling __set_name__ on it.",
        ):
            Foo().cp

    def test_access_from_class(self):
        descriptor = CachedCostItem.cost
        self.assertIsInstance(descriptor, py_functools.cached_property)

    def test_doc(self):
        documented = CachedCostItem.cost.__doc__
        self.assertEqual(documented, "The cost of the item.")
|  |  | 
|  |  | 
# Run this module's test cases when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()