Browse Source

Add ability for access tokens to belong to one user but grant access to another user. (#8616)

We do it this way round so that only the "owner" can delete the access token (i.e. `/logout/all` by the "owner" also deletes that token, but `/logout/all` by the "target user" doesn't).

A future PR will add an API for creating such a token.

When the target user and authenticated entity are different the `Processed request` log line will be logged with a: `{@admin:server as @bob:server} ...`. I'm not convinced by that format (especially since it adds spaces in there, making it harder to use `cut -d ' '` to chop off the start of log lines). Suggestions welcome.
Erik Johnston 3 years ago
parent
commit
f21e24ffc2

+ 1 - 0
changelog.d/8616.misc

@@ -0,0 +1 @@
+Change schema to support access tokens belonging to one user but granting access to another.

+ 46 - 67
synapse/api/auth.py

@@ -33,6 +33,7 @@ from synapse.api.errors import (
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
 from synapse.logging import opentracing as opentracing
+from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import StateMap, UserID
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.metrics import Measure
@@ -190,10 +191,6 @@ class Auth:
 
             user_id, app_service = await self._get_appservice_user_id(request)
             if user_id:
-                request.authenticated_entity = user_id
-                opentracing.set_tag("authenticated_entity", user_id)
-                opentracing.set_tag("appservice_id", app_service.id)
-
                 if ip_addr and self._track_appservice_user_ips:
                     await self.store.insert_client_ip(
                         user_id=user_id,
@@ -203,31 +200,38 @@ class Auth:
                         device_id="dummy-device",  # stubbed
                     )
 
-                return synapse.types.create_requester(user_id, app_service=app_service)
+                requester = synapse.types.create_requester(
+                    user_id, app_service=app_service
+                )
+
+                request.requester = user_id
+                opentracing.set_tag("authenticated_entity", user_id)
+                opentracing.set_tag("user_id", user_id)
+                opentracing.set_tag("appservice_id", app_service.id)
+
+                return requester
 
             user_info = await self.get_user_by_access_token(
                 access_token, rights, allow_expired=allow_expired
             )
-            user = user_info["user"]
-            token_id = user_info["token_id"]
-            is_guest = user_info["is_guest"]
-            shadow_banned = user_info["shadow_banned"]
+            token_id = user_info.token_id
+            is_guest = user_info.is_guest
+            shadow_banned = user_info.shadow_banned
 
             # Deny the request if the user account has expired.
             if self._account_validity.enabled and not allow_expired:
-                user_id = user.to_string()
-                if await self.store.is_account_expired(user_id, self.clock.time_msec()):
+                if await self.store.is_account_expired(
+                    user_info.user_id, self.clock.time_msec()
+                ):
                     raise AuthError(
                         403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
                     )
 
-            # device_id may not be present if get_user_by_access_token has been
-            # stubbed out.
-            device_id = user_info.get("device_id")
+            device_id = user_info.device_id
 
-            if user and access_token and ip_addr:
+            if access_token and ip_addr:
                 await self.store.insert_client_ip(
-                    user_id=user.to_string(),
+                    user_id=user_info.token_owner,
                     access_token=access_token,
                     ip=ip_addr,
                     user_agent=user_agent,
@@ -241,19 +245,23 @@ class Auth:
                     errcode=Codes.GUEST_ACCESS_FORBIDDEN,
                 )
 
-            request.authenticated_entity = user.to_string()
-            opentracing.set_tag("authenticated_entity", user.to_string())
-            if device_id:
-                opentracing.set_tag("device_id", device_id)
-
-            return synapse.types.create_requester(
-                user,
+            requester = synapse.types.create_requester(
+                user_info.user_id,
                 token_id,
                 is_guest,
                 shadow_banned,
                 device_id,
                 app_service=app_service,
+                authenticated_entity=user_info.token_owner,
             )
+
+            request.requester = requester
+            opentracing.set_tag("authenticated_entity", user_info.token_owner)
+            opentracing.set_tag("user_id", user_info.user_id)
+            if device_id:
+                opentracing.set_tag("device_id", device_id)
+
+            return requester
         except KeyError:
             raise MissingClientTokenError()
 
@@ -284,7 +292,7 @@ class Auth:
 
     async def get_user_by_access_token(
         self, token: str, rights: str = "access", allow_expired: bool = False,
-    ) -> dict:
+    ) -> TokenLookupResult:
         """ Validate access token and get user_id from it
 
         Args:
@@ -293,13 +301,7 @@ class Auth:
                 allow this
             allow_expired: If False, raises an InvalidClientTokenError
                 if the token is expired
-        Returns:
-            dict that includes:
-               `user` (UserID)
-               `is_guest` (bool)
-               `shadow_banned` (bool)
-               `token_id` (int|None): access token id. May be None if guest
-               `device_id` (str|None): device corresponding to access token
+
         Raises:
             InvalidClientTokenError if a user by that token exists, but the token is
                 expired
@@ -309,9 +311,9 @@ class Auth:
 
         if rights == "access":
             # first look in the database
-            r = await self._look_up_user_by_access_token(token)
+            r = await self.store.get_user_by_access_token(token)
             if r:
-                valid_until_ms = r["valid_until_ms"]
+                valid_until_ms = r.valid_until_ms
                 if (
                     not allow_expired
                     and valid_until_ms is not None
@@ -328,7 +330,6 @@ class Auth:
         # otherwise it needs to be a valid macaroon
         try:
             user_id, guest = self._parse_and_validate_macaroon(token, rights)
-            user = UserID.from_string(user_id)
 
             if rights == "access":
                 if not guest:
@@ -354,23 +355,17 @@ class Auth:
                     raise InvalidClientTokenError(
                         "Guest access token used for regular user"
                     )
-                ret = {
-                    "user": user,
-                    "is_guest": True,
-                    "shadow_banned": False,
-                    "token_id": None,
+
+                ret = TokenLookupResult(
+                    user_id=user_id,
+                    is_guest=True,
                     # all guests get the same device id
-                    "device_id": GUEST_DEVICE_ID,
-                }
+                    device_id=GUEST_DEVICE_ID,
+                )
             elif rights == "delete_pusher":
                 # We don't store these tokens in the database
-                ret = {
-                    "user": user,
-                    "is_guest": False,
-                    "shadow_banned": False,
-                    "token_id": None,
-                    "device_id": None,
-                }
+
+                ret = TokenLookupResult(user_id=user_id, is_guest=False)
             else:
                 raise RuntimeError("Unknown rights setting %s", rights)
             return ret
@@ -479,31 +474,15 @@ class Auth:
         now = self.hs.get_clock().time_msec()
         return now < expiry
 
-    async def _look_up_user_by_access_token(self, token):
-        ret = await self.store.get_user_by_access_token(token)
-        if not ret:
-            return None
-
-        # we use ret.get() below because *lots* of unit tests stub out
-        # get_user_by_access_token in a way where it only returns a couple of
-        # the fields.
-        user_info = {
-            "user": UserID.from_string(ret.get("name")),
-            "token_id": ret.get("token_id", None),
-            "is_guest": False,
-            "shadow_banned": ret.get("shadow_banned"),
-            "device_id": ret.get("device_id"),
-            "valid_until_ms": ret.get("valid_until_ms"),
-        }
-        return user_info
-
     def get_appservice_by_req(self, request):
         token = self.get_access_token_from_request(request)
         service = self.store.get_app_service_by_token(token)
         if not service:
             logger.warning("Unrecognised appservice access token.")
             raise InvalidClientTokenError()
-        request.authenticated_entity = service.sender
+        request.requester = synapse.types.create_requester(
+            service.sender, app_service=service
+        )
         return service
 
     async def is_server_admin(self, user: UserID) -> bool:

+ 2 - 2
synapse/appservice/__init__.py

@@ -52,11 +52,11 @@ class ApplicationService:
         self,
         token,
         hostname,
+        id,
+        sender,
         url=None,
         namespaces=None,
         hs_token=None,
-        sender=None,
-        id=None,
         protocols=None,
         rate_limited=True,
         ip_range_whitelist=None,

+ 1 - 1
synapse/federation/transport/server.py

@@ -154,7 +154,7 @@ class Authenticator:
         )
 
         logger.debug("Request from %s", origin)
-        request.authenticated_entity = origin
+        request.requester = origin
 
         # If we get a valid signed request from the other side, its probably
         # alive

+ 4 - 4
synapse/handlers/auth.py

@@ -991,17 +991,17 @@ class AuthHandler(BaseHandler):
                 # This might return an awaitable, if it does block the log out
                 # until it completes.
                 result = provider.on_logged_out(
-                    user_id=str(user_info["user"]),
-                    device_id=user_info["device_id"],
+                    user_id=user_info.user_id,
+                    device_id=user_info.device_id,
                     access_token=access_token,
                 )
                 if inspect.isawaitable(result):
                     await result
 
         # delete pushers associated with this access token
-        if user_info["token_id"] is not None:
+        if user_info.token_id is not None:
             await self.hs.get_pusherpool().remove_pushers_by_access_token(
-                str(user_info["user"]), (user_info["token_id"],)
+                user_info.user_id, (user_info.token_id,)
             )
 
     async def delete_access_tokens_for_user(

+ 5 - 2
synapse/handlers/register.py

@@ -115,7 +115,10 @@ class RegistrationHandler(BaseHandler):
                     400, "User ID already taken.", errcode=Codes.USER_IN_USE
                 )
             user_data = await self.auth.get_user_by_access_token(guest_access_token)
-            if not user_data["is_guest"] or user_data["user"].localpart != localpart:
+            if (
+                not user_data.is_guest
+                or UserID.from_string(user_data.user_id).localpart != localpart
+            ):
                 raise AuthError(
                     403,
                     "Cannot register taken user ID without valid guest "
@@ -741,7 +744,7 @@ class RegistrationHandler(BaseHandler):
             # up when the access token is saved, but that's quite an
             # invasive change I'd rather do separately.
             user_tuple = await self.store.get_user_by_access_token(token)
-            token_id = user_tuple["token_id"]
+            token_id = user_tuple.token_id
 
             await self.pusher_pool.add_pusher(
                 user_id=user_id,

+ 23 - 7
synapse/http/site.py

@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Optional
+from typing import Optional, Union
 
 from twisted.python.failure import Failure
 from twisted.web.server import Request, Site
@@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig
 from synapse.http import redact_uri
 from synapse.http.request_metrics import RequestMetrics, requests_counter
 from synapse.logging.context import LoggingContext, PreserveLoggingContext
+from synapse.types import Requester
 
 logger = logging.getLogger(__name__)
 
@@ -54,9 +55,12 @@ class SynapseRequest(Request):
         Request.__init__(self, channel, *args, **kw)
         self.site = channel.site
         self._channel = channel  # this is used by the tests
-        self.authenticated_entity = None
         self.start_time = 0.0
 
+        # The requester, if authenticated. For federation requests this is the
+        # server name, for client requests this is the Requester object.
+        self.requester = None  # type: Optional[Union[Requester, str]]
+
         # we can't yet create the logcontext, as we don't know the method.
         self.logcontext = None  # type: Optional[LoggingContext]
 
@@ -271,11 +275,23 @@ class SynapseRequest(Request):
         # to the client (nb may be negative)
         response_send_time = self.finish_time - self._processing_finished_time
 
-        # need to decode as it could be raw utf-8 bytes
-        # from a IDN servname in an auth header
-        authenticated_entity = self.authenticated_entity
-        if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
-            authenticated_entity = authenticated_entity.decode("utf-8", "replace")
+        # Convert the requester into a string that we can log
+        authenticated_entity = None
+        if isinstance(self.requester, str):
+            authenticated_entity = self.requester
+        elif isinstance(self.requester, Requester):
+            authenticated_entity = self.requester.authenticated_entity
+
+            # If this is a request where the target user doesn't match the user who
+            # authenticated (e.g. and admin is puppetting a user) then we log both.
+            if self.requester.user.to_string() != authenticated_entity:
+                authenticated_entity = "{},{}".format(
+                    authenticated_entity, self.requester.user.to_string(),
+                )
+        elif self.requester is not None:
+            # This shouldn't happen, but we log it so we don't lose information
+            # and can see that we're doing something wrong.
+            authenticated_entity = repr(self.requester)  # type: ignore[unreachable]
 
         # ...or could be raw utf-8 bytes in the User-Agent header.
         # N.B. if you don't do this, the logger explodes cryptically

+ 2 - 4
synapse/replication/http/membership.py

@@ -77,8 +77,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
 
         requester = Requester.deserialize(self.store, content["requester"])
 
-        if requester.user:
-            request.authenticated_entity = requester.user.to_string()
+        request.requester = requester
 
         logger.info("remote_join: %s into room: %s", user_id, room_id)
 
@@ -142,8 +141,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
 
         requester = Requester.deserialize(self.store, content["requester"])
 
-        if requester.user:
-            request.authenticated_entity = requester.user.to_string()
+        request.requester = requester
 
         # hopefully we're now on the master, so this won't recurse!
         event_id, stream_id = await self.member_handler.remote_reject_invite(

+ 1 - 2
synapse/replication/http/send_event.py

@@ -115,8 +115,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             ratelimit = content["ratelimit"]
             extra_users = [UserID.from_string(u) for u in content["extra_users"]]
 
-        if requester.user:
-            request.authenticated_entity = requester.user.to_string()
+        request.requester = requester
 
         logger.info(
             "Got event to send with ID: %s into room: %s", event.event_id, event.room_id

+ 39 - 9
synapse/storage/databases/main/registration.py

@@ -18,6 +18,8 @@ import logging
 import re
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
+import attr
+
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
@@ -38,6 +40,35 @@ THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
 logger = logging.getLogger(__name__)
 
 
+@attr.s(frozen=True, slots=True)
+class TokenLookupResult:
+    """Result of looking up an access token.
+
+    Attributes:
+        user_id: The user that this token authenticates as
+        is_guest
+        shadow_banned
+        token_id: The ID of the access token looked up
+        device_id: The device associated with the token, if any.
+        valid_until_ms: The timestamp the token expires, if any.
+        token_owner: The "owner" of the token. This is either the same as the
+            user, or a server admin who is logged in as the user.
+    """
+
+    user_id = attr.ib(type=str)
+    is_guest = attr.ib(type=bool, default=False)
+    shadow_banned = attr.ib(type=bool, default=False)
+    token_id = attr.ib(type=Optional[int], default=None)
+    device_id = attr.ib(type=Optional[str], default=None)
+    valid_until_ms = attr.ib(type=Optional[int], default=None)
+    token_owner = attr.ib(type=str)
+
+    # Make the token owner default to the user ID, which is the common case.
+    @token_owner.default
+    def _default_token_owner(self):
+        return self.user_id
+
+
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
@@ -102,15 +133,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         return is_trial
 
     @cached()
-    async def get_user_by_access_token(self, token: str) -> Optional[dict]:
+    async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
         """Get a user from the given access token.
 
         Args:
             token: The access token of a user.
         Returns:
-            None, if the token did not match, otherwise dict
-            including the keys `name`, `is_guest`, `device_id`, `token_id`,
-            `valid_until_ms`.
+            None, if the token did not match, otherwise a `TokenLookupResult`
         """
         return await self.db_pool.runInteraction(
             "get_user_by_access_token", self._query_for_auth, token
@@ -331,23 +360,24 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
 
-    def _query_for_auth(self, txn, token):
+    def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
         sql = """
-            SELECT users.name,
+            SELECT users.name as user_id,
                 users.is_guest,
                 users.shadow_banned,
                 access_tokens.id as token_id,
                 access_tokens.device_id,
-                access_tokens.valid_until_ms
+                access_tokens.valid_until_ms,
+                access_tokens.user_id as token_owner
             FROM users
-            INNER JOIN access_tokens on users.name = access_tokens.user_id
+            INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
             WHERE token = ?
         """
 
         txn.execute(sql, (token,))
         rows = self.db_pool.cursor_to_dict(txn)
         if rows:
-            return rows[0]
+            return TokenLookupResult(**rows[0])
 
         return None
 

+ 17 - 0
synapse/storage/databases/main/schema/delta/58/22puppet_token.sql

@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Whether the access token is an admin token for controlling another user.
+ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;

+ 26 - 7
synapse/types.py

@@ -29,6 +29,7 @@ from typing import (
     Tuple,
     Type,
     TypeVar,
+    Union,
 )
 
 import attr
@@ -38,6 +39,7 @@ from unpaddedbase64 import decode_base64
 from synapse.api.errors import Codes, SynapseError
 
 if TYPE_CHECKING:
+    from synapse.appservice.api import ApplicationService
     from synapse.storage.databases.main import DataStore
 
 # define a version of typing.Collection that works on python 3.5
@@ -74,6 +76,7 @@ class Requester(
             "shadow_banned",
             "device_id",
             "app_service",
+            "authenticated_entity",
         ],
     )
 ):
@@ -104,6 +107,7 @@ class Requester(
             "shadow_banned": self.shadow_banned,
             "device_id": self.device_id,
             "app_server_id": self.app_service.id if self.app_service else None,
+            "authenticated_entity": self.authenticated_entity,
         }
 
     @staticmethod
@@ -129,16 +133,18 @@ class Requester(
             shadow_banned=input["shadow_banned"],
             device_id=input["device_id"],
             app_service=appservice,
+            authenticated_entity=input["authenticated_entity"],
         )
 
 
 def create_requester(
-    user_id,
-    access_token_id=None,
-    is_guest=False,
-    shadow_banned=False,
-    device_id=None,
-    app_service=None,
+    user_id: Union[str, "UserID"],
+    access_token_id: Optional[int] = None,
+    is_guest: Optional[bool] = False,
+    shadow_banned: Optional[bool] = False,
+    device_id: Optional[str] = None,
+    app_service: Optional["ApplicationService"] = None,
+    authenticated_entity: Optional[str] = None,
 ):
     """
     Create a new ``Requester`` object
@@ -151,14 +157,27 @@ def create_requester(
         shadow_banned (bool):  True if the user making this request is shadow-banned.
         device_id (str|None):  device_id which was set at authentication time
         app_service (ApplicationService|None):  the AS requesting on behalf of the user
+        authenticated_entity: The entity that authenticated when making the request.
+            This is different to the user_id when an admin user or the server is
+            "puppeting" the user.
 
     Returns:
         Requester
     """
     if not isinstance(user_id, UserID):
         user_id = UserID.from_string(user_id)
+
+    if authenticated_entity is None:
+        authenticated_entity = user_id.to_string()
+
     return Requester(
-        user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
+        user_id,
+        access_token_id,
+        is_guest,
+        shadow_banned,
+        device_id,
+        app_service,
+        authenticated_entity,
     )
 
 

+ 13 - 16
tests/api/test_auth.py

@@ -29,6 +29,7 @@ from synapse.api.errors import (
     MissingClientTokenError,
     ResourceLimitError,
 )
+from synapse.storage.databases.main.registration import TokenLookupResult
 from synapse.types import UserID
 
 from tests import unittest
@@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_get_user_by_req_user_valid_token(self):
-        user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
+        user_info = TokenLookupResult(
+            user_id=self.test_user, token_id=5, device_id="device"
+        )
         self.store.get_user_by_access_token = Mock(
             return_value=defer.succeed(user_info)
         )
@@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
         self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
 
     def test_get_user_by_req_user_missing_token(self):
-        user_info = {"name": self.test_user, "token_id": "ditto"}
+        user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
         self.store.get_user_by_access_token = Mock(
             return_value=defer.succeed(user_info)
         )
@@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
     def test_get_user_from_macaroon(self):
         self.store.get_user_by_access_token = Mock(
             return_value=defer.succeed(
-                {"name": "@baldrick:matrix.org", "device_id": "device"}
+                TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
             )
         )
 
@@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
         user_info = yield defer.ensureDeferred(
             self.auth.get_user_by_access_token(macaroon.serialize())
         )
-        user = user_info["user"]
-        self.assertEqual(UserID.from_string(user_id), user)
+        self.assertEqual(user_id, user_info.user_id)
 
         # TODO: device_id should come from the macaroon, but currently comes
         # from the db.
-        self.assertEqual(user_info["device_id"], "device")
+        self.assertEqual(user_info.device_id, "device")
 
     @defer.inlineCallbacks
     def test_get_guest_user_from_macaroon(self):
@@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
         user_info = yield defer.ensureDeferred(
             self.auth.get_user_by_access_token(serialized)
         )
-        user = user_info["user"]
-        is_guest = user_info["is_guest"]
-        self.assertEqual(UserID.from_string(user_id), user)
-        self.assertTrue(is_guest)
+        self.assertEqual(user_id, user_info.user_id)
+        self.assertTrue(user_info.is_guest)
         self.store.get_user_by_id.assert_called_with(user_id)
 
     @defer.inlineCallbacks
@@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
             if token != tok:
                 return defer.succeed(None)
             return defer.succeed(
-                {
-                    "name": USER_ID,
-                    "is_guest": False,
-                    "token_id": 1234,
-                    "device_id": "DEVICE",
-                }
+                TokenLookupResult(
+                    user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
+                )
             )
 
         self.store.get_user_by_access_token = get_user

+ 2 - 2
tests/api/test_ratelimiting.py

@@ -43,7 +43,7 @@ class TestRatelimiter(unittest.TestCase):
 
     def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
         appservice = ApplicationService(
-            None, "example.com", id="foo", rate_limited=True,
+            None, "example.com", id="foo", rate_limited=True, sender="@as:example.com",
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)
 
@@ -68,7 +68,7 @@ class TestRatelimiter(unittest.TestCase):
 
     def test_allowed_appservice_via_can_requester_do_action(self):
         appservice = ApplicationService(
-            None, "example.com", id="foo", rate_limited=False,
+            None, "example.com", id="foo", rate_limited=False, sender="@as:example.com",
         )
         as_requester = create_requester("@user:example.com", app_service=appservice)
 

+ 1 - 0
tests/appservice/test_appservice.py

@@ -31,6 +31,7 @@ class ApplicationServiceTestCase(unittest.TestCase):
     def setUp(self):
         self.service = ApplicationService(
             id="unique_identifier",
+            sender="@as:test",
             url="some_url",
             token="some_token",
             hostname="matrix.org",  # only used by get_groups_for_user

+ 1 - 1
tests/handlers/test_device.py

@@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         # make sure that our device ID has changed
         user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
 
-        self.assertEqual(user_info["device_id"], retrieved_device_id)
+        self.assertEqual(user_info.device_id, retrieved_device_id)
 
         # make sure the device has the display name that was set from the login
         res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))

+ 1 - 1
tests/handlers/test_message.py

@@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         self.info = self.get_success(
             self.hs.get_datastore().get_user_by_access_token(self.access_token,)
         )
-        self.token_id = self.info["token_id"]
+        self.token_id = self.info.token_id
 
         self.requester = create_requester(self.user_id, access_token_id=self.token_id)
 

+ 1 - 1
tests/push/test_email.py

@@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastore().get_user_by_access_token(self.access_token)
         )
-        token_id = user_tuple["token_id"]
+        token_id = user_tuple.token_id
 
         self.pusher = self.get_success(
             self.hs.get_pusherpool().add_pusher(

+ 5 - 5
tests/push/test_http.py

@@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastore().get_user_by_access_token(access_token)
         )
-        token_id = user_tuple["token_id"]
+        token_id = user_tuple.token_id
 
         self.get_success(
             self.hs.get_pusherpool().add_pusher(
@@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastore().get_user_by_access_token(access_token)
         )
-        token_id = user_tuple["token_id"]
+        token_id = user_tuple.token_id
 
         self.get_success(
             self.hs.get_pusherpool().add_pusher(
@@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastore().get_user_by_access_token(access_token)
         )
-        token_id = user_tuple["token_id"]
+        token_id = user_tuple.token_id
 
         self.get_success(
             self.hs.get_pusherpool().add_pusher(
@@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastore().get_user_by_access_token(access_token)
         )
-        token_id = user_tuple["token_id"]
+        token_id = user_tuple.token_id
 
         self.get_success(
             self.hs.get_pusherpool().add_pusher(
@@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
         user_tuple = self.get_success(
             self.hs.get_datastore().get_user_by_access_token(access_token)
         )
-        token_id = user_tuple["token_id"]
+        token_id = user_tuple.token_id
 
         self.get_success(
             self.hs.get_pusherpool().add_pusher(

+ 1 - 1
tests/replication/test_pusher_shard.py

@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         user_dict = self.get_success(
             self.hs.get_datastore().get_user_by_access_token(access_token)
         )
-        token_id = user_dict["token_id"]
+        token_id = user_dict.token_id
 
         self.get_success(
             self.hs.get_pusherpool().add_pusher(

+ 1 - 0
tests/rest/client/v2_alpha/test_register.py

@@ -55,6 +55,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             self.hs.config.server_name,
             id="1234",
             namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
+            sender="@as:test",
         )
 
         self.hs.get_datastore().services_cache.append(appservice)

+ 4 - 6
tests/storage/test_registration.py

@@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
             self.store.get_user_by_access_token(self.tokens[1])
         )
 
-        self.assertDictContainsSubset(
-            {"name": self.user_id, "device_id": self.device_id}, result
-        )
-
-        self.assertTrue("token_id" in result)
+        self.assertEqual(result.user_id, self.user_id)
+        self.assertEqual(result.device_id, self.device_id)
+        self.assertIsNotNone(result.token_id)
 
     @defer.inlineCallbacks
     def test_user_delete_access_tokens(self):
@@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
         user = yield defer.ensureDeferred(
             self.store.get_user_by_access_token(self.tokens[0])
         )
-        self.assertEqual(self.user_id, user["name"])
+        self.assertEqual(self.user_id, user.user_id)
 
         # now delete the rest
         yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))