Browse Source

Fix checking for access_token query parameter (#294)

Andrew Morgan 3 years ago
parent
commit
08b4085ad0
4 changed files with 106 additions and 27 deletions
  1. 1 0
      changelog.d/294.bugfix
  2. 4 5
      sydent/http/auth.py
  3. 29 22
      sydent/http/servlets/__init__.py
  4. 72 0
      tests/test_auth.py

+ 1 - 0
changelog.d/294.bugfix

@@ -0,0 +1 @@
+Fix a bug that prevented Sydent from checking for OpenID auth tokens in request parameters when running on Python3.

+ 4 - 5
sydent/http/auth.py

@@ -17,11 +17,9 @@ from __future__ import absolute_import
 
 import logging
 
-import twisted.internet.ssl
-
 from sydent.db.accounts import AccountStore
 from sydent.terms.terms import get_terms
-from sydent.http.servlets import MatrixRestError
+from sydent.http.servlets import MatrixRestError, get_args
 
 
 logger = logging.getLogger(__name__)
@@ -43,8 +41,9 @@ def tokenFromRequest(request):
         token = authHeader[len("Bearer "):]
 
     # no? try access_token query param
-    if token is None and 'access_token' in request.args:
-        token = request.args['access_token'][0]
+    if token is None:
+        args = get_args(request, ('access_token',), required=False)
+        token = args.get('access_token')
 
     # Ensure we're dealing with unicode.
     if token and isinstance(token, bytes):

+ 29 - 22
sydent/http/servlets/__init__.py

@@ -38,7 +38,7 @@ class MatrixRestError(Exception):
         self.error = error
 
 
-def get_args(request, required_args):
+def get_args(request, args, required=True):
     """
     Helper function to get arguments for an HTTP request.
     Currently takes args from the top level keys of a json object or
@@ -49,9 +49,14 @@ def get_args(request, required_args):
 
     :param request: The request received by the servlet.
     :type request: twisted.web.server.Request
-    :param required_args: The args that needs to be found in the
-        request's parameters.
-    :type required_args: tuple[unicode]
+    :param args: The args to look for in the request's parameters.
+    :type args: tuple[unicode]
+    :param required: Whether to raise a MatrixRestError with 400
+        M_MISSING_PARAMS if an argument is not found.
+    :type required: bool
+
+    :raises: MatrixRestError if required is True and a given parameter
+        was not found in the request's query parameters.
 
     :return: A dict containing the requested args and their values. String values
         are of type unicode.
@@ -59,7 +64,7 @@ def get_args(request, required_args):
     """
     v1_path = request.path.startswith(b'/_matrix/identity/api/v1')
 
-    args = None
+    request_args = None
     # for v1 paths, only look for json args if content type is json
     if (
         request.method in (b'POST', b'PUT') and (
@@ -71,25 +76,25 @@ def get_args(request, required_args):
     ):
         try:
             # json.loads doesn't allow bytes in Python 3.5
-            args = json.loads(request.content.read().decode("UTF-8"))
+            request_args = json.loads(request.content.read().decode("UTF-8"))
         except ValueError:
             raise MatrixRestError(400, 'M_BAD_JSON', 'Malformed JSON')
 
     # If we didn't get anything from that, and it's a v1 api path, try the request args
     # (riot-web's usage of the ed25519 sign servlet currently involves
     # sending the params in the query string with a json body of 'null')
-    if args is None and (v1_path or request.method == b'GET'):
-        args_bytes = copy.copy(request.args)
+    if request_args is None and (v1_path or request.method == b'GET'):
+        request_args_bytes = copy.copy(request.args)
         # Twisted supplies everything as an array because it's valid to
         # supply the same params multiple times with www-form-urlencoded
         # params. This make it incompatible with the json object though,
         # so we need to convert one of them. Since this is the
         # backwards-compat option, we convert this one.
-        args = {}
-        for k, v in args_bytes.items():
+        request_args = {}
+        for k, v in request_args_bytes.items():
             if isinstance(v, list) and len(v) == 1:
                 try:
-                    args[k.decode("UTF-8")] = v[0].decode("UTF-8")
+                    request_args[k.decode("UTF-8")] = v[0].decode("UTF-8")
                 except UnicodeDecodeError:
                     # Get a version of the key that has non-UTF-8 characters replaced by
                     # their \xNN escape sequence so it doesn't raise another exception.
@@ -100,20 +105,22 @@ def get_args(request, required_args):
                         "Parameter %s and its value must be valid UTF-8" % safe_k,
                     )
 
-    elif args is None:
-        args = {}
+    elif request_args is None:
+        request_args = {}
 
-    missing = []
-    for a in required_args:
-        if a not in args:
-            missing.append(a)
+    if required:
+        # Check for any missing arguments
+        missing = []
+        for a in args:
+            if a not in request_args:
+                missing.append(a)
 
-    if len(missing) > 0:
-        request.setResponseCode(400)
-        msg = "Missing parameters: "+(",".join(missing))
-        raise MatrixRestError(400, 'M_MISSING_PARAMS', msg)
+        if len(missing) > 0:
+            request.setResponseCode(400)
+            msg = "Missing parameters: "+(",".join(missing))
+            raise MatrixRestError(400, 'M_MISSING_PARAMS', msg)
 
-    return args
+    return request_args
 
 
 def jsonwrap(f):

+ 72 - 0
tests/test_auth.py

@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.trial import unittest
+
+from sydent.http.auth import tokenFromRequest
+from tests.utils import make_request, make_sydent
+
+
+class AuthTestCase(unittest.TestCase):
+    """Tests Sydent's auth code"""
+
+    def setUp(self):
+        # Create a new sydent
+        self.sydent = make_sydent()
+        self.test_token = "testingtoken"
+
+        # Inject a fake OpenID token into the database
+        cur = self.sydent.db.cursor()
+        cur.execute(
+            "INSERT INTO accounts (user_id, created_ts, consent_version)"
+            "VALUES (?, ?, ?)",
+            ("@bob:localhost", 101010101, "asd")
+        )
+        cur.execute(
+            "INSERT INTO tokens (user_id, token)"
+            "VALUES (?, ?)",
+            ("@bob:localhost", self.test_token)
+        )
+
+        self.sydent.db.commit()
+
+    def test_can_read_token_from_headers(self):
+        """Tests that Sydent correct extracts an auth token from request headers"""
+        self.sydent.run()
+
+        request, _ = make_request(
+            self.sydent.reactor, "GET", "/_matrix/identity/v2/hash_details"
+        )
+        request.requestHeaders.addRawHeader(
+            b"Authorization", b"Bearer " + self.test_token.encode("ascii")
+        )
+
+        token = tokenFromRequest(request)
+
+        self.assertEqual(token, self.test_token)
+
+    def test_can_read_token_from_query_parameters(self):
+        """Tests that Sydent correct extracts an auth token from query parameters"""
+        self.sydent.run()
+
+        request, _ = make_request(
+            self.sydent.reactor, "GET",
+            "/_matrix/identity/v2/hash_details?access_token=" + self.test_token
+        )
+
+        token = tokenFromRequest(request)
+
+        self.assertEqual(token, self.test_token)