Browse Source

Convert the device message and pagination handlers to async/await. (#7678)

Patrick Cloke 3 years ago
parent
commit
98c4e35e3c
3 changed files with 19 additions and 31 deletions
  1. 1 0
      changelog.d/7678.misc
  2. 10 15
      synapse/handlers/devicemessage.py
  3. 8 16
      synapse/handlers/pagination.py

+ 1 - 0
changelog.d/7678.misc

@@ -0,0 +1 @@
+Convert the device message and pagination handlers to async/await.

+ 10 - 15
synapse/handlers/devicemessage.py

@@ -18,8 +18,6 @@ from typing import Any, Dict
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.api.errors import SynapseError
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
@@ -51,8 +49,7 @@ class DeviceMessageHandler(object):
 
         self._device_list_updater = hs.get_device_handler().device_list_updater
 
-    @defer.inlineCallbacks
-    def on_direct_to_device_edu(self, origin, content):
+    async def on_direct_to_device_edu(self, origin, content):
         local_messages = {}
         sender_user_id = content["sender"]
         if origin != get_domain_from_id(sender_user_id):
@@ -82,11 +79,11 @@ class DeviceMessageHandler(object):
             }
             local_messages[user_id] = messages_by_device
 
-            yield self._check_for_unknown_devices(
+            await self._check_for_unknown_devices(
                 message_type, sender_user_id, by_device
             )
 
-        stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
+        stream_id = await self.store.add_messages_from_remote_to_device_inbox(
             origin, message_id, local_messages
         )
 
@@ -94,14 +91,13 @@ class DeviceMessageHandler(object):
             "to_device_key", stream_id, users=local_messages.keys()
         )
 
-    @defer.inlineCallbacks
-    def _check_for_unknown_devices(
+    async def _check_for_unknown_devices(
         self,
         message_type: str,
         sender_user_id: str,
         by_device: Dict[str, Dict[str, Any]],
     ):
-        """Checks inbound device messages for unkown remote devices, and if
+        """Checks inbound device messages for unknown remote devices, and if
         found marks the remote cache for the user as stale.
         """
 
@@ -115,7 +111,7 @@ class DeviceMessageHandler(object):
             requesting_device_ids.add(device_id)
 
         # Check if we are tracking the devices of the remote user.
-        room_ids = yield self.store.get_rooms_for_user(sender_user_id)
+        room_ids = await self.store.get_rooms_for_user(sender_user_id)
         if not room_ids:
             logger.info(
                 "Received device message from remote device we don't"
@@ -127,7 +123,7 @@ class DeviceMessageHandler(object):
 
         # If we are tracking check that we know about the sending
         # devices.
-        cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id)
+        cached_devices = await self.store.get_cached_devices_for_user(sender_user_id)
 
         unknown_devices = requesting_device_ids - set(cached_devices)
         if unknown_devices:
@@ -136,15 +132,14 @@ class DeviceMessageHandler(object):
                 sender_user_id,
                 unknown_devices,
             )
-            yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+            await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
 
             # Immediately attempt a resync in the background
             run_in_background(
                 self._device_list_updater.user_device_resync, sender_user_id
             )
 
-    @defer.inlineCallbacks
-    def send_device_message(self, sender_user_id, message_type, messages):
+    async def send_device_message(self, sender_user_id, message_type, messages):
         set_tag("number_of_messages", len(messages))
         set_tag("sender", sender_user_id)
         local_messages = {}
@@ -183,7 +178,7 @@ class DeviceMessageHandler(object):
                 }
 
         log_kv({"local_messages": local_messages})
-        stream_id = yield self.store.add_messages_to_device_inbox(
+        stream_id = await self.store.add_messages_to_device_inbox(
             local_messages, remote_edu_contents
         )
 

+ 8 - 16
synapse/handlers/pagination.py

@@ -15,7 +15,6 @@
 # limitations under the License.
 import logging
 
-from twisted.internet import defer
 from twisted.python.failure import Failure
 
 from synapse.api.constants import EventTypes, Membership
@@ -97,8 +96,7 @@ class PaginationHandler(object):
                     job["longest_max_lifetime"],
                 )
 
-    @defer.inlineCallbacks
-    def purge_history_for_rooms_in_range(self, min_ms, max_ms):
+    async def purge_history_for_rooms_in_range(self, min_ms, max_ms):
         """Purge outdated events from rooms within the given retention range.
 
         If a default retention policy is defined in the server's configuration and its
@@ -137,7 +135,7 @@ class PaginationHandler(object):
             include_null,
         )
 
-        rooms = yield self.store.get_rooms_for_retention_period_in_range(
+        rooms = await self.store.get_rooms_for_retention_period_in_range(
             min_ms, max_ms, include_null
         )
 
@@ -165,9 +163,9 @@ class PaginationHandler(object):
             # Figure out what token we should start purging at.
             ts = self.clock.time_msec() - max_lifetime
 
-            stream_ordering = yield self.store.find_first_stream_ordering_after_ts(ts)
+            stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
 
-            r = yield self.store.get_room_event_before_stream_ordering(
+            r = await self.store.get_room_event_before_stream_ordering(
                 room_id, stream_ordering,
             )
             if not r:
@@ -227,8 +225,7 @@ class PaginationHandler(object):
         )
         return purge_id
 
-    @defer.inlineCallbacks
-    def _purge_history(self, purge_id, room_id, token, delete_local_events):
+    async def _purge_history(self, purge_id, room_id, token, delete_local_events):
         """Carry out a history purge on a room.
 
         Args:
@@ -237,14 +234,11 @@ class PaginationHandler(object):
             token (str): topological token to delete events before
             delete_local_events (bool): True to delete local events as well as
                 remote ones
-
-        Returns:
-            Deferred
         """
         self._purges_in_progress_by_room.add(room_id)
         try:
-            with (yield self.pagination_lock.write(room_id)):
-                yield self.storage.purge_events.purge_history(
+            with await self.pagination_lock.write(room_id):
+                await self.storage.purge_events.purge_history(
                     room_id, token, delete_local_events
                 )
             logger.info("[purge] complete")
@@ -282,9 +276,7 @@ class PaginationHandler(object):
             await self.store.get_room_version_id(room_id)
 
             # first check that we have no users in this room
-            joined = await defer.maybeDeferred(
-                self.store.is_host_joined, room_id, self._server_name
-            )
+            joined = await self.store.is_host_joined(room_id, self._server_name)
 
             if joined:
                 raise SynapseError(400, "Users are still joined to this room")