Răsfoiți Sursa

Merge pull request #6196 from matrix-org/erikj/await

Move rest/admin to use async/await.
Erik Johnston 4 ani în urmă
părinte
comite
d98029ea89

+ 1 - 0
changelog.d/6196.misc

@@ -0,0 +1 @@
+Port synapse.rest.admin module to use async/await.

+ 58 - 72
synapse/rest/admin/__init__.py

@@ -23,8 +23,6 @@ import re
 from six import text_type
 from six.moves import http_client
 
-from twisted.internet import defer
-
 import synapse
 from synapse.api.constants import Membership, UserTypes
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -46,6 +44,7 @@ from synapse.rest.admin.purge_room_servlet import PurgeRoomServlet
 from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet
 from synapse.rest.admin.users import UserAdminServlet
 from synapse.types import UserID, create_requester
+from synapse.util.async_helpers import maybe_awaitable
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger(__name__)
@@ -59,15 +58,14 @@ class UsersRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.handlers = hs.get_handlers()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, user_id):
+    async def on_GET(self, request, user_id):
         target_user = UserID.from_string(user_id)
-        yield assert_requester_is_admin(self.auth, request)
+        await assert_requester_is_admin(self.auth, request)
 
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "Can only users a local user")
 
-        ret = yield self.handlers.admin_handler.get_users()
+        ret = await self.handlers.admin_handler.get_users()
 
         return 200, ret
 
@@ -122,8 +120,7 @@ class UserRegisterServlet(RestServlet):
         self.nonces[nonce] = int(self.reactor.seconds())
         return 200, {"nonce": nonce}
 
-    @defer.inlineCallbacks
-    def on_POST(self, request):
+    async def on_POST(self, request):
         self._clear_old_nonces()
 
         if not self.hs.config.registration_shared_secret:
@@ -204,14 +201,14 @@ class UserRegisterServlet(RestServlet):
 
         register = RegisterRestServlet(self.hs)
 
-        user_id = yield register.registration_handler.register_user(
+        user_id = await register.registration_handler.register_user(
             localpart=body["username"].lower(),
             password=body["password"],
             admin=bool(admin),
             user_type=user_type,
         )
 
-        result = yield register._create_registration_details(user_id, body)
+        result = await register._create_registration_details(user_id, body)
         return 200, result
 
 
@@ -223,19 +220,18 @@ class WhoisRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.handlers = hs.get_handlers()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, user_id):
+    async def on_GET(self, request, user_id):
         target_user = UserID.from_string(user_id)
-        requester = yield self.auth.get_user_by_req(request)
+        requester = await self.auth.get_user_by_req(request)
         auth_user = requester.user
 
         if target_user != auth_user:
-            yield assert_user_is_admin(self.auth, auth_user)
+            await assert_user_is_admin(self.auth, auth_user)
 
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "Can only whois a local user")
 
-        ret = yield self.handlers.admin_handler.get_whois(target_user)
+        ret = await self.handlers.admin_handler.get_whois(target_user)
 
         return 200, ret
 
@@ -255,9 +251,8 @@ class PurgeHistoryRestServlet(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id, event_id):
-        yield assert_requester_is_admin(self.auth, request)
+    async def on_POST(self, request, room_id, event_id):
+        await assert_requester_is_admin(self.auth, request)
 
         body = parse_json_object_from_request(request, allow_empty_body=True)
 
@@ -270,12 +265,12 @@ class PurgeHistoryRestServlet(RestServlet):
             event_id = body.get("purge_up_to_event_id")
 
         if event_id is not None:
-            event = yield self.store.get_event(event_id)
+            event = await self.store.get_event(event_id)
 
             if event.room_id != room_id:
                 raise SynapseError(400, "Event is for wrong room.")
 
-            token = yield self.store.get_topological_token_for_event(event_id)
+            token = await self.store.get_topological_token_for_event(event_id)
 
             logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
         elif "purge_up_to_ts" in body:
@@ -285,12 +280,10 @@ class PurgeHistoryRestServlet(RestServlet):
                     400, "purge_up_to_ts must be an int", errcode=Codes.BAD_JSON
                 )
 
-            stream_ordering = (yield self.store.find_first_stream_ordering_after_ts(ts))
+            stream_ordering = await self.store.find_first_stream_ordering_after_ts(ts)
 
-            r = (
-                yield self.store.get_room_event_after_stream_ordering(
-                    room_id, stream_ordering
-                )
+            r = await self.store.get_room_event_after_stream_ordering(
+                room_id, stream_ordering
             )
             if not r:
                 logger.warn(
@@ -318,7 +311,7 @@ class PurgeHistoryRestServlet(RestServlet):
                 errcode=Codes.BAD_JSON,
             )
 
-        purge_id = yield self.pagination_handler.start_purge_history(
+        purge_id = self.pagination_handler.start_purge_history(
             room_id, token, delete_local_events=delete_local_events
         )
 
@@ -339,9 +332,8 @@ class PurgeHistoryStatusRestServlet(RestServlet):
         self.pagination_handler = hs.get_pagination_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, purge_id):
-        yield assert_requester_is_admin(self.auth, request)
+    async def on_GET(self, request, purge_id):
+        await assert_requester_is_admin(self.auth, request)
 
         purge_status = self.pagination_handler.get_purge_status(purge_id)
         if purge_status is None:
@@ -357,9 +349,8 @@ class DeactivateAccountRestServlet(RestServlet):
         self._deactivate_account_handler = hs.get_deactivate_account_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, target_user_id):
-        yield assert_requester_is_admin(self.auth, request)
+    async def on_POST(self, request, target_user_id):
+        await assert_requester_is_admin(self.auth, request)
         body = parse_json_object_from_request(request, allow_empty_body=True)
         erase = body.get("erase", False)
         if not isinstance(erase, bool):
@@ -371,7 +362,7 @@ class DeactivateAccountRestServlet(RestServlet):
 
         UserID.from_string(target_user_id)
 
-        result = yield self._deactivate_account_handler.deactivate_account(
+        result = await self._deactivate_account_handler.deactivate_account(
             target_user_id, erase
         )
         if result:
@@ -405,10 +396,9 @@ class ShutdownRoomRestServlet(RestServlet):
         self.room_member_handler = hs.get_room_member_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id):
-        requester = yield self.auth.get_user_by_req(request)
-        yield assert_user_is_admin(self.auth, requester.user)
+    async def on_POST(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
 
         content = parse_json_object_from_request(request)
         assert_params_in_dict(content, ["new_room_user_id"])
@@ -419,7 +409,7 @@ class ShutdownRoomRestServlet(RestServlet):
         message = content.get("message", self.DEFAULT_MESSAGE)
         room_name = content.get("room_name", "Content Violation Notification")
 
-        info = yield self._room_creation_handler.create_room(
+        info = await self._room_creation_handler.create_room(
             room_creator_requester,
             config={
                 "preset": "public_chat",
@@ -438,9 +428,9 @@ class ShutdownRoomRestServlet(RestServlet):
 
         # This will work even if the room is already blocked, but that is
         # desirable in case the first attempt at blocking the room failed below.
-        yield self.store.block_room(room_id, requester_user_id)
+        await self.store.block_room(room_id, requester_user_id)
 
-        users = yield self.state.get_current_users_in_room(room_id)
+        users = await self.state.get_current_users_in_room(room_id)
         kicked_users = []
         failed_to_kick_users = []
         for user_id in users:
@@ -451,7 +441,7 @@ class ShutdownRoomRestServlet(RestServlet):
 
             try:
                 target_requester = create_requester(user_id)
-                yield self.room_member_handler.update_membership(
+                await self.room_member_handler.update_membership(
                     requester=target_requester,
                     target=target_requester.user,
                     room_id=room_id,
@@ -461,9 +451,9 @@ class ShutdownRoomRestServlet(RestServlet):
                     require_consent=False,
                 )
 
-                yield self.room_member_handler.forget(target_requester.user, room_id)
+                await self.room_member_handler.forget(target_requester.user, room_id)
 
-                yield self.room_member_handler.update_membership(
+                await self.room_member_handler.update_membership(
                     requester=target_requester,
                     target=target_requester.user,
                     room_id=new_room_id,
@@ -480,7 +470,7 @@ class ShutdownRoomRestServlet(RestServlet):
                 )
                 failed_to_kick_users.append(user_id)
 
-        yield self.event_creation_handler.create_and_send_nonmember_event(
+        await self.event_creation_handler.create_and_send_nonmember_event(
             room_creator_requester,
             {
                 "type": "m.room.message",
@@ -491,9 +481,11 @@ class ShutdownRoomRestServlet(RestServlet):
             ratelimit=False,
         )
 
-        aliases_for_room = yield self.store.get_aliases_for_room(room_id)
+        aliases_for_room = await maybe_awaitable(
+            self.store.get_aliases_for_room(room_id)
+        )
 
-        yield self.store.update_aliases_for_room(
+        await self.store.update_aliases_for_room(
             room_id, new_room_id, requester_user_id
         )
 
@@ -532,13 +524,12 @@ class ResetPasswordRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self._set_password_handler = hs.get_set_password_handler()
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, target_user_id):
+    async def on_POST(self, request, target_user_id):
         """Post request to allow an administrator reset password for a user.
         This needs user to have administrator access in Synapse.
         """
-        requester = yield self.auth.get_user_by_req(request)
-        yield assert_user_is_admin(self.auth, requester.user)
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
 
         UserID.from_string(target_user_id)
 
@@ -546,7 +537,7 @@ class ResetPasswordRestServlet(RestServlet):
         assert_params_in_dict(params, ["new_password"])
         new_password = params["new_password"]
 
-        yield self._set_password_handler.set_password(
+        await self._set_password_handler.set_password(
             target_user_id, new_password, requester
         )
         return 200, {}
@@ -572,12 +563,11 @@ class GetUsersPaginatedRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.handlers = hs.get_handlers()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, target_user_id):
+    async def on_GET(self, request, target_user_id):
         """Get request to get specific number of users from Synapse.
         This needs user to have administrator access in Synapse.
         """
-        yield assert_requester_is_admin(self.auth, request)
+        await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(target_user_id)
 
@@ -590,11 +580,10 @@ class GetUsersPaginatedRestServlet(RestServlet):
 
         logger.info("limit: %s, start: %s", limit, start)
 
-        ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
+        ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
         return 200, ret
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, target_user_id):
+    async def on_POST(self, request, target_user_id):
         """Post request to get specific number of users from Synapse..
         This needs user to have administrator access in Synapse.
         Example:
@@ -608,7 +597,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
         Returns:
             200 OK with json object {list[dict[str, Any]], count} or empty object.
         """
-        yield assert_requester_is_admin(self.auth, request)
+        await assert_requester_is_admin(self.auth, request)
         UserID.from_string(target_user_id)
 
         order = "name"  # order by name in user table
@@ -618,7 +607,7 @@ class GetUsersPaginatedRestServlet(RestServlet):
         start = params["start"]
         logger.info("limit: %s, start: %s", limit, start)
 
-        ret = yield self.handlers.admin_handler.get_users_paginate(order, start, limit)
+        ret = await self.handlers.admin_handler.get_users_paginate(order, start, limit)
         return 200, ret
 
 
@@ -641,13 +630,12 @@ class SearchUsersRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.handlers = hs.get_handlers()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, target_user_id):
+    async def on_GET(self, request, target_user_id):
         """Get request to search user table for specific users according to
         search term.
         This needs user to have a administrator access in Synapse.
         """
-        yield assert_requester_is_admin(self.auth, request)
+        await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(target_user_id)
 
@@ -661,7 +649,7 @@ class SearchUsersRestServlet(RestServlet):
         term = parse_string(request, "term", required=True)
         logger.info("term: %s ", term)
 
-        ret = yield self.handlers.admin_handler.search_users(term)
+        ret = await self.handlers.admin_handler.search_users(term)
         return 200, ret
 
 
@@ -676,15 +664,14 @@ class DeleteGroupAdminRestServlet(RestServlet):
         self.is_mine_id = hs.is_mine_id
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, group_id):
-        requester = yield self.auth.get_user_by_req(request)
-        yield assert_user_is_admin(self.auth, requester.user)
+    async def on_POST(self, request, group_id):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
 
         if not self.is_mine_id(group_id):
             raise SynapseError(400, "Can only delete local groups")
 
-        yield self.group_server.delete_group(group_id, requester.user.to_string())
+        await self.group_server.delete_group(group_id, requester.user.to_string())
         return 200, {}
 
 
@@ -700,16 +687,15 @@ class AccountValidityRenewServlet(RestServlet):
         self.account_activity_handler = hs.get_account_validity_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_POST(self, request):
-        yield assert_requester_is_admin(self.auth, request)
+    async def on_POST(self, request):
+        await assert_requester_is_admin(self.auth, request)
 
         body = parse_json_object_from_request(request)
 
         if "user_id" not in body:
             raise SynapseError(400, "Missing property 'user_id' in the request body")
 
-        expiration_ts = yield self.account_activity_handler.renew_account_for_user(
+        expiration_ts = await self.account_activity_handler.renew_account_for_user(
             body["user_id"],
             body.get("expiration_ts"),
             not body.get("enable_renewal_emails", True),

+ 5 - 9
synapse/rest/admin/_base.py

@@ -15,8 +15,6 @@
 
 import re
 
-from twisted.internet import defer
-
 from synapse.api.errors import AuthError
 
 
@@ -42,8 +40,7 @@ def historical_admin_path_patterns(path_regex):
     )
 
 
-@defer.inlineCallbacks
-def assert_requester_is_admin(auth, request):
+async def assert_requester_is_admin(auth, request):
     """Verify that the requester is an admin user
 
     WARNING: MAKE SURE YOU YIELD ON THE RESULT!
@@ -58,12 +55,11 @@ def assert_requester_is_admin(auth, request):
     Raises:
         AuthError if the requester is not an admin
     """
-    requester = yield auth.get_user_by_req(request)
-    yield assert_user_is_admin(auth, requester.user)
+    requester = await auth.get_user_by_req(request)
+    await assert_user_is_admin(auth, requester.user)
 
 
-@defer.inlineCallbacks
-def assert_user_is_admin(auth, user_id):
+async def assert_user_is_admin(auth, user_id):
     """Verify that the given user is an admin user
 
     WARNING: MAKE SURE YOU YIELD ON THE RESULT!
@@ -79,6 +75,6 @@ def assert_user_is_admin(auth, user_id):
         AuthError if the user is not an admin
     """
 
-    is_admin = yield auth.is_server_admin(user_id)
+    is_admin = await auth.is_server_admin(user_id)
     if not is_admin:
         raise AuthError(403, "You are not a server admin")

+ 11 - 16
synapse/rest/admin/media.py

@@ -16,8 +16,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.errors import AuthError
 from synapse.http.servlet import RestServlet, parse_integer
 from synapse.rest.admin._base import (
@@ -40,12 +38,11 @@ class QuarantineMediaInRoom(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id):
-        requester = yield self.auth.get_user_by_req(request)
-        yield assert_user_is_admin(self.auth, requester.user)
+    async def on_POST(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
 
-        num_quarantined = yield self.store.quarantine_media_ids_in_room(
+        num_quarantined = await self.store.quarantine_media_ids_in_room(
             room_id, requester.user.to_string()
         )
 
@@ -62,14 +59,13 @@ class ListMediaInRoom(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id):
-        requester = yield self.auth.get_user_by_req(request)
-        is_admin = yield self.auth.is_server_admin(requester.user)
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request)
+        is_admin = await self.auth.is_server_admin(requester.user)
         if not is_admin:
             raise AuthError(403, "You are not a server admin")
 
-        local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
+        local_mxcs, remote_mxcs = await self.store.get_media_mxcs_in_room(room_id)
 
         return 200, {"local": local_mxcs, "remote": remote_mxcs}
 
@@ -81,14 +77,13 @@ class PurgeMediaCacheRestServlet(RestServlet):
         self.media_repository = hs.get_media_repository()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_POST(self, request):
-        yield assert_requester_is_admin(self.auth, request)
+    async def on_POST(self, request):
+        await assert_requester_is_admin(self.auth, request)
 
         before_ts = parse_integer(request, "before_ts", required=True)
         logger.info("before_ts: %r", before_ts)
 
-        ret = yield self.media_repository.delete_old_remote_media(before_ts)
+        ret = await self.media_repository.delete_old_remote_media(before_ts)
 
         return 200, ret
 

+ 3 - 6
synapse/rest/admin/server_notice_servlet.py

@@ -14,8 +14,6 @@
 # limitations under the License.
 import re
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import (
@@ -69,9 +67,8 @@ class SendServerNoticeServlet(RestServlet):
             self.__class__.__name__,
         )
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, txn_id=None):
-        yield assert_requester_is_admin(self.auth, request)
+    async def on_POST(self, request, txn_id=None):
+        await assert_requester_is_admin(self.auth, request)
         body = parse_json_object_from_request(request)
         assert_params_in_dict(body, ("user_id", "content"))
         event_type = body.get("type", EventTypes.Message)
@@ -85,7 +82,7 @@ class SendServerNoticeServlet(RestServlet):
         if not self.hs.is_mine_id(user_id):
             raise SynapseError(400, "Server notices can only be sent to local users")
 
-        event = yield self.snm.send_notice(
+        event = await self.snm.send_notice(
             user_id=body["user_id"],
             type=event_type,
             state_key=state_key,

+ 7 - 11
synapse/rest/admin/users.py

@@ -14,8 +14,6 @@
 # limitations under the License.
 import re
 
-from twisted.internet import defer
-
 from synapse.api.errors import SynapseError
 from synapse.http.servlet import (
     RestServlet,
@@ -59,24 +57,22 @@ class UserAdminServlet(RestServlet):
         self.auth = hs.get_auth()
         self.handlers = hs.get_handlers()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, user_id):
-        yield assert_requester_is_admin(self.auth, request)
+    async def on_GET(self, request, user_id):
+        await assert_requester_is_admin(self.auth, request)
 
         target_user = UserID.from_string(user_id)
 
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "Only local users can be admins of this homeserver")
 
-        is_admin = yield self.handlers.admin_handler.get_user_server_admin(target_user)
+        is_admin = await self.handlers.admin_handler.get_user_server_admin(target_user)
         is_admin = bool(is_admin)
 
         return 200, {"admin": is_admin}
 
-    @defer.inlineCallbacks
-    def on_PUT(self, request, user_id):
-        requester = yield self.auth.get_user_by_req(request)
-        yield assert_user_is_admin(self.auth, requester.user)
+    async def on_PUT(self, request, user_id):
+        requester = await self.auth.get_user_by_req(request)
+        await assert_user_is_admin(self.auth, requester.user)
         auth_user = requester.user
 
         target_user = UserID.from_string(user_id)
@@ -93,7 +89,7 @@ class UserAdminServlet(RestServlet):
         if target_user == auth_user and not set_admin_to:
             raise SynapseError(400, "You may not demote yourself.")
 
-        yield self.handlers.admin_handler.set_user_server_admin(
+        await self.handlers.admin_handler.set_user_server_admin(
             target_user, set_admin_to
         )
 

+ 29 - 0
synapse/util/async_helpers.py

@@ -21,6 +21,8 @@ from typing import Dict, Sequence, Set, Union
 
 from six.moves import range
 
+import attr
+
 from twisted.internet import defer
 from twisted.internet.defer import CancelledError
 from twisted.python import failure
@@ -483,3 +485,30 @@ def timeout_deferred(deferred, timeout, reactor, on_timeout_cancel=None):
     deferred.addCallbacks(success_cb, failure_cb)
 
     return new_d
+
+
+@attr.s(slots=True, frozen=True)
+class DoneAwaitable(object):
+    """Simple awaitable that returns the provided value.
+    """
+
+    value = attr.ib()
+
+    def __await__(self):
+        return self
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        raise StopIteration(self.value)
+
+
+def maybe_awaitable(value):
+    """Convert a value to an awaitable if not already an awaitable.
+    """
+
+    if hasattr(value, "__await__"):
+        return value
+
+    return DoneAwaitable(value)