Browse Source

Revert "Merge two of the room join codepaths"

This reverts commit cf81375b94c4763766440471e632fc4b103450ab.

It subtly violates a guest joining auth check
Daniel Wagner-Hall 8 năm trước cách đây
mục cha
commit
4de08a4672

+ 0 - 5
synapse/api/errors.py

@@ -84,11 +84,6 @@ class RegistrationError(SynapseError):
     pass
 
 
-class BadIdentifierError(SynapseError):
-    """An error indicating an identifier couldn't be parsed."""
-    pass
-
-
 class UnrecognizedRequestError(SynapseError):
     """An error indicating we don't understand the request you're trying to make"""
     def __init__(self, *args, **kwargs):

+ 2 - 9
synapse/handlers/profile.py

@@ -169,15 +169,8 @@ class ProfileHandler(BaseHandler):
             consumeErrors=True
         ).addErrback(unwrapFirstError)
 
-        if displayname is None:
-            del state["displayname"]
-        else:
-            state["displayname"] = displayname
-
-        if avatar_url is None:
-            del state["avatar_url"]
-        else:
-            state["avatar_url"] = avatar_url
+        state["displayname"] = displayname
+        state["avatar_url"] = avatar_url
 
         defer.returnValue(None)
 

+ 9 - 35
synapse/handlers/room.py

@@ -527,17 +527,7 @@ class RoomMemberHandler(BaseHandler):
         defer.returnValue({"room_id": room_id})
 
     @defer.inlineCallbacks
-    def lookup_room_alias(self, room_alias):
-        """
-        Gets the room ID for an alias.
-
-        Args:
-            room_alias (str): The room alias to look up.
-        Returns:
-            A tuple of the room ID (str) and the hosts hosting the room ([str])
-        Raises:
-            SynapseError if the room couldn't be looked up.
-        """
+    def join_room_alias(self, joinee, room_alias, content={}):
         directory_handler = self.hs.get_handlers().directory_handler
         mapping = yield directory_handler.get_association(room_alias)
 
@@ -549,40 +539,24 @@ class RoomMemberHandler(BaseHandler):
         if not hosts:
             raise SynapseError(404, "No known servers")
 
-        defer.returnValue((room_id, hosts))
-
-    @defer.inlineCallbacks
-    def do_join(self, requester, room_id, hosts=None):
-        """
-        Joins requester to room_id.
-
-        Args:
-            requester (Requester): The user joining the room.
-            room_id (str): The room ID (not alias) being joined.
-            hosts ([str]): A list of hosts which are hopefully in the room.
-        Raises:
-            SynapseError if the room couldn't be joined.
-        """
-        hosts = hosts or []
-
-        content = {"membership": Membership.JOIN}
-        if requester.is_guest:
-            content["kind"] = "guest"
-
-        yield collect_presencelike_data(self.distributor, requester.user, content)
+        # If event doesn't include a display name, add one.
+        yield collect_presencelike_data(self.distributor, joinee, content)
 
+        content.update({"membership": Membership.JOIN})
         builder = self.event_builder_factory.new({
             "type": EventTypes.Member,
-            "state_key": requester.user.to_string(),
+            "state_key": joinee.to_string(),
             "room_id": room_id,
-            "sender": requester.user.to_string(),
-            "membership": Membership.JOIN,  # For backwards compatibility
+            "sender": joinee.to_string(),
+            "membership": Membership.JOIN,
             "content": content,
         })
         event, context = yield self._create_new_client_event(builder)
 
         yield self._do_join(event, context, room_hosts=hosts)
 
+        defer.returnValue({"room_id": room_id})
+
     @defer.inlineCallbacks
     def _do_join(self, event, context, room_hosts=None):
         room_id = event.room_id

+ 55 - 13
synapse/rest/client/v1/room.py

@@ -216,7 +216,11 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
 
 # TODO: Needs unit testing for room ID + alias joins
 class JoinRoomAliasServlet(ClientV1RestServlet):
-    PATTERNS = client_path_patterns("/join/(?P<room_identifier>[^/]*)$")
+
+    def register(self, http_server):
+        # /join/$room_identifier[/$txn_id]
+        PATTERNS = ("/join/(?P<room_identifier>[^/]*)")
+        register_txn_path(self, PATTERNS, http_server)
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_identifier, txn_id=None):
@@ -225,22 +229,60 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
             allow_guest=True,
         )
 
-        handler = self.handlers.room_member_handler
+        # the identifier could be a room alias or a room id. Try one then the
+        # other if it fails to parse, without swallowing other valid
+        # SynapseErrors.
 
-        room_id = None
-        hosts = []
-        if RoomAlias.is_valid(room_identifier):
-            room_alias = RoomAlias.from_string(room_identifier)
-            room_id, hosts = yield handler.lookup_room_alias(room_alias)
-        else:
-            room_id = RoomID.from_string(room_identifier).to_string()
+        identifier = None
+        is_room_alias = False
+        try:
+            identifier = RoomAlias.from_string(room_identifier)
+            is_room_alias = True
+        except SynapseError:
+            identifier = RoomID.from_string(room_identifier)
 
         # TODO: Support for specifying the home server to join with?
 
-        yield handler.do_join(
-            requester, room_id, hosts=hosts
-        )
-        defer.returnValue((200, {"room_id": room_id}))
+        if is_room_alias:
+            handler = self.handlers.room_member_handler
+            ret_dict = yield handler.join_room_alias(
+                requester.user,
+                identifier,
+            )
+            defer.returnValue((200, ret_dict))
+        else:  # room id
+            msg_handler = self.handlers.message_handler
+            content = {"membership": Membership.JOIN}
+            if requester.is_guest:
+                content["kind"] = "guest"
+            yield msg_handler.create_and_send_event(
+                {
+                    "type": EventTypes.Member,
+                    "content": content,
+                    "room_id": identifier.to_string(),
+                    "sender": requester.user.to_string(),
+                    "state_key": requester.user.to_string(),
+                },
+                token_id=requester.access_token_id,
+                txn_id=txn_id,
+                is_guest=requester.is_guest,
+            )
+
+            defer.returnValue((200, {"room_id": identifier.to_string()}))
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, room_identifier, txn_id):
+        try:
+            defer.returnValue(
+                self.txns.get_client_transaction(request, txn_id)
+            )
+        except KeyError:
+            pass
+
+        response = yield self.on_POST(request, room_identifier, txn_id)
+
+        self.txns.store_client_transaction(request, txn_id, response)
+        defer.returnValue(response)
 
 
 # TODO: Needs unit testing

+ 3 - 11
synapse/types.py

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.api.errors import SynapseError, BadIdentifierError
+from synapse.api.errors import SynapseError
 
 from collections import namedtuple
 
@@ -51,13 +51,13 @@ class DomainSpecificString(
     def from_string(cls, s):
         """Parse the string given by 's' into a structure object."""
         if len(s) < 1 or s[0] != cls.SIGIL:
-            raise BadIdentifierError(400, "Expected %s string to start with '%s'" % (
+            raise SynapseError(400, "Expected %s string to start with '%s'" % (
                 cls.__name__, cls.SIGIL,
             ))
 
         parts = s[1:].split(':', 1)
         if len(parts) != 2:
-            raise BadIdentifierError(
+            raise SynapseError(
                 400, "Expected %s of the form '%slocalname:domain'" % (
                     cls.__name__, cls.SIGIL,
                 )
@@ -69,14 +69,6 @@ class DomainSpecificString(
         # names on one HS
         return cls(localpart=parts[0], domain=domain)
 
-    @classmethod
-    def is_valid(cls, s):
-        try:
-            cls.from_string(s)
-            return True
-        except:
-            return False
-
     def to_string(self):
         """Return a string encoding the fields of the structure object."""
         return "%s%s:%s" % (self.SIGIL, self.localpart, self.domain)