8000 bpo-17013: Extend Mock.called to allow waiting for calls · python/cpython@eab4041 · GitHub
[go: up one dir, main page]

Skip to content

Commit eab4041

Browse files
committed
bpo-17013: Extend Mock.called to allow waiting for calls
New methods allow tests to wait for calls executing in other threads.
1 parent 138ccbb commit eab4041

File tree

5 files changed

+197
-8
lines changed

5 files changed

+197
-8
lines changed

Doc/library/unittest.mock.rst

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,7 @@ the *new_callable* argument to :func:`patch`.
493493

494494
.. attribute:: called
495495

496-
A boolean representing whether or not the mock object has been called:
496+
A boolean-like object representing whether or not the mock object has been called:
497497

498498
>>> mock = Mock(return_value=None)
499499
>>> mock.called
@@ -502,6 +502,15 @@ the *new_callable* argument to :func:`patch`.
502502
>>> mock.called
503503
True
504504

505+
The object gives access to methods helpful in multithreaded tests:
506+
507+
- :meth:`wait(/, skip=0, timeout=None)` asserts that mock is called
508+
*skip* times during *timeout*
509+
510+
- :meth:`wait_for(predicate, /, timeout=None)` asserts that
511+
*predicate* was ``True`` at least once during the timeout;
512+
*predicate* receives exactly one positional argument: the mock itself
513+
505514
.. attribute:: call_count
506515

507516
An integer telling you how many times the mock object has been called:

Lib/unittest/mock.py

Lines changed: 67 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,10 @@
3232
import pprint
3333
import sys
3434
import builtins
35+
import threading
3536
from types import CodeType, ModuleType, MethodType
3637
from unittest.util import safe_repr
37-
from functools import wraps, partial
38+
from functools import wraps, partial, total_ordering
3839

3940

4041
_builtins = {name for name in dir(builtins) if not name.startswith('_')}
@@ -217,7 +218,7 @@ def reset_mock():
217218
if _is_instance_mock(ret) and not ret is mock:
218219
ret.reset_mock()
219220

