Browse Source

Catch failures to contact remote homeserver for `/register` (#456)

* Test cases to check we handle these cases properly
* Log a warning in all cases where we return 500

No more exceptions. Let's log something so it's not silent.
David Robertson 2 years ago
parent
commit
001c046140

+ 1 - 1
.github/workflows/pipeline.yml

@@ -31,7 +31,7 @@ jobs:
       - uses: actions/setup-python@v2
         with:
          python-version: ${{ matrix.python-version }}
-      - run: python -m pip install -e .
+      - run: python -m pip install -e .[dev]
       - run: trial tests
 
   run-matrix-is-tests:

+ 1 - 0
changelog.d/456.misc

@@ -0,0 +1 @@
+Handle federation request failures in `/request` explicitly, to reduce Sentry noise.

+ 1 - 0
setup.py

@@ -56,6 +56,7 @@ setup(
     ],
     extras_require={
         "dev": [
+            "parameterized==0.8.1",
             "flake8==3.9.2",
             "flake8-pyi==20.10.0",
             "black==21.6b0",

+ 2 - 0
stubs/twisted/internet/error.pyi

@@ -2,3 +2,5 @@ from typing import Any, Optional
 
 class ConnectError(Exception):
     def __init__(self, osError: Optional[Any] = ..., string: str = ...): ...
+
+class DNSLookupError(IOError): ...

+ 7 - 1
stubs/twisted/web/client.pyi

@@ -1,8 +1,9 @@
-from typing import BinaryIO, Optional, Type, TypeVar
+from typing import BinaryIO, Optional, Sequence, Type, TypeVar
 
 from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IConsumer, IProtocol
 from twisted.internet.task import Cooperator
+from twisted.python.failure import Failure
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import (
     IAgent,
@@ -15,6 +16,11 @@ from zope.interface import implementer
 
 _C = TypeVar("_C")
 
+class ResponseFailed(Exception):
+    def __init__(
+        self, reasons: Sequence[Failure], response: Optional[Response] = ...
+    ): ...
+
 class HTTPConnectionPool:
     persistent: bool
     maxPersistentPerHost: int

+ 40 - 30
sydent/http/servlets/registerservlet.py

@@ -14,8 +14,11 @@
 
 import logging
 import urllib
-from typing import TYPE_CHECKING
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Dict
 
+from twisted.internet.error import ConnectError, DNSLookupError
+from twisted.web.client import ResponseFailed
 from twisted.web.resource import Resource
 from twisted.web.server import Request
 
@@ -56,52 +59,59 @@ class RegisterServlet(Resource):
                 "error": "matrix_server_name must be a valid Matrix server name (IP address or hostname)",
             }
 
-        result = await self.client.get_json(
-            "matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s"
-            % (
-                matrix_server,
-                urllib.parse.quote(args["access_token"]),
-            ),
-            1024 * 5,
-        )
+        def federation_request_problem(error: str) -> Dict[str, str]:
+            logger.warning(error)
+            request.setResponseCode(HTTPStatus.INTERNAL_SERVER_ERROR)
+            return {
+                "errcode": "M_UNKNOWN",
+                "error": error,
+            }
+
+        try:
+            result = await self.client.get_json(
+                "matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s"
+                % (
+                    matrix_server,
+                    urllib.parse.quote(args["access_token"]),
+                ),
+                1024 * 5,
+            )
+        except (DNSLookupError, ConnectError, ResponseFailed) as e:
+            return federation_request_problem(
+                f"Unable to contact the Matrix homeserver ({type(e).__name__})"
+            )
 
         if "sub" not in result:
-            raise Exception("Invalid response from homeserver")
+            return federation_request_problem(
+                "The Matrix homeserver did not include 'sub' in its response",
+            )
 
         user_id = result["sub"]
 
         if not isinstance(user_id, str):
-            request.setResponseCode(500)
-            return {
-                "errcode": "M_UNKNOWN",
-                "error": "The Matrix homeserver returned a malformed reply",
-            }
+            return federation_request_problem(
+                "The Matrix homeserver returned a malformed reply"
+            )
 
         user_id_components = user_id.split(":", 1)
 
         # Ensure there's a localpart and domain in the returned user ID.
         if len(user_id_components) != 2:
-            request.setResponseCode(500)
-            return {
-                "errcode": "M_UNKNOWN",
-                "error": "The Matrix homeserver returned an invalid MXID",
-            }
+            return federation_request_problem(
+                "The Matrix homeserver returned an invalid MXID"
+            )
 
         user_id_server = user_id_components[1]
 
         if not is_valid_matrix_server_name(user_id_server):
-            request.setResponseCode(500)
-            return {
-                "errcode": "M_UNKNOWN",
-                "error": "The Matrix homeserver returned an invalid MXID",
-            }
+            return federation_request_problem(
+                "The Matrix homeserver returned an invalid MXID"
+            )
 
         if user_id_server != matrix_server:
-            request.setResponseCode(500)
-            return {
-                "errcode": "M_UNKNOWN",
-                "error": "The Matrix homeserver returned a MXID belonging to another homeserver",
-            }
+            return federation_request_problem(
+                "The Matrix homeserver returned a MXID belonging to another homeserver"
+            )
 
         tok = issueToken(self.sydent, user_id)
 

+ 38 - 2
tests/test_register.py

@@ -11,7 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from http import HTTPStatus
+from unittest.mock import patch
 
+import twisted.internet.error
+import twisted.web.client
+from parameterized import parameterized
 from twisted.trial import unittest
 
 from tests.utils import make_request, make_sydent
@@ -20,11 +25,11 @@ from tests.utils import make_request, make_sydent
 class RegisterTestCase(unittest.TestCase):
     """Tests Sydent's register servlet"""
 
-    def setUp(self):
+    def setUp(self) -> None:
         # Create a new sydent
         self.sydent = make_sydent()
 
-    def test_sydent_rejects_invalid_hostname(self):
+    def test_sydent_rejects_invalid_hostname(self) -> None:
         """Tests that the /register endpoint rejects an invalid hostname passed as matrix_server_name"""
         self.sydent.run()
 
@@ -40,3 +45,34 @@ class RegisterTestCase(unittest.TestCase):
         request.render(self.sydent.servlets.registerServlet)
 
         self.assertEqual(channel.code, 400)
+
+    @parameterized.expand(
+        [
+            (twisted.internet.error.DNSLookupError(),),
+            (twisted.internet.error.TimeoutError(),),
+            (twisted.internet.error.ConnectionRefusedError(),),
+            # Naughty: strictly we're supposed to initialise a ResponseNeverReceived
+            # with a list of 1 or more failures.
+            (twisted.web.client.ResponseNeverReceived([]),),
+        ]
+    )
+    def test_connection_failure(self, exc: Exception) -> None:
+        self.sydent.run()
+        request, channel = make_request(
+            self.sydent.reactor,
+            "POST",
+            "/_matrix/identity/v2/account/register",
+            content={
+                "matrix_server_name": "matrix.alice.com",
+                "access_token": "back_in_wonderland",
+            },
+        )
+        servlet = self.sydent.servlets.registerServlet
+
+        with patch.object(servlet.client, "get_json", side_effect=exc):
+            request.render(servlet)
+        self.assertEqual(channel.code, HTTPStatus.INTERNAL_SERVER_ERROR)
+        self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN")
+        # Check that we haven't just returned the generic error message in asyncjsonwrap
+        self.assertNotEqual(channel.json_body["error"], "Internal Server Error")
+        self.assertIn("contact", channel.json_body["error"])