Browse Source

Replace simple_async_mock with AsyncMock (#16180)

Python 3.8 has a native AsyncMock, use it instead of a custom
implementation.
Patrick Cloke 8 months ago
parent
commit
a8a46b1336

+ 1 - 0
changelog.d/16180.misc

@@ -0,0 +1 @@
+Use `AsyncMock` instead of custom code.

+ 49 - 48
tests/api/test_auth.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import pymacaroons
 
@@ -35,7 +35,6 @@ from synapse.types import Requester, UserID
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import simple_async_mock
 from tests.unittest import override_config
 from tests.utils import mock_getRawHeaders
 
@@ -60,16 +59,16 @@ class AuthTestCase(unittest.HomeserverTestCase):
         # this is overridden for the appservice tests
         self.store.get_app_service_by_token = Mock(return_value=None)
 
-        self.store.insert_client_ip = simple_async_mock(None)
-        self.store.is_support_user = simple_async_mock(False)
+        self.store.insert_client_ip = AsyncMock(return_value=None)
+        self.store.is_support_user = AsyncMock(return_value=False)
 
     def test_get_user_by_req_user_valid_token(self) -> None:
         user_info = TokenLookupResult(
             user_id=self.test_user, token_id=5, device_id="device"
         )
-        self.store.get_user_by_access_token = simple_async_mock(user_info)
-        self.store.mark_access_token_as_used = simple_async_mock(None)
-        self.store.get_user_locked_status = simple_async_mock(False)
+        self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
+        self.store.mark_access_token_as_used = AsyncMock(return_value=None)
+        self.store.get_user_locked_status = AsyncMock(return_value=False)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -78,7 +77,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(requester.user.to_string(), self.test_user)
 
     def test_get_user_by_req_user_bad_token(self) -> None:
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -91,7 +90,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_get_user_by_req_user_missing_token(self) -> None:
         user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
-        self.store.get_user_by_access_token = simple_async_mock(user_info)
+        self.store.get_user_by_access_token = AsyncMock(return_value=user_info)
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -106,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             token="foobar", url="a_url", sender=self.test_user, ip_range_whitelist=None
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -125,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "192.168.10.10"
@@ -144,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ip_range_whitelist=IPSet(["192.168/16"]),
         )
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "131.111.8.42"
@@ -158,7 +157,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_get_user_by_req_appservice_bad_token(self) -> None:
         self.store.get_app_service_by_token = Mock(return_value=None)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.args[b"access_token"] = [self.test_token]
@@ -172,7 +171,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_get_user_by_req_appservice_missing_token(self) -> None:
         app_service = Mock(token="foobar", url="a_url", sender=self.test_user)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
@@ -190,8 +189,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         # This just needs to return a truth-y value.
-        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -210,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
         app_service.is_interested_in_user = Mock(return_value=False)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -234,10 +233,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         # This just needs to return a truth-y value.
-        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
         # This also needs to just return a truth-y value
-        self.store.get_device = simple_async_mock({"hidden": False})
+        self.store.get_device = AsyncMock(return_value={"hidden": False})
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -266,10 +265,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         app_service.is_interested_in_user = Mock(return_value=True)
         self.store.get_app_service_by_token = Mock(return_value=app_service)
         # This just needs to return a truth-y value.
-        self.store.get_user_by_id = simple_async_mock({"is_guest": False})
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": False})
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
         # This also needs to just return a falsey value
-        self.store.get_device = simple_async_mock(None)
+        self.store.get_device = AsyncMock(return_value=None)
 
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
@@ -283,8 +282,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(failure.value.errcode, Codes.EXCLUSIVE)
 
     def test_get_user_by_req__puppeted_token__not_tracking_puppeted_mau(self) -> None:
