Browse Source

Move callback-related code from AccountData to its own class

Andrew Morgan 1 year ago
parent
commit
3dcc1efc43

+ 3 - 16
synapse/handlers/account_data.py

@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import random
-from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse.api.constants import AccountDataTypes
 from synapse.replication.http.account_data import (
@@ -33,10 +33,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[
-    [str, Optional[str], str, JsonDict], Awaitable
-]
-
 
 class AccountDataHandler:
     def __init__(self, hs: "HomeServer"):
@@ -60,16 +56,7 @@ class AccountDataHandler:
         self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs)
         self._account_data_writers = hs.config.worker.writers.account_data
 
-        self._on_account_data_updated_callbacks: List[
-            ON_ACCOUNT_DATA_UPDATED_CALLBACK
-        ] = []
-
-    def register_module_callbacks(
-        self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None
-    ) -> None:
-        """Register callbacks from modules."""
-        if on_account_data_updated is not None:
-            self._on_account_data_updated_callbacks.append(on_account_data_updated)
+        self._module_api_callbacks = hs.get_module_api_callbacks().account_data
 
     async def _notify_modules(
         self,
@@ -92,7 +79,7 @@ class AccountDataHandler:
             account_data_type: The type of the account data.
             content: The content that is now associated with this type.
         """
-        for callback in self._on_account_data_updated_callbacks:
+        for callback in self._module_api_callbacks.on_account_data_updated_callbacks:
             try:
                 await callback(user_id, room_id, account_data_type, content)
             except Exception as e:

+ 4 - 4
synapse/module_api/__init__.py

@@ -41,7 +41,6 @@ from synapse.api.errors import SynapseError
 from synapse.events import EventBase
 from synapse.events.presence_router import PresenceRouter
 from synapse.events.spamcheck import SpamChecker
-from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK
 from synapse.handlers.auth import AuthHandler
 from synapse.handlers.device import DeviceHandler
 from synapse.handlers.push_rules import RuleSpec, check_actions
@@ -59,6 +58,9 @@ from synapse.logging.context import (
     run_in_background,
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.module_api.callbacks.account_data_callbacks import (
+    ON_ACCOUNT_DATA_UPDATED_CALLBACK,
+)
 from synapse.module_api.callbacks.account_validity_callbacks import (
     IS_USER_EXPIRED_CALLBACK,
     ON_LEGACY_ADMIN_REQUEST,
@@ -271,8 +273,6 @@ class ModuleApi:
         self._public_room_list_manager = PublicRoomListManager(hs)
         self._account_data_manager = AccountDataManager(hs)
 
-        self._account_data_handler = hs.get_account_data_handler()
-
     #################################################################################
     # The following methods should only be called during the module's initialisation.
 
@@ -452,7 +452,7 @@ class ModuleApi:
 
         Added in Synapse 1.57.0.
         """
-        return self._account_data_handler.register_module_callbacks(
+        return self._callbacks.account_data.register_callbacks(
             on_account_data_updated=on_account_data_updated,
         )
 

+ 2 - 0
synapse/module_api/callbacks/__init__.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from .account_data_callbacks import AccountDataModuleApiCallbacks
 from .account_validity_callbacks import AccountValidityModuleApiCallbacks
 from .background_updater_callbacks import BackgroundUpdaterModuleApiCallbacks
 from .password_auth_provider_callbacks import PasswordAuthProviderModuleApiCallbacks
@@ -26,6 +27,7 @@ __all__ = [
 
 class ModuleApiCallbacks:
     def __init__(self) -> None:
+        self.account_data = AccountDataModuleApiCallbacks()
         self.account_validity = AccountValidityModuleApiCallbacks()
         self.background_updater = BackgroundUpdaterModuleApiCallbacks()
         self.password_auth_provider = PasswordAuthProviderModuleApiCallbacks()

+ 35 - 0
synapse/module_api/callbacks/account_data_callbacks.py

@@ -0,0 +1,35 @@
+# Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2021, 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Awaitable, Callable, List, Optional
+
+from synapse.types import JsonDict
+
+ON_ACCOUNT_DATA_UPDATED_CALLBACK = Callable[
+    [str, Optional[str], str, JsonDict], Awaitable
+]
+
+
+class AccountDataModuleApiCallbacks:
+    def __init__(self) -> None:
+        self.on_account_data_updated_callbacks: List[
+            ON_ACCOUNT_DATA_UPDATED_CALLBACK
+        ] = []
+
+    def register_callbacks(
+        self, on_account_data_updated: Optional[ON_ACCOUNT_DATA_UPDATED_CALLBACK] = None
+    ) -> None:
+        """Register callbacks from modules."""
+        if on_account_data_updated is not None:
+            self.on_account_data_updated_callbacks.append(on_account_data_updated)

+ 1 - 1
tests/rest/client/test_account_data.py

@@ -33,7 +33,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
         a user's account data changes.
         """
         mocked_callback = Mock(return_value=make_awaitable(None))
-        self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append(
+        self.hs.get_module_api_callbacks().account_data.on_account_data_updated_callbacks.append(
             mocked_callback
         )