Browse Source

Add type hints to the receipts and user directory handlers. (#8976)

Patrick Cloke 3 years ago
parent
commit
31b1905e13
4 changed files with 64 additions and 32 deletions
  1. 1 0
      changelog.d/8976.misc
  2. 2 0
      mypy.ini
  3. 19 11
      synapse/handlers/receipts.py
  4. 42 21
      synapse/handlers/user_directory.py

+ 1 - 0
changelog.d/8976.misc

@@ -0,0 +1 @@
+Add type hints to the receipts and user directory handlers.

+ 2 - 0
mypy.ini

@@ -45,6 +45,7 @@ files =
   synapse/handlers/presence.py,
   synapse/handlers/presence.py,
   synapse/handlers/profile.py,
   synapse/handlers/profile.py,
   synapse/handlers/read_marker.py,
   synapse/handlers/read_marker.py,
+  synapse/handlers/receipts.py,
   synapse/handlers/register.py,
   synapse/handlers/register.py,
   synapse/handlers/room.py,
   synapse/handlers/room.py,
   synapse/handlers/room_list.py,
   synapse/handlers/room_list.py,
@@ -53,6 +54,7 @@ files =
   synapse/handlers/saml_handler.py,
   synapse/handlers/saml_handler.py,
   synapse/handlers/sso.py,
   synapse/handlers/sso.py,
   synapse/handlers/sync.py,
   synapse/handlers/sync.py,
+  synapse/handlers/user_directory.py,
   synapse/handlers/ui_auth,
   synapse/handlers/ui_auth,
   synapse/http/client.py,
   synapse/http/client.py,
   synapse/http/federation/matrix_federation_agent.py,
   synapse/http/federation/matrix_federation_agent.py,

+ 19 - 11
synapse/handlers/receipts.py

@@ -13,17 +13,20 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 import logging
 import logging
-from typing import List, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 
 from synapse.appservice import ApplicationService
 from synapse.appservice import ApplicationService
 from synapse.handlers._base import BaseHandler
 from synapse.handlers._base import BaseHandler
 from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
 from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
 
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
 class ReceiptsHandler(BaseHandler):
 class ReceiptsHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         super().__init__(hs)
 
 
         self.server_name = hs.config.server_name
         self.server_name = hs.config.server_name
@@ -36,7 +39,7 @@ class ReceiptsHandler(BaseHandler):
         self.clock = self.hs.get_clock()
         self.clock = self.hs.get_clock()
         self.state = hs.get_state_handler()
         self.state = hs.get_state_handler()
 
 
-    async def _received_remote_receipt(self, origin, content):
+    async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
         """Called when we receive an EDU of type m.receipt from a remote HS.
         """Called when we receive an EDU of type m.receipt from a remote HS.
         """
         """
         receipts = []
         receipts = []
@@ -63,11 +66,11 @@ class ReceiptsHandler(BaseHandler):
 
 
         await self._handle_new_receipts(receipts)
         await self._handle_new_receipts(receipts)
 
 
-    async def _handle_new_receipts(self, receipts):
+    async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
         """Takes a list of receipts, stores them and informs the notifier.
         """Takes a list of receipts, stores them and informs the notifier.
         """
         """
-        min_batch_id = None
-        max_batch_id = None
+        min_batch_id = None  # type: Optional[int]
+        max_batch_id = None  # type: Optional[int]
 
 
         for receipt in receipts:
         for receipt in receipts:
             res = await self.store.insert_receipt(
             res = await self.store.insert_receipt(
@@ -89,7 +92,8 @@ class ReceiptsHandler(BaseHandler):
             if max_batch_id is None or max_persisted_id > max_batch_id:
             if max_batch_id is None or max_persisted_id > max_batch_id:
                 max_batch_id = max_persisted_id
                 max_batch_id = max_persisted_id
 
 
-        if min_batch_id is None:
+        # Either both of these should be None or neither.
+        if min_batch_id is None or max_batch_id is None:
             # no new receipts
             # no new receipts
             return False
             return False
 
 
@@ -103,7 +107,9 @@ class ReceiptsHandler(BaseHandler):
 
 
         return True
         return True
 
 
-    async def received_client_receipt(self, room_id, receipt_type, user_id, event_id):
+    async def received_client_receipt(
+        self, room_id: str, receipt_type: str, user_id: str, event_id: str
+    ) -> None:
         """Called when a client tells us a local user has read up to the given
         """Called when a client tells us a local user has read up to the given
         event_id in the room.
         event_id in the room.
         """
         """
@@ -123,10 +129,12 @@ class ReceiptsHandler(BaseHandler):
 
 
 
 
 class ReceiptEventSource:
 class ReceiptEventSource:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.store = hs.get_datastore()
 
 
-    async def get_new_events(self, from_key, room_ids, **kwargs):
+    async def get_new_events(
+        self, from_key: int, room_ids: List[str], **kwargs
+    ) -> Tuple[List[JsonDict], int]:
         from_key = int(from_key)
         from_key = int(from_key)
         to_key = self.get_current_key()
         to_key = self.get_current_key()
 
 
@@ -171,5 +179,5 @@ class ReceiptEventSource:
 
 
         return (events, to_key)
         return (events, to_key)
 
 
-    def get_current_key(self, direction="f"):
+    def get_current_key(self, direction: str = "f") -> int:
         return self.store.get_max_receipt_stream_id()
         return self.store.get_max_receipt_stream_id()

+ 42 - 21
synapse/handlers/user_directory.py

@@ -14,14 +14,19 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import logging
 import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 
 import synapse.metrics
 import synapse.metrics
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
 from synapse.handlers.state_deltas import StateDeltasHandler
 from synapse.handlers.state_deltas import StateDeltasHandler
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.roommember import ProfileInfo
 from synapse.storage.roommember import ProfileInfo
+from synapse.types import JsonDict
 from synapse.util.metrics import Measure
 from synapse.util.metrics import Measure
 
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
@@ -36,7 +41,7 @@ class UserDirectoryHandler(StateDeltasHandler):
     be in the directory or not when necessary.
     be in the directory or not when necessary.
     """
     """
 
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         super().__init__(hs)
 
 
         self.store = hs.get_datastore()
         self.store = hs.get_datastore()
@@ -49,7 +54,7 @@ class UserDirectoryHandler(StateDeltasHandler):
         self.search_all_users = hs.config.user_directory_search_all_users
         self.search_all_users = hs.config.user_directory_search_all_users
         self.spam_checker = hs.get_spam_checker()
         self.spam_checker = hs.get_spam_checker()
         # The current position in the current_state_delta stream
         # The current position in the current_state_delta stream
-        self.pos = None
+        self.pos = None  # type: Optional[int]
 
 
         # Guard to ensure we only process deltas one at a time
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
         self._is_processing = False
@@ -61,7 +66,9 @@ class UserDirectoryHandler(StateDeltasHandler):
             # we start populating the user directory
             # we start populating the user directory
             self.clock.call_later(0, self.notify_new_event)
             self.clock.call_later(0, self.notify_new_event)
 
 
-    async def search_users(self, user_id, search_term, limit):
+    async def search_users(
+        self, user_id: str, search_term: str, limit: int
+    ) -> JsonDict:
         """Searches for users in directory
         """Searches for users in directory
 
 
         Returns:
         Returns:
@@ -89,7 +96,7 @@ class UserDirectoryHandler(StateDeltasHandler):
 
 
         return results
         return results
 
 
-    def notify_new_event(self):
+    def notify_new_event(self) -> None:
         """Called when there may be more deltas to process
         """Called when there may be more deltas to process
         """
         """
         if not self.update_user_directory:
         if not self.update_user_directory:
@@ -107,7 +114,9 @@ class UserDirectoryHandler(StateDeltasHandler):
         self._is_processing = True
         self._is_processing = True
         run_as_background_process("user_directory.notify_new_event", process)
         run_as_background_process("user_directory.notify_new_event", process)
 
 
-    async def handle_local_profile_change(self, user_id, profile):
+    async def handle_local_profile_change(
+        self, user_id: str, profile: ProfileInfo
+    ) -> None:
         """Called to update index of our local user profiles when they change
         """Called to update index of our local user profiles when they change
         irrespective of any rooms the user may be in.
         irrespective of any rooms the user may be in.
         """
         """
@@ -124,14 +133,14 @@ class UserDirectoryHandler(StateDeltasHandler):
                 user_id, profile.display_name, profile.avatar_url
                 user_id, profile.display_name, profile.avatar_url
             )
             )
 
 
-    async def handle_user_deactivated(self, user_id):
+    async def handle_user_deactivated(self, user_id: str) -> None:
         """Called when a user ID is deactivated
         """Called when a user ID is deactivated
         """
         """
         # FIXME(#3714): We should probably do this in the same worker as all
         # FIXME(#3714): We should probably do this in the same worker as all
         # the other changes.
         # the other changes.
         await self.store.remove_from_user_dir(user_id)
         await self.store.remove_from_user_dir(user_id)
 
 
-    async def _unsafe_process(self):
+    async def _unsafe_process(self) -> None:
         # If self.pos is None then means we haven't fetched it from DB
         # If self.pos is None then means we haven't fetched it from DB
         if self.pos is None:
         if self.pos is None:
             self.pos = await self.store.get_user_directory_stream_pos()
             self.pos = await self.store.get_user_directory_stream_pos()
@@ -166,7 +175,7 @@ class UserDirectoryHandler(StateDeltasHandler):
 
 
                 await self.store.update_user_directory_stream_pos(max_pos)
                 await self.store.update_user_directory_stream_pos(max_pos)
 
 
-    async def _handle_deltas(self, deltas):
+    async def _handle_deltas(self, deltas: List[Dict[str, Any]]) -> None:
         """Called with the state deltas to process
         """Called with the state deltas to process
         """
         """
         for delta in deltas:
         for delta in deltas:
@@ -236,16 +245,20 @@ class UserDirectoryHandler(StateDeltasHandler):
                 logger.debug("Ignoring irrelevant type: %r", typ)
                 logger.debug("Ignoring irrelevant type: %r", typ)
 
 
     async def _handle_room_publicity_change(
     async def _handle_room_publicity_change(
-        self, room_id, prev_event_id, event_id, typ
-    ):
+        self,
+        room_id: str,
+        prev_event_id: Optional[str],
+        event_id: Optional[str],
+        typ: str,
+    ) -> None:
         """Handle a room having potentially changed from/to world_readable/publicly
         """Handle a room having potentially changed from/to world_readable/publicly
         joinable.
         joinable.
 
 
         Args:
         Args:
-            room_id (str)
-            prev_event_id (str|None): The previous event before the state change
-            event_id (str|None): The new event after the state change
-            typ (str): Type of the event
+            room_id: The ID of the room which changed.
+            prev_event_id: The previous event before the state change
+            event_id: The new event after the state change
+            typ: Type of the event
         """
         """
         logger.debug("Handling change for %s: %s", typ, room_id)
         logger.debug("Handling change for %s: %s", typ, room_id)
 
 
@@ -303,12 +316,14 @@ class UserDirectoryHandler(StateDeltasHandler):
         for user_id, profile in users_with_profile.items():
         for user_id, profile in users_with_profile.items():
             await self._handle_new_user(room_id, user_id, profile)
             await self._handle_new_user(room_id, user_id, profile)
 
 
-    async def _handle_new_user(self, room_id, user_id, profile):
+    async def _handle_new_user(
+        self, room_id: str, user_id: str, profile: ProfileInfo
+    ) -> None:
         """Called when we might need to add user to directory
         """Called when we might need to add user to directory
 
 
         Args:
         Args:
-            room_id (str): room_id that user joined or started being public
-            user_id (str)
+            room_id: The room ID that user joined or started being public
+            user_id
         """
         """
         logger.debug("Adding new user to dir, %r", user_id)
         logger.debug("Adding new user to dir, %r", user_id)
 
 
@@ -356,12 +371,12 @@ class UserDirectoryHandler(StateDeltasHandler):
             if to_insert:
             if to_insert:
                 await self.store.add_users_who_share_private_room(room_id, to_insert)
                 await self.store.add_users_who_share_private_room(room_id, to_insert)
 
 
-    async def _handle_remove_user(self, room_id, user_id):
+    async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
         """Called when we might need to remove user from directory
         """Called when we might need to remove user from directory
 
 
         Args:
         Args:
-            room_id (str): room_id that user left or stopped being public that
-            user_id (str)
+            room_id: The room ID that user left or stopped being public that
+            user_id
         """
         """
         logger.debug("Removing user %r", user_id)
         logger.debug("Removing user %r", user_id)
 
 
@@ -374,7 +389,13 @@ class UserDirectoryHandler(StateDeltasHandler):
         if len(rooms_user_is_in) == 0:
         if len(rooms_user_is_in) == 0:
             await self.store.remove_from_user_dir(user_id)
             await self.store.remove_from_user_dir(user_id)
 
 
-    async def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
+    async def _handle_profile_change(
+        self,
+        user_id: str,
+        room_id: str,
+        prev_event_id: Optional[str],
+        event_id: Optional[str],
+    ) -> None:
         """Check member event changes for any profile changes and update the
         """Check member event changes for any profile changes and update the
         database if there are.
         database if there are.
         """
         """