220-
funcopy.called = False
221+
funcopy.called = _CallEvent(mock)
221222
funcopy.call_count = 0
222223
funcopy.call_args = None
223224
funcopy.call_args_list = _CallList()
@@ -439,7 +440,7 @@ def __init__(
439440
__dict__['_mock_wraps'] = wraps
440441
__dict__['_mock_delegate'] = None
441442

442-
__dict__['_mock_called'] = False
443+
__dict__['_mock_called'] = _CallEvent(self)
443444
__dict__['_mock_call_args'] = None
444445
__dict__['_mock_call_count'] = 0
445446
__dict__['_mock_call_args_list'] = _CallList()
@@ -577,7 +578,7 @@ def reset_mock(self, visited=None,*, return_value=False, side_effect=False):
577578
return
578579
visited.append(id(self))
579580

580-
self.called = False
581+
self.called = _CallEvent(self)
581582
self.call_args = None
582583
self.call_count = 0
583584
self.mock_calls = _CallList()
@@ -1093,8 +1094,8 @@ def _mock_call(self, /, *args, **kwargs):
10931094
return self._execute_mock_call(*args, **kwargs)
10941095

10951096
def _increment_mock_call(self, /, *args, **kwargs):
1096-
self.called = True
10971097
self.call_count += 1
1098+
self.called._notify()
10981099

10991100
# handle call_args
11001101
# needs to be set here so assertions on call arguments pass before
@@ -2358,6 +2359,67 @@ def _format_call_signature(name, args, kwargs):
23582359
return message % formatted_args
23592360

23602361

2362+
@total_ordering
2363+
class _CallEvent(object):
2364+
def __init__(self, mock):
2365+
self._mock = mock
2366+
self._condition = threading.Condition()
2367+
2368+
def wait(self, /, skip=0, timeout=None):
2369+
"""
2370+
Wait for any call.
2371+
2372+
:param skip: How many calls will be skipped.
2373+
As a result, the mock should be called at least
2374+
``skip + 1`` times.
2375+
"""
2376+
def predicate(mock):
2377+
return mock.call_count > skip
2378+
2379+
self.wait_for(predicate, timeout=timeout)
2380+
2381+
def wait_for(self, predicate, /, timeout=None):
2382+
"""
2383+
Wait for a given predicate to become True.
2384+
2385+
:param predicate: A callable that receives mock which result
2386+
will be interpreted as a boolean value.
2387+
The final predicate value is the return value.
2388+
"""
2389+
try:
2390+
self._condition.acquire()
2391+
2392+
def _predicate():
2393+
return predicate(self._mock)
2394+
2395+
b = self._condition.wait_for(_predicate, timeout)
2396+
2397+
if not b:
2398+
msg = (f"{self._mock._mock_name or 'mock'} was not called before"
2399+
f" timeout({timeout}).")
2400+
raise AssertionError(msg)
2401+
finally:
2402+
self._condition.release()
2403+
2404+
def __bool__(self):
2405+
return self._mock.call_count != 0
2406+
2407+
def __eq__(self, other):
2408+
return bool(self) == other
2409+
2410+
def __lt__(self, other):
2411+
return bool(self) < other
2412+
2413+
def __repr__(self):
2414+
return repr(bool(self))
2415+
2416+
def _notify(self):
2417+
try:
2418+
self._condition.acquire()
2419+
self._condition.notify_all()
2420+
finally:
2421+
self._condition.release()
2422+
23612423

23622424
class _Call(tuple):
23632425
"""

Lib/unittest/test/testmock/support.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
import concurrent.futures
2+
import time
3+
4+
15
target = {'foo': 'FOO'}
26

37

@@ -14,3 +18,15 @@ def wibble(self): pass
1418

1519
class X(object):
1620
pass
21+
22+
23+
def call_after_delay(func, /, *args, **kwargs):
24+
time.sleep(kwargs.pop('delay'))
25+
func(*args, **kwargs)
26+
27+
28+
def run_async(func, /, *args, executor=None, delay=0, **kwargs):
29+
if executor is None:
30+
executor = concurrent.futures.ThreadPoolExecutor(max_workers=5)
31+
32+
executor.submit(call_after_delay, func, *args, **kwargs, delay=delay)

Lib/unittest/test/testmock/testmock.py

Lines changed: 102 additions & 2 deletions
10000
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
import concurrent.futures
12
import copy
23
import re
34
import sys
45
import tempfile
56

6-
from test.support import ALWAYS_EQ
7+
from test.support import ALWAYS_EQ, start_threads
78
import unittest
8-
from unittest.test.testmock.support import is_instance
9+
from unittest.test.testmock.support import is_instance, run_async
910
from unittest import mock
1011
from unittest.mock import (
1112
call, DEFAULT, patch, sentinel,
@@ -2059,6 +2060,105 @@ def trace(frame, event, arg): # pragma: no cover
20592060
obj = mock(spec=Something)
20602061
self.assertIsInstance(obj, Something)
20612062

2063+
def test_wait_until_called(self):
2064+
mock = Mock(spec=Something)()
2065+
run_async(mock.method_1, delay=0.01)
2066+
mock.method_1.called.wait()
2067+
mock.method_1.assert_called_once()
2068+
2069+
def test_wait_until_called_called_before(self):
2070+
mock = Mock(spec=Something)()
2071+
mock.method_1()
2072+
mock.method_1.wait_until_called()
2073+
mock.method_1.assert_called_once()
2074+
2075+
def test_wait_until_called_magic_method(self):
2076+
mock = MagicMock(spec=Something)()
2077+
run_async(mock.method_1.__str__, delay=0.01)
2078+
mock.method_1.__str__.called.wait()
2079+
mock.method_1.__str__.assert_called_once()
2080+
2081+
def test_wait_until_called_timeout(self):
2082+
mock = Mock(spec=Something)()
2083+
run_async(mock.method_1, delay=0.2)
2084+
2085+
with self.assertRaises(AssertionError):
2086+
mock.method_1.called.wait(timeout=0.1)
2087+
2088+
mock.method_1.assert_not_called()
2089+
mock.method_1.called.wait()
2090+
mock.method_1.assert_called_once()
2091+
2092+
def test_wait_until_any_call_positional(self):
2093+
mock = Mock(spec=Something)()
2094+
run_async(mock.method_1, 1, delay=0.1)
2095+
run_async(mock.method_1, 2, delay=0.2)
2096+
run_async(mock.method_1, 3, delay=0.3)
2097+
2098+
for arg in (1, 2, 3):
2099+
self.assertNotIn(call(arg), mock.method_1.mock_calls)
2100+
mock.method_1.called.wait_for(lambda m: call(arg) in m.call_args_list)
2101+
mock.method_1.assert_called_with(arg)
2102+
2103+
def test_wait_until_any_call_keywords(self):
2104+
mock = Mock(spec=Something)()
2105+
run_async(mock.method_1, a=1, delay=0.1)
2106+
run_async(mock.method_1, a=2, delay=0.2)
2107+
run_async(mock.method_1, a=3, delay=0.3)
2108+
2109+
for arg in (1, 2, 3):
2110+
self.assertNotIn(call(arg), mock.method_1.mock_calls)
2111+
mock.method_1.called.wait_for(lambda m: call(a=arg) in m.call_args_list)
2112+
mock.method_1.assert_called_with(a=arg)
2113+
2114+
def test_wait_until_any_call_no_argument(self):
2115+
mock = Mock(spec=Something)()
2116+
mock.method_1(1)
2117+
mock.method_1assert_called_once_with(1)
2118+
2119+
with self.assertRaises(AssertionError):
2120+
mock.method_1.called.wait_for(lambda m: call() in m.call_args_list, timeout=0.01)
2121+
2122+
mock.method_1()
2123+
mock.method_1.called.wait_for(lambda m: call() in m.call_args_list, timeout=0.01)
2124+
2125+
def test_called_is_boolean_like(self):
2126+
mock = Mock(spec=Something)()
2127+
2128+
self.assertFalse(mock.method_1.called)
2129+
2130+
self.assertEqual(mock.method_1.called, False)
2131+
self.assertEqual(mock.method_1.called, 0)
2132+
self.assertEqual(mock.method_1.called, 0.0)
2133+
2134+
self.assertLess(mock.method_1.called, 1)
2135+
self.assertLess(mock.method_1.called, 1.0)
2136+
2137+
self.assertLessEqual(mock.method_1.called, False)
2138+
self.assertLessEqual(mock.method_1.called, 0)
2139+
self.assertLessEqual(mock.method_1.called, 0.0)
2140+
2141+
self.assertEqual(str(mock.method_1.called), str(False))
2142+
self.assertEqual(repr(mock.method_1.called), repr(False))
2143+
2144+
mock.method_1()
2145+
2146+
self.assertTrue(mock.method_1.called)
2147+
2148+
self.assertEqual(mock.method_1.called, True)
2149+
self.assertEqual(mock.method_1.called, 1)
2150+
self.assertEqual(mock.method_1.called, 1.0)
2151+
2152+
self.assertGreater(mock.method_1.called, 0)
2153+
self.assertGreater(mock.method_1.called, 0.0)
2154+
2155+
self.assertGreaterEqual(mock.method_1.called, True)
2156+
self.assertGreaterEqual(mock.method_1.called, 1)
2157+
self.assertGreaterEqual(mock.method_1.called, 1.0)
2158+
2159+
self.assertEqual(str(mock.method_1.called), str(True))
2160+
self.assertEqual(repr(mock.method_1.called), repr(True))
2161+
20622162

20632163
if __name__ == '__main__':
20642164
unittest.main()
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Extend :attr:`called` of :class:`Mock.called` to wait for the calls in
2+
multithreaded tests. Patch by Ilya Kulakov.

0 commit comments

Comments
 (0)
0