-        self.store.get_user_by_access_token = simple_async_mock(
-            TokenLookupResult(
+        self.store.get_user_by_access_token = AsyncMock(
+            return_value=TokenLookupResult(
                 user_id="@baldrick:matrix.org",
                 device_id="device",
                 token_id=5,
@@ -292,9 +291,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
                 token_used=True,
             )
         )
-        self.store.insert_client_ip = simple_async_mock(None)
-        self.store.mark_access_token_as_used = simple_async_mock(None)
-        self.store.get_user_locked_status = simple_async_mock(False)
+        self.store.insert_client_ip = AsyncMock(return_value=None)
+        self.store.mark_access_token_as_used = AsyncMock(return_value=None)
+        self.store.get_user_locked_status = AsyncMock(return_value=False)
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
@@ -304,8 +303,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
     def test_get_user_by_req__puppeted_token__tracking_puppeted_mau(self) -> None:
         self.auth._track_puppeted_user_ips = True
-        self.store.get_user_by_access_token = simple_async_mock(
-            TokenLookupResult(
+        self.store.get_user_by_access_token = AsyncMock(
+            return_value=TokenLookupResult(
                 user_id="@baldrick:matrix.org",
                 device_id="device",
                 token_id=5,
@@ -313,9 +312,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
                 token_used=True,
             )
         )
-        self.store.get_user_locked_status = simple_async_mock(False)
-        self.store.insert_client_ip = simple_async_mock(None)
-        self.store.mark_access_token_as_used = simple_async_mock(None)
+        self.store.get_user_locked_status = AsyncMock(return_value=False)
+        self.store.insert_client_ip = AsyncMock(return_value=None)
+        self.store.mark_access_token_as_used = AsyncMock(return_value=None)
         request = Mock(args={})
         request.getClientAddress.return_value.host = "127.0.0.1"
         request.args[b"access_token"] = [self.test_token]
@@ -324,7 +323,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(self.store.insert_client_ip.call_count, 2)
 
     def test_get_user_from_macaroon(self) -> None:
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
@@ -342,8 +341,8 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
 
     def test_get_guest_user_from_macaroon(self) -> None:
-        self.store.get_user_by_id = simple_async_mock({"is_guest": True})
-        self.store.get_user_by_access_token = simple_async_mock(None)
+        self.store.get_user_by_id = AsyncMock(return_value={"is_guest": True})
+        self.store.get_user_by_access_token = AsyncMock(return_value=None)
 
         user_id = "@baldrick:matrix.org"
         macaroon = pymacaroons.Macaroon(
@@ -373,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         self.auth_blocking._limit_usage_by_mau = True
 
-        self.store.get_monthly_active_count = simple_async_mock(lots_of_users)
+        self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)
 
         e = self.get_failure(
             self.auth_blocking.check_auth_blocking(), ResourceLimitError
@@ -383,25 +382,27 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(e.value.code, 403)
 
         # Ensure does not throw an error
-        self.store.get_monthly_active_count = simple_async_mock(small_number_of_users)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=small_number_of_users
+        )
         self.get_success(self.auth_blocking.check_auth_blocking())
 
     def test_blocking_mau__depending_on_user_type(self) -> None:
         self.auth_blocking._max_mau_value = 50
         self.auth_blocking._limit_usage_by_mau = True
 
-        self.store.get_monthly_active_count = simple_async_mock(100)
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
         # Support users allowed
         self.get_success(
             self.auth_blocking.check_auth_blocking(user_type=UserTypes.SUPPORT)
         )
-        self.store.get_monthly_active_count = simple_async_mock(100)
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
         # Bots not allowed
         self.get_failure(
             self.auth_blocking.check_auth_blocking(user_type=UserTypes.BOT),
             ResourceLimitError,
         )
-        self.store.get_monthly_active_count = simple_async_mock(100)
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
         # Real users not allowed
         self.get_failure(self.auth_blocking.check_auth_blocking(), ResourceLimitError)
 
@@ -412,9 +413,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._track_appservice_user_ips = False
 
-        self.store.get_monthly_active_count = simple_async_mock(100)
-        self.store.user_last_seen_monthly_active = simple_async_mock()
-        self.store.is_trial_user = simple_async_mock()
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
+        self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
+        self.store.is_trial_user = AsyncMock(return_value=False)
 
         appservice = ApplicationService(
             "abcd",
@@ -443,9 +444,9 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._track_appservice_user_ips = True
 
-        self.store.get_monthly_active_count = simple_async_mock(100)
-        self.store.user_last_seen_monthly_active = simple_async_mock()
-        self.store.is_trial_user = simple_async_mock()
+        self.store.get_monthly_active_count = AsyncMock(return_value=100)
+        self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
+        self.store.is_trial_user = AsyncMock(return_value=False)
 
         appservice = ApplicationService(
             "abcd",
@@ -473,7 +474,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
     def test_reserved_threepid(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._max_mau_value = 1
-        self.store.get_monthly_active_count = simple_async_mock(2)
+        self.store.get_monthly_active_count = AsyncMock(return_value=2)
         threepid = {"medium": "email", "address": "reserved@server.com"}
         unknown_threepid = {"medium": "email", "address": "unreserved@server.com"}
         self.auth_blocking._mau_limits_reserved_threepids = [threepid]

+ 16 - 15
tests/appservice/test_appservice.py

@@ -13,14 +13,13 @@
 # limitations under the License.
 import re
 from typing import Any, Generator
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.internet import defer
 
 from synapse.appservice import ApplicationService, Namespace
 
 from tests import unittest
-from tests.test_utils import simple_async_mock
 
 
 def _regex(regex: str, exclusive: bool = True) -> Namespace:
@@ -43,8 +42,8 @@ class ApplicationServiceTestCase(unittest.TestCase):
         )
 
         self.store = Mock()
-        self.store.get_aliases_for_room = simple_async_mock([])
-        self.store.get_local_users_in_room = simple_async_mock([])
+        self.store.get_aliases_for_room = AsyncMock(return_value=[])
+        self.store.get_local_users_in_room = AsyncMock(return_value=[])
 
     @defer.inlineCallbacks
     def test_regex_user_id_prefix_match(
@@ -127,10 +126,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
-        self.store.get_aliases_for_room = simple_async_mock(
-            ["#irc_foobar:matrix.org", "#athing:matrix.org"]
+        self.store.get_aliases_for_room = AsyncMock(
+            return_value=["#irc_foobar:matrix.org", "#athing:matrix.org"]
         )
-        self.store.get_local_users_in_room = simple_async_mock([])
+        self.store.get_local_users_in_room = AsyncMock(return_value=[])
         self.assertTrue(
             (
                 yield self.service.is_interested_in_event(
@@ -182,10 +181,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
         self.service.namespaces[ApplicationService.NS_ALIASES].append(
             _regex("#irc_.*:matrix.org")
         )
-        self.store.get_aliases_for_room = simple_async_mock(
-            ["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
+        self.store.get_aliases_for_room = AsyncMock(
+            return_value=["#xmpp_foobar:matrix.org", "#athing:matrix.org"]
         )
-        self.store.get_local_users_in_room = simple_async_mock([])
+        self.store.get_local_users_in_room = AsyncMock(return_value=[])
         self.assertFalse(
             (
                 yield defer.ensureDeferred(
@@ -205,8 +204,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
         )
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         self.event.sender = "@irc_foobar:matrix.org"
-        self.store.get_aliases_for_room = simple_async_mock(["#irc_barfoo:matrix.org"])
-        self.store.get_local_users_in_room = simple_async_mock([])
+        self.store.get_aliases_for_room = AsyncMock(
+            return_value=["#irc_barfoo:matrix.org"]
+        )
+        self.store.get_local_users_in_room = AsyncMock(return_value=[])
         self.assertTrue(
             (
                 yield self.service.is_interested_in_event(
@@ -235,10 +236,10 @@ class ApplicationServiceTestCase(unittest.TestCase):
     def test_member_list_match(self) -> Generator["defer.Deferred[Any]", object, None]:
         self.service.namespaces[ApplicationService.NS_USERS].append(_regex("@irc_.*"))
         # Note that @irc_fo:here is the AS user.
-        self.store.get_local_users_in_room = simple_async_mock(
-            ["@alice:here", "@irc_fo:here", "@bob:here"]
+        self.store.get_local_users_in_room = AsyncMock(
+            return_value=["@alice:here", "@irc_fo:here", "@bob:here"]
         )
-        self.store.get_aliases_for_room = simple_async_mock([])
+        self.store.get_aliases_for_room = AsyncMock(return_value=[])
 
         self.event.sender = "@xmpp_foobar:matrix.org"
         self.assertTrue(

+ 23 - 20
tests/appservice/test_scheduler.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import List, Optional, Sequence, Tuple, cast
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from typing_extensions import TypeAlias
 
@@ -37,7 +37,6 @@ from synapse.types import DeviceListUpdates, JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import simple_async_mock
 
 from ..utils import MockClock
 
@@ -62,10 +61,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         txn = Mock(id=txn_id, service=service, events=events)
 
         # mock methods
-        self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
-        txn.send = simple_async_mock(True)
-        txn.complete = simple_async_mock(True)
-        self.store.create_appservice_txn = simple_async_mock(txn)
+        self.store.get_appservice_state = AsyncMock(
+            return_value=ApplicationServiceState.UP
+        )
+        txn.send = AsyncMock(return_value=True)
+        txn.complete = AsyncMock(return_value=True)
+        self.store.create_appservice_txn = AsyncMock(return_value=txn)
 
         # actual call
         self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -89,10 +90,10 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         events = [Mock(), Mock()]
 
         txn = Mock(id="idhere", service=service, events=events)
-        self.store.get_appservice_state = simple_async_mock(
-            ApplicationServiceState.DOWN
+        self.store.get_appservice_state = AsyncMock(
+            return_value=ApplicationServiceState.DOWN
         )
-        self.store.create_appservice_txn = simple_async_mock(txn)
+        self.store.create_appservice_txn = AsyncMock(return_value=txn)
 
         # actual call
         self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -118,10 +119,12 @@ class ApplicationServiceSchedulerTransactionCtrlTestCase(unittest.TestCase):
         txn = Mock(id=txn_id, service=service, events=events)
 
         # mock methods
-        self.store.get_appservice_state = simple_async_mock(ApplicationServiceState.UP)
-        self.store.set_appservice_state = simple_async_mock(True)
-        txn.send = simple_async_mock(False)  # fails to send
-        self.store.create_appservice_txn = simple_async_mock(txn)
+        self.store.get_appservice_state = AsyncMock(
+            return_value=ApplicationServiceState.UP
+        )
+        self.store.set_appservice_state = AsyncMock(return_value=True)
+        txn.send = AsyncMock(return_value=False)  # fails to send
+        self.store.create_appservice_txn = AsyncMock(return_value=txn)
 
         # actual call
         self.successResultOf(defer.ensureDeferred(self.txnctrl.send(service, events)))
@@ -150,7 +153,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.as_api = Mock()
         self.store = Mock()
         self.service = Mock()
-        self.callback = simple_async_mock()
+        self.callback = AsyncMock()
         self.recoverer = _Recoverer(
             clock=cast(Clock, self.clock),
             as_api=self.as_api,
@@ -174,8 +177,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.recoverer.recover()
         # shouldn't have called anything prior to waiting for exp backoff
         self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
-        txn.send = simple_async_mock(True)
-        txn.complete = simple_async_mock(None)
+        txn.send = AsyncMock(return_value=True)
+        txn.complete = AsyncMock(return_value=None)
         # wait for exp backoff
         self.clock.advance_time(2)
         self.assertEqual(1, txn.send.call_count)
@@ -202,8 +205,8 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
 
         self.recoverer.recover()
         self.assertEqual(0, self.store.get_oldest_unsent_txn.call_count)
-        txn.send = simple_async_mock(False)
-        txn.complete = simple_async_mock(None)
+        txn.send = AsyncMock(return_value=False)
+        txn.complete = AsyncMock(return_value=None)
         self.clock.advance_time(2)
         self.assertEqual(1, txn.send.call_count)
         self.assertEqual(0, txn.complete.call_count)
@@ -216,7 +219,7 @@ class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase):
         self.assertEqual(3, txn.send.call_count)
         self.assertEqual(0, txn.complete.call_count)
         self.assertEqual(0, self.callback.call_count)
-        txn.send = simple_async_mock(True)  # successfully send the txn
+        txn.send = AsyncMock(return_value=True)  # successfully send the txn
         pop_txn = True  # returns the txn the first time, then no more.
         self.clock.advance_time(16)
         self.assertEqual(1, txn.send.call_count)  # new mock reset call count
@@ -244,7 +247,7 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor: "MemoryReactor", clock: Clock, hs: HomeServer) -> None:
         self.scheduler = ApplicationServiceScheduler(hs)
         self.txn_ctrl = Mock()
-        self.txn_ctrl.send = simple_async_mock()
+        self.txn_ctrl.send = AsyncMock()
 
         # Replace instantiated _TransactionController instances with our Mock
         self.scheduler.txn_ctrl = self.txn_ctrl

+ 2 - 3
tests/events/test_presence_router.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import attr
 
@@ -30,7 +30,6 @@ from synapse.types import JsonDict, StreamToken, create_requester
 from synapse.util import Clock
 
 from tests.handlers.test_sync import generate_sync_config
-from tests.test_utils import simple_async_mock
 from tests.unittest import (
     FederatingHomeserverTestCase,
     HomeserverTestCase,
@@ -157,7 +156,7 @@ class PresenceRouterTestCase(FederatingHomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         self.fed_transport_client = Mock(spec=["send_transaction"])
-        self.fed_transport_client.send_transaction = simple_async_mock({})
+        self.fed_transport_client.send_transaction = AsyncMock(return_value={})
 
         hs = self.setup_test_homeserver(
             federation_transport_client=self.fed_transport_client,

+ 4 - 4
tests/handlers/test_appservice.py

@@ -36,7 +36,7 @@ from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
-from tests.test_utils import event_injection, simple_async_mock
+from tests.test_utils import event_injection
 from tests.unittest import override_config
 from tests.utils import MockClock
 
@@ -399,7 +399,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
         self.hs = hs
         # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
         # we can track any outgoing ephemeral events
-        self.send_mock = simple_async_mock()
+        self.send_mock = AsyncMock()
         hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock  # type: ignore[assignment]
 
         # Mock out application services, and allow defining our own in tests
@@ -897,7 +897,7 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
 
         # Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
         # will be sent over the wire
-        self.put_json = simple_async_mock()
+        self.put_json = AsyncMock()
         hs.get_application_service_api().put_json = self.put_json  # type: ignore[assignment]
 
         # Mock out application services, and allow defining our own in tests
@@ -1003,7 +1003,7 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # Mock the ApplicationServiceScheduler's _TransactionController's send method so that
         # we can track what's going out
-        self.send_mock = simple_async_mock()
+        self.send_mock = AsyncMock()
         hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock  # type: ignore[assignment]  # We assign to a method.
 
         # Define an application service for the tests

+ 5 - 6
tests/handlers/test_cas.py

@@ -12,7 +12,7 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 from typing import Any, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -20,7 +20,6 @@ from synapse.handlers.cas import CasResponse
 from synapse.server import HomeServer
 from synapse.util import Clock
 
-from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 # These are a few constants that are used as config parameters in the tests.
@@ -61,7 +60,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[assignment]
 
         cas_response = CasResponse("test_user", {})
         request = _mock_request()
@@ -89,7 +88,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[assignment]
 
         # Map a user via SSO.
         cas_response = CasResponse("test_user", {})
@@ -129,7 +128,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[assignment]
 
         cas_response = CasResponse("föö", {})
         request = _mock_request()
@@ -160,7 +159,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[assignment]
 
         # The response doesn't have the proper userGroup or department.
         cas_response = CasResponse("test_user", {})

+ 21 - 21
tests/handlers/test_oauth_delegation.py

@@ -39,7 +39,7 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
 
-from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils import FakeResponse, get_awaitable_result
 from tests.unittest import HomeserverTestCase, skip_unless
 from tests.utils import mock_getRawHeaders
 
@@ -147,7 +147,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_inactive_token(self) -> None:
         """The handler should return a 403 where the token is inactive."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={"active": False},
@@ -166,7 +166,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_no_scope(self) -> None:
         """The handler should return a 403 where no scope is given."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={"active": True},
@@ -185,7 +185,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_user_no_subject(self) -> None:
         """The handler should return a 500 when no subject is present."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={"active": True, "scope": " ".join([MATRIX_USER_SCOPE])},
@@ -204,7 +204,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_no_user_scope(self) -> None:
         """The handler should return a 500 when no subject is present."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -227,7 +227,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_admin_not_user(self) -> None:
         """The handler should raise when the scope has admin right but not user."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -251,7 +251,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_admin(self) -> None:
         """The handler should return a requester with admin rights."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -281,7 +281,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_admin_highest_privilege(self) -> None:
         """The handler should resolve to the most permissive scope."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -313,7 +313,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_user(self) -> None:
         """The handler should return a requester with normal user rights."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -344,7 +344,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
         """The handler should return a requester with normal user rights
         and an user ID matching the one specified in query param `user_id`"""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -378,7 +378,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_user_with_device(self) -> None:
         """The handler should return a requester with normal user rights and a device ID."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -408,7 +408,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_multiple_devices(self) -> None:
         """The handler should raise an error if multiple devices are found in the scope."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -433,7 +433,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_guest_not_allowed(self) -> None:
         """The handler should return an insufficient scope error."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -463,7 +463,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_active_guest_allowed(self) -> None:
         """The handler should return a requester with guest user rights and a device ID."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -499,19 +499,19 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
         request.requestHeaders.getRawHeaders = mock_getRawHeaders()
 
         # The introspection endpoint is returning an error.
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse(code=500, body=b"Internal Server Error")
         )
         error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
         self.assertEqual(error.value.code, 503)
 
         # The introspection endpoint request fails.
-        self.http_client.request = simple_async_mock(raises=Exception())
+        self.http_client.request = AsyncMock(side_effect=Exception())
         error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
         self.assertEqual(error.value.code, 503)
 
         # The introspection endpoint does not return a JSON object.
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200, payload=["this is an array", "not an object"]
             )
@@ -520,7 +520,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
         self.assertEqual(error.value.code, 503)
 
         # The introspection endpoint does not return valid JSON.
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse(code=200, body=b"this is not valid JSON")
         )
         error = self.get_failure(self.auth.get_user_by_req(request), SynapseError)
@@ -528,7 +528,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
 
     def test_introspection_token_cache(self) -> None:
         access_token = "open_sesame"
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={"active": "true", "scope": "guest", "jti": access_token},
@@ -559,7 +559,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
 
         # test that if a cached token is expired, a fresh token will be pulled from authorizing server - first add a
         # token with a soon-to-expire `exp` field to the cache
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={
@@ -640,7 +640,7 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
     def test_cross_signing(self) -> None:
         """Try uploading device keys with OAuth delegation enabled."""
 
-        self.http_client.request = simple_async_mock(
+        self.http_client.request = AsyncMock(
             return_value=FakeResponse.json(
                 code=200,
                 payload={

+ 3 - 3
tests/handlers/test_oidc.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from typing import Any, Awaitable, ContextManager, Dict, Optional, Tuple
-from unittest.mock import ANY, Mock, patch
+from unittest.mock import ANY, AsyncMock, Mock, patch
 from urllib.parse import parse_qs, urlparse
 
 import pymacaroons
@@ -28,7 +28,7 @@ from synapse.util import Clock
 from synapse.util.macaroons import get_value_from_macaroon
 from synapse.util.stringutils import random_string
 
-from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils import FakeResponse, get_awaitable_result
 from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
 from tests.unittest import HomeserverTestCase, override_config
 
@@ -164,7 +164,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         auth_handler = hs.get_auth_handler()
         # Mock the complete SSO login method.
-        self.complete_sso_login = simple_async_mock()
+        self.complete_sso_login = AsyncMock()
         auth_handler.complete_sso_login = self.complete_sso_login  # type: ignore[assignment]
 
         return hs

+ 6 - 7
tests/handlers/test_saml.py

@@ -13,7 +13,7 @@
 #  limitations under the License.
 
 from typing import Any, Dict, Optional, Set, Tuple
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import attr
 
@@ -25,7 +25,6 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
 
-from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 # Check if we have the dependencies to run the tests.
@@ -134,7 +133,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[assignment]
 
         # send a mocked-up SAML response to the callback
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
@@ -164,7 +163,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[assignment]
 
         # Map a user via SSO.
         saml_response = FakeAuthnResponse(
@@ -206,7 +205,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[assignment]
 
         # mock out the error renderer too
         sso_handler = self.hs.get_sso_handler()
@@ -227,7 +226,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler and error renderer
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[assignment]
         sso_handler = self.hs.get_sso_handler()
         sso_handler.render_error = Mock(return_value=None)  # type: ignore[assignment]
 
@@ -312,7 +311,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # stub out the auth handler
         auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()  # type: ignore[assignment]
+        auth_handler.complete_sso_login = AsyncMock()  # type: ignore[assignment]
 
         # The response doesn't have the proper userGroup or department.
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})

+ 3 - 6
tests/module_api/test_api.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Dict, Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.internet import defer
 from twisted.test.proto_helpers import MemoryReactor
@@ -33,7 +33,6 @@ from synapse.util import Clock
 
 from tests.events.test_presence_router import send_presence_update, sync_presence
 from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.test_utils import simple_async_mock
 from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import HomeserverTestCase, override_config
 
@@ -70,7 +69,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         self.fed_transport_client = Mock(spec=["send_transaction"])
-        self.fed_transport_client.send_transaction = simple_async_mock({})
+        self.fed_transport_client.send_transaction = AsyncMock(return_value={})
 
         return self.setup_test_homeserver(
             federation_transport_client=self.fed_transport_client,
@@ -579,9 +578,7 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
         """Test that the module API can join a remote room."""
         # Necessary to fake a remote join.
         fake_stream_id = 1
-        mocked_remote_join = simple_async_mock(
-            return_value=("fake-event-id", fake_stream_id)
-        )
+        mocked_remote_join = AsyncMock(return_value=("fake-event-id", fake_stream_id))
         self.hs.get_room_member_handler()._remote_join = mocked_remote_join  # type: ignore[assignment]
         fake_remote_host = f"{self.module_api.server_name}-remote"
 

+ 2 - 3
tests/push/test_bulk_push_rule_evaluator.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from typing import Any, Optional
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
 
 from parameterized import parameterized
 
@@ -28,7 +28,6 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
-from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
 
@@ -191,7 +190,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
         # Mock the method which calculates push rules -- we do this instead of
         # e.g. checking the results in the database because we want to ensure
         # that code isn't even running.
-        bulk_evaluator._action_for_event_by_user = simple_async_mock()  # type: ignore[assignment]
+        bulk_evaluator._action_for_event_by_user = AsyncMock()  # type: ignore[assignment]
 
         # Ensure no actions are generated!
         self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))

+ 2 - 3
tests/rest/client/test_notifications.py

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -20,7 +20,6 @@ from synapse.rest.client import login, notifications, receipts, room
 from synapse.server import HomeServer
 from synapse.util import Clock
 
-from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase
 
 
@@ -45,7 +44,7 @@ class HTTPPusherTests(HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         # Mock out the calls over federation.
         fed_transport_client = Mock(spec=["send_transaction"])
-        fed_transport_client.send_transaction = simple_async_mock({})
+        fed_transport_client.send_transaction = AsyncMock(return_value={})
 
         return self.setup_test_homeserver(
             federation_transport_client=fed_transport_client,

+ 2 - 3
tests/storage/test_background_update.py

@@ -32,7 +32,6 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import simple_async_mock
 from tests.unittest import override_config
 
 
@@ -348,8 +347,8 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
 
         # Mock out the AsyncContextManager
         class MockCM:
-            __aenter__ = simple_async_mock(return_value=None)
-            __aexit__ = simple_async_mock(return_value=None)
+            __aenter__ = AsyncMock(return_value=None)
+            __aexit__ = AsyncMock(return_value=None)
 
         self._update_ctx_manager = MockCM
 

+ 1 - 18
tests/test_utils/__init__.py

@@ -19,8 +19,7 @@ import json
 import sys
 import warnings
 from binascii import unhexlify
-from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
-from unittest.mock import Mock
+from typing import TYPE_CHECKING, Awaitable, Callable, Tuple, TypeVar
 
 import attr
 import zope.interface
@@ -62,10 +61,6 @@ def setup_awaitable_errors() -> Callable[[], None]:
     """
     warnings.simplefilter("error", RuntimeWarning)
 
-    # unraisablehook was added in Python 3.8.
-    if not hasattr(sys, "unraisablehook"):
-        return lambda: None
-
     # State shared between unraisablehook and check_for_unraisable_exceptions.
     unraisable_exceptions = []
     orig_unraisablehook = sys.unraisablehook
@@ -88,18 +83,6 @@ def setup_awaitable_errors() -> Callable[[], None]:
     return cleanup
 
 
-def simple_async_mock(
-    return_value: Optional[TV] = None, raises: Optional[Exception] = None
-) -> Mock:
-    # AsyncMock is not available in python3.5, this mimics part of its behaviour
-    async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
-        if raises:
-            raise raises
-        return return_value
-
-    return Mock(side_effect=cb)
-
-
 # Type ignore: it does not fully implement IResponse, but is good enough for tests
 @zope.interface.implementer(IResponse)
 @attr.s(slots=True, frozen=True, auto_attribs=True)