Browse Source

Convert user_get_threepids response to attrs. (#16468)

This improves type annotations by not having a dictionary of Any values.
Patrick Cloke 6 months ago
parent
commit
cc865fffc0

+ 1 - 0
changelog.d/16468.misc

@@ -0,0 +1 @@
+Improve type hints.

+ 2 - 2
synapse/handlers/account_validity.py

@@ -212,8 +212,8 @@ class AccountValidityHandler:
 
         addresses = []
         for threepid in threepids:
-            if threepid["medium"] == "email":
-                addresses.append(threepid["address"])
+            if threepid.medium == "email":
+                addresses.append(threepid.address)
 
         return addresses
 

+ 3 - 1
synapse/handlers/admin.py

@@ -16,6 +16,8 @@ import abc
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set
 
+import attr
+
 from synapse.api.constants import Direction, Membership
 from synapse.events import EventBase
 from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo
@@ -93,7 +95,7 @@ class AdminHandler:
         ]
         user_info_dict["displayname"] = profile.display_name
         user_info_dict["avatar_url"] = profile.avatar_url
-        user_info_dict["threepids"] = threepids
+        user_info_dict["threepids"] = [attr.asdict(t) for t in threepids]
         user_info_dict["external_ids"] = external_ids
         user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
 

+ 2 - 2
synapse/handlers/deactivate_account.py

@@ -117,9 +117,9 @@ class DeactivateAccountHandler:
 
         # Remove any local threepid associations for this account.
         local_threepids = await self.store.user_get_threepids(user_id)
-        for threepid in local_threepids:
+        for local_threepid in local_threepids:
             await self._auth_handler.delete_local_threepid(
-                user_id, threepid["medium"], threepid["address"]
+                user_id, local_threepid.medium, local_threepid.address
             )
 
         # delete any devices belonging to the user, which will also

+ 1 - 1
synapse/module_api/__init__.py

@@ -678,7 +678,7 @@ class ModuleApi:
             "msisdn" for phone numbers, and an "address" key which value is the
             threepid's address.
         """
-        return await self._store.user_get_threepids(user_id)
+        return [attr.asdict(t) for t in await self._store.user_get_threepids(user_id)]
 
     def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]":
         """Check if user exists.

+ 1 - 2
synapse/rest/admin/users.py

@@ -329,9 +329,8 @@ class UserRestServletV2(RestServlet):
 
             if threepids is not None:
                 # get changed threepids (added and removed)
-                # convert List[Dict[str, Any]] into Set[Tuple[str, str]]
                 cur_threepids = {
-                    (threepid["medium"], threepid["address"])
+                    (threepid.medium, threepid.address)
                     for threepid in await self.store.user_get_threepids(user_id)
                 }
                 add_threepids = new_threepids - cur_threepids

+ 3 - 1
synapse/rest/client/account.py

@@ -24,6 +24,8 @@ if TYPE_CHECKING or HAS_PYDANTIC_V2:
     from pydantic.v1 import StrictBool, StrictStr, constr
 else:
     from pydantic import StrictBool, StrictStr, constr
+
+import attr
 from typing_extensions import Literal
 
 from twisted.web.server import Request
@@ -595,7 +597,7 @@ class ThreepidRestServlet(RestServlet):
 
         threepids = await self.datastore.user_get_threepids(requester.user.to_string())
 
-        return 200, {"threepids": threepids}
+        return 200, {"threepids": [attr.asdict(t) for t in threepids]}
 
     # NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because
     # the endpoint is deprecated. (If you really want to, you could do this by reusing

+ 14 - 5
synapse/storage/databases/main/registration.py

@@ -143,6 +143,14 @@ class LoginTokenLookupResult:
     """The session ID advertised by the SSO Identity Provider."""
 
 
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class ThreepidResult:
+    medium: str
+    address: str
+    validated_at: int
+    added_at: int
+
+
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     def __init__(
         self,
@@ -988,13 +996,14 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
         )
 
-    async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
-        return await self.db_pool.simple_select_list(
+    async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
+        results = await self.db_pool.simple_select_list(
             "user_threepids",
-            {"user_id": user_id},
-            ["medium", "address", "validated_at", "added_at"],
-            "user_get_threepids",
+            keyvalues={"user_id": user_id},
+            retcols=["medium", "address", "validated_at", "added_at"],
+            desc="user_get_threepids",
         )
+        return [ThreepidResult(**r) for r in results]
 
     async def user_delete_threepid(
         self, user_id: str, medium: str, address: str

+ 4 - 4
tests/module_api/test_api.py

@@ -94,12 +94,12 @@ class ModuleApiTestCase(BaseModuleApiTestCase):
         self.assertEqual(len(emails), 1)
 
         email = emails[0]
-        self.assertEqual(email["medium"], "email")
-        self.assertEqual(email["address"], "bob@bobinator.bob")
+        self.assertEqual(email.medium, "email")
+        self.assertEqual(email.address, "bob@bobinator.bob")
 
         # Should these be 0?
-        self.assertEqual(email["validated_at"], 0)
-        self.assertEqual(email["added_at"], 0)
+        self.assertEqual(email.validated_at, 0)
+        self.assertEqual(email.added_at, 0)
 
         # Check that the displayname was assigned
         displayname = self.get_success(