|
@@ -13,7 +13,7 @@
|
|
|
# limitations under the License.
|
|
|
|
|
|
from typing import Dict, Iterable, List, Optional
|
|
|
-from unittest.mock import Mock
|
|
|
+from unittest.mock import AsyncMock, Mock
|
|
|
|
|
|
from parameterized import parameterized
|
|
|
|
|
@@ -36,7 +36,7 @@ from synapse.util import Clock
|
|
|
from synapse.util.stringutils import random_string
|
|
|
|
|
|
from tests import unittest
|
|
|
-from tests.test_utils import event_injection, make_awaitable, simple_async_mock
|
|
|
+from tests.test_utils import event_injection, simple_async_mock
|
|
|
from tests.unittest import override_config
|
|
|
from tests.utils import MockClock
|
|
|
|
|
@@ -46,15 +46,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|
|
|
|
|
def setUp(self) -> None:
|
|
|
self.mock_store = Mock()
|
|
|
- self.mock_as_api = Mock()
|
|
|
+ self.mock_as_api = AsyncMock()
|
|
|
self.mock_scheduler = Mock()
|
|
|
hs = Mock()
|
|
|
hs.get_datastores.return_value = Mock(main=self.mock_store)
|
|
|
- self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None)
|
|
|
- self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
|
|
|
- self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
|
|
|
- None
|
|
|
- )
|
|
|
+ self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None)
|
|
|
+ self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None)
|
|
|
+ self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None)
|
|
|
hs.get_application_service_api.return_value = self.mock_as_api
|
|
|
hs.get_application_service_scheduler.return_value = self.mock_scheduler
|
|
|
hs.get_clock.return_value = MockClock()
|
|
@@ -69,21 +67,25 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|
|
self._mkservice(is_interested_in_event=False),
|
|
|
]
|
|
|
|
|
|
- self.mock_as_api.query_user.return_value = make_awaitable(True)
|
|
|
+ self.mock_as_api.query_user.return_value = True
|
|
|
self.mock_store.get_app_services.return_value = services
|
|
|
- self.mock_store.get_user_by_id.return_value = make_awaitable([])
|
|
|
+ self.mock_store.get_user_by_id = AsyncMock(return_value=[])
|
|
|
|
|
|
event = Mock(
|
|
|
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
|
|
|
)
|
|
|
- self.mock_store.get_all_new_event_ids_stream.side_effect = [
|
|
|
- make_awaitable((0, {})),
|
|
|
- make_awaitable((1, {event.event_id: 0})),
|
|
|
- ]
|
|
|
- self.mock_store.get_events_as_list.side_effect = [
|
|
|
- make_awaitable([]),
|
|
|
- make_awaitable([event]),
|
|
|
- ]
|
|
|
+ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
|
|
|
+ side_effect=[
|
|
|
+ (0, {}),
|
|
|
+ (1, {event.event_id: 0}),
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ self.mock_store.get_events_as_list = AsyncMock(
|
|
|
+ side_effect=[
|
|
|
+ [],
|
|
|
+ [event],
|
|
|
+ ]
|
|
|
+ )
|
|
|
self.handler.notify_interested_services(RoomStreamToken(None, 1))
|
|
|
|
|
|
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
|
|
@@ -95,14 +97,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|
|
services = [self._mkservice(is_interested_in_event=True)]
|
|
|
services[0].is_interested_in_user.return_value = True
|
|
|
self.mock_store.get_app_services.return_value = services
|
|
|
- self.mock_store.get_user_by_id.return_value = make_awaitable(None)
|
|
|
+ self.mock_store.get_user_by_id = AsyncMock(return_value=None)
|
|
|
|
|
|
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
|
|
|
- self.mock_as_api.query_user.return_value = make_awaitable(True)
|
|
|
- self.mock_store.get_all_new_event_ids_stream.side_effect = [
|
|
|
- make_awaitable((0, {event.event_id: 0})),
|
|
|
- ]
|
|
|
- self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
|
|
|
+ self.mock_as_api.query_user.return_value = True
|
|
|
+ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
|
|
|
+ side_effect=[
|
|
|
+ (0, {event.event_id: 0}),
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
|
|
|
self.handler.notify_interested_services(RoomStreamToken(None, 0))
|
|
|
|
|
|
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
|
|
@@ -112,13 +116,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|
|
services = [self._mkservice(is_interested_in_event=True)]
|
|
|
services[0].is_interested_in_user.return_value = True
|
|
|
self.mock_store.get_app_services.return_value = services
|
|
|
- self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
|
|
|
+ self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id})
|
|
|
|
|
|
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
|
|
|
- self.mock_as_api.query_user.return_value = make_awaitable(True)
|
|
|
- self.mock_store.get_all_new_event_ids_stream.side_effect = [
|
|
|
- make_awaitable((0, [event], {event.event_id: 0})),
|
|
|
- ]
|
|
|
+ self.mock_as_api.query_user.return_value = True
|
|
|
+ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
|
|
|
+ side_effect=[
|
|
|
+ (0, [event], {event.event_id: 0}),
|
|
|
+ ]
|
|
|
+ )
|
|
|
|
|
|
self.handler.notify_interested_services(RoomStreamToken(None, 0))
|
|
|
|
|
@@ -141,10 +147,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|
|
self._mkservice_alias(is_room_alias_in_namespace=False),
|
|
|
]
|
|
|
|
|
|
- self.mock_as_api.query_alias.return_value = make_awaitable(True)
|
|
|
+ self.mock_as_api.query_alias = AsyncMock(return_value=True)
|
|
|
self.mock_store.get_app_services.return_value = services
|
|
|
- self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
|
|
|
- Mock(room_id=room_id, servers=servers)
|
|
|
+ self.mock_store.get_association_from_room_alias = AsyncMock(
|
|
|
+ return_value=Mock(room_id=room_id, servers=servers)
|
|
|
)
|
|
|
|
|
|
result = self.successResultOf(
|
|
@@ -177,7 +183,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|
|
def test_get_3pe_protocols_protocol_no_response(self) -> None:
|
|
|
service = self._mkservice(False, ["my-protocol"])
|
|
|
self.mock_store.get_app_services.return_value = [service]
|
|
|
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
|
|
|
+ self.mock_as_api.get_3pe_protocol.return_value = None
|
|
|
response = self.successResultOf(
|
|
|
defer.ensureDeferred(self.handler.get_3pe_protocols())
|
|
|
)
|
|
@@ -189,9 +195,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|
|
def test_get_3pe_protocols_select_one_protocol(self) -> None:
|
|
|
service = self._mkservice(False, ["my-protocol"])
|
|
|
self.mock_store.get_app_services.return_value = [service]
|
|
|
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
|
|
|
- {"x-protocol-data": 42, "instances": []}
|
|
|
- )
|
|
|
+ self.mock_as_api.get_3pe_protocol.return_value = {
|
|
|
+ "x-protocol-data": 42,
|
|
|
+ "instances": [],
|
|
|
+ }
|
|
|
response = self.successResultOf(
|
|
|
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
|
|
|
)
|
|
@@ -205,9 +212,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|
|
def test_get_3pe_protocols_one_protocol(self) -> None:
|
|
|
service = self._mkservice(False, ["my-protocol"])
|
|
|
self.mock_store.get_app_services.return_value = [service]
|
|
|
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
|
|
|
- {"x-protocol-data": 42, "instances": []}
|
|
|
- )
|
|
|
+ self.mock_as_api.get_3pe_protocol.return_value = {
|
|
|
+ "x-protocol-data": 42,
|
|
|
+ "instances": [],
|
|
|
+ }
|
|
|
response = self.successResultOf(
|
|
|
defer.ensureDeferred(self.handler.get_3pe_protocols())
|
|
|
)
|
|
@@ -222,9 +230,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|
|
service_one = self._mkservice(False, ["my-protocol"])
|
|
|
service_two = self._mkservice(False, ["other-protocol"])
|
|
|
self.mock_store.get_app_services.return_value = [service_one, service_two]
|
|
|
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
|
|
|
- {"x-protocol-data": 42, "instances": []}
|
|
|
- )
|
|
|
+ self.mock_as_api.get_3pe_protocol.return_value = {
|
|
|
+ "x-protocol-data": 42,
|
|
|
+ "instances": [],
|
|
|
+ }
|
|
|
response = self.successResultOf(
|
|
|
defer.ensureDeferred(self.handler.get_3pe_protocols())
|
|
|
)
|
|
@@ -287,13 +296,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|
|
interested_service = self._mkservice(is_interested_in_event=True)
|
|
|
services = [interested_service]
|
|
|
self.mock_store.get_app_services.return_value = services
|
|
|
- self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
|
|
|
- 579
|
|
|
- )
|
|
|
+ self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579)
|
|
|
|
|
|
event = Mock(event_id="event_1")
|
|
|
- self.event_source.sources.receipt.get_new_events_as.return_value = (
|
|
|
- make_awaitable(([event], None))
|
|
|
+ self.event_source.sources.receipt.get_new_events_as = AsyncMock(
|
|
|
+ return_value=([event], None)
|
|
|
)
|
|
|
|
|
|
self.handler.notify_interested_services_ephemeral(
|
|
@@ -317,13 +324,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|
|
services = [interested_service]
|
|
|
|
|
|
self.mock_store.get_app_services.return_value = services
|
|
|
- self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
|
|
|
- 580
|
|
|
- )
|
|
|
+ self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580)
|
|
|
|
|
|
event = Mock(event_id="event_1")
|
|
|
- self.event_source.sources.receipt.get_new_events_as.return_value = (
|
|
|
- make_awaitable(([event], None))
|
|
|
+ self.event_source.sources.receipt.get_new_events_as = AsyncMock(
|
|
|
+ return_value=([event], None)
|
|
|
)
|
|
|
|
|
|
self.handler.notify_interested_services_ephemeral(
|
|
@@ -350,9 +355,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
|
|
A mock representing the ApplicationService.
|
|
|
"""
|
|
|
service = Mock()
|
|
|
- service.is_interested_in_event.return_value = make_awaitable(
|
|
|
- is_interested_in_event
|
|
|
- )
|
|
|
+ service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event)
|
|
|
service.token = "mock_service_token"
|
|
|
service.url = "mock_service_url"
|
|
|
service.protocols = protocols
|