Browse Source

Factor SSO success handling out of CAS login (#4264)

This is mostly factoring out the post-CAS-login code to somewhere we can reuse
it for other SSO flows, but it also fixes the userid mapping while we're at it.
Richard van der Hoff 5 years ago
parent
commit
c588b9b9e4
5 changed files with 184 additions and 32 deletions
  1. 1 0
      changelog.d/4264.bugfix
  2. 11 2
      synapse/handlers/auth.py
  3. 76 29
      synapse/rest/client/v1/login.py
  4. 66 0
      synapse/types.py
  5. 30 1
      tests/test_types.py

+ 1 - 0
changelog.d/4264.bugfix

@@ -0,0 +1 @@
+Fix CAS login when username is not valid in an MXID

+ 11 - 2
synapse/handlers/auth.py

@@ -563,10 +563,10 @@ class AuthHandler(BaseHandler):
         insensitively, but return None if there are multiple inexact matches.
 
         Args:
-            (str) user_id: complete @user:id
+            (unicode|bytes) user_id: complete @user:id
 
         Returns:
-            defer.Deferred: (str) canonical_user_id, or None if zero or
+            defer.Deferred: (unicode) canonical_user_id, or None if zero or
             multiple matches
         """
         res = yield self._find_user_id_and_pwd_hash(user_id)
@@ -954,6 +954,15 @@ class MacaroonGenerator(object):
         return macaroon.serialize()
 
     def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
+        """
+
+        Args:
+            user_id (unicode):
+            duration_in_ms (int):
+
+        Returns:
+            unicode
+        """
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = login")
         now = self.hs.get_clock().time_msec()

+ 76 - 29
synapse/rest/client/v1/login.py

@@ -23,8 +23,12 @@ from twisted.web.client import PartialDownloadError
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.http.server import finish_request
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.http.servlet import (
+    RestServlet,
+    parse_json_object_from_request,
+    parse_string,
+)
+from synapse.types import UserID, map_username_to_mxid_localpart
 from synapse.util.msisdn import phone_number_to_msisdn
 
 from .base import ClientV1RestServlet, client_path_patterns
@@ -358,17 +362,15 @@ class CasTicketServlet(ClientV1RestServlet):
         self.cas_server_url = hs.config.cas_server_url
         self.cas_service_url = hs.config.cas_service_url
         self.cas_required_attributes = hs.config.cas_required_attributes
-        self.auth_handler = hs.get_auth_handler()
-        self.handlers = hs.get_handlers()
-        self.macaroon_gen = hs.get_macaroon_generator()
+        self._sso_auth_handler = SSOAuthHandler(hs)
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        client_redirect_url = request.args[b"redirectUrl"][0]
+        client_redirect_url = parse_string(request, "redirectUrl", required=True)
         http_client = self.hs.get_simple_http_client()
         uri = self.cas_server_url + "/proxyValidate"
         args = {
-            "ticket": request.args[b"ticket"][0].decode('ascii'),
+            "ticket": parse_string(request, "ticket", required=True),
             "service": self.cas_service_url
         }
         try:
@@ -380,7 +382,6 @@ class CasTicketServlet(ClientV1RestServlet):
         result = yield self.handle_cas_response(request, body, client_redirect_url)
         defer.returnValue(result)
 
-    @defer.inlineCallbacks
     def handle_cas_response(self, request, cas_response_body, client_redirect_url):
         user, attributes = self.parse_cas_response(cas_response_body)
 
@@ -396,28 +397,9 @@ class CasTicketServlet(ClientV1RestServlet):
                 if required_value != actual_value:
                     raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
-        user_id = UserID(user, self.hs.hostname).to_string()
-        auth_handler = self.auth_handler
-        registered_user_id = yield auth_handler.check_user_exists(user_id)
-        if not registered_user_id:
-            registered_user_id, _ = (
-                yield self.handlers.registration_handler.register(localpart=user)
-            )
-
-        login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id
+        return self._sso_auth_handler.on_successful_auth(
+            user, request, client_redirect_url,
         )
-        redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
-                                                            login_token)
-        request.redirect(redirect_url)
-        finish_request(request)
-
-    def add_login_token_to_redirect_url(self, url, token):
-        url_parts = list(urllib.parse.urlparse(url))
-        query = dict(urllib.parse.parse_qsl(url_parts[4]))
-        query.update({"loginToken": token})
-        url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
-        return urllib.parse.urlunparse(url_parts)
 
     def parse_cas_response(self, cas_response_body):
         user = None
@@ -452,6 +434,71 @@ class CasTicketServlet(ClientV1RestServlet):
         return user, attributes
 
 
+class SSOAuthHandler(object):
+    """
+    Utility class for Resources and Servlets which handle the response from an SSO
+    service
+
+    Args:
+        hs (synapse.server.HomeServer)
+    """
+    def __init__(self, hs):
+        self._hostname = hs.hostname
+        self._auth_handler = hs.get_auth_handler()
+        self._registration_handler = hs.get_handlers().registration_handler
+        self._macaroon_gen = hs.get_macaroon_generator()
+
+    @defer.inlineCallbacks
+    def on_successful_auth(
+        self, username, request, client_redirect_url,
+    ):
+        """Called once the user has successfully authenticated with the SSO.
+
+        Registers the user if necessary, and then returns a redirect (with
+        a login token) to the client.
+
+        Args:
+            username (unicode|bytes): the remote user id. We'll map this onto
+                something sane for an MXID localpart.
+
+            request (SynapseRequest): the incoming request from the browser. We'll
+                respond to it with a redirect.
+
+            client_redirect_url (unicode): the redirect_url the client gave us when
+                it first started the process.
+
+        Returns:
+            Deferred[None]: Completes once we have handled the request.
+        """
+        localpart = map_username_to_mxid_localpart(username)
+        user_id = UserID(localpart, self._hostname).to_string()
+        registered_user_id = yield self._auth_handler.check_user_exists(user_id)
+        if not registered_user_id:
+            registered_user_id, _ = (
+                yield self._registration_handler.register(
+                    localpart=localpart,
+                    generate_token=False,
+                )
+            )
+
+        login_token = self._macaroon_gen.generate_short_term_login_token(
+            registered_user_id
+        )
+        redirect_url = self._add_login_token_to_redirect_url(
+            client_redirect_url, login_token
+        )
+        request.redirect(redirect_url)
+        finish_request(request)
+
+    @staticmethod
+    def _add_login_token_to_redirect_url(url, token):
+        url_parts = list(urllib.parse.urlparse(url))
+        query = dict(urllib.parse.parse_qsl(url_parts[4]))
+        query.update({"loginToken": token})
+        url_parts[4] = urllib.parse.urlencode(query)
+        return urllib.parse.urlunparse(url_parts)
+
+
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
     if hs.config.cas_enabled:

+ 66 - 0
synapse/types.py

@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 import string
 from collections import namedtuple
 
@@ -228,6 +229,71 @@ def contains_invalid_mxid_characters(localpart):
     return any(c not in mxid_localpart_allowed_characters for c in localpart)
 
 
+UPPER_CASE_PATTERN = re.compile(b"[A-Z_]")
+
+# the following is a pattern which matches '=', and bytes which are not allowed in a mxid
+# localpart.
+#
+# It works by:
+#  * building a string containing the allowed characters (excluding '=')
+#  * escaping every special character with a backslash (to stop '-' being interpreted as a
+#    range operator)
+#  * wrapping it in a '[^...]' regex
+#  * converting the whole lot to a 'bytes' sequence, so that we can use it to match
+#    bytes rather than strings
+#
+NON_MXID_CHARACTER_PATTERN = re.compile(
+    ("[^%s]" % (
+        re.escape("".join(mxid_localpart_allowed_characters - {"="}),),
+    )).encode("ascii"),
+)
+
+
+def map_username_to_mxid_localpart(username, case_sensitive=False):
+    """Map a username onto a string suitable for a MXID
+
+    This follows the algorithm laid out at
+    https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets.
+
+    Args:
+        username (unicode|bytes): username to be mapped
+        case_sensitive (bool): true if TEST and test should be mapped
+            onto different mxids
+
+    Returns:
+        unicode: string suitable for a mxid localpart
+    """
+    if not isinstance(username, bytes):
+        username = username.encode('utf-8')
+
+    # first we sort out upper-case characters
+    if case_sensitive:
+        def f1(m):
+            return b"_" + m.group().lower()
+
+        username = UPPER_CASE_PATTERN.sub(f1, username)
+    else:
+        username = username.lower()
+
+    # then we =-escape '=' and any other bytes not allowed in a mxid localpart
+    def f2(m):
+        g = m.group()[0]
+        if isinstance(g, str):
+            # on python 2, we need to do an ord(). On python 3, the
+            # byte itself will do.
+            g = ord(g)
+        return b"=%02x" % (g,)
+
+    username = NON_MXID_CHARACTER_PATTERN.sub(f2, username)
+
+    # we also =-escape a leading underscore, so the result never starts with '_'.
+    username = re.sub(b'^_', b'=5f', username)
+
+    # we should now only have ascii bytes left, so can decode back to a
+    # unicode.
+    return username.decode('ascii')
+
+
 class StreamToken(
     namedtuple("Token", (
         "room_key",

+ 30 - 1
tests/test_types.py

@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from synapse.api.errors import SynapseError
-from synapse.types import GroupID, RoomAlias, UserID
+from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
 
 from tests import unittest
 from tests.utils import TestHomeServer
@@ -79,3 +79,32 @@ class GroupIDTestCase(unittest.TestCase):
             except SynapseError as exc:
                 self.assertEqual(400, exc.code)
                 self.assertEqual("M_UNKNOWN", exc.errcode)
+
+
+class MapUsernameTestCase(unittest.TestCase):
+    def testPassThrough(self):
+        self.assertEqual(map_username_to_mxid_localpart("test1234"), "test1234")
+
+    def testUpperCase(self):
+        self.assertEqual(map_username_to_mxid_localpart("tEST_1234"), "test_1234")
+        self.assertEqual(
+            map_username_to_mxid_localpart("tEST_1234", case_sensitive=True),
+            "t_e_s_t__1234",
+        )
+
+    def testSymbols(self):
+        self.assertEqual(
+            map_username_to_mxid_localpart("test=$?_1234"),
+            "test=3d=24=3f_1234",
+        )
+
+    def testLeadingUnderscore(self):
+        self.assertEqual(map_username_to_mxid_localpart("_test_1234"), "=5ftest_1234")
+
+    def testNonAscii(self):
+        # this should work with either a unicode or a bytes
+        self.assertEqual(map_username_to_mxid_localpart(u'têst'), "t=c3=aast")
+        self.assertEqual(
+            map_username_to_mxid_localpart(u'têst'.encode('utf-8')),
+            "t=c3=aast",
+        )