Browse Source

Merge branch 'gitlab-master' into dinsic

Brendan Abolivier 3 years ago
parent
commit
5600527467
38 changed files with 1221 additions and 186 deletions
  1. 18 0
      CHANGELOG.md
  2. 3 0
      matrix_is_test/launcher.py
  3. 15 16
      res/matrix-org/invite_template.eml
  4. 5 5
      res/matrix-org/verification_template.eml
  5. 14 15
      res/vector-im/invite_template.eml
  6. 5 11
      res/vector-im/verification_template.eml
  7. 18 2
      sydent/hs_federation/verifier.py
  8. 16 18
      sydent/http/auth.py
  9. 155 0
      sydent/http/blacklisting_reactor.py
  10. 24 7
      sydent/http/httpclient.py
  11. 119 1
      sydent/http/httpcommon.py
  12. 30 30
      sydent/http/httpserver.py
  13. 7 2
      sydent/http/matrixfederationagent.py
  14. 2 3
      sydent/http/servlets/accountservlet.py
  15. 5 3
      sydent/http/servlets/blindlysignstuffservlet.py
  16. 0 3
      sydent/http/servlets/bulklookupservlet.py
  17. 17 6
      sydent/http/servlets/emailservlet.py
  18. 5 3
      sydent/http/servlets/getvalidated3pidservlet.py
  19. 2 2
      sydent/http/servlets/hashdetailsservlet.py
  20. 2 2
      sydent/http/servlets/logoutservlet.py
  21. 1 3
      sydent/http/servlets/lookupservlet.py
  22. 2 2
      sydent/http/servlets/lookupv2servlet.py
  23. 9 5
      sydent/http/servlets/msisdnservlet.py
  24. 50 3
      sydent/http/servlets/registerservlet.py
  25. 24 6
      sydent/http/servlets/store_invite_servlet.py
  26. 2 3
      sydent/http/servlets/termsservlet.py
  27. 6 3
      sydent/http/servlets/threepidbindservlet.py
  28. 7 2
      sydent/http/servlets/threepidunbindservlet.py
  29. 47 9
      sydent/sydent.py
  30. 14 2
      sydent/threepid/bind.py
  31. 5 0
      sydent/util/emailutils.py
  32. 118 0
      sydent/util/ip_range.py
  33. 104 3
      sydent/util/stringutils.py
  34. 2 2
      tests/test_auth.py
  35. 243 0
      tests/test_blacklisting.py
  36. 46 0
      tests/test_register.py
  37. 37 0
      tests/test_util.py
  38. 42 14
      tests/utils.py

+ 18 - 0
CHANGELOG.md

@@ -1,3 +1,21 @@
+Sydent 2.3.0 (unreleased)
+=========================
+
+Bug fixes
+---------
+- During user registration on the identity server, validate that the MXID returned by the contacted homeserver is valid for that homeserver. ([cc97fff](https://github.com/matrix-org/sydent/commit/cc97fff))
+- Ensure that `/v2/` endponts are correctly authenticated. ([ce04a68](https://github.com/matrix-org/sydent/commit/ce04a68))
+- Perform additional validation on the response received when requesting server signing keys. ([07e6da7](https://github.com/matrix-org/sydent/commit/07e6da7))
+
+Security fixes
+--------------
+
+- Validate the `matrix_server_name` parameter given during user registration. ([9e57334](https://github.com/matrix-org/sydent/commit/9e57334), [8936925](https://github.com/matrix-org/sydent/commit/8936925), [3d531ed](https://github.com/matrix-org/sydent/commit/3d531ed), [0f00412](https://github.com/matrix-org/sydent/commit/0f00412))
+- Limit the size of requests received from HTTP clients. ([89071a1](https://github.com/matrix-org/sydent/commit/89071a1), [0523511](https://github.com/matrix-org/sydent/commit/0523511), [f56eee3](https://github.com/matrix-org/sydent/commit/f56eee3))
+- Limit the size of responses received from HTTP servers. ([89071a1](https://github.com/matrix-org/sydent/commit/89071a1), [0523511](https://github.com/matrix-org/sydent/commit/0523511), [f56eee3](https://github.com/matrix-org/sydent/commit/f56eee3))
+- In invite emails, randomise the multipart boundary, and include MXIDs where available. ([4469d1d](https://github.com/matrix-org/sydent/commit/4469d1d), [6b405a8](https://github.com/matrix-org/sydent/commit/6b405a8), [65a6e91](https://github.com/matrix-org/sydent/commit/65a6e91))
+- Perform additional validation on the `client_secret` and `email` parameters to various APIs. ([3175fd3](https://github.com/matrix-org/sydent/commit/3175fd3))
+
 Sydent 2.2.0 (2020-09-11)
 =========================
 

+ 3 - 0
matrix_is_test/launcher.py

@@ -37,6 +37,9 @@ info_path = {info_path}
 templates.path = {testsubject_path}/res
 brand.default = is-test
 
+
+ip.whitelist = 127.0.0.1
+
 [email]
 email.tlsmode = 0
 email.invite.subject = %(sender_display_name)s has invited you to chat

+ 15 - 16
res/matrix-org/invite_template.eml

@@ -4,19 +4,20 @@ To: %(to)s
 Message-ID: %(messageid)s
 Subject: %(subject_header_value)s
 MIME-Version: 1.0
-Content-Type: multipart/alternative; 
-	boundary="7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ"
+Content-Type: multipart/alternative;
+	boundary="%(multipart_boundary)s"
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
+--%(multipart_boundary)s
 Content-Type: text/plain; charset=UTF-8
 Content-Disposition: inline
 
 Hi,
 
-%(sender_display_name)s has invited you into a room %(bracketed_room_name)s on
-Matrix. To join the conversation, either pick a Matrix client from
-https://matrix.org/docs/projects/try-matrix-now.html or use the single-click
-link below to join via Element (requires Chrome, Firefox, Safari, iOS or Android)
+%(sender_display_name)s %(bracketed_verified_sender)shas invited you into a room
+%(bracketed_room_name)son Matrix. To join the conversation, either pick a
+Matrix client from https://matrix.org/docs/projects/try-matrix-now.html or use
+the single-click link below to join via Element (requires Chrome, Firefox,
+Safari, iOS or Android)
 
 %(web_client_location)/#/room/%(room_id_forurl)s?email=%(to_forurl)s&signurl=https%%3A%%2F%%2Fmatrix.org%%2F_matrix%%2Fidentity%%2Fapi%%2Fv1%%2Fsign-ed25519%%3Ftoken%%3D%(token)s%%26private_key%%3D%(ephemeral_private_key)s&room_name=%(room_name_forurl)s&room_avatar_url=%(room_avatar_url_forurl)s&inviter_name=%(sender_display_name_forurl)s&guest_access_token=%(guest_access_token_forurl)s&guest_user_id=%(guest_user_id_forurl)s
 
@@ -38,12 +39,7 @@ Thanks,
 
 Matrix
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
-Content-Type: multipart/related;
-	boundary="M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR";
-	type="text/html"
-
---M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR
+--%(multipart_boundary)s
 Content-Type: text/html; charset=UTF-8
 Content-Disposition: inline
 
@@ -68,6 +64,10 @@ pre, code {
     padding: 20px;
 }
 
+.low-contrast {
+    color: #666666
+}
+
 #inner {
     width: 640px;
 }
@@ -102,7 +102,7 @@ pre, code {
 
 <p>Hi,</p>
 
-<p>%(sender_display_name_forhtml)s has invited you into a room %(bracketed_room_name_forhtml)s on
+<p>%(sender_display_name_forhtml)s <span class="low-contrast">%(bracketed_verified_sender_forhtml)s</span> has invited you into a room %(bracketed_room_name_forhtml)s on
 Matrix. To join the conversation, either <a href="https://matrix.org/docs/projects/try-matrix-now.html">pick a Matrix client</a> or use the single-click
 link below to join via Element (requires
 <a href="https://www.google.com/chrome">Chrome</a>,
@@ -139,6 +139,5 @@ create new communication solutions or extend the capabilities and reach of exist
         </table>
     </body>
 </html>
---M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR--
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ--
+--%(multipart_boundary)s--

+ 5 - 5
res/matrix-org/verification_template.eml

@@ -4,10 +4,10 @@ To: %(to)s
 Message-ID: %(messageid)s
 Subject: Confirm your email address for Matrix
 MIME-Version: 1.0
-Content-Type: multipart/alternative; 
-	boundary="7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ"
+Content-Type: multipart/alternative;
+	boundary="%(multipart_boundary)s"
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
+--%(multipart_boundary)s
 Content-Type: text/plain; charset=UTF-8
 Content-Disposition: inline
 
@@ -35,7 +35,7 @@ Matrix defines the standard, and provides open source reference implementations
 Matrix-compatible Servers, Clients, Client SDKs and Application Services to help you
 create new communication solutions or extend the capabilities and reach of existing ones.
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
+--%(multipart_boundary)s
 Content-Type: text/html; charset=UTF-8
 Content-Disposition: inline
 
@@ -85,4 +85,4 @@ create new communication solutions or extend the capabilities and reach of exist
 </body>
 </html>
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ--
+--%(multipart_boundary)s--

+ 14 - 15
res/vector-im/invite_template.eml

@@ -4,17 +4,18 @@ To: %(to)s
 Message-ID: %(messageid)s
 Subject: %(subject_header_value)s
 MIME-Version: 1.0
-Content-Type: multipart/alternative; 
-	boundary="7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ"
+Content-Type: multipart/alternative;
+	boundary="%(multipart_boundary)s"
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
+--%(multipart_boundary)s
 Content-Type: text/plain; charset=UTF-8
 Content-Disposition: inline
 
 Hi,
 
-%(sender_display_name)s has invited you into a room %(bracketed_room_name)s on
-Element. To join the conversation please follow the link below.
+%(sender_display_name)s %(bracketed_verified_sender)shas invited you into a room
+%(bracketed_room_name)son Element. To join the conversation please follow the
+link below.
 
 %(web_client_location)/#/room/%(room_id_forurl)s?email=%(to_forurl)s&signurl=https%%3A%%2F%%2Fvector.im%%2F_matrix%%2Fidentity%%2Fapi%%2Fv1%%2Fsign-ed25519%%3Ftoken%%3D%(token)s%%26private_key%%3D%(ephemeral_private_key)s&room_name=%(room_name_forurl)s&room_avatar_url=%(room_avatar_url_forurl)s&inviter_name=%(sender_display_name_forurl)s&guest_access_token=%(guest_access_token_forurl)s&guest_user_id=%(guest_user_id_forurl)s
 
@@ -51,12 +52,7 @@ decentralized communication delivering a community of users, bridged networks,
 integrated bots and applications plus full end-to-end encryption. To learn more about
 Matrix visit https://matrix.org.
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
-Content-Type: multipart/related;
-	boundary="M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR";
-	type="text/html"
-
---M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR
+--%(multipart_boundary)s
 Content-Type: text/html; charset=UTF-8
 Content-Disposition: inline
 
@@ -81,6 +77,10 @@ pre, code {
     padding: 20px;
 }
 
+.low-contrast {
+    color: #666666
+}
+
 #inner {
     width: 640px;
 }
@@ -123,8 +123,8 @@ pre, code {
 
 <p>Hi,</p>
 
-<p>%(sender_display_name_forhtml)s has invited you into a room %(bracketed_room_name_forhtml)s on
-Element.</p>
+<p>%(sender_display_name_forhtml)s <span class="low-contrast">%(bracketed_verified_sender_forhtml)s</span> has invited you into a
+room %(bracketed_room_name_forhtml)s on Element.</p>
 
 <p>
     <a
@@ -173,6 +173,5 @@ Matrix visit https://matrix.org.</p>
         </table>
     </body>
 </html>
---M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR--
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ--
+--%(multipart_boundary)s--

+ 5 - 11
res/vector-im/verification_template.eml

@@ -4,10 +4,10 @@ To: %(to)s
 Message-ID: %(messageid)s
 Subject: Confirm your email address for Element
 MIME-Version: 1.0
-Content-Type: multipart/alternative; 
-	boundary="7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ"
+Content-Type: multipart/alternative;
+	boundary="%(multipart_boundary)s"
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
+--%(multipart_boundary)s
 Content-Type: text/plain; charset=UTF-8
 Content-Disposition: inline
 
@@ -51,12 +51,7 @@ decentralized communication delivering a community of users, bridged networks,
 integrated bots and applications plus full end-to-end encryption. To learn more about
 Matrix visit https://matrix.org.
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
-Content-Type: multipart/related;
-	boundary="M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR";
-	type="text/html"
-
---M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR
+--%(multipart_boundary)s
 Content-Type: text/html; charset=UTF-8
 Content-Disposition: inline
 
@@ -171,6 +166,5 @@ pre, code {
         </table>
     </body>
 </html>
---M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR--
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ--
+--%(multipart_boundary)s--

+ 18 - 2
sydent/hs_federation/verifier.py

@@ -25,6 +25,7 @@ import signedjson.key
 from signedjson.sign import SignatureVerifyException
 
 from sydent.http.httpclient import FederationHttpClient
+from sydent.util.stringutils import is_valid_matrix_server_name
 
 
 logger = logging.getLogger(__name__)
@@ -37,6 +38,13 @@ class NoAuthenticationError(Exception):
     pass
 
 
+class InvalidServerName(Exception):
+    """
+    Raised when the provided origin parameter is not a valid hostname (plus optional port).
+    """
+    pass
+
+
 class Verifier(object):
     """
     Verifies signed json blobs from Matrix Homeservers by finding the
@@ -69,14 +77,19 @@ class Verifier(object):
                 defer.returnValue(self.cache[server_name]['verify_keys'])
 
         client = FederationHttpClient(self.sydent)
-        result = yield client.get_json("matrix://%s/_matrix/key/v2/server/" % server_name)
+        result = yield client.get_json("matrix://%s/_matrix/key/v2/server/" % server_name, 1024 * 50)
+
         if 'verify_keys' not in result:
             raise SignatureVerifyException("No key found in response")
 
         if 'valid_until_ts' in result:
+            if not isinstance(result['valid_until_ts'], int):
+                raise SignatureVerifyException("Invalid valid_until_ts received, must be an integer")
+
             # Don't cache anything without a valid_until_ts or we wouldn't
             # know when to expire it.
-            logger.info("Got keys for %s: caching until %s", server_name, result['valid_until_ts'])
+
+            logger.info("Got keys for %s: caching until %d", server_name, result['valid_until_ts'])
             self.cache[server_name] = result
 
         defer.returnValue(result['verify_keys'])
@@ -197,6 +210,9 @@ class Verifier(object):
         if not json_request["signatures"]:
             raise NoAuthenticationError("Missing X-Matrix Authorization header")
 
+        if not is_valid_matrix_server_name(json_request["origin"]):
+            raise InvalidServerName("X-Matrix header's origin parameter must be a valid Matrix server name")
+
         yield self.verifyServerSignedJson(json_request, [origin])
 
         logger.info("Verified request from HS %s", origin)

+ 16 - 18
sydent/http/auth.py

@@ -52,7 +52,7 @@ def tokenFromRequest(request):
     return token
 
 
-def authIfV2(sydent, request, requireTermsAgreed=True):
+def authV2(sydent, request, requireTermsAgreed=True):
     """For v2 APIs check that the request has a valid access token associated with it
 
     :param sydent: The Sydent instance to use.
@@ -67,25 +67,23 @@ def authIfV2(sydent, request, requireTermsAgreed=True):
     :raises MatrixRestError: If the request is v2 but could not be authed or the user has
         not accepted terms.
     """
-    if request.path.startswith(b'/_matrix/identity/v2'):
-        token = tokenFromRequest(request)
+    token = tokenFromRequest(request)
 
-        if token is None:
-            raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")
+    if token is None:
+        raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")
 
-        accountStore = AccountStore(sydent)
+    accountStore = AccountStore(sydent)
 
-        account = accountStore.getAccountByToken(token)
-        if account is None:
-            raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")
+    account = accountStore.getAccountByToken(token)
+    if account is None:
+        raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")
 
-        if requireTermsAgreed:
-            terms = get_terms(sydent)
-            if (
-                terms.getMasterVersion() is not None and
-                account.consentVersion != terms.getMasterVersion()
-            ):
-                raise MatrixRestError(403, "M_TERMS_NOT_SIGNED", "Terms not signed")
+    if requireTermsAgreed:
+        terms = get_terms(sydent)
+        if (
+            terms.getMasterVersion() is not None and
+            account.consentVersion != terms.getMasterVersion()
+        ):
+            raise MatrixRestError(403, "M_TERMS_NOT_SIGNED", "Terms not signed")
 
-        return account
-    return None
+    return account

+ 155 - 0
sydent/http/blacklisting_reactor.py

@@ -0,0 +1,155 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import (
+    Any,
+    List,
+    Optional,
+)
+
+from zope.interface import implementer, provider
+from netaddr import IPAddress, IPSet
+
+from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.interfaces import (
+    IAddress,
+    IHostResolution,
+    IReactorPluggableNameResolver,
+    IResolutionReceiver,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def check_against_blacklist(
+    ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
+) -> bool:
+    """
+    Compares an IP address to allowed and disallowed IP sets.
+
+    Args:
+        ip_address: The IP address to check
+        ip_whitelist: Allowed IP addresses.
+        ip_blacklist: Disallowed IP addresses.
+
+    Returns:
+        True if the IP address is in the blacklist and not in the whitelist.
+    """
+    if ip_address in ip_blacklist:
+        if ip_whitelist is None or ip_address not in ip_whitelist:
+            return True
+    return False
+
+
+class _IPBlacklistingResolver:
+    """
+    A proxy for reactor.nameResolver which only produces non-blacklisted IP
+    addresses, preventing DNS rebinding attacks on URL preview.
+    """
+
+    def __init__(
+        self,
+        reactor: IReactorPluggableNameResolver,
+        ip_whitelist: Optional[IPSet],
+        ip_blacklist: IPSet,
+    ):
+        """
+        Args:
+            reactor: The twisted reactor.
+            ip_whitelist: IP addresses to allow.
+            ip_blacklist: IP addresses to disallow.
+        """
+        self._reactor = reactor
+        self._ip_whitelist = ip_whitelist
+        self._ip_blacklist = ip_blacklist
+
+    def resolveHostName(
+        self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
+    ) -> IResolutionReceiver:
+        addresses = []  # type: List[IAddress]
+
+        def _callback() -> None:
+            has_bad_ip = False
+            for address in addresses:
+                # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
+                # should go through this path.
+                if not isinstance(address, (IPv4Address, IPv6Address)):
+                    continue
+
+                ip_address = IPAddress(address.host)
+
+                if check_against_blacklist(
+                    ip_address, self._ip_whitelist, self._ip_blacklist
+                ):
+                    logger.info(
+                        "Dropped %s from DNS resolution to %s due to blacklist"
+                        % (ip_address, hostname)
+                    )
+                    has_bad_ip = True
+
+            # if we have a blacklisted IP, we'd like to raise an error to block the
+            # request, but all we can really do from here is claim that there were no
+            # valid results.
+            if not has_bad_ip:
+                for address in addresses:
+                    recv.addressResolved(address)
+            recv.resolutionComplete()
+
+        @provider(IResolutionReceiver)
+        class EndpointReceiver:
+            @staticmethod
+            def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
+                recv.resolutionBegan(resolutionInProgress)
+
+            @staticmethod
+            def addressResolved(address: IAddress) -> None:
+                addresses.append(address)
+
+            @staticmethod
+            def resolutionComplete() -> None:
+                _callback()
+
+        self._reactor.nameResolver.resolveHostName(
+            EndpointReceiver, hostname, portNumber=portNumber
+        )
+
+        return recv
+
+
+@implementer(IReactorPluggableNameResolver)
+class BlacklistingReactorWrapper:
+    """
+    A Reactor wrapper which will prevent DNS resolution to blacklisted IP
+    addresses, to prevent DNS rebinding.
+    """
+
+    def __init__(
+        self,
+        reactor: IReactorPluggableNameResolver,
+        ip_whitelist: Optional[IPSet],
+        ip_blacklist: IPSet,
+    ):
+        self._reactor = reactor
+
+        # We need to use a DNS resolver which filters out blacklisted IP
+        # addresses, to prevent DNS rebinding.
+        self.nameResolver = _IPBlacklistingResolver(
+            self._reactor, ip_whitelist, ip_blacklist
+        )
+
+    def __getattr__(self, attr: str) -> Any:
+        # Passthrough to the real reactor except for the DNS resolver.
+        return getattr(self._reactor, attr)

+ 24 - 7
sydent/http/httpclient.py

@@ -22,9 +22,11 @@ from io import BytesIO
 from twisted.internet import defer
 from twisted.web.client import FileBodyProducer, Agent, readBody
 from twisted.web.http_headers import Headers
-from sydent.http.matrixfederationagent import MatrixFederationAgent
 
+from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper
+from sydent.http.matrixfederationagent import MatrixFederationAgent
 from sydent.http.federation_tls_options import ClientTLSOptionsFactory
+from sydent.http.httpcommon import BodyExceededMaxSize, read_body_with_max_size
 from sydent.util import json_decoder
 
 logger = logging.getLogger(__name__)
@@ -35,12 +37,15 @@ class HTTPClient(object):
     requests.
     """
     @defer.inlineCallbacks
-    def get_json(self, uri):
+    def get_json(self, uri, max_size = None):
         """Make a GET request to an endpoint returning JSON and parse result
 
         :param uri: The URI to make a GET request to.
         :type uri: unicode
 
+        :param max_size: The maximum size (in bytes) to allow as a response.
+        :type max_size: int
+
         :return: A deferred containing JSON parsed into a Python object.
         :rtype: twisted.internet.defer.Deferred[dict[any, any]]
         """
@@ -50,7 +55,7 @@ class HTTPClient(object):
             b"GET",
             uri.encode("utf8"),
         )
-        body = yield readBody(response)
+        body = yield read_body_with_max_size(response, max_size)
         try:
             # json.loads doesn't allow bytes in Python 3.5
             json_body = json_decoder.decode(body.decode("UTF-8"))
@@ -95,7 +100,11 @@ class HTTPClient(object):
         # Ensure the body object is read otherwise we'll leak HTTP connections
         # as per
         # https://twistedmatrix.com/documents/current/web/howto/client.html
-        yield readBody(response)
+        try:
+            # TODO Will this cause the server to think the request was a failure?
+            yield read_body_with_max_size(response, 0)
+        except BodyExceededMaxSize:
+            pass
 
         defer.returnValue(response)
 
@@ -109,7 +118,11 @@ class SimpleHttpClient(HTTPClient):
         # BrowserLikePolicyForHTTPS context factory which will do regular cert validation
         # 'like a browser'
         self.agent = Agent(
-            self.sydent.reactor,
+            BlacklistingReactorWrapper(
+                reactor=self.sydent.reactor,
+                ip_whitelist=sydent.ip_whitelist,
+                ip_blacklist=sydent.ip_blacklist,
+            ),
             connectTimeout=15,
         )
 
@@ -120,6 +133,10 @@ class FederationHttpClient(HTTPClient):
     def __init__(self, sydent):
         self.sydent = sydent
         self.agent = MatrixFederationAgent(
-            self.sydent.reactor,
-            ClientTLSOptionsFactory(sydent.cfg),
+            BlacklistingReactorWrapper(
+                reactor=self.sydent.reactor,
+                ip_whitelist=sydent.ip_whitelist,
+                ip_blacklist=sydent.ip_blacklist,
+            ),
+            ClientTLSOptionsFactory(sydent.cfg) if sydent.use_tls_for_federation else None,
         )

+ 119 - 1
sydent/http/httpcommon.py

@@ -15,11 +15,23 @@
 # limitations under the License.
 
 import logging
+from io import BytesIO
 
 import twisted.internet.ssl
+from twisted.internet import defer, protocol
+from twisted.internet.protocol import connectionDone
+from twisted.web._newclient import ResponseDone
+from twisted.web.http import PotentialDataLoss
+from twisted.web.iweb import UNKNOWN_LENGTH
+from twisted.web import server
+
 
 logger = logging.getLogger(__name__)
 
+# Arbitrarily limited to 512 KiB.
+MAX_REQUEST_SIZE = 512 * 1024
+
+
 class SslComponents:
     def __init__(self, sydent):
         self.sydent = sydent
@@ -55,10 +67,116 @@ class SslComponents:
                 fp = open(caCertFilename)
                 caCert = twisted.internet.ssl.Certificate.loadPEM(fp.read())
                 fp.close()
-            except:
+            except Exception:
                 logger.warn("Failed to open CA cert file %s", caCertFilename)
                 raise
             logger.warn("Using custom CA cert file: %s", caCertFilename)
             return twisted.internet._sslverify.OpenSSLCertificateAuthorities([caCert.original])
         else:
             return twisted.internet.ssl.OpenSSLDefaultPaths()
+
+
+class BodyExceededMaxSize(Exception):
+    """The maximum allowed size of the HTTP body was exceeded."""
+
+
+class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
+    """A protocol which immediately errors upon receiving data."""
+
+    def __init__(self, deferred):
+        self.deferred = deferred
+
+    def _maybe_fail(self):
+        """
+        Report a max size exceed error and disconnect the first time this is called.
+        """
+        if not self.deferred.called:
+            self.deferred.errback(BodyExceededMaxSize())
+            # Close the connection (forcefully) since all the data will get
+            # discarded anyway.
+            self.transport.abortConnection()
+
+    def dataReceived(self, data) -> None:
+        self._maybe_fail()
+
+    def connectionLost(self, reason) -> None:
+        self._maybe_fail()
+
+
+class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
+    """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
+
+    def __init__(self, deferred, max_size):
+        self.stream = BytesIO()
+        self.deferred = deferred
+        self.length = 0
+        self.max_size = max_size
+
+    def dataReceived(self, data) -> None:
+        # If the deferred was called, bail early.
+        if self.deferred.called:
+            return
+
+        self.stream.write(data)
+        self.length += len(data)
+        # The first time the maximum size is exceeded, error and cancel the
+        # connection. dataReceived might be called again if data was received
+        # in the meantime.
+        if self.max_size is not None and self.length >= self.max_size:
+            self.deferred.errback(BodyExceededMaxSize())
+            # Close the connection (forcefully) since all the data will get
+            # discarded anyway.
+            self.transport.abortConnection()
+
+    def connectionLost(self, reason=connectionDone) -> None:
+        # If the maximum size was already exceeded, there's nothing to do.
+        if self.deferred.called:
+            return
+
+        if reason.check(ResponseDone):
+            self.deferred.callback(self.stream.getvalue())
+        elif reason.check(PotentialDataLoss):
+            # stolen from https://github.com/twisted/treq/pull/49/files
+            # http://twistedmatrix.com/trac/ticket/4840
+            self.deferred.callback(self.stream.getvalue())
+        else:
+            self.deferred.errback(reason)
+
+
+def read_body_with_max_size(response, max_size):
+    """
+    Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
+
+    If the maximum file size is reached, the returned Deferred will resolve to a
+    Failure with a BodyExceededMaxSize exception.
+
+    Args:
+        response: The HTTP response to read from.
+        max_size: The maximum file size to allow.
+
+    Returns:
+        A Deferred which resolves to the read body.
+    """
+    d = defer.Deferred()
+
+    # If the Content-Length header gives a size larger than the maximum allowed
+    # size, do not bother downloading the body.
+    if max_size is not None and response.length != UNKNOWN_LENGTH:
+        if response.length > max_size:
+            response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
+            return d
+
+    response.deliverBody(_ReadBodyWithMaxSizeProtocol(d, max_size))
+    return d
+
+
+class SizeLimitingRequest(server.Request):
+    def handleContentChunk(self, data):
+        if self.content.tell() + len(data) > MAX_REQUEST_SIZE:
+            logger.info(
+                "Aborting connection from %s because the request exceeds maximum size",
+                self.client.host)
+            self.transport.abortConnection()
+            return
+
+        return super().handleContentChunk(data)

+ 30 - 30
sydent/http/httpserver.py

@@ -29,6 +29,7 @@ from sydent.http.servlets.authenticated_bind_threepid_servlet import (
 from sydent.http.servlets.authenticated_unbind_threepid_servlet import (
     AuthenticatedUnbindThreePidServlet,
 )
+from sydent.http.httpcommon import SizeLimitingRequest
 
 logger = logging.getLogger(__name__)
 
@@ -45,26 +46,17 @@ class ClientApiHttpServer:
         v2 = self.sydent.servlets.v2
 
         validate = Resource()
+        validate_v2 = Resource()
         email = Resource()
+        email_v2 = Resource()
         msisdn = Resource()
-        emailReqCode = self.sydent.servlets.emailRequestCode
-        emailValCode = self.sydent.servlets.emailValidate
-        msisdnReqCode = self.sydent.servlets.msisdnRequestCode
-        msisdnValCode = self.sydent.servlets.msisdnValidate
-        getValidated3pid = self.sydent.servlets.getValidated3pid
-
-        lookup = self.sydent.servlets.lookup
-        bulk_lookup = self.sydent.servlets.bulk_lookup
+        msisdn_v2 = Resource()
 
         info = self.sydent.servlets.info
         internalInfo = self.sydent.servlets.internalInfo
 
-        hash_details = self.sydent.servlets.hash_details
-        lookup_v2 = self.sydent.servlets.lookup_v2
-
         threepid_v1 = Resource()
         threepid_v2 = Resource()
-        bind = self.sydent.servlets.threepidBind
         unbind = self.sydent.servlets.threepidUnbind
 
         pubkey = Resource()
@@ -72,8 +64,6 @@ class ClientApiHttpServer:
 
         userDirectory = Resource()
 
-        pk_ed25519 = self.sydent.servlets.pubkey_ed25519
-
         root.putChild(b'_matrix', matrix)
         matrix.putChild(b'identity', identity)
         identity.putChild(b'api', api)
@@ -83,36 +73,45 @@ class ClientApiHttpServer:
         validate.putChild(b'email', email)
         validate.putChild(b'msisdn', msisdn)
 
+        validate_v2.putChild(b'email', email_v2)
+        validate_v2.putChild(b'msisdn', msisdn_v2)
+
         v1.putChild(b'validate', validate)
 
-        v1.putChild(b'lookup', lookup)
-        v1.putChild(b'bulk_lookup', bulk_lookup)
+        v1.putChild(b'lookup', self.sydent.servlets.lookup)
+        v1.putChild(b'bulk_lookup', self.sydent.servlets.bulk_lookup)
 
         v1.putChild(b'pubkey', pubkey)
         pubkey.putChild(b'isvalid', self.sydent.servlets.pubkeyIsValid)
-        pubkey.putChild(b'ed25519:0', pk_ed25519)
+        pubkey.putChild(b'ed25519:0', self.sydent.servlets.pubkey_ed25519)
         pubkey.putChild(b'ephemeral', ephemeralPubkey)
         ephemeralPubkey.putChild(b'isvalid', self.sydent.servlets.ephemeralPubkeyIsValid)
 
-        threepid_v2.putChild(b'getValidated3pid', getValidated3pid)
-        threepid_v2.putChild(b'bind', bind)
+        threepid_v2.putChild(b'getValidated3pid', self.sydent.servlets.getValidated3pidV2)
+        threepid_v2.putChild(b'bind', self.sydent.servlets.threepidBindV2)
         threepid_v2.putChild(b'unbind', unbind)
 
-        threepid_v1.putChild(b'getValidated3pid', getValidated3pid)
+        threepid_v1.putChild(b'getValidated3pid', self.sydent.servlets.getValidated3pid)
         threepid_v1.putChild(b'unbind', unbind)
         if self.sydent.enable_v1_associations:
-            threepid_v1.putChild(b'bind', bind)
+            threepid_v1.putChild(b'bind', self.sydent.servlets.threepidBind)
 
         v1.putChild(b'3pid', threepid_v1)
 
         v1.putChild(b'info', info)
         v1.putChild(b'internal-info', internalInfo)
 
-        email.putChild(b'requestToken', emailReqCode)
-        email.putChild(b'submitToken', emailValCode)
+        email.putChild(b'requestToken', self.sydent.servlets.emailRequestCode)
+        email.putChild(b'submitToken', self.sydent.servlets.emailValidate)
+
+        email_v2.putChild(b'requestToken', self.sydent.servlets.emailRequestCodeV2)
+        email_v2.putChild(b'submitToken', self.sydent.servlets.emailValidateV2)
+
+        msisdn.putChild(b'requestToken', self.sydent.servlets.msisdnRequestCode)
+        msisdn.putChild(b'submitToken', self.sydent.servlets.msisdnValidate)
 
-        msisdn.putChild(b'requestToken', msisdnReqCode)
-        msisdn.putChild(b'submitToken', msisdnValCode)
+        msisdn_v2.putChild(b'requestToken', self.sydent.servlets.msisdnRequestCodeV2)
+        msisdn_v2.putChild(b'submitToken', self.sydent.servlets.msisdnValidateV2)
 
         v1.putChild(b'store-invite', self.sydent.servlets.storeInviteServlet)
 
@@ -136,15 +135,16 @@ class ClientApiHttpServer:
         account.putChild(b'logout', self.sydent.servlets.logoutServlet)
 
         # v2 versions of existing APIs
-        v2.putChild(b'validate', validate)
+        v2.putChild(b'validate', validate_v2)
         v2.putChild(b'pubkey', pubkey)
         v2.putChild(b'3pid', threepid_v2)
-        v2.putChild(b'store-invite', self.sydent.servlets.storeInviteServlet)
-        v2.putChild(b'sign-ed25519', self.sydent.servlets.blindlySignStuffServlet)
-        v2.putChild(b'lookup', lookup_v2)
-        v2.putChild(b'hash_details', hash_details)
+        v2.putChild(b'store-invite', self.sydent.servlets.storeInviteServletV2)
+        v2.putChild(b'sign-ed25519', self.sydent.servlets.blindlySignStuffServletV2)
+        v2.putChild(b'lookup', self.sydent.servlets.lookup_v2)
+        v2.putChild(b'hash_details', self.sydent.servlets.hash_details)
 
         self.factory = Site(root)
+        self.factory.requestFactory = SizeLimitingRequest
         self.factory.displayTracebacks = False
 
     def setup(self):

+ 7 - 2
sydent/http/matrixfederationagent.py

@@ -25,11 +25,12 @@ from zope.interface import implementer
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody
+from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent
 from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent
 
+from sydent.http.httpcommon import BodyExceededMaxSize, read_body_with_max_size
 from sydent.http.srvresolver import SrvResolver, pick_server_from_list
 from sydent.util import json_decoder
 from sydent.util.ttlcache import TTLCache
@@ -46,6 +47,9 @@ WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
 # cap for .well-known cache period
 WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
 
+# The maximum size (in bytes) to allow a well-known file to be.
+WELL_KNOWN_MAX_SIZE = 50 * 1024  # 50 KiB
+
 logger = logging.getLogger(__name__)
 well_known_cache = TTLCache('well-known')
 
@@ -316,7 +320,7 @@ class MatrixFederationAgent(object):
         logger.info("Fetching %s", uri_str)
         try:
             response = yield self._well_known_agent.request(b"GET", uri)
-            body = yield readBody(response)
+            body = yield read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE)
             if response.code != 200:
                 raise Exception("Non-200 response %s" % (response.code, ))
 
@@ -334,6 +338,7 @@ class MatrixFederationAgent(object):
             cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
             cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
             defer.returnValue((None, cache_period))
+            return
 
         result = parsed_body["m.server"].encode("ascii")
 

+ 2 - 3
sydent/http/servlets/accountservlet.py

@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from twisted.web.resource import Resource
 
 from sydent.http.servlets import jsonwrap, send_cors
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 
 
 class AccountServlet(Resource):
@@ -36,7 +36,7 @@ class AccountServlet(Resource):
         """
         send_cors(request)
 
-        account = authIfV2(self.sydent, request)
+        account = authV2(self.sydent, request)
 
         return {
             "user_id": account.userId,
@@ -45,4 +45,3 @@ class AccountServlet(Resource):
     def render_OPTIONS(self, request):
         send_cors(request)
         return b''
-

+ 5 - 3
sydent/http/servlets/blindlysignstuffservlet.py

@@ -22,7 +22,7 @@ import signedjson.key
 import signedjson.sign
 from sydent.db.invite_tokens import JoinTokenStore
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 
 logger = logging.getLogger(__name__)
 
@@ -30,16 +30,18 @@ logger = logging.getLogger(__name__)
 class BlindlySignStuffServlet(Resource):
     isLeaf = True
 
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
         self.server_name = syd.server_name
         self.tokenStore = JoinTokenStore(syd)
+        self.require_auth = require_auth
 
     @jsonwrap
     def render_POST(self, request):
         send_cors(request)
 
-        authIfV2(self.sydent, request)
+        if self.require_auth:
+            authV2(self.sydent, request)
 
         args = get_args(request, ("private_key", "token", "mxid"))
 

+ 0 - 3
sydent/http/servlets/bulklookupservlet.py

@@ -21,7 +21,6 @@ from sydent.db.threepid_associations import GlobalAssociationStore
 import logging
 
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
-from sydent.http.auth import authIfV2
 
 
 logger = logging.getLogger(__name__)
@@ -45,8 +44,6 @@ class BulkLookupServlet(Resource):
         """
         send_cors(request)
 
-        authIfV2(self.sydent, request)
-
         args = get_args(request, ('threepids',))
 
         threepids = args['threepids']

+ 17 - 6
sydent/http/servlets/emailservlet.py

@@ -19,7 +19,7 @@ import logging
 
 from twisted.web.resource import Resource
 
-from sydent.util.stringutils import is_valid_client_secret
+from sydent.util.stringutils import is_valid_client_secret, MAX_EMAIL_ADDRESS_LENGTH
 from sydent.util.emailutils import EmailAddressException, EmailSendException
 from sydent.validators import (
     IncorrectClientSecretException,
@@ -31,7 +31,7 @@ from sydent.validators import (
 from sydent.validators.common import validate_next_link
 
 from sydent.http.servlets import get_args, jsonwrap, send_cors
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 
 logger = logging.getLogger(__name__)
 
@@ -39,14 +39,16 @@ logger = logging.getLogger(__name__)
 class EmailRequestCodeServlet(Resource):
     isLeaf = True
 
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
+        self.require_auth = require_auth
 
     @jsonwrap
     def render_POST(self, request):
         send_cors(request)
 
-        authIfV2(self.sydent, request)
+        if self.require_auth:
+            authV2(self.sydent, request)
 
         args = get_args(request, ('email', 'client_secret', 'send_attempt'))
 
@@ -61,6 +63,13 @@ class EmailRequestCodeServlet(Resource):
                 'error': 'Invalid client_secret provided'
             }
 
+        if not (0 < len(email) <= MAX_EMAIL_ADDRESS_LENGTH):
+            request.setResponseCode(400)
+            return {
+                'errcode': 'M_INVALID_PARAM',
+                'error': 'Invalid email provided'
+            }
+
         ipaddress = self.sydent.ip_from_request(request)
         brand = self.sydent.brand_from_request(request)
 
@@ -99,8 +108,9 @@ class EmailRequestCodeServlet(Resource):
 class EmailValidateCodeServlet(Resource):
     isLeaf = True
 
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
+        self.require_auth = require_auth
 
     def render_GET(self, request):
         args = get_args(request, ('nextLink',), required=False)
@@ -137,7 +147,8 @@ class EmailValidateCodeServlet(Resource):
     def render_POST(self, request):
         send_cors(request)
 
-        authIfV2(self.sydent, request)
+        if self.require_auth:
+            authV2(self.sydent, request)
 
         return self.do_validate_request(request)
 

+ 5 - 3
sydent/http/servlets/getvalidated3pidservlet.py

@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from twisted.web.resource import Resource
 
 from sydent.http.servlets import jsonwrap, get_args
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 from sydent.db.valsession import ThreePidValSessionStore
 from sydent.util.stringutils import is_valid_client_secret
 from sydent.validators import (
@@ -32,12 +32,14 @@ from sydent.validators import (
 class GetValidated3pidServlet(Resource):
     isLeaf = True
 
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
+        self.require_auth = require_auth
 
     @jsonwrap
     def render_GET(self, request):
-        authIfV2(self.sydent, request)
+        if self.require_auth:
+            authV2(self.sydent, request)
 
         args = get_args(request, ('sid', 'client_secret'))
 

+ 2 - 2
sydent/http/servlets/hashdetailsservlet.py

@@ -16,7 +16,7 @@
 from __future__ import absolute_import
 
 from twisted.web.resource import Resource
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 
 import logging
 
@@ -48,7 +48,7 @@ class HashDetailsServlet(Resource):
         """
         send_cors(request)
 
-        authIfV2(self.sydent, request)
+        authV2(self.sydent, request)
 
         return {
             "algorithms": self.known_algorithms,

+ 2 - 2
sydent/http/servlets/logoutservlet.py

@@ -21,7 +21,7 @@ import logging
 
 from sydent.http.servlets import jsonwrap, send_cors
 from sydent.db.accounts import AccountStore
-from sydent.http.auth import authIfV2, tokenFromRequest
+from sydent.http.auth import authV2, tokenFromRequest
 
 
 logger = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ class LogoutServlet(Resource):
         """
         send_cors(request)
 
-        authIfV2(self.sydent, request, False)
+        authV2(self.sydent, request, False)
 
         token = tokenFromRequest(request)
 

+ 1 - 3
sydent/http/servlets/lookupservlet.py

@@ -23,9 +23,9 @@ import logging
 import signedjson.sign
 
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
-from sydent.http.auth import authIfV2
 from sydent.util import json_decoder
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -48,8 +48,6 @@ class LookupServlet(Resource):
         """
         send_cors(request)
 
-        authIfV2(self.sydent, request)
-
         args = get_args(request, ('medium', 'address'))
 
         medium = args['medium']

+ 2 - 2
sydent/http/servlets/lookupv2servlet.py

@@ -21,7 +21,7 @@ import logging
 
 from sydent.http.servlets import get_args, jsonwrap, send_cors
 from sydent.db.threepid_associations import GlobalAssociationStore
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 from sydent.http.servlets.hashdetailsservlet import HashDetailsServlet
 
 logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ class LookupV2Servlet(Resource):
         """
         send_cors(request)
 
-        authIfV2(self.sydent, request)
+        authV2(self.sydent, request)
 
         args = get_args(request, ('addresses', 'algorithm', 'pepper'))
 

+ 9 - 5
sydent/http/servlets/msisdnservlet.py

@@ -30,7 +30,7 @@ from sydent.validators import (
 
 from sydent.http.servlets import get_args, jsonwrap, send_cors
 from sydent.util.stringutils import is_valid_client_secret
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 from sydent.util.stringutils import is_valid_client_secret
 
 
@@ -40,14 +40,16 @@ logger = logging.getLogger(__name__)
 class MsisdnRequestCodeServlet(Resource):
     isLeaf = True
 
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
+        self.require_auth = require_auth
 
     @jsonwrap
     def render_POST(self, request):
         send_cors(request)
 
-        authIfV2(self.sydent, request)
+        if self.require_auth:
+            authV2(self.sydent, request)
 
         args = get_args(request, ('phone_number', 'country', 'client_secret', 'send_attempt'))
 
@@ -115,8 +117,9 @@ class MsisdnRequestCodeServlet(Resource):
 class MsisdnValidateCodeServlet(Resource):
     isLeaf = True
 
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
+        self.require_auth = require_auth
 
     def render_GET(self, request):
         send_cors(request)
@@ -150,7 +153,8 @@ class MsisdnValidateCodeServlet(Resource):
     def render_POST(self, request):
         send_cors(request)
 
-        authIfV2(self.sydent, request)
+        if self.require_auth:
+            authV2(self.sydent, request)
 
         return self.do_validate_request(request)
 

+ 50 - 3
sydent/http/servlets/registerservlet.py

@@ -25,7 +25,7 @@ from six.moves import urllib
 from sydent.http.servlets import get_args, jsonwrap, deferjsonwrap, send_cors
 from sydent.http.httpclient import FederationHttpClient
 from sydent.users.tokens import issueToken
-
+from sydent.util.stringutils import is_valid_matrix_server_name
 
 logger = logging.getLogger(__name__)
 
@@ -47,15 +47,62 @@ class RegisterServlet(Resource):
 
         args = get_args(request, ('matrix_server_name', 'access_token'))
 
+        matrix_server = args['matrix_server_name'].lower()
+
+        if not is_valid_matrix_server_name(matrix_server):
+            request.setResponseCode(400)
+            return {
+                'errcode': 'M_INVALID_PARAM',
+                'error': 'matrix_server_name must be a valid Matrix server name (IP address or hostname)'
+            }
+
         result = yield self.client.get_json(
-            "matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s" % (
-                args['matrix_server_name'], urllib.parse.quote(args['access_token']),
+            "matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s"
+            % (
+                matrix_server,
+                urllib.parse.quote(args['access_token']),
             ),
+            1024 * 5,
         )
+
         if 'sub' not in result:
             raise Exception("Invalid response from homeserver")
 
         user_id = result['sub']
+
+        if not isinstance(user_id, str):
+            request.setResponseCode(500)
+            return {
+                'errcode': 'M_UNKNOWN',
+                'error': 'The Matrix homeserver returned a malformed reply'
+            }
+
+        user_id_components = user_id.split(':', 1)
+
+        # Ensure there's a localpart and domain in the returned user ID.
+        if len(user_id_components) != 2:
+            request.setResponseCode(500)
+            return {
+                'errcode': 'M_UNKNOWN',
+                'error': 'The Matrix homeserver returned an invalid MXID'
+            }
+
+        user_id_server = user_id_components[1]
+
+        if not is_valid_matrix_server_name(user_id_server):
+            request.setResponseCode(500)
+            return {
+                'errcode': 'M_UNKNOWN',
+                'error': 'The Matrix homeserver returned an invalid MXID'
+            }
+
+        if user_id_server != matrix_server:
+            request.setResponseCode(500)
+            return {
+                'errcode': 'M_UNKNOWN',
+                'error': 'The Matrix homeserver returned a MXID belonging to another homeserver'
+            }
+
         tok = yield issueToken(self.sydent, user_id)
 
         # XXX: `token` is correct for the spec, but we released with `access_token`

+ 24 - 6
sydent/http/servlets/store_invite_servlet.py

@@ -28,28 +28,35 @@ from unpaddedbase64 import encode_base64
 from sydent.db.invite_tokens import JoinTokenStore
 from sydent.db.threepid_associations import GlobalAssociationStore
 
-from sydent.http.servlets import get_args, send_cors, jsonwrap
-from sydent.http.auth import authIfV2
+from sydent.http.servlets import get_args, send_cors, jsonwrap, MatrixRestError
+from sydent.http.auth import authV2
 from sydent.util.emailutils import sendEmail
+from sydent.util.stringutils import MAX_EMAIL_ADDRESS_LENGTH
 
 
 class StoreInviteServlet(Resource):
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
         self.random = random.SystemRandom()
+        self.require_auth = require_auth
 
     @jsonwrap
     def render_POST(self, request):
         send_cors(request)
 
-        authIfV2(self.sydent, request)
-
         args = get_args(request, ("medium", "address", "room_id", "sender",))
         medium = args["medium"]
         address = args["address"]
         roomId = args["room_id"]
         sender = args["sender"]
 
+        verified_sender = None
+        if self.require_auth:
+            account = authV2(self.sydent, request)
+            verified_sender = sender
+            if account.userId != sender:
+                raise MatrixRestError(403, "M_UNAUTHORIZED", "'sender' doesn't match")
+
         globalAssocStore = GlobalAssociationStore(self.sydent)
         mxid = globalAssocStore.getMxid(medium, address)
         if mxid:
@@ -67,6 +74,13 @@ class StoreInviteServlet(Resource):
                 "error": "Didn't understand medium '%s'" % (medium,),
             }
 
+        if not (0 < len(address) <= MAX_EMAIL_ADDRESS_LENGTH):
+            request.setResponseCode(400)
+            return {
+                'errcode': 'M_INVALID_PARAM',
+                'error': 'Invalid email provided'
+            }
+
         token = self._randomString(128)
 
         tokenStore = JoinTokenStore(self.sydent)
@@ -103,9 +117,13 @@ class StoreInviteServlet(Resource):
         for k in extra_substitutions:
             substitutions.setdefault(k, '')
 
+        substitutions["bracketed_verified_sender"] = ""
+        if verified_sender:
+            substitutions["bracketed_verified_sender"] = "(%s) " % (verified_sender,)
+
         substitutions["ephemeral_private_key"] = ephemeralPrivateKeyBase64
         if substitutions["room_name"] != '':
-            substitutions["bracketed_room_name"] = "(%s)" % substitutions["room_name"]
+            substitutions["bracketed_room_name"] = "(%s) " % substitutions["room_name"]
 
         substitutions["web_client_location"] = self.sydent.default_web_client_location
         if 'org.matrix.web_client_location' in substitutions:

+ 2 - 3
sydent/http/servlets/termsservlet.py

@@ -21,7 +21,7 @@ import logging
 
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
 from sydent.terms.terms import get_terms
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 from sydent.db.terms import TermsStore
 from sydent.db.accounts import AccountStore
 
@@ -54,7 +54,7 @@ class TermsServlet(Resource):
         """
         send_cors(request)
 
-        account = authIfV2(self.sydent, request, False)
+        account = authV2(self.sydent, request, False)
 
         args = get_args(request, ("user_accepts",))
 
@@ -80,4 +80,3 @@ class TermsServlet(Resource):
     def render_OPTIONS(self, request):
         send_cors(request)
         return b''
-

+ 6 - 3
sydent/http/servlets/threepidbindservlet.py

@@ -21,7 +21,7 @@ from twisted.web.resource import Resource
 from sydent.util.stringutils import is_valid_client_secret
 from sydent.db.valsession import ThreePidValSessionStore
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 from sydent.util.stringutils import is_valid_client_secret
 from sydent.validators import SessionExpiredException, IncorrectClientSecretException, InvalidSessionIdException,\
     SessionNotValidatedException
@@ -30,14 +30,17 @@ from sydent.threepid.bind import BindingNotPermittedException
 
 
 class ThreePidBindServlet(Resource):
-    def __init__(self, sydent):
+    def __init__(self, sydent, require_auth=False):
         self.sydent = sydent
+        self.require_auth = require_auth
 
     @jsonwrap
     def render_POST(self, request):
         send_cors(request)
 
-        account = authIfV2(self.sydent, request)
+        account = None
+        if self.require_auth:
+            account = authV2(self.sydent, request)
 
         args = get_args(request, ('sid', 'client_secret', 'mxid'))
 

+ 7 - 2
sydent/http/servlets/threepidunbindservlet.py

@@ -19,7 +19,7 @@ from __future__ import absolute_import
 import json
 import logging
 
-from sydent.hs_federation.verifier import NoAuthenticationError
+from sydent.hs_federation.verifier import NoAuthenticationError, InvalidServerName
 from signedjson.sign import SignatureVerifyException
 
 from sydent.http.servlets import dict_to_json_bytes
@@ -144,7 +144,12 @@ class ThreePidUnbindServlet(Resource):
                     request.write(dict_to_json_bytes({'errcode': 'M_FORBIDDEN', 'error': str(ex)}))
                     request.finish()
                     return
-                except:
+                except InvalidServerName as ex:
+                    request.setResponseCode(400)
+                    request.write(dict_to_json_bytes({'errcode': 'M_INVALID_PARAM', 'error': str(ex)}))
+                    request.finish()
+                    return
+                except Exception:
                     logger.exception("Exception whilst authenticating unbind request")
                     request.setResponseCode(500)
                     request.write(dict_to_json_bytes({'errcode': 'M_UNKNOWN', 'error': 'Internal Server Error'}))

+ 47 - 9
sydent/sydent.py

@@ -24,7 +24,7 @@ import copy
 import logging
 import logging.handlers
 import os
-import re
+from typing import Set
 
 import twisted.internet.reactor
 from twisted.internet import task
@@ -49,6 +49,7 @@ from sydent.hs_federation.verifier import Verifier
 
 from sydent.util.hash import sha256_and_url_safe_base64
 from sydent.util.tokenutils import generateAlphanumericTokenOfLength
+from sydent.util.ip_range import generate_ip_set, DEFAULT_IP_RANGE_BLACKLIST
 
 from sydent.sign.ed25519 import SydentEd25519
 
@@ -84,13 +85,6 @@ from sydent.replication.pusher import Pusher
 
 logger = logging.getLogger(__name__)
 
-
-def set_from_comma_sep_string(rawstr):
-    if rawstr == '':
-        return set()
-    return {x.strip() for x in rawstr.split(',')}
-
-
 CONFIG_DEFAULTS = {
     'general': {
         'server.name': os.environ.get('SYDENT_SERVER_NAME', ''),
@@ -136,6 +130,26 @@ CONFIG_DEFAULTS = {
         # Whether clients and homeservers can register an association using v1 endpoints.
         'enable_v1_associations': 'true',
         'delete_tokens_on_bind': 'true',
+
+        # Prevent outgoing requests from being sent to the following blacklisted
+        # IP address CIDR ranges. If this option is not specified or empty then
+        # it defaults to private IP address ranges.
+        #
+        # The blacklist applies to all outbound requests except replication
+        # requests.
+        #
+        # (0.0.0.0 and :: are always blacklisted, whether or not they are
+        # explicitly listed here, since they correspond to unroutable
+        # addresses.)
+        'ip.blacklist': '',
+
+        # List of IP address CIDR ranges that should be allowed for outbound
+        # requests. This is useful for specifying exceptions to wide-ranging
+        # blacklisted target IP ranges.
+        #
+        # This whitelist overrides `ip.blacklist` and defaults to an empty
+        # list.
+        'ip.whitelist': '',
     },
     'db': {
         'db.file': os.environ.get('SYDENT_DB_PATH', 'sydent.db'),
@@ -245,9 +259,10 @@ CONFIG_DEFAULTS = {
 
 
 class Sydent:
-    def __init__(self, cfg, reactor=twisted.internet.reactor):
+    def __init__(self, cfg, reactor=twisted.internet.reactor, use_tls_for_federation=True):
         self.reactor = reactor
         self.config_file = get_config_file_path()
+        self.use_tls_for_federation = use_tls_for_federation
 
         self.cfg = cfg
 
@@ -328,6 +343,15 @@ class Sydent:
             self.cfg.get("general", "delete_tokens_on_bind")
         )
 
+        ip_blacklist = set_from_comma_sep_string(self.cfg.get("general", "ip.blacklist"))
+        if not ip_blacklist:
+            ip_blacklist = DEFAULT_IP_RANGE_BLACKLIST
+
+        ip_whitelist = set_from_comma_sep_string(self.cfg.get("general", "ip.whitelist"))
+
+        self.ip_blacklist = generate_ip_set(ip_blacklist)
+        self.ip_whitelist = generate_ip_set(ip_whitelist)
+
         self.username_reveal_characters = int(self.cfg.get(
             "email", "email.third_party_invite_username_reveal_characters"
         ))
@@ -400,9 +424,13 @@ class Sydent:
         self.servlets.v1 = V1Servlet(self)
         self.servlets.v2 = V2Servlet(self)
         self.servlets.emailRequestCode = EmailRequestCodeServlet(self)
+        self.servlets.emailRequestCodeV2 = EmailRequestCodeServlet(self, require_auth=True)
         self.servlets.emailValidate = EmailValidateCodeServlet(self)
+        self.servlets.emailValidateV2 = EmailValidateCodeServlet(self, require_auth=True)
         self.servlets.msisdnRequestCode = MsisdnRequestCodeServlet(self)
+        self.servlets.msisdnRequestCodeV2 = MsisdnRequestCodeServlet(self, require_auth=True)
         self.servlets.msisdnValidate = MsisdnValidateCodeServlet(self)
+        self.servlets.msisdnValidateV2 = MsisdnValidateCodeServlet(self, require_auth=True)
         self.servlets.lookup = LookupServlet(self)
         self.servlets.bulk_lookup = BulkLookupServlet(self)
         self.servlets.hash_details = HashDetailsServlet(self, lookup_pepper)
@@ -411,11 +439,15 @@ class Sydent:
         self.servlets.pubkeyIsValid = PubkeyIsValidServlet(self)
         self.servlets.ephemeralPubkeyIsValid = EphemeralPubkeyIsValidServlet(self)
         self.servlets.threepidBind = ThreePidBindServlet(self)
+        self.servlets.threepidBindV2 = ThreePidBindServlet(self, require_auth=True)
         self.servlets.threepidUnbind = ThreePidUnbindServlet(self)
         self.servlets.replicationPush = ReplicationPushServlet(self)
         self.servlets.getValidated3pid = GetValidated3pidServlet(self)
+        self.servlets.getValidated3pidV2 = GetValidated3pidServlet(self, require_auth=True)
         self.servlets.storeInviteServlet = StoreInviteServlet(self)
+        self.servlets.storeInviteServletV2 = StoreInviteServlet(self, require_auth=True)
         self.servlets.blindlySignStuffServlet = BlindlySignStuffServlet(self)
+        self.servlets.blindlySignStuffServletV2 = BlindlySignStuffServlet(self, require_auth=True)
         self.servlets.profileReplicationServlet = ProfileReplicationServlet(self)
         self.servlets.userDirectorySearchServlet = UserDirectorySearchServlet(self)
         self.servlets.termsServlet = TermsServlet(self)
@@ -657,6 +689,12 @@ def parse_cfg_bool(value):
     return value.lower() == "true"
 
 
+def set_from_comma_sep_string(rawstr: str) -> Set[str]:
+    if rawstr == '':
+        return set()
+    return {x.strip() for x in rawstr.split(',')}
+
+
 def run_gc():
     threshold = gc.get_threshold()
     counts = gc.get_count()

+ 14 - 2
sydent/threepid/bind.py

@@ -32,6 +32,8 @@ from sydent.http.httpclient import FederationHttpClient
 
 from sydent.threepid import ThreepidAssociation
 
+from sydent.util.stringutils import is_valid_matrix_server_name
+
 from twisted.internet import defer
 
 logger = logging.getLogger(__name__)
@@ -180,6 +182,7 @@ class ThreepidBinder:
         """
         mxid = assoc["mxid"]
         mxid_parts = mxid.split(":", 1)
+
         if len(mxid_parts) != 2:
             logger.error(
                 "Can't notify on bind for unparseable mxid %s. Not retrying.",
@@ -187,8 +190,17 @@ class ThreepidBinder:
             )
             return
 
+        matrix_server = mxid_parts[1]
+
+        if not is_valid_matrix_server_name(matrix_server):
+            logger.error(
+                "MXID server part '%s' not a valid Matrix server name. Not retrying.",
+                matrix_server,
+            )
+            return
+
         post_url = "matrix://%s/_matrix/federation/v1/3pid/onbind" % (
-            mxid_parts[1],
+            matrix_server,
         )
 
         logger.info("Making bind callback to: %s", post_url)
@@ -229,7 +241,7 @@ class ThreepidBinder:
                 "Successfully deleted invite for %s from the store",
                 assoc["address"],
             )
-        except Exception as e:
+        except Exception:
             logger.exception(
                 "Couldn't remove invite for %s from the store",
                 assoc["address"],

+ 5 - 0
sydent/util/emailutils.py

@@ -35,6 +35,7 @@ else:
 import email.utils
 
 from sydent.util import time_msec
+from sydent.util.tokenutils import generateAlphanumericTokenOfLength
 
 logger = logging.getLogger(__name__)
 
@@ -75,6 +76,10 @@ def sendEmail(sydent, templateFile, mailTo, substitutions):
         allSubstitutions[k+"_forhtml"] = escape(v)
         allSubstitutions[k+"_forurl"] = urllib.parse.quote(v)
 
+    # We add randomize the multipart boundary to stop user input from
+    # conflicting with it.
+    allSubstitutions["multipart_boundary"] = generateAlphanumericTokenOfLength(32)
+
     mailString = open(templateFile, encoding="utf-8").read() % allSubstitutions
     parsedFrom = email.utils.parseaddr(mailFrom)[1]
     parsedTo = email.utils.parseaddr(mailTo)[1]

+ 118 - 0
sydent/util/ip_range.py

@@ -0,0 +1,118 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+import itertools
+from typing import Iterable, Optional
+
+from netaddr import AddrFormatError, IPNetwork, IPSet
+
+# IP ranges that are considered private / unroutable / don't make sense.
+DEFAULT_IP_RANGE_BLACKLIST = [
+    # Localhost
+    "127.0.0.0/8",
+    # Private networks.
+    "10.0.0.0/8",
+    "172.16.0.0/12",
+    "192.168.0.0/16",
+    # Carrier grade NAT.
+    "100.64.0.0/10",
+    # Address registry.
+    "192.0.0.0/24",
+    # Link-local networks.
+    "169.254.0.0/16",
+    # Formerly used for 6to4 relay.
+    "192.88.99.0/24",
+    # Testing networks.
+    "198.18.0.0/15",
+    "192.0.2.0/24",
+    "198.51.100.0/24",
+    "203.0.113.0/24",
+    # Multicast.
+    "224.0.0.0/4",
+    # Localhost
+    "::1/128",
+    # Link-local addresses.
+    "fe80::/10",
+    # Unique local addresses.
+    "fc00::/7",
+    # Testing networks.
+    "2001:db8::/32",
+    # Multicast.
+    "ff00::/8",
+    # Site-local addresses
+    "fec0::/10",
+]
+
+
+def generate_ip_set(
+    ip_addresses: Optional[Iterable[str]],
+    extra_addresses: Optional[Iterable[str]] = None,
+    config_path: Optional[Iterable[str]] = None,
+) -> IPSet:
+    """
+    Generate an IPSet from a list of IP addresses or CIDRs.
+
+    Additionally, for each IPv4 network in the list of IP addresses, also
+    includes the corresponding IPv6 networks.
+
+    This includes:
+
+    * IPv4-Compatible IPv6 Address (see RFC 4291, section 2.5.5.1)
+    * IPv4-Mapped IPv6 Address (see RFC 4291, section 2.5.5.2)
+    * 6to4 Address (see RFC 3056, section 2)
+
+    Args:
+        ip_addresses: An iterable of IP addresses or CIDRs.
+        extra_addresses: An iterable of IP addresses or CIDRs.
+        config_path: The path in the configuration for error messages.
+
+    Returns:
+        A new IP set.
+    """
+    result = IPSet()
+    for ip in itertools.chain(ip_addresses or (), extra_addresses or ()):
+        try:
+            network = IPNetwork(ip)
+        except AddrFormatError as e:
+            raise Exception(
+                "Invalid IP range provided: %s." % (ip,), config_path
+            ) from e
+        result.add(network)
+
+        # It is possible that these already exist in the set, but that's OK.
+        if ":" not in str(network):
+            result.add(IPNetwork(network).ipv6(ipv4_compatible=True))
+            result.add(IPNetwork(network).ipv6(ipv4_compatible=False))
+            result.add(_6to4(network))
+
+    return result
+
+
+def _6to4(network: IPNetwork) -> IPNetwork:
+    """Convert an IPv4 network into a 6to4 IPv6 network per RFC 3056."""
+
+    # 6to4 networks consist of:
+    # * 2002 as the first 16 bits
+    # * The first IPv4 address in the network hex-encoded as the next 32 bits
+    # * The new prefix length needs to include the bits from the 2002 prefix.
+    hex_network = hex(network.first)[2:]
+    hex_network = ("0" * (8 - len(hex_network))) + hex_network
+    return IPNetwork(
+        "2002:%s:%s::/%d"
+        % (
+            hex_network[:4],
+            hex_network[4:],
+            16 + network.prefixlen,
+        )
+    )

+ 104 - 3
sydent/util/stringutils.py

@@ -13,18 +13,119 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import re
+from typing import Optional, Tuple
+
+from twisted.internet.abstract import isIPAddress, isIPv6Address
 
 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
-client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
+CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
+
+# hostname/domain name
+# https://regex101.com/r/OyN1lg/2
+hostname_regex = re.compile(
+    r"^(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)*$",
+    flags=re.IGNORECASE)
+
+# it's unclear what the maximum length of an email address is. RFC3696 (as corrected
+# by errata) says:
+#    the upper limit on address lengths should normally be considered to be 254.
+#
+# In practice, mail servers appear to be more tolerant and allow 400 characters
+# or so. Let's allow 500, which should be plenty for everyone.
+#
+MAX_EMAIL_ADDRESS_LENGTH = 500
 
 
 def is_valid_client_secret(client_secret):
     """Validate that a given string matches the client_secret regex defined by the spec
 
     :param client_secret: The client_secret to validate
-    :type client_secret: unicode
+    :type client_secret: str
 
     :return: Whether the client_secret is valid
     :rtype: bool
     """
-    return client_secret_regex.match(client_secret) is not None
+    return (
+        0 < len(client_secret) <= 255
+        and CLIENT_SECRET_REGEX.match(client_secret) is not None
+    )
+
+
+def is_valid_hostname(string: str) -> bool:
+    """Validate that a given string is a valid hostname or domain name.
+
+    For domain names, this only validates that the form is right (for
+    instance, it doesn't check that the TLD is valid).
+
+    :param string: The string to validate
+    :type string: str
+
+    :return: Whether the input is a valid hostname
+    :rtype: bool
+    """
+
+    return hostname_regex.match(string) is not None
+
+
+def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
+    """Split a server name into host/port parts.
+
+    No validation is done on the host part. The port part is validated to be
+    a valid port number.
+
+    Args:
+        server_name: server name to parse
+
+    Returns:
+        host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    try:
+        if server_name[-1] == "]":
+            # ipv6 literal, hopefully
+            return server_name, None
+
+        host_port = server_name.rsplit(":", 1)
+        host = host_port[0]
+        port = host_port[1] if host_port[1:] else None
+
+        if port:
+            port_num = int(port)
+
+            # exclude things like '08090' or ' 8090'
+            if port != str(port_num) or not (1 <= port_num < 65536):
+                raise ValueError("Invalid port")
+
+        return host, port
+    except Exception:
+        raise ValueError("Invalid server name '%s'" % server_name)
+
+
+def is_valid_matrix_server_name(string: str) -> bool:
+    """Validate that the given string is a valid Matrix server name.
+
+    A string is a valid Matrix server name if it is one of the following, plus
+    an optional port:
+
+    a. IPv4 address
+    b. IPv6 literal (`[IPV6_ADDRESS]`)
+    c. A valid hostname
+
+    :param string: The string to validate
+    :type string: str
+
+    :return: Whether the input is a valid Matrix server name
+    :rtype: bool
+    """
+
+    try:
+        host, port = parse_server_name(string)
+    except ValueError:
+        return False
+
+    valid_ipv4_addr = isIPAddress(host)
+    valid_ipv6_literal = host[0] == "[" and host[-1] == "]" and isIPv6Address(host[1:-1])
+
+    return valid_ipv4_addr or valid_ipv6_literal or is_valid_hostname(host)

+ 2 - 2
tests/test_auth.py

@@ -44,7 +44,7 @@ class AuthTestCase(unittest.TestCase):
         self.sydent.db.commit()
 
     def test_can_read_token_from_headers(self):
-        """Tests that Sydent correct extracts an auth token from request headers"""
+        """Tests that Sydent correctly extracts an auth token from request headers"""
         self.sydent.run()
 
         request, _ = make_request(
@@ -59,7 +59,7 @@ class AuthTestCase(unittest.TestCase):
         self.assertEqual(token, self.test_token)
 
     def test_can_read_token_from_query_parameters(self):
-        """Tests that Sydent correct extracts an auth token from query parameters"""
+        """Tests that Sydent correctly extracts an auth token from query parameters"""
         self.sydent.run()
 
         request, _ = make_request(

+ 243 - 0
tests/test_blacklisting.py

@@ -0,0 +1,243 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+
+from mock import patch
+from netaddr import IPSet
+from twisted.internet import defer
+from twisted.internet.error import DNSLookupError
+from twisted.test.proto_helpers import StringTransport
+from twisted.trial.unittest import TestCase
+from twisted.web.client import Agent
+
+from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper
+from sydent.http.srvresolver import Server
+from tests.utils import make_request, make_sydent
+
+
+class BlacklistingAgentTest(TestCase):
+    def setUp(self):
+        config = {
+            "general": {
+                "ip.blacklist": "5.0.0.0/8",
+                "ip.whitelist": "5.1.1.1",
+            },
+        }
+
+        self.sydent = make_sydent(test_config=config)
+
+        self.reactor = self.sydent.reactor
+
+        self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
+        self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
+        self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
+
+        # Configure the reactor's DNS resolver.
+        for (domain, ip) in (
+            (self.safe_domain, self.safe_ip),
+            (self.unsafe_domain, self.unsafe_ip),
+            (self.allowed_domain, self.allowed_ip),
+        ):
+            self.reactor.lookups[domain.decode()] = ip.decode()
+            self.reactor.lookups[ip.decode()] = ip.decode()
+
+        self.ip_whitelist = self.sydent.ip_whitelist
+        self.ip_blacklist = self.sydent.ip_blacklist
+
+    def test_reactor(self):
+        """Apply the blacklisting reactor and ensure it properly blocks
+        connections to particular domains and IPs.
+        """
+        agent = Agent(
+            BlacklistingReactorWrapper(
+                self.reactor,
+                ip_whitelist=self.ip_whitelist,
+                ip_blacklist=self.ip_blacklist,
+            ),
+        )
+
+        # The unsafe domains and IPs should be rejected.
+        for domain in (self.unsafe_domain, self.unsafe_ip):
+            self.failureResultOf(
+                agent.request(b"GET", b"http://" + domain), DNSLookupError
+            )
+
+        self.reactor.tcpClients = []
+
+        # The safe domains IPs should be accepted.
+        for domain in (
+            self.safe_domain,
+            self.allowed_domain,
+            self.safe_ip,
+            self.allowed_ip,
+        ):
+            agent.request(b"GET", b"http://" + domain)
+
+            # Grab the latest TCP connection.
+            (
+                host,
+                port,
+                client_factory,
+                _timeout,
+                _bindAddress,
+            ) = self.reactor.tcpClients.pop()
+
+    @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
+    def test_federation_client_allowed_ip(self, resolver):
+        self.sydent.run()
+
+        request, channel = make_request(
+            self.sydent.reactor,
+            "POST",
+            "/_matrix/identity/v2/account/register",
+            {
+                "access_token": "foo",
+                "expires_in": 300,
+                "matrix_server_name": "example.com",
+                "token_type": "Bearer",
+            },
+        )
+
+        resolver.return_value = defer.succeed(
+            [
+                Server(
+                    host=self.allowed_domain,
+                    port=443,
+                    priority=1,
+                    weight=1,
+                    expires=100,
+                )
+            ]
+        )
+
+        request.render(self.sydent.servlets.registerServlet)
+
+        transport, protocol = self._get_http_request(
+            self.allowed_ip.decode("ascii"), 443
+        )
+
+        self.assertRegex(
+            transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
+        )
+        self.assertRegex(transport.value(), b"Host: example.com")
+
+        # Send it the HTTP response
+        res_json = '{ "sub": "@test:example.com" }'.encode("ascii")
+        protocol.dataReceived(
+            b"HTTP/1.1 200 OK\r\n"
+            b"Server: Fake\r\n"
+            b"Content-Type: application/json\r\n"
+            b"Content-Length: %i\r\n"
+            b"\r\n"
+            b"%s" % (len(res_json), res_json)
+        )
+
+        self.assertEqual(channel.code, 200)
+
+    @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
+    def test_federation_client_safe_ip(self, resolver):
+        self.sydent.run()
+
+        request, channel = make_request(
+            self.sydent.reactor,
+            "POST",
+            "/_matrix/identity/v2/account/register",
+            {
+                "access_token": "foo",
+                "expires_in": 300,
+                "matrix_server_name": "example.com",
+                "token_type": "Bearer",
+            },
+        )
+
+        resolver.return_value = defer.succeed(
+            [
+                Server(
+                    host=self.safe_domain,
+                    port=443,
+                    priority=1,
+                    weight=1,
+                    expires=100,
+                )
+            ]
+        )
+
+        request.render(self.sydent.servlets.registerServlet)
+
+        transport, protocol = self._get_http_request(self.safe_ip.decode("ascii"), 443)
+
+        self.assertRegex(
+            transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
+        )
+        self.assertRegex(transport.value(), b"Host: example.com")
+
+        # Send it the HTTP response
+        res_json = '{ "sub": "@test:example.com" }'.encode("ascii")
+        protocol.dataReceived(
+            b"HTTP/1.1 200 OK\r\n"
+            b"Server: Fake\r\n"
+            b"Content-Type: application/json\r\n"
+            b"Content-Length: %i\r\n"
+            b"\r\n"
+            b"%s" % (len(res_json), res_json)
+        )
+
+        self.assertEqual(channel.code, 200)
+
+    @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
+    def test_federation_client_unsafe_ip(self, resolver):
+        self.sydent.run()
+
+        request, channel = make_request(
+            self.sydent.reactor,
+            "POST",
+            "/_matrix/identity/v2/account/register",
+            {
+                "access_token": "foo",
+                "expires_in": 300,
+                "matrix_server_name": "example.com",
+                "token_type": "Bearer",
+            },
+        )
+
+        resolver.return_value = defer.succeed(
+            [
+                Server(
+                    host=self.unsafe_domain,
+                    port=443,
+                    priority=1,
+                    weight=1,
+                    expires=100,
+                )
+            ]
+        )
+
+        request.render(self.sydent.servlets.registerServlet)
+
+        self.assertNot(self.reactor.tcpClients)
+
+        self.assertEqual(channel.code, 500)
+
+    def _get_http_request(self, expected_host, expected_port):
+        clients = self.reactor.tcpClients
+        (host, port, factory, _timeout, _bindAddress) = clients[-1]
+        self.assertEqual(host, expected_host)
+        self.assertEqual(port, expected_port)
+
+        # complete the connection and wire it up to a fake transport
+        protocol = factory.buildProtocol(None)
+        transport = StringTransport()
+        protocol.makeConnection(transport)
+
+        return transport, protocol

+ 46 - 0
tests/test_register.py

@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.trial import unittest
+
+from tests.utils import make_request, make_sydent
+
+
+class RegisterTestCase(unittest.TestCase):
+    """Tests Sydent's register servlet"""
+
+    def setUp(self):
+        # Create a new sydent
+        self.sydent = make_sydent()
+
+    def test_sydent_rejects_invalid_hostname(self):
+        """Tests that the /register endpoint rejects an invalid hostname passed as matrix_server_name"""
+        self.sydent.run()
+
+        bad_hostname = "example.com#"
+
+        request, channel = make_request(
+            self.sydent.reactor,
+            "POST",
+            "/_matrix/identity/v2/account/register",
+            content={
+                "matrix_server_name": bad_hostname,
+                "access_token": "foo"
+            })
+
+        request.render(self.sydent.servlets.registerServlet)
+
+        self.assertEqual(channel.code, 400)

+ 37 - 0
tests/test_util.py

@@ -0,0 +1,37 @@
+from twisted.trial import unittest
+from sydent.util.stringutils import is_valid_matrix_server_name
+
+
+class UtilTests(unittest.TestCase):
+    """Tests Sydent utility functions."""
+
+    def test_is_valid_matrix_server_name(self):
+        """Tests that the is_valid_matrix_server_name function accepts only
+        valid hostnames (or domain names), with optional port number.
+        """
+        self.assertTrue(is_valid_matrix_server_name("9.9.9.9"))
+        self.assertTrue(is_valid_matrix_server_name("9.9.9.9:4242"))
+        self.assertTrue(is_valid_matrix_server_name("[::]"))
+        self.assertTrue(is_valid_matrix_server_name("[::]:4242"))
+        self.assertTrue(is_valid_matrix_server_name("[a:b:c::]:4242"))
+
+        self.assertTrue(is_valid_matrix_server_name("example.com"))
+        self.assertTrue(is_valid_matrix_server_name("EXAMPLE.COM"))
+        self.assertTrue(is_valid_matrix_server_name("ExAmPlE.CoM"))
+        self.assertTrue(is_valid_matrix_server_name("example.com:4242"))
+        self.assertTrue(is_valid_matrix_server_name("localhost"))
+        self.assertTrue(is_valid_matrix_server_name("localhost:9000"))
+        self.assertTrue(is_valid_matrix_server_name("a.b.c.d:1234"))
+
+        self.assertFalse(is_valid_matrix_server_name("[:::]"))
+        self.assertFalse(is_valid_matrix_server_name("a:b:c::"))
+
+        self.assertFalse(is_valid_matrix_server_name("example.com:65536"))
+        self.assertFalse(is_valid_matrix_server_name("example.com:0"))
+        self.assertFalse(is_valid_matrix_server_name("example.com:-1"))
+        self.assertFalse(is_valid_matrix_server_name("example.com:a"))
+        self.assertFalse(is_valid_matrix_server_name("example.com: "))
+        self.assertFalse(is_valid_matrix_server_name("example.com:04242"))
+        self.assertFalse(is_valid_matrix_server_name("example.com: 4242"))
+        self.assertFalse(is_valid_matrix_server_name("example.com/example.com"))
+        self.assertFalse(is_valid_matrix_server_name("example.com#example.com"))

+ 42 - 14
tests/utils.py

@@ -2,9 +2,19 @@ import json
 from io import BytesIO
 import logging
 import os
-
+from typing import Dict
 import attr
 from six import text_type
+from zope.interface import implementer
+from twisted.internet._resolver import SimpleResolverComplexifier
+from twisted.internet.defer import fail, succeed
+from twisted.internet.error import DNSLookupError
+from twisted.internet.interfaces import (
+    IHostnameResolver,
+    IReactorPluggableNameResolver,
+    IResolverSimple,
+)
+
 from twisted.internet import address
 import twisted.logger
 from twisted.web.http_headers import Headers
@@ -53,13 +63,13 @@ def make_sydent(test_config={}):
     # Use an in-memory SQLite database. Note that the database isn't cleaned up between
     # tests, so by default the same database will be used for each test if changed to be
     # a file on disk.
-    if 'db' not in test_config:
-        test_config['db'] = {'db.file': ':memory:'}
+    if "db" not in test_config:
+        test_config["db"] = {"db.file": ":memory:"}
     else:
-        test_config['db'].setdefault('db.file', ':memory:')
+        test_config["db"].setdefault("db.file", ":memory:")
 
-    reactor = MemoryReactorClock()
-    return Sydent(reactor=reactor, cfg=parse_config_dict(test_config))
+    reactor = ResolvingMemoryReactorClock()
+    return Sydent(reactor=reactor, cfg=parse_config_dict(test_config), use_tls_for_federation=False)
 
 
 @attr.s
@@ -149,6 +159,7 @@ class FakeChannel(object):
 
 class FakeSite:
     """A fake Twisted Web Site."""
+
     pass
 
 
@@ -191,10 +202,7 @@ def make_request(
         path = path.encode("ascii")
 
     # Decorate it to be the full path, if we're using shorthand
-    if (
-        shorthand
-        and not path.startswith(b"/_matrix")
-    ):
+    if shorthand and not path.startswith(b"/_matrix"):
         path = b"/_matrix/identity/v2/" + path
         path = path.replace(b"//", b"/")
 
@@ -253,10 +261,7 @@ def setup_logging():
     """
     root_logger = logging.getLogger()
 
-    log_format = (
-        "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s"
-        " - %(message)s"
-    )
+    log_format = "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s" " - %(message)s"
 
     handler = ToTwistedHandler()
     formatter = logging.Formatter(log_format)
@@ -268,3 +273,26 @@ def setup_logging():
 
 
 setup_logging()
+
+
+@implementer(IReactorPluggableNameResolver)
+class ResolvingMemoryReactorClock(MemoryReactorClock):
+    """
+    A MemoryReactorClock that supports name resolution.
+    """
+
+    def __init__(self):
+        lookups = self.lookups = {}  # type: Dict[str, str]
+
+        @implementer(IResolverSimple)
+        class FakeResolver:
+            def getHostByName(self, name, timeout=None):
+                if name not in lookups:
+                    return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
+                return succeed(lookups[name])
+
+        self.nameResolver = SimpleResolverComplexifier(FakeResolver())
+        super().__init__()
+
+    def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
+        raise NotImplementedError()