Skip to content

Commit a2cb0be

Browse files
sobolevncjw296
authored andcommitted
gh-98086: Now patch.dict can decorate async functions (#98095)
Backports: 67b4d2772c5124b908f8ed9b13166a79bbeb88d2 Signed-off-by: Chris Withers <chris@simplistix.co.uk>
1 parent 4d3f197 commit a2cb0be

File tree

3 files changed

+36
-0
lines changed

3 files changed

+36
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Make sure ``patch.dict()`` can be applied on async functions.

mock/mock.py

+18
Original file line numberDiff line numberDiff line change
@@ -1839,6 +1839,12 @@ def __init__(self, in_dict, values=(), clear=False, **kwargs):
18391839
def __call__(self, f):
18401840
if isinstance(f, type):
18411841
return self.decorate_class(f)
1842+
if inspect.iscoroutinefunction(f):
1843+
return self.decorate_async_callable(f)
1844+
return self.decorate_callable(f)
1845+
1846+
1847+
def decorate_callable(self, f):
18421848
@wraps(f)
18431849
def _inner(*args, **kw):
18441850
self._patch_dict()
@@ -1850,6 +1856,18 @@ def _inner(*args, **kw):
18501856
return _inner
18511857

18521858

1859+
def decorate_async_callable(self, f):
1860+
@wraps(f)
1861+
async def _inner(*args, **kw):
1862+
self._patch_dict()
1863+
try:
1864+
return await f(*args, **kw)
1865+
finally:
1866+
self._unpatch_dict()
1867+
1868+
return _inner
1869+
1870+
18531871
def decorate_class(self, klass):
18541872
for attr in dir(klass):
18551873
attr_value = getattr(klass, attr)

mock/tests/testasync.py

+17
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,23 @@ async def test_async():
159159

160160
run(test_async())
161161

162+
def test_patch_dict_async_def(self):
163+
foo = {'a': 'a'}
164+
@patch.dict(foo, {'a': 'b'})
165+
async def test_async():
166+
self.assertEqual(foo['a'], 'b')
167+
168+
self.assertTrue(iscoroutinefunction(test_async))
169+
run(test_async())
170+
171+
def test_patch_dict_async_def_context(self):
172+
foo = {'a': 'a'}
173+
async def test_async():
174+
with patch.dict(foo, {'a': 'b'}):
175+
self.assertEqual(foo['a'], 'b')
176+
177+
run(test_async())
178+
162179

163180
class AsyncMockTest(unittest.TestCase):
164181
def test_iscoroutinefunction_default(self):

0 commit comments

Comments
 (0)