Browse Source

Add type hints to groups code. (#9393)

Patrick Cloke 3 years ago
parent
commit
d2f0ec12d5

+ 1 - 1
changelog.d/9321.bugfix

@@ -1 +1 @@
-Assert a maximum length for the `client_secret` parameter for spec compliance.
+Assert a maximum length for some parameters for spec compliance.

+ 1 - 0
changelog.d/9393.bugfix

@@ -0,0 +1 @@
+Assert a maximum length for some parameters for spec compliance.

+ 1 - 0
mypy.ini

@@ -23,6 +23,7 @@ files =
   synapse/events/validator.py,
   synapse/events/spamcheck.py,
   synapse/federation,
+  synapse/groups,
   synapse/handlers,
   synapse/http/client.py,
   synapse/http/federation/matrix_federation_agent.py,

+ 5 - 0
synapse/api/constants.py

@@ -27,6 +27,11 @@ MAX_ALIAS_LENGTH = 255
 # the maximum length for a user id is 255 characters
 MAX_USERID_LENGTH = 255
 
+# The maximum length for a group id is 255 characters
+MAX_GROUPID_LENGTH = 255
+MAX_GROUP_CATEGORYID_LENGTH = 255
+MAX_GROUP_ROLEID_LENGTH = 255
+
 
 class Membership:
 

+ 39 - 2
synapse/federation/transport/server.py

@@ -21,6 +21,7 @@ import re
 from typing import Optional, Tuple, Type
 
 import synapse
+from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH
 from synapse.api.errors import Codes, FederationDeniedError, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.api.urls import (
@@ -1118,7 +1119,17 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
         if category_id == "":
-            raise SynapseError(400, "category_id cannot be empty string")
+            raise SynapseError(
+                400, "category_id cannot be empty string", Codes.INVALID_PARAM
+            )
+
+        if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+            raise SynapseError(
+                400,
+                "category_id may not be longer than %s characters"
+                % (MAX_GROUP_CATEGORYID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
 
         resp = await self.handler.update_group_summary_room(
             group_id,
@@ -1184,6 +1195,14 @@ class FederationGroupsCategoryServlet(BaseFederationServlet):
         if category_id == "":
             raise SynapseError(400, "category_id cannot be empty string")
 
+        if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+            raise SynapseError(
+                400,
+                "category_id may not be longer than %s characters"
+                % (MAX_GROUP_CATEGORYID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
         resp = await self.handler.upsert_group_category(
             group_id, requester_user_id, category_id, content
         )
@@ -1240,7 +1259,17 @@ class FederationGroupsRoleServlet(BaseFederationServlet):
             raise SynapseError(403, "requester_user_id doesn't match origin")
 
         if role_id == "":
-            raise SynapseError(400, "role_id cannot be empty string")
+            raise SynapseError(
+                400, "role_id cannot be empty string", Codes.INVALID_PARAM
+            )
+
+        if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+            raise SynapseError(
+                400,
+                "role_id may not be longer than %s characters"
+                % (MAX_GROUP_ROLEID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
 
         resp = await self.handler.update_group_role(
             group_id, requester_user_id, role_id, content
@@ -1285,6 +1314,14 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
         if role_id == "":
             raise SynapseError(400, "role_id cannot be empty string")
 
+        if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+            raise SynapseError(
+                400,
+                "role_id may not be longer than %s characters"
+                % (MAX_GROUP_ROLEID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
         resp = await self.handler.update_group_summary_user(
             group_id,
             requester_user_id,

+ 24 - 13
synapse/groups/attestations.py

@@ -37,13 +37,16 @@ An attestation is a signed blob of json that looks like:
 
 import logging
 import random
-from typing import Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from signedjson.sign import sign_json
 
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -63,15 +66,19 @@ UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
 class GroupAttestationSigning:
     """Creates and verifies group attestations."""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.keyring = hs.get_keyring()
         self.clock = hs.get_clock()
         self.server_name = hs.hostname
         self.signing_key = hs.signing_key
 
     async def verify_attestation(
-        self, attestation, group_id, user_id, server_name=None
-    ):
+        self,
+        attestation: JsonDict,
+        group_id: str,
+        user_id: str,
+        server_name: Optional[str] = None,
+    ) -> None:
         """Verifies that the given attestation matches the given parameters.
 
         An optional server_name can be supplied to explicitly set which server's
@@ -100,16 +107,18 @@ class GroupAttestationSigning:
         if valid_until_ms < now:
             raise SynapseError(400, "Attestation expired")
 
+        assert server_name is not None
         await self.keyring.verify_json_for_server(
             server_name, attestation, now, "Group attestation"
         )
 
-    def create_attestation(self, group_id, user_id):
+    def create_attestation(self, group_id: str, user_id: str) -> JsonDict:
         """Create an attestation for the group_id and user_id with default
         validity length.
         """
-        validity_period = DEFAULT_ATTESTATION_LENGTH_MS
-        validity_period *= random.uniform(*DEFAULT_ATTESTATION_JITTER)
+        validity_period = DEFAULT_ATTESTATION_LENGTH_MS * random.uniform(
+            *DEFAULT_ATTESTATION_JITTER
+        )
         valid_until_ms = int(self.clock.time_msec() + validity_period)
 
         return sign_json(
@@ -126,7 +135,7 @@ class GroupAttestationSigning:
 class GroupAttestionRenewer:
     """Responsible for sending and receiving attestation updates."""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
         self.assestations = hs.get_groups_attestation_signing()
@@ -139,7 +148,9 @@ class GroupAttestionRenewer:
                 self._start_renew_attestations, 30 * 60 * 1000
             )
 
-    async def on_renew_attestation(self, group_id, user_id, content):
+    async def on_renew_attestation(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> JsonDict:
         """When a remote updates an attestation"""
         attestation = content["attestation"]
 
@@ -154,10 +165,10 @@ class GroupAttestionRenewer:
 
         return {}
 
-    def _start_renew_attestations(self):
+    def _start_renew_attestations(self) -> None:
         return run_as_background_process("renew_attestations", self._renew_attestations)
 
-    async def _renew_attestations(self):
+    async def _renew_attestations(self) -> None:
         """Called periodically to check if we need to update any of our attestations"""
 
         now = self.clock.time_msec()
@@ -166,7 +177,7 @@ class GroupAttestionRenewer:
             now + UPDATE_ATTESTATION_TIME_MS
         )
 
-        async def _renew_attestation(group_user: Tuple[str, str]):
+        async def _renew_attestation(group_user: Tuple[str, str]) -> None:
             group_id, user_id = group_user
             try:
                 if not self.is_mine_id(group_id):

+ 143 - 83
synapse/groups/groups_server.py

@@ -16,12 +16,17 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.handlers.groups_local import GroupsLocalHandler
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
-from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
+from synapse.types import GroupID, JsonDict, RoomID, UserID, get_domain_from_id
 from synapse.util.async_helpers import concurrently_execute
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -39,7 +44,7 @@ MAX_LONG_DESC_LEN = 10000
 
 
 class GroupsServerWorkerHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.room_list_handler = hs.get_room_list_handler()
@@ -54,16 +59,21 @@ class GroupsServerWorkerHandler:
         self.profile_handler = hs.get_profile_handler()
 
     async def check_group_is_ours(
-        self, group_id, requester_user_id, and_exists=False, and_is_admin=None
-    ):
+        self,
+        group_id: str,
+        requester_user_id: str,
+        and_exists: bool = False,
+        and_is_admin: Optional[str] = None,
+    ) -> Optional[dict]:
         """Check that the group is ours, and optionally if it exists.
 
         If group does exist then return group.
 
         Args:
-            group_id (str)
-            and_exists (bool): whether to also check if group exists
-            and_is_admin (str): whether to also check if given str is a user_id
+            group_id: The group ID to check.
+            requester_user_id: The user ID of the requester.
+            and_exists: whether to also check if group exists
+            and_is_admin: whether to also check if given str is a user_id
                 that is an admin
         """
         if not self.is_mine_id(group_id):
@@ -86,7 +96,9 @@ class GroupsServerWorkerHandler:
 
         return group
 
-    async def get_group_summary(self, group_id, requester_user_id):
+    async def get_group_summary(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the summary for a group as seen by requester_user_id.
 
         The group summary consists of the profile of the room, and a curated
@@ -119,6 +131,8 @@ class GroupsServerWorkerHandler:
             entry = await self.room_list_handler.generate_room_entry(
                 room_id, len(joined_users), with_alias=False, allow_private=True
             )
+            if entry is None:
+                continue
             entry = dict(entry)  # so we don't change what's cached
             entry.pop("room_id", None)
 
@@ -126,22 +140,22 @@ class GroupsServerWorkerHandler:
 
         rooms.sort(key=lambda e: e.get("order", 0))
 
-        for entry in users:
-            user_id = entry["user_id"]
+        for user in users:
+            user_id = user["user_id"]
 
             if not self.is_mine_id(requester_user_id):
                 attestation = await self.store.get_remote_attestation(group_id, user_id)
                 if not attestation:
                     continue
 
-                entry["attestation"] = attestation
+                user["attestation"] = attestation
             else:
-                entry["attestation"] = self.attestations.create_attestation(
+                user["attestation"] = self.attestations.create_attestation(
                     group_id, user_id
                 )
 
             user_profile = await self.profile_handler.get_profile_from_cache(user_id)
-            entry.update(user_profile)
+            user.update(user_profile)
 
         users.sort(key=lambda e: e.get("order", 0))
 
@@ -164,40 +178,43 @@ class GroupsServerWorkerHandler:
             "user": membership_info,
         }
 
-    async def get_group_categories(self, group_id, requester_user_id):
+    async def get_group_categories(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get all categories in a group (as seen by user)"""
         await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
 
         categories = await self.store.get_group_categories(group_id=group_id)
         return {"categories": categories}
 
-    async def get_group_category(self, group_id, requester_user_id, category_id):
+    async def get_group_category(
+        self, group_id: str, requester_user_id: str, category_id: str
+    ) -> JsonDict:
         """Get a specific category in a group (as seen by user)"""
         await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
 
-        res = await self.store.get_group_category(
+        return await self.store.get_group_category(
             group_id=group_id, category_id=category_id
         )
 
-        logger.info("group %s", res)
-
-        return res
-
-    async def get_group_roles(self, group_id, requester_user_id):
+    async def get_group_roles(self, group_id: str, requester_user_id: str) -> JsonDict:
         """Get all roles in a group (as seen by user)"""
         await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
 
         roles = await self.store.get_group_roles(group_id=group_id)
         return {"roles": roles}
 
-    async def get_group_role(self, group_id, requester_user_id, role_id):
+    async def get_group_role(
+        self, group_id: str, requester_user_id: str, role_id: str
+    ) -> JsonDict:
         """Get a specific role in a group (as seen by user)"""
         await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
 
-        res = await self.store.get_group_role(group_id=group_id, role_id=role_id)
-        return res
+        return await self.store.get_group_role(group_id=group_id, role_id=role_id)
 
-    async def get_group_profile(self, group_id, requester_user_id):
+    async def get_group_profile(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the group profile as seen by requester_user_id"""
 
         await self.check_group_is_ours(group_id, requester_user_id)
@@ -219,7 +236,9 @@ class GroupsServerWorkerHandler:
         else:
             raise SynapseError(404, "Unknown group")
 
-    async def get_users_in_group(self, group_id, requester_user_id):
+    async def get_users_in_group(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the users in group as seen by requester_user_id.
 
         The ordering is arbitrary at the moment
@@ -268,7 +287,9 @@ class GroupsServerWorkerHandler:
 
         return {"chunk": chunk, "total_user_count_estimate": len(user_results)}
 
-    async def get_invited_users_in_group(self, group_id, requester_user_id):
+    async def get_invited_users_in_group(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the users that have been invited to a group as seen by requester_user_id.
 
         The ordering is arbitrary at the moment
@@ -298,7 +319,9 @@ class GroupsServerWorkerHandler:
 
         return {"chunk": user_profiles, "total_user_count_estimate": len(invited_users)}
 
-    async def get_rooms_in_group(self, group_id, requester_user_id):
+    async def get_rooms_in_group(
+        self, group_id: str, requester_user_id: str
+    ) -> JsonDict:
         """Get the rooms in group as seen by requester_user_id
 
         This returns rooms in order of decreasing number of joined users
@@ -336,15 +359,20 @@ class GroupsServerWorkerHandler:
 
 
 class GroupsServerHandler(GroupsServerWorkerHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         # Ensure attestations get renewed
         hs.get_groups_attestation_renewer()
 
     async def update_group_summary_room(
-        self, group_id, requester_user_id, room_id, category_id, content
-    ):
+        self,
+        group_id: str,
+        requester_user_id: str,
+        room_id: str,
+        category_id: str,
+        content: JsonDict,
+    ) -> JsonDict:
         """Add/update a room to the group summary"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@@ -367,8 +395,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         return {}
 
     async def delete_group_summary_room(
-        self, group_id, requester_user_id, room_id, category_id
-    ):
+        self, group_id: str, requester_user_id: str, room_id: str, category_id: str
+    ) -> JsonDict:
         """Remove a room from the summary"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@@ -380,7 +408,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def set_group_join_policy(self, group_id, requester_user_id, content):
+    async def set_group_join_policy(
+        self, group_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Sets the group join policy.
 
         Currently supported policies are:
@@ -400,8 +430,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         return {}
 
     async def update_group_category(
-        self, group_id, requester_user_id, category_id, content
-    ):
+        self, group_id: str, requester_user_id: str, category_id: str, content: JsonDict
+    ) -> JsonDict:
         """Add/Update a group category"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@@ -419,7 +449,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def delete_group_category(self, group_id, requester_user_id, category_id):
+    async def delete_group_category(
+        self, group_id: str, requester_user_id: str, category_id: str
+    ) -> JsonDict:
         """Delete a group category"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@@ -431,7 +463,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def update_group_role(self, group_id, requester_user_id, role_id, content):
+    async def update_group_role(
+        self, group_id: str, requester_user_id: str, role_id: str, content: JsonDict
+    ) -> JsonDict:
         """Add/update a role in a group"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@@ -447,7 +481,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def delete_group_role(self, group_id, requester_user_id, role_id):
+    async def delete_group_role(
+        self, group_id: str, requester_user_id: str, role_id: str
+    ) -> JsonDict:
         """Remove role from group"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@@ -458,8 +494,13 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         return {}
 
     async def update_group_summary_user(
-        self, group_id, requester_user_id, user_id, role_id, content
-    ):
+        self,
+        group_id: str,
+        requester_user_id: str,
+        user_id: str,
+        role_id: str,
+        content: JsonDict,
+    ) -> JsonDict:
         """Add/update a users entry in the group summary"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@@ -480,8 +521,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         return {}
 
     async def delete_group_summary_user(
-        self, group_id, requester_user_id, user_id, role_id
-    ):
+        self, group_id: str, requester_user_id: str, user_id: str, role_id: str
+    ) -> JsonDict:
         """Remove a user from the group summary"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@@ -493,7 +534,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def update_group_profile(self, group_id, requester_user_id, content):
+    async def update_group_profile(
+        self, group_id: str, requester_user_id: str, content: JsonDict
+    ) -> None:
         """Update the group profile"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@@ -524,7 +567,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         await self.store.update_group_profile(group_id, profile)
 
-    async def add_room_to_group(self, group_id, requester_user_id, room_id, content):
+    async def add_room_to_group(
+        self, group_id: str, requester_user_id: str, room_id: str, content: JsonDict
+    ) -> JsonDict:
         """Add room to group"""
         RoomID.from_string(room_id)  # Ensure valid room id
 
@@ -539,8 +584,13 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         return {}
 
     async def update_room_in_group(
-        self, group_id, requester_user_id, room_id, config_key, content
-    ):
+        self,
+        group_id: str,
+        requester_user_id: str,
+        room_id: str,
+        config_key: str,
+        content: JsonDict,
+    ) -> JsonDict:
         """Update room in group"""
         RoomID.from_string(room_id)  # Ensure valid room id
 
@@ -559,7 +609,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def remove_room_from_group(self, group_id, requester_user_id, room_id):
+    async def remove_room_from_group(
+        self, group_id: str, requester_user_id: str, room_id: str
+    ) -> JsonDict:
         """Remove room from group"""
         await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
@@ -569,12 +621,16 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def invite_to_group(self, group_id, user_id, requester_user_id, content):
+    async def invite_to_group(
+        self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Invite user to group"""
 
         group = await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True, and_is_admin=requester_user_id
         )
+        if not group:
+            raise SynapseError(400, "Group does not exist", errcode=Codes.BAD_STATE)
 
         # TODO: Check if user knocked
 
@@ -597,6 +653,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         if self.hs.is_mine_id(user_id):
             groups_local = self.hs.get_groups_local_handler()
+            assert isinstance(
+                groups_local, GroupsLocalHandler
+            ), "Workers cannot invites users to groups."
             res = await groups_local.on_invite(group_id, user_id, content)
             local_attestation = None
         else:
@@ -632,6 +691,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
                 local_attestation=local_attestation,
                 remote_attestation=remote_attestation,
             )
+            return {"state": "join"}
         elif res["state"] == "invite":
             await self.store.add_group_invite(group_id, user_id)
             return {"state": "invite"}
@@ -640,13 +700,17 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         else:
             raise SynapseError(502, "Unknown state returned by HS")
 
-    async def _add_user(self, group_id, user_id, content):
+    async def _add_user(
+        self, group_id: str, user_id: str, content: JsonDict
+    ) -> Optional[JsonDict]:
         """Add a user to a group based on a content dict.
 
         See accept_invite, join_group.
         """
         if not self.hs.is_mine_id(user_id):
-            local_attestation = self.attestations.create_attestation(group_id, user_id)
+            local_attestation = self.attestations.create_attestation(
+                group_id, user_id
+            )  # type: Optional[JsonDict]
 
             remote_attestation = content["attestation"]
 
@@ -670,7 +734,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return local_attestation
 
-    async def accept_invite(self, group_id, requester_user_id, content):
+    async def accept_invite(
+        self, group_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """User tries to accept an invite to the group.
 
         This is different from them asking to join, and so should error if no
@@ -689,7 +755,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {"state": "join", "attestation": local_attestation}
 
-    async def join_group(self, group_id, requester_user_id, content):
+    async def join_group(
+        self, group_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """User tries to join the group.
 
         This will error if the group requires an invite/knock to join
@@ -698,6 +766,8 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         group_info = await self.check_group_is_ours(
             group_id, requester_user_id, and_exists=True
         )
+        if not group_info:
+            raise SynapseError(404, "Group does not exist", errcode=Codes.NOT_FOUND)
         if group_info["join_policy"] != "open":
             raise SynapseError(403, "Group is not publicly joinable")
 
@@ -705,25 +775,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {"state": "join", "attestation": local_attestation}
 
-    async def knock(self, group_id, requester_user_id, content):
-        """A user requests becoming a member of the group"""
-        await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
-        raise NotImplementedError()
-
-    async def accept_knock(self, group_id, requester_user_id, content):
-        """Accept a users knock to the room.
-
-        Errors if the user hasn't knocked, rather than inviting them.
-        """
-
-        await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
-
-        raise NotImplementedError()
-
     async def remove_user_from_group(
-        self, group_id, user_id, requester_user_id, content
-    ):
+        self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         """Remove a user from the group; either a user is leaving or an admin
         kicked them.
         """
@@ -745,6 +799,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         if is_kick:
             if self.hs.is_mine_id(user_id):
                 groups_local = self.hs.get_groups_local_handler()
+                assert isinstance(
+                    groups_local, GroupsLocalHandler
+                ), "Workers cannot remove users from groups."
                 await groups_local.user_removed_from_group(group_id, user_id, {})
             else:
                 await self.transport_client.remove_user_from_group_notification(
@@ -761,14 +818,15 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {}
 
-    async def create_group(self, group_id, requester_user_id, content):
-        group = await self.check_group_is_ours(group_id, requester_user_id)
-
+    async def create_group(
+        self, group_id: str, requester_user_id: str, content: JsonDict
+    ) -> JsonDict:
         logger.info("Attempting to create group with ID: %r", group_id)
 
         # parsing the id into a GroupID validates it.
         group_id_obj = GroupID.from_string(group_id)
 
+        group = await self.check_group_is_ours(group_id, requester_user_id)
         if group:
             raise SynapseError(400, "Group already exists")
 
@@ -813,7 +871,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
             local_attestation = self.attestations.create_attestation(
                 group_id, requester_user_id
-            )
+            )  # type: Optional[JsonDict]
         else:
             local_attestation = None
             remote_attestation = None
@@ -836,15 +894,14 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
 
         return {"group_id": group_id}
 
-    async def delete_group(self, group_id, requester_user_id):
+    async def delete_group(self, group_id: str, requester_user_id: str) -> None:
         """Deletes a group, kicking out all current members.
 
         Only group admins or server admins can call this request
 
         Args:
-            group_id (str)
-            request_user_id (str)
-
+            group_id: The group ID to delete.
+            requester_user_id: The user requesting to delete the group.
         """
 
         await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
@@ -867,6 +924,9 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         async def _kick_user_from_group(user_id):
             if self.hs.is_mine_id(user_id):
                 groups_local = self.hs.get_groups_local_handler()
+                assert isinstance(
+                    groups_local, GroupsLocalHandler
+                ), "Workers cannot kick users from groups."
                 await groups_local.user_removed_from_group(group_id, user_id, {})
             else:
                 await self.transport_client.remove_user_from_group_notification(
@@ -898,7 +958,7 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
         await self.store.delete_group(group_id)
 
 
-def _parse_join_policy_from_contents(content):
+def _parse_join_policy_from_contents(content: JsonDict) -> Optional[str]:
     """Given a content for a request, return the specified join policy or None"""
 
     join_policy_dict = content.get("m.join_policy")
@@ -908,7 +968,7 @@ def _parse_join_policy_from_contents(content):
         return None
 
 
-def _parse_join_policy_dict(join_policy_dict):
+def _parse_join_policy_dict(join_policy_dict: JsonDict) -> str:
     """Given a dict for the "m.join_policy" config return the join policy specified"""
     join_policy_type = join_policy_dict.get("type")
     if not join_policy_type:
@@ -919,7 +979,7 @@ def _parse_join_policy_dict(join_policy_dict):
     return join_policy_type
 
 
-def _parse_visibility_from_contents(content):
+def _parse_visibility_from_contents(content: JsonDict) -> bool:
     """Given a content for a request parse out whether the entity should be
     public or not
     """
@@ -933,7 +993,7 @@ def _parse_visibility_from_contents(content):
     return is_public
 
 
-def _parse_visibility_dict(visibility):
+def _parse_visibility_dict(visibility: JsonDict) -> bool:
     """Given a dict for the "m.visibility" config return if the entity should
     be public or not
     """

+ 120 - 23
synapse/rest/client/v2_alpha/groups.py

@@ -16,11 +16,16 @@
 
 import logging
 from functools import wraps
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from twisted.web.http import Request
 
-from synapse.api.errors import SynapseError
+from synapse.api.constants import (
+    MAX_GROUP_CATEGORYID_LENGTH,
+    MAX_GROUP_ROLEID_LENGTH,
+    MAX_GROUPID_LENGTH,
+)
+from synapse.api.errors import Codes, SynapseError
 from synapse.handlers.groups_local import GroupsLocalHandler
 from synapse.http.servlet import (
     RestServlet,
@@ -84,7 +89,9 @@ class GroupServlet(RestServlet):
         assert_params_in_dict(
             content, ("name", "avatar_url", "short_description", "long_description")
         )
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot create group profiles."
         await self.groups_handler.update_group_profile(
             group_id, requester_user_id, content
         )
@@ -137,13 +144,26 @@ class GroupSummaryRoomsCatServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, category_id: str, room_id: str
+        self, request: Request, group_id: str, category_id: Optional[str], room_id: str
     ):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        if category_id == "":
+            raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
+
+        if category_id and len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+            raise SynapseError(
+                400,
+                "category_id may not be longer than %s characters"
+                % (MAX_GROUP_CATEGORYID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
         content = parse_json_object_from_request(request)
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group summaries."
         resp = await self.groups_handler.update_group_summary_room(
             group_id,
             requester_user_id,
@@ -161,7 +181,9 @@ class GroupSummaryRoomsCatServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group profiles."
         resp = await self.groups_handler.delete_group_summary_room(
             group_id, requester_user_id, room_id=room_id, category_id=category_id
         )
@@ -202,8 +224,21 @@ class GroupCategoryServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        if not category_id:
+            raise SynapseError(400, "category_id cannot be empty", Codes.INVALID_PARAM)
+
+        if len(category_id) > MAX_GROUP_CATEGORYID_LENGTH:
+            raise SynapseError(
+                400,
+                "category_id may not be longer than %s characters"
+                % (MAX_GROUP_CATEGORYID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
         content = parse_json_object_from_request(request)
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group categories."
         resp = await self.groups_handler.update_group_category(
             group_id, requester_user_id, category_id=category_id, content=content
         )
@@ -217,7 +252,9 @@ class GroupCategoryServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group categories."
         resp = await self.groups_handler.delete_group_category(
             group_id, requester_user_id, category_id=category_id
         )
@@ -279,8 +316,21 @@ class GroupRoleServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        if not role_id:
+            raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
+
+        if len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+            raise SynapseError(
+                400,
+                "role_id may not be longer than %s characters"
+                % (MAX_GROUP_ROLEID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
         content = parse_json_object_from_request(request)
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group roles."
         resp = await self.groups_handler.update_group_role(
             group_id, requester_user_id, role_id=role_id, content=content
         )
@@ -294,7 +344,9 @@ class GroupRoleServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group roles."
         resp = await self.groups_handler.delete_group_role(
             group_id, requester_user_id, role_id=role_id
         )
@@ -347,13 +399,26 @@ class GroupSummaryUsersRoleServlet(RestServlet):
 
     @_validate_group_id
     async def on_PUT(
-        self, request: Request, group_id: str, role_id: str, user_id: str
+        self, request: Request, group_id: str, role_id: Optional[str], user_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
+        if role_id == "":
+            raise SynapseError(400, "role_id cannot be empty", Codes.INVALID_PARAM)
+
+        if role_id and len(role_id) > MAX_GROUP_ROLEID_LENGTH:
+            raise SynapseError(
+                400,
+                "role_id may not be longer than %s characters"
+                % (MAX_GROUP_ROLEID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
         content = parse_json_object_from_request(request)
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group summaries."
         resp = await self.groups_handler.update_group_summary_user(
             group_id,
             requester_user_id,
@@ -371,7 +436,9 @@ class GroupSummaryUsersRoleServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group summaries."
         resp = await self.groups_handler.delete_group_summary_user(
             group_id, requester_user_id, user_id=user_id, role_id=role_id
         )
@@ -465,7 +532,9 @@ class GroupSettingJoinPolicyServlet(RestServlet):
 
         content = parse_json_object_from_request(request)
 
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group join policy."
         result = await self.groups_handler.set_group_join_policy(
             group_id, requester_user_id, content
         )
@@ -494,7 +563,19 @@ class GroupCreateServlet(RestServlet):
         localpart = content.pop("localpart")
         group_id = GroupID(localpart, self.server_name).to_string()
 
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        if not localpart:
+            raise SynapseError(400, "Group ID cannot be empty", Codes.INVALID_PARAM)
+
+        if len(group_id) > MAX_GROUPID_LENGTH:
+            raise SynapseError(
+                400,
+                "Group ID may not be longer than %s characters" % (MAX_GROUPID_LENGTH,),
+                Codes.INVALID_PARAM,
+            )
+
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot create groups."
         result = await self.groups_handler.create_group(
             group_id, requester_user_id, content
         )
@@ -523,7 +604,9 @@ class GroupAdminRoomsServlet(RestServlet):
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify rooms in a group."
         result = await self.groups_handler.add_room_to_group(
             group_id, requester_user_id, room_id, content
         )
@@ -537,7 +620,9 @@ class GroupAdminRoomsServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
         requester_user_id = requester.user.to_string()
 
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group categories."
         result = await self.groups_handler.remove_room_from_group(
             group_id, requester_user_id, room_id
         )
@@ -567,7 +652,9 @@ class GroupAdminRoomsConfigServlet(RestServlet):
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot modify group categories."
         result = await self.groups_handler.update_room_in_group(
             group_id, requester_user_id, room_id, config_key, content
         )
@@ -597,7 +684,9 @@ class GroupAdminUsersInviteServlet(RestServlet):
 
         content = parse_json_object_from_request(request)
         config = content.get("config", {})
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot invite users to a group."
         result = await self.groups_handler.invite(
             group_id, user_id, requester_user_id, config
         )
@@ -624,7 +713,9 @@ class GroupAdminUsersKickServlet(RestServlet):
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot kick users from a group."
         result = await self.groups_handler.remove_user_from_group(
             group_id, user_id, requester_user_id, content
         )
@@ -649,7 +740,9 @@ class GroupSelfLeaveServlet(RestServlet):
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot leave a group for a users."
         result = await self.groups_handler.remove_user_from_group(
             group_id, requester_user_id, requester_user_id, content
         )
@@ -674,7 +767,9 @@ class GroupSelfJoinServlet(RestServlet):
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot join a user to a group."
         result = await self.groups_handler.join_group(
             group_id, requester_user_id, content
         )
@@ -699,7 +794,9 @@ class GroupSelfAcceptInviteServlet(RestServlet):
         requester_user_id = requester.user.to_string()
 
         content = parse_json_object_from_request(request)
-        assert isinstance(self.groups_handler, GroupsLocalHandler)
+        assert isinstance(
+            self.groups_handler, GroupsLocalHandler
+        ), "Workers cannot accept an invite to a group."
         result = await self.groups_handler.accept_invite(
             group_id, requester_user_id, content
         )

+ 7 - 2
synapse/storage/databases/main/group_server.py

@@ -14,7 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple
+
+from typing_extensions import TypedDict
 
 from synapse.api.errors import SynapseError
 from synapse.storage._base import SQLBaseStore, db_to_json
@@ -26,6 +28,9 @@ from synapse.util import json_encoder
 _DEFAULT_CATEGORY_ID = ""
 _DEFAULT_ROLE_ID = ""
 
+# A room in a group.
+_RoomInGroup = TypedDict("_RoomInGroup", {"room_id": str, "is_public": bool})
+
 
 class GroupServerWorkerStore(SQLBaseStore):
     async def get_group(self, group_id: str) -> Optional[Dict[str, Any]]:
@@ -72,7 +77,7 @@ class GroupServerWorkerStore(SQLBaseStore):
 
     async def get_rooms_in_group(
         self, group_id: str, include_private: bool = False
-    ) -> List[Dict[str, Union[str, bool]]]:
+    ) -> List[_RoomInGroup]:
         """Retrieve the rooms that belong to a given group. Does not return rooms that
         lack members.