Browse Source

Merge branch 'babolivier/dinsic-merge-2' into dinsic

Brendan Abolivier 3 years ago
parent
commit
aeda939403
38 changed files with 1221 additions and 186 deletions
  1. 18 0
      CHANGELOG.md
  2. 3 0
      matrix_is_test/launcher.py
  3. 15 16
      res/matrix-org/invite_template.eml
  4. 5 5
      res/matrix-org/verification_template.eml
  5. 14 15
      res/vector-im/invite_template.eml
  6. 5 11
      res/vector-im/verification_template.eml
  7. 18 2
      sydent/hs_federation/verifier.py
  8. 16 18
      sydent/http/auth.py
  9. 155 0
      sydent/http/blacklisting_reactor.py
  10. 24 7
      sydent/http/httpclient.py
  11. 119 1
      sydent/http/httpcommon.py
  12. 30 30
      sydent/http/httpserver.py
  13. 7 2
      sydent/http/matrixfederationagent.py
  14. 2 3
      sydent/http/servlets/accountservlet.py
  15. 5 3
      sydent/http/servlets/blindlysignstuffservlet.py
  16. 0 3
      sydent/http/servlets/bulklookupservlet.py
  17. 17 6
      sydent/http/servlets/emailservlet.py
  18. 5 3
      sydent/http/servlets/getvalidated3pidservlet.py
  19. 2 2
      sydent/http/servlets/hashdetailsservlet.py
  20. 2 2
      sydent/http/servlets/logoutservlet.py
  21. 1 3
      sydent/http/servlets/lookupservlet.py
  22. 2 2
      sydent/http/servlets/lookupv2servlet.py
  23. 9 5
      sydent/http/servlets/msisdnservlet.py
  24. 50 3
      sydent/http/servlets/registerservlet.py
  25. 24 6
      sydent/http/servlets/store_invite_servlet.py
  26. 2 3
      sydent/http/servlets/termsservlet.py
  27. 6 3
      sydent/http/servlets/threepidbindservlet.py
  28. 7 2
      sydent/http/servlets/threepidunbindservlet.py
  29. 47 9
      sydent/sydent.py
  30. 14 2
      sydent/threepid/bind.py
  31. 5 0
      sydent/util/emailutils.py
  32. 118 0
      sydent/util/ip_range.py
  33. 104 3
      sydent/util/stringutils.py
  34. 2 2
      tests/test_auth.py
  35. 243 0
      tests/test_blacklisting.py
  36. 46 0
      tests/test_register.py
  37. 37 0
      tests/test_util.py
  38. 42 14
      tests/utils.py

+ 18 - 0
CHANGELOG.md

@@ -1,3 +1,21 @@
+Sydent 2.3.0 (unreleased)
+=========================
+
+Bug fixes
+---------
+- During user registration on the identity server, validate that the MXID returned by the contacted homeserver is valid for that homeserver. ([cc97fff](https://github.com/matrix-org/sydent/commit/cc97fff))
+- Ensure that `/v2/` endponts are correctly authenticated. ([ce04a68](https://github.com/matrix-org/sydent/commit/ce04a68))
+- Perform additional validation on the response received when requesting server signing keys. ([07e6da7](https://github.com/matrix-org/sydent/commit/07e6da7))
+
+Security fixes
+--------------
+
+- Validate the `matrix_server_name` parameter given during user registration. ([9e57334](https://github.com/matrix-org/sydent/commit/9e57334), [8936925](https://github.com/matrix-org/sydent/commit/8936925), [3d531ed](https://github.com/matrix-org/sydent/commit/3d531ed), [0f00412](https://github.com/matrix-org/sydent/commit/0f00412))
+- Limit the size of requests received from HTTP clients. ([89071a1](https://github.com/matrix-org/sydent/commit/89071a1), [0523511](https://github.com/matrix-org/sydent/commit/0523511), [f56eee3](https://github.com/matrix-org/sydent/commit/f56eee3))
+- Limit the size of responses received from HTTP servers. ([89071a1](https://github.com/matrix-org/sydent/commit/89071a1), [0523511](https://github.com/matrix-org/sydent/commit/0523511), [f56eee3](https://github.com/matrix-org/sydent/commit/f56eee3))
+- In invite emails, randomise the multipart boundary, and include MXIDs where available. ([4469d1d](https://github.com/matrix-org/sydent/commit/4469d1d), [6b405a8](https://github.com/matrix-org/sydent/commit/6b405a8), [65a6e91](https://github.com/matrix-org/sydent/commit/65a6e91))
+- Perform additional validation on the `client_secret` and `email` parameters to various APIs. ([3175fd3](https://github.com/matrix-org/sydent/commit/3175fd3))
+
 Sydent 2.2.0 (2020-09-11)
 Sydent 2.2.0 (2020-09-11)
 =========================
 =========================
 
 

+ 3 - 0
matrix_is_test/launcher.py

@@ -37,6 +37,9 @@ info_path = {info_path}
 templates.path = {testsubject_path}/res
 templates.path = {testsubject_path}/res
 brand.default = is-test
 brand.default = is-test
 
 
+
+ip.whitelist = 127.0.0.1
+
 [email]
 [email]
 email.tlsmode = 0
 email.tlsmode = 0
 email.invite.subject = %(sender_display_name)s has invited you to chat
 email.invite.subject = %(sender_display_name)s has invited you to chat

+ 15 - 16
res/matrix-org/invite_template.eml

@@ -4,19 +4,20 @@ To: %(to)s
 Message-ID: %(messageid)s
 Message-ID: %(messageid)s
 Subject: %(subject_header_value)s
 Subject: %(subject_header_value)s
 MIME-Version: 1.0
 MIME-Version: 1.0
-Content-Type: multipart/alternative; 
-	boundary="7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ"
+Content-Type: multipart/alternative;
+	boundary="%(multipart_boundary)s"
 
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
+--%(multipart_boundary)s
 Content-Type: text/plain; charset=UTF-8
 Content-Type: text/plain; charset=UTF-8
 Content-Disposition: inline
 Content-Disposition: inline
 
 
 Hi,
 Hi,
 
 
-%(sender_display_name)s has invited you into a room %(bracketed_room_name)s on
-Matrix. To join the conversation, either pick a Matrix client from
-https://matrix.org/docs/projects/try-matrix-now.html or use the single-click
-link below to join via Element (requires Chrome, Firefox, Safari, iOS or Android)
+%(sender_display_name)s %(bracketed_verified_sender)shas invited you into a room
+%(bracketed_room_name)son Matrix. To join the conversation, either pick a
+Matrix client from https://matrix.org/docs/projects/try-matrix-now.html or use
+the single-click link below to join via Element (requires Chrome, Firefox,
+Safari, iOS or Android)
 
 
 %(web_client_location)/#/room/%(room_id_forurl)s?email=%(to_forurl)s&signurl=https%%3A%%2F%%2Fmatrix.org%%2F_matrix%%2Fidentity%%2Fapi%%2Fv1%%2Fsign-ed25519%%3Ftoken%%3D%(token)s%%26private_key%%3D%(ephemeral_private_key)s&room_name=%(room_name_forurl)s&room_avatar_url=%(room_avatar_url_forurl)s&inviter_name=%(sender_display_name_forurl)s&guest_access_token=%(guest_access_token_forurl)s&guest_user_id=%(guest_user_id_forurl)s
 %(web_client_location)/#/room/%(room_id_forurl)s?email=%(to_forurl)s&signurl=https%%3A%%2F%%2Fmatrix.org%%2F_matrix%%2Fidentity%%2Fapi%%2Fv1%%2Fsign-ed25519%%3Ftoken%%3D%(token)s%%26private_key%%3D%(ephemeral_private_key)s&room_name=%(room_name_forurl)s&room_avatar_url=%(room_avatar_url_forurl)s&inviter_name=%(sender_display_name_forurl)s&guest_access_token=%(guest_access_token_forurl)s&guest_user_id=%(guest_user_id_forurl)s
 
 
@@ -38,12 +39,7 @@ Thanks,
 
 
 Matrix
 Matrix
 
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
-Content-Type: multipart/related;
-	boundary="M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR";
-	type="text/html"
-
---M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR
+--%(multipart_boundary)s
 Content-Type: text/html; charset=UTF-8
 Content-Type: text/html; charset=UTF-8
 Content-Disposition: inline
 Content-Disposition: inline
 
 
@@ -68,6 +64,10 @@ pre, code {
     padding: 20px;
     padding: 20px;
 }
 }
 
 
+.low-contrast {
+    color: #666666
+}
+
 #inner {
 #inner {
     width: 640px;
     width: 640px;
 }
 }
@@ -102,7 +102,7 @@ pre, code {
 
 
 <p>Hi,</p>
 <p>Hi,</p>
 
 
-<p>%(sender_display_name_forhtml)s has invited you into a room %(bracketed_room_name_forhtml)s on
+<p>%(sender_display_name_forhtml)s <span class="low-contrast">%(bracketed_verified_sender_forhtml)s</span> has invited you into a room %(bracketed_room_name_forhtml)s on
 Matrix. To join the conversation, either <a href="https://matrix.org/docs/projects/try-matrix-now.html">pick a Matrix client</a> or use the single-click
 Matrix. To join the conversation, either <a href="https://matrix.org/docs/projects/try-matrix-now.html">pick a Matrix client</a> or use the single-click
 link below to join via Element (requires
 link below to join via Element (requires
 <a href="https://www.google.com/chrome">Chrome</a>,
 <a href="https://www.google.com/chrome">Chrome</a>,
@@ -139,6 +139,5 @@ create new communication solutions or extend the capabilities and reach of exist
         </table>
         </table>
     </body>
     </body>
 </html>
 </html>
---M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR--
 
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ--
+--%(multipart_boundary)s--

+ 5 - 5
res/matrix-org/verification_template.eml

@@ -4,10 +4,10 @@ To: %(to)s
 Message-ID: %(messageid)s
 Message-ID: %(messageid)s
 Subject: Confirm your email address for Matrix
 Subject: Confirm your email address for Matrix
 MIME-Version: 1.0
 MIME-Version: 1.0
-Content-Type: multipart/alternative; 
-	boundary="7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ"
+Content-Type: multipart/alternative;
+	boundary="%(multipart_boundary)s"
 
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
+--%(multipart_boundary)s
 Content-Type: text/plain; charset=UTF-8
 Content-Type: text/plain; charset=UTF-8
 Content-Disposition: inline
 Content-Disposition: inline
 
 
@@ -35,7 +35,7 @@ Matrix defines the standard, and provides open source reference implementations
 Matrix-compatible Servers, Clients, Client SDKs and Application Services to help you
 Matrix-compatible Servers, Clients, Client SDKs and Application Services to help you
 create new communication solutions or extend the capabilities and reach of existing ones.
 create new communication solutions or extend the capabilities and reach of existing ones.
 
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
+--%(multipart_boundary)s
 Content-Type: text/html; charset=UTF-8
 Content-Type: text/html; charset=UTF-8
 Content-Disposition: inline
 Content-Disposition: inline
 
 
@@ -85,4 +85,4 @@ create new communication solutions or extend the capabilities and reach of exist
 </body>
 </body>
 </html>
 </html>
 
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ--
+--%(multipart_boundary)s--

+ 14 - 15
res/vector-im/invite_template.eml

@@ -4,17 +4,18 @@ To: %(to)s
 Message-ID: %(messageid)s
 Message-ID: %(messageid)s
 Subject: %(subject_header_value)s
 Subject: %(subject_header_value)s
 MIME-Version: 1.0
 MIME-Version: 1.0
-Content-Type: multipart/alternative; 
-	boundary="7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ"
+Content-Type: multipart/alternative;
+	boundary="%(multipart_boundary)s"
 
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
+--%(multipart_boundary)s
 Content-Type: text/plain; charset=UTF-8
 Content-Type: text/plain; charset=UTF-8
 Content-Disposition: inline
 Content-Disposition: inline
 
 
 Hi,
 Hi,
 
 
-%(sender_display_name)s has invited you into a room %(bracketed_room_name)s on
-Element. To join the conversation please follow the link below.
+%(sender_display_name)s %(bracketed_verified_sender)shas invited you into a room
+%(bracketed_room_name)son Element. To join the conversation please follow the
+link below.
 
 
 %(web_client_location)/#/room/%(room_id_forurl)s?email=%(to_forurl)s&signurl=https%%3A%%2F%%2Fvector.im%%2F_matrix%%2Fidentity%%2Fapi%%2Fv1%%2Fsign-ed25519%%3Ftoken%%3D%(token)s%%26private_key%%3D%(ephemeral_private_key)s&room_name=%(room_name_forurl)s&room_avatar_url=%(room_avatar_url_forurl)s&inviter_name=%(sender_display_name_forurl)s&guest_access_token=%(guest_access_token_forurl)s&guest_user_id=%(guest_user_id_forurl)s
 %(web_client_location)/#/room/%(room_id_forurl)s?email=%(to_forurl)s&signurl=https%%3A%%2F%%2Fvector.im%%2F_matrix%%2Fidentity%%2Fapi%%2Fv1%%2Fsign-ed25519%%3Ftoken%%3D%(token)s%%26private_key%%3D%(ephemeral_private_key)s&room_name=%(room_name_forurl)s&room_avatar_url=%(room_avatar_url_forurl)s&inviter_name=%(sender_display_name_forurl)s&guest_access_token=%(guest_access_token_forurl)s&guest_user_id=%(guest_user_id_forurl)s
 
 
@@ -51,12 +52,7 @@ decentralized communication delivering a community of users, bridged networks,
 integrated bots and applications plus full end-to-end encryption. To learn more about
 integrated bots and applications plus full end-to-end encryption. To learn more about
 Matrix visit https://matrix.org.
 Matrix visit https://matrix.org.
 
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
-Content-Type: multipart/related;
-	boundary="M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR";
-	type="text/html"
-
---M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR
+--%(multipart_boundary)s
 Content-Type: text/html; charset=UTF-8
 Content-Type: text/html; charset=UTF-8
 Content-Disposition: inline
 Content-Disposition: inline
 
 
@@ -81,6 +77,10 @@ pre, code {
     padding: 20px;
     padding: 20px;
 }
 }
 
 
+.low-contrast {
+    color: #666666
+}
+
 #inner {
 #inner {
     width: 640px;
     width: 640px;
 }
 }
@@ -123,8 +123,8 @@ pre, code {
 
 
 <p>Hi,</p>
 <p>Hi,</p>
 
 
-<p>%(sender_display_name_forhtml)s has invited you into a room %(bracketed_room_name_forhtml)s on
-Element.</p>
+<p>%(sender_display_name_forhtml)s <span class="low-contrast">%(bracketed_verified_sender_forhtml)s</span> has invited you into a
+room %(bracketed_room_name_forhtml)s on Element.</p>
 
 
 <p>
 <p>
     <a
     <a
@@ -173,6 +173,5 @@ Matrix visit https://matrix.org.</p>
         </table>
         </table>
     </body>
     </body>
 </html>
 </html>
---M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR--
 
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ--
+--%(multipart_boundary)s--

+ 5 - 11
res/vector-im/verification_template.eml

@@ -4,10 +4,10 @@ To: %(to)s
 Message-ID: %(messageid)s
 Message-ID: %(messageid)s
 Subject: Confirm your email address for Element
 Subject: Confirm your email address for Element
 MIME-Version: 1.0
 MIME-Version: 1.0
-Content-Type: multipart/alternative; 
-	boundary="7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ"
+Content-Type: multipart/alternative;
+	boundary="%(multipart_boundary)s"
 
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
+--%(multipart_boundary)s
 Content-Type: text/plain; charset=UTF-8
 Content-Type: text/plain; charset=UTF-8
 Content-Disposition: inline
 Content-Disposition: inline
 
 
@@ -51,12 +51,7 @@ decentralized communication delivering a community of users, bridged networks,
 integrated bots and applications plus full end-to-end encryption. To learn more about
 integrated bots and applications plus full end-to-end encryption. To learn more about
 Matrix visit https://matrix.org.
 Matrix visit https://matrix.org.
 
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ
-Content-Type: multipart/related;
-	boundary="M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR";
-	type="text/html"
-
---M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR
+--%(multipart_boundary)s
 Content-Type: text/html; charset=UTF-8
 Content-Type: text/html; charset=UTF-8
 Content-Disposition: inline
 Content-Disposition: inline
 
 
@@ -171,6 +166,5 @@ pre, code {
         </table>
         </table>
     </body>
     </body>
 </html>
 </html>
---M3yzHl5YZehm9v4bAM8sKEdcOoVnRnKR--
 
 
---7REaIwWQCioQ6NaBlAQlg8ztbUQj6PKJ--
+--%(multipart_boundary)s--

+ 18 - 2
sydent/hs_federation/verifier.py

@@ -25,6 +25,7 @@ import signedjson.key
 from signedjson.sign import SignatureVerifyException
 from signedjson.sign import SignatureVerifyException
 
 
 from sydent.http.httpclient import FederationHttpClient
 from sydent.http.httpclient import FederationHttpClient
+from sydent.util.stringutils import is_valid_matrix_server_name
 
 
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
@@ -37,6 +38,13 @@ class NoAuthenticationError(Exception):
     pass
     pass
 
 
 
 
+class InvalidServerName(Exception):
+    """
+    Raised when the provided origin parameter is not a valid hostname (plus optional port).
+    """
+    pass
+
+
 class Verifier(object):
 class Verifier(object):
     """
     """
     Verifies signed json blobs from Matrix Homeservers by finding the
     Verifies signed json blobs from Matrix Homeservers by finding the
@@ -69,14 +77,19 @@ class Verifier(object):
                 defer.returnValue(self.cache[server_name]['verify_keys'])
                 defer.returnValue(self.cache[server_name]['verify_keys'])
 
 
         client = FederationHttpClient(self.sydent)
         client = FederationHttpClient(self.sydent)
-        result = yield client.get_json("matrix://%s/_matrix/key/v2/server/" % server_name)
+        result = yield client.get_json("matrix://%s/_matrix/key/v2/server/" % server_name, 1024 * 50)
+
         if 'verify_keys' not in result:
         if 'verify_keys' not in result:
             raise SignatureVerifyException("No key found in response")
             raise SignatureVerifyException("No key found in response")
 
 
         if 'valid_until_ts' in result:
         if 'valid_until_ts' in result:
+            if not isinstance(result['valid_until_ts'], int):
+                raise SignatureVerifyException("Invalid valid_until_ts received, must be an integer")
+
             # Don't cache anything without a valid_until_ts or we wouldn't
             # Don't cache anything without a valid_until_ts or we wouldn't
             # know when to expire it.
             # know when to expire it.
-            logger.info("Got keys for %s: caching until %s", server_name, result['valid_until_ts'])
+
+            logger.info("Got keys for %s: caching until %d", server_name, result['valid_until_ts'])
             self.cache[server_name] = result
             self.cache[server_name] = result
 
 
         defer.returnValue(result['verify_keys'])
         defer.returnValue(result['verify_keys'])
@@ -197,6 +210,9 @@ class Verifier(object):
         if not json_request["signatures"]:
         if not json_request["signatures"]:
             raise NoAuthenticationError("Missing X-Matrix Authorization header")
             raise NoAuthenticationError("Missing X-Matrix Authorization header")
 
 
+        if not is_valid_matrix_server_name(json_request["origin"]):
+            raise InvalidServerName("X-Matrix header's origin parameter must be a valid Matrix server name")
+
         yield self.verifyServerSignedJson(json_request, [origin])
         yield self.verifyServerSignedJson(json_request, [origin])
 
 
         logger.info("Verified request from HS %s", origin)
         logger.info("Verified request from HS %s", origin)

+ 16 - 18
sydent/http/auth.py

@@ -52,7 +52,7 @@ def tokenFromRequest(request):
     return token
     return token
 
 
 
 
-def authIfV2(sydent, request, requireTermsAgreed=True):
+def authV2(sydent, request, requireTermsAgreed=True):
     """For v2 APIs check that the request has a valid access token associated with it
     """For v2 APIs check that the request has a valid access token associated with it
 
 
     :param sydent: The Sydent instance to use.
     :param sydent: The Sydent instance to use.
@@ -67,25 +67,23 @@ def authIfV2(sydent, request, requireTermsAgreed=True):
     :raises MatrixRestError: If the request is v2 but could not be authed or the user has
     :raises MatrixRestError: If the request is v2 but could not be authed or the user has
         not accepted terms.
         not accepted terms.
     """
     """
-    if request.path.startswith(b'/_matrix/identity/v2'):
-        token = tokenFromRequest(request)
+    token = tokenFromRequest(request)
 
 
-        if token is None:
-            raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")
+    if token is None:
+        raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")
 
 
-        accountStore = AccountStore(sydent)
+    accountStore = AccountStore(sydent)
 
 
-        account = accountStore.getAccountByToken(token)
-        if account is None:
-            raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")
+    account = accountStore.getAccountByToken(token)
+    if account is None:
+        raise MatrixRestError(401, "M_UNAUTHORIZED", "Unauthorized")
 
 
-        if requireTermsAgreed:
-            terms = get_terms(sydent)
-            if (
-                terms.getMasterVersion() is not None and
-                account.consentVersion != terms.getMasterVersion()
-            ):
-                raise MatrixRestError(403, "M_TERMS_NOT_SIGNED", "Terms not signed")
+    if requireTermsAgreed:
+        terms = get_terms(sydent)
+        if (
+            terms.getMasterVersion() is not None and
+            account.consentVersion != terms.getMasterVersion()
+        ):
+            raise MatrixRestError(403, "M_TERMS_NOT_SIGNED", "Terms not signed")
 
 
-        return account
-    return None
+    return account

+ 155 - 0
sydent/http/blacklisting_reactor.py

@@ -0,0 +1,155 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import (
+    Any,
+    List,
+    Optional,
+)
+
+from zope.interface import implementer, provider
+from netaddr import IPAddress, IPSet
+
+from twisted.internet.address import IPv4Address, IPv6Address
+from twisted.internet.interfaces import (
+    IAddress,
+    IHostResolution,
+    IReactorPluggableNameResolver,
+    IResolutionReceiver,
+)
+
+
+logger = logging.getLogger(__name__)
+
+
+def check_against_blacklist(
+    ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
+) -> bool:
+    """
+    Compares an IP address to allowed and disallowed IP sets.
+
+    Args:
+        ip_address: The IP address to check
+        ip_whitelist: Allowed IP addresses.
+        ip_blacklist: Disallowed IP addresses.
+
+    Returns:
+        True if the IP address is in the blacklist and not in the whitelist.
+    """
+    if ip_address in ip_blacklist:
+        if ip_whitelist is None or ip_address not in ip_whitelist:
+            return True
+    return False
+
+
+class _IPBlacklistingResolver:
+    """
+    A proxy for reactor.nameResolver which only produces non-blacklisted IP
+    addresses, preventing DNS rebinding attacks on URL preview.
+    """
+
+    def __init__(
+        self,
+        reactor: IReactorPluggableNameResolver,
+        ip_whitelist: Optional[IPSet],
+        ip_blacklist: IPSet,
+    ):
+        """
+        Args:
+            reactor: The twisted reactor.
+            ip_whitelist: IP addresses to allow.
+            ip_blacklist: IP addresses to disallow.
+        """
+        self._reactor = reactor
+        self._ip_whitelist = ip_whitelist
+        self._ip_blacklist = ip_blacklist
+
+    def resolveHostName(
+        self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
+    ) -> IResolutionReceiver:
+        addresses = []  # type: List[IAddress]
+
+        def _callback() -> None:
+            has_bad_ip = False
+            for address in addresses:
+                # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
+                # should go through this path.
+                if not isinstance(address, (IPv4Address, IPv6Address)):
+                    continue
+
+                ip_address = IPAddress(address.host)
+
+                if check_against_blacklist(
+                    ip_address, self._ip_whitelist, self._ip_blacklist
+                ):
+                    logger.info(
+                        "Dropped %s from DNS resolution to %s due to blacklist"
+                        % (ip_address, hostname)
+                    )
+                    has_bad_ip = True
+
+            # if we have a blacklisted IP, we'd like to raise an error to block the
+            # request, but all we can really do from here is claim that there were no
+            # valid results.
+            if not has_bad_ip:
+                for address in addresses:
+                    recv.addressResolved(address)
+            recv.resolutionComplete()
+
+        @provider(IResolutionReceiver)
+        class EndpointReceiver:
+            @staticmethod
+            def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
+                recv.resolutionBegan(resolutionInProgress)
+
+            @staticmethod
+            def addressResolved(address: IAddress) -> None:
+                addresses.append(address)
+
+            @staticmethod
+            def resolutionComplete() -> None:
+                _callback()
+
+        self._reactor.nameResolver.resolveHostName(
+            EndpointReceiver, hostname, portNumber=portNumber
+        )
+
+        return recv
+
+
+@implementer(IReactorPluggableNameResolver)
+class BlacklistingReactorWrapper:
+    """
+    A Reactor wrapper which will prevent DNS resolution to blacklisted IP
+    addresses, to prevent DNS rebinding.
+    """
+
+    def __init__(
+        self,
+        reactor: IReactorPluggableNameResolver,
+        ip_whitelist: Optional[IPSet],
+        ip_blacklist: IPSet,
+    ):
+        self._reactor = reactor
+
+        # We need to use a DNS resolver which filters out blacklisted IP
+        # addresses, to prevent DNS rebinding.
+        self.nameResolver = _IPBlacklistingResolver(
+            self._reactor, ip_whitelist, ip_blacklist
+        )
+
+    def __getattr__(self, attr: str) -> Any:
+        # Passthrough to the real reactor except for the DNS resolver.
+        return getattr(self._reactor, attr)

+ 24 - 7
sydent/http/httpclient.py

@@ -22,9 +22,11 @@ from io import BytesIO
 from twisted.internet import defer
 from twisted.internet import defer
 from twisted.web.client import FileBodyProducer, Agent, readBody
 from twisted.web.client import FileBodyProducer, Agent, readBody
 from twisted.web.http_headers import Headers
 from twisted.web.http_headers import Headers
-from sydent.http.matrixfederationagent import MatrixFederationAgent
 
 
+from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper
+from sydent.http.matrixfederationagent import MatrixFederationAgent
 from sydent.http.federation_tls_options import ClientTLSOptionsFactory
 from sydent.http.federation_tls_options import ClientTLSOptionsFactory
+from sydent.http.httpcommon import BodyExceededMaxSize, read_body_with_max_size
 from sydent.util import json_decoder
 from sydent.util import json_decoder
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
@@ -35,12 +37,15 @@ class HTTPClient(object):
     requests.
     requests.
     """
     """
     @defer.inlineCallbacks
     @defer.inlineCallbacks
-    def get_json(self, uri):
+    def get_json(self, uri, max_size = None):
         """Make a GET request to an endpoint returning JSON and parse result
         """Make a GET request to an endpoint returning JSON and parse result
 
 
         :param uri: The URI to make a GET request to.
         :param uri: The URI to make a GET request to.
         :type uri: unicode
         :type uri: unicode
 
 
+        :param max_size: The maximum size (in bytes) to allow as a response.
+        :type max_size: int
+
         :return: A deferred containing JSON parsed into a Python object.
         :return: A deferred containing JSON parsed into a Python object.
         :rtype: twisted.internet.defer.Deferred[dict[any, any]]
         :rtype: twisted.internet.defer.Deferred[dict[any, any]]
         """
         """
@@ -50,7 +55,7 @@ class HTTPClient(object):
             b"GET",
             b"GET",
             uri.encode("utf8"),
             uri.encode("utf8"),
         )
         )
-        body = yield readBody(response)
+        body = yield read_body_with_max_size(response, max_size)
         try:
         try:
             # json.loads doesn't allow bytes in Python 3.5
             # json.loads doesn't allow bytes in Python 3.5
             json_body = json_decoder.decode(body.decode("UTF-8"))
             json_body = json_decoder.decode(body.decode("UTF-8"))
@@ -95,7 +100,11 @@ class HTTPClient(object):
         # Ensure the body object is read otherwise we'll leak HTTP connections
         # Ensure the body object is read otherwise we'll leak HTTP connections
         # as per
         # as per
         # https://twistedmatrix.com/documents/current/web/howto/client.html
         # https://twistedmatrix.com/documents/current/web/howto/client.html
-        yield readBody(response)
+        try:
+            # TODO Will this cause the server to think the request was a failure?
+            yield read_body_with_max_size(response, 0)
+        except BodyExceededMaxSize:
+            pass
 
 
         defer.returnValue(response)
         defer.returnValue(response)
 
 
@@ -109,7 +118,11 @@ class SimpleHttpClient(HTTPClient):
         # BrowserLikePolicyForHTTPS context factory which will do regular cert validation
         # BrowserLikePolicyForHTTPS context factory which will do regular cert validation
         # 'like a browser'
         # 'like a browser'
         self.agent = Agent(
         self.agent = Agent(
-            self.sydent.reactor,
+            BlacklistingReactorWrapper(
+                reactor=self.sydent.reactor,
+                ip_whitelist=sydent.ip_whitelist,
+                ip_blacklist=sydent.ip_blacklist,
+            ),
             connectTimeout=15,
             connectTimeout=15,
         )
         )
 
 
@@ -120,6 +133,10 @@ class FederationHttpClient(HTTPClient):
     def __init__(self, sydent):
     def __init__(self, sydent):
         self.sydent = sydent
         self.sydent = sydent
         self.agent = MatrixFederationAgent(
         self.agent = MatrixFederationAgent(
-            self.sydent.reactor,
-            ClientTLSOptionsFactory(sydent.cfg),
+            BlacklistingReactorWrapper(
+                reactor=self.sydent.reactor,
+                ip_whitelist=sydent.ip_whitelist,
+                ip_blacklist=sydent.ip_blacklist,
+            ),
+            ClientTLSOptionsFactory(sydent.cfg) if sydent.use_tls_for_federation else None,
         )
         )

+ 119 - 1
sydent/http/httpcommon.py

@@ -15,11 +15,23 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import logging
 import logging
+from io import BytesIO
 
 
 import twisted.internet.ssl
 import twisted.internet.ssl
+from twisted.internet import defer, protocol
+from twisted.internet.protocol import connectionDone
+from twisted.web._newclient import ResponseDone
+from twisted.web.http import PotentialDataLoss
+from twisted.web.iweb import UNKNOWN_LENGTH
+from twisted.web import server
+
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
+# Arbitrarily limited to 512 KiB.
+MAX_REQUEST_SIZE = 512 * 1024
+
+
 class SslComponents:
 class SslComponents:
     def __init__(self, sydent):
     def __init__(self, sydent):
         self.sydent = sydent
         self.sydent = sydent
@@ -55,10 +67,116 @@ class SslComponents:
                 fp = open(caCertFilename)
                 fp = open(caCertFilename)
                 caCert = twisted.internet.ssl.Certificate.loadPEM(fp.read())
                 caCert = twisted.internet.ssl.Certificate.loadPEM(fp.read())
                 fp.close()
                 fp.close()
-            except:
+            except Exception:
                 logger.warn("Failed to open CA cert file %s", caCertFilename)
                 logger.warn("Failed to open CA cert file %s", caCertFilename)
                 raise
                 raise
             logger.warn("Using custom CA cert file: %s", caCertFilename)
             logger.warn("Using custom CA cert file: %s", caCertFilename)
             return twisted.internet._sslverify.OpenSSLCertificateAuthorities([caCert.original])
             return twisted.internet._sslverify.OpenSSLCertificateAuthorities([caCert.original])
         else:
         else:
             return twisted.internet.ssl.OpenSSLDefaultPaths()
             return twisted.internet.ssl.OpenSSLDefaultPaths()
+
+
+class BodyExceededMaxSize(Exception):
+    """The maximum allowed size of the HTTP body was exceeded."""
+
+
+class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
+    """A protocol which immediately errors upon receiving data."""
+
+    def __init__(self, deferred):
+        self.deferred = deferred
+
+    def _maybe_fail(self):
+        """
+        Report a max size exceed error and disconnect the first time this is called.
+        """
+        if not self.deferred.called:
+            self.deferred.errback(BodyExceededMaxSize())
+            # Close the connection (forcefully) since all the data will get
+            # discarded anyway.
+            self.transport.abortConnection()
+
+    def dataReceived(self, data) -> None:
+        self._maybe_fail()
+
+    def connectionLost(self, reason) -> None:
+        self._maybe_fail()
+
+
+class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
+    """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
+
+    def __init__(self, deferred, max_size):
+        self.stream = BytesIO()
+        self.deferred = deferred
+        self.length = 0
+        self.max_size = max_size
+
+    def dataReceived(self, data) -> None:
+        # If the deferred was called, bail early.
+        if self.deferred.called:
+            return
+
+        self.stream.write(data)
+        self.length += len(data)
+        # The first time the maximum size is exceeded, error and cancel the
+        # connection. dataReceived might be called again if data was received
+        # in the meantime.
+        if self.max_size is not None and self.length >= self.max_size:
+            self.deferred.errback(BodyExceededMaxSize())
+            # Close the connection (forcefully) since all the data will get
+            # discarded anyway.
+            self.transport.abortConnection()
+
+    def connectionLost(self, reason=connectionDone) -> None:
+        # If the maximum size was already exceeded, there's nothing to do.
+        if self.deferred.called:
+            return
+
+        if reason.check(ResponseDone):
+            self.deferred.callback(self.stream.getvalue())
+        elif reason.check(PotentialDataLoss):
+            # stolen from https://github.com/twisted/treq/pull/49/files
+            # http://twistedmatrix.com/trac/ticket/4840
+            self.deferred.callback(self.stream.getvalue())
+        else:
+            self.deferred.errback(reason)
+
+
+def read_body_with_max_size(response, max_size):
+    """
+    Read a HTTP response body to a file-object. Optionally enforcing a maximum file size.
+
+    If the maximum file size is reached, the returned Deferred will resolve to a
+    Failure with a BodyExceededMaxSize exception.
+
+    Args:
+        response: The HTTP response to read from.
+        max_size: The maximum file size to allow.
+
+    Returns:
+        A Deferred which resolves to the read body.
+    """
+    d = defer.Deferred()
+
+    # If the Content-Length header gives a size larger than the maximum allowed
+    # size, do not bother downloading the body.
+    if max_size is not None and response.length != UNKNOWN_LENGTH:
+        if response.length > max_size:
+            response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
+            return d
+
+    response.deliverBody(_ReadBodyWithMaxSizeProtocol(d, max_size))
+    return d
+
+
+class SizeLimitingRequest(server.Request):
+    def handleContentChunk(self, data):
+        if self.content.tell() + len(data) > MAX_REQUEST_SIZE:
+            logger.info(
+                "Aborting connection from %s because the request exceeds maximum size",
+                self.client.host)
+            self.transport.abortConnection()
+            return
+
+        return super().handleContentChunk(data)

+ 30 - 30
sydent/http/httpserver.py

@@ -29,6 +29,7 @@ from sydent.http.servlets.authenticated_bind_threepid_servlet import (
 from sydent.http.servlets.authenticated_unbind_threepid_servlet import (
 from sydent.http.servlets.authenticated_unbind_threepid_servlet import (
     AuthenticatedUnbindThreePidServlet,
     AuthenticatedUnbindThreePidServlet,
 )
 )
+from sydent.http.httpcommon import SizeLimitingRequest
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -45,26 +46,17 @@ class ClientApiHttpServer:
         v2 = self.sydent.servlets.v2
         v2 = self.sydent.servlets.v2
 
 
         validate = Resource()
         validate = Resource()
+        validate_v2 = Resource()
         email = Resource()
         email = Resource()
+        email_v2 = Resource()
         msisdn = Resource()
         msisdn = Resource()
-        emailReqCode = self.sydent.servlets.emailRequestCode
-        emailValCode = self.sydent.servlets.emailValidate
-        msisdnReqCode = self.sydent.servlets.msisdnRequestCode
-        msisdnValCode = self.sydent.servlets.msisdnValidate
-        getValidated3pid = self.sydent.servlets.getValidated3pid
-
-        lookup = self.sydent.servlets.lookup
-        bulk_lookup = self.sydent.servlets.bulk_lookup
+        msisdn_v2 = Resource()
 
 
         info = self.sydent.servlets.info
         info = self.sydent.servlets.info
         internalInfo = self.sydent.servlets.internalInfo
         internalInfo = self.sydent.servlets.internalInfo
 
 
-        hash_details = self.sydent.servlets.hash_details
-        lookup_v2 = self.sydent.servlets.lookup_v2
-
         threepid_v1 = Resource()
         threepid_v1 = Resource()
         threepid_v2 = Resource()
         threepid_v2 = Resource()
-        bind = self.sydent.servlets.threepidBind
         unbind = self.sydent.servlets.threepidUnbind
         unbind = self.sydent.servlets.threepidUnbind
 
 
         pubkey = Resource()
         pubkey = Resource()
@@ -72,8 +64,6 @@ class ClientApiHttpServer:
 
 
         userDirectory = Resource()
         userDirectory = Resource()
 
 
-        pk_ed25519 = self.sydent.servlets.pubkey_ed25519
-
         root.putChild(b'_matrix', matrix)
         root.putChild(b'_matrix', matrix)
         matrix.putChild(b'identity', identity)
         matrix.putChild(b'identity', identity)
         identity.putChild(b'api', api)
         identity.putChild(b'api', api)
@@ -83,36 +73,45 @@ class ClientApiHttpServer:
         validate.putChild(b'email', email)
         validate.putChild(b'email', email)
         validate.putChild(b'msisdn', msisdn)
         validate.putChild(b'msisdn', msisdn)
 
 
+        validate_v2.putChild(b'email', email_v2)
+        validate_v2.putChild(b'msisdn', msisdn_v2)
+
         v1.putChild(b'validate', validate)
         v1.putChild(b'validate', validate)
 
 
-        v1.putChild(b'lookup', lookup)
-        v1.putChild(b'bulk_lookup', bulk_lookup)
+        v1.putChild(b'lookup', self.sydent.servlets.lookup)
+        v1.putChild(b'bulk_lookup', self.sydent.servlets.bulk_lookup)
 
 
         v1.putChild(b'pubkey', pubkey)
         v1.putChild(b'pubkey', pubkey)
         pubkey.putChild(b'isvalid', self.sydent.servlets.pubkeyIsValid)
         pubkey.putChild(b'isvalid', self.sydent.servlets.pubkeyIsValid)
-        pubkey.putChild(b'ed25519:0', pk_ed25519)
+        pubkey.putChild(b'ed25519:0', self.sydent.servlets.pubkey_ed25519)
         pubkey.putChild(b'ephemeral', ephemeralPubkey)
         pubkey.putChild(b'ephemeral', ephemeralPubkey)
         ephemeralPubkey.putChild(b'isvalid', self.sydent.servlets.ephemeralPubkeyIsValid)
         ephemeralPubkey.putChild(b'isvalid', self.sydent.servlets.ephemeralPubkeyIsValid)
 
 
-        threepid_v2.putChild(b'getValidated3pid', getValidated3pid)
-        threepid_v2.putChild(b'bind', bind)
+        threepid_v2.putChild(b'getValidated3pid', self.sydent.servlets.getValidated3pidV2)
+        threepid_v2.putChild(b'bind', self.sydent.servlets.threepidBindV2)
         threepid_v2.putChild(b'unbind', unbind)
         threepid_v2.putChild(b'unbind', unbind)
 
 
-        threepid_v1.putChild(b'getValidated3pid', getValidated3pid)
+        threepid_v1.putChild(b'getValidated3pid', self.sydent.servlets.getValidated3pid)
         threepid_v1.putChild(b'unbind', unbind)
         threepid_v1.putChild(b'unbind', unbind)
         if self.sydent.enable_v1_associations:
         if self.sydent.enable_v1_associations:
-            threepid_v1.putChild(b'bind', bind)
+            threepid_v1.putChild(b'bind', self.sydent.servlets.threepidBind)
 
 
         v1.putChild(b'3pid', threepid_v1)
         v1.putChild(b'3pid', threepid_v1)
 
 
         v1.putChild(b'info', info)
         v1.putChild(b'info', info)
         v1.putChild(b'internal-info', internalInfo)
         v1.putChild(b'internal-info', internalInfo)
 
 
-        email.putChild(b'requestToken', emailReqCode)
-        email.putChild(b'submitToken', emailValCode)
+        email.putChild(b'requestToken', self.sydent.servlets.emailRequestCode)
+        email.putChild(b'submitToken', self.sydent.servlets.emailValidate)
+
+        email_v2.putChild(b'requestToken', self.sydent.servlets.emailRequestCodeV2)
+        email_v2.putChild(b'submitToken', self.sydent.servlets.emailValidateV2)
+
+        msisdn.putChild(b'requestToken', self.sydent.servlets.msisdnRequestCode)
+        msisdn.putChild(b'submitToken', self.sydent.servlets.msisdnValidate)
 
 
-        msisdn.putChild(b'requestToken', msisdnReqCode)
-        msisdn.putChild(b'submitToken', msisdnValCode)
+        msisdn_v2.putChild(b'requestToken', self.sydent.servlets.msisdnRequestCodeV2)
+        msisdn_v2.putChild(b'submitToken', self.sydent.servlets.msisdnValidateV2)
 
 
         v1.putChild(b'store-invite', self.sydent.servlets.storeInviteServlet)
         v1.putChild(b'store-invite', self.sydent.servlets.storeInviteServlet)
 
 
@@ -136,15 +135,16 @@ class ClientApiHttpServer:
         account.putChild(b'logout', self.sydent.servlets.logoutServlet)
         account.putChild(b'logout', self.sydent.servlets.logoutServlet)
 
 
         # v2 versions of existing APIs
         # v2 versions of existing APIs
-        v2.putChild(b'validate', validate)
+        v2.putChild(b'validate', validate_v2)
         v2.putChild(b'pubkey', pubkey)
         v2.putChild(b'pubkey', pubkey)
         v2.putChild(b'3pid', threepid_v2)
         v2.putChild(b'3pid', threepid_v2)
-        v2.putChild(b'store-invite', self.sydent.servlets.storeInviteServlet)
-        v2.putChild(b'sign-ed25519', self.sydent.servlets.blindlySignStuffServlet)
-        v2.putChild(b'lookup', lookup_v2)
-        v2.putChild(b'hash_details', hash_details)
+        v2.putChild(b'store-invite', self.sydent.servlets.storeInviteServletV2)
+        v2.putChild(b'sign-ed25519', self.sydent.servlets.blindlySignStuffServletV2)
+        v2.putChild(b'lookup', self.sydent.servlets.lookup_v2)
+        v2.putChild(b'hash_details', self.sydent.servlets.hash_details)
 
 
         self.factory = Site(root)
         self.factory = Site(root)
+        self.factory.requestFactory = SizeLimitingRequest
         self.factory.displayTracebacks = False
         self.factory.displayTracebacks = False
 
 
     def setup(self):
     def setup(self):

+ 7 - 2
sydent/http/matrixfederationagent.py

@@ -25,11 +25,12 @@ from zope.interface import implementer
 from twisted.internet import defer
 from twisted.internet import defer
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
 from twisted.internet.interfaces import IStreamClientEndpoint
 from twisted.internet.interfaces import IStreamClientEndpoint
-from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent, readBody
+from twisted.web.client import URI, Agent, HTTPConnectionPool, RedirectAgent
 from twisted.web.http import stringToDatetime
 from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IAgent
 from twisted.web.iweb import IAgent
 
 
+from sydent.http.httpcommon import BodyExceededMaxSize, read_body_with_max_size
 from sydent.http.srvresolver import SrvResolver, pick_server_from_list
 from sydent.http.srvresolver import SrvResolver, pick_server_from_list
 from sydent.util import json_decoder
 from sydent.util import json_decoder
 from sydent.util.ttlcache import TTLCache
 from sydent.util.ttlcache import TTLCache
@@ -46,6 +47,9 @@ WELL_KNOWN_INVALID_CACHE_PERIOD = 1 * 3600
 # cap for .well-known cache period
 # cap for .well-known cache period
 WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
 WELL_KNOWN_MAX_CACHE_PERIOD = 48 * 3600
 
 
+# The maximum size (in bytes) to allow a well-known file to be.
+WELL_KNOWN_MAX_SIZE = 50 * 1024  # 50 KiB
+
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 well_known_cache = TTLCache('well-known')
 well_known_cache = TTLCache('well-known')
 
 
@@ -316,7 +320,7 @@ class MatrixFederationAgent(object):
         logger.info("Fetching %s", uri_str)
         logger.info("Fetching %s", uri_str)
         try:
         try:
             response = yield self._well_known_agent.request(b"GET", uri)
             response = yield self._well_known_agent.request(b"GET", uri)
-            body = yield readBody(response)
+            body = yield read_body_with_max_size(response, WELL_KNOWN_MAX_SIZE)
             if response.code != 200:
             if response.code != 200:
                 raise Exception("Non-200 response %s" % (response.code, ))
                 raise Exception("Non-200 response %s" % (response.code, ))
 
 
@@ -334,6 +338,7 @@ class MatrixFederationAgent(object):
             cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
             cache_period = WELL_KNOWN_INVALID_CACHE_PERIOD
             cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
             cache_period += random.uniform(0, WELL_KNOWN_DEFAULT_CACHE_PERIOD_JITTER)
             defer.returnValue((None, cache_period))
             defer.returnValue((None, cache_period))
+            return
 
 
         result = parsed_body["m.server"].encode("ascii")
         result = parsed_body["m.server"].encode("ascii")
 
 

+ 2 - 3
sydent/http/servlets/accountservlet.py

@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from twisted.web.resource import Resource
 from twisted.web.resource import Resource
 
 
 from sydent.http.servlets import jsonwrap, send_cors
 from sydent.http.servlets import jsonwrap, send_cors
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 
 
 
 
 class AccountServlet(Resource):
 class AccountServlet(Resource):
@@ -36,7 +36,7 @@ class AccountServlet(Resource):
         """
         """
         send_cors(request)
         send_cors(request)
 
 
-        account = authIfV2(self.sydent, request)
+        account = authV2(self.sydent, request)
 
 
         return {
         return {
             "user_id": account.userId,
             "user_id": account.userId,
@@ -45,4 +45,3 @@ class AccountServlet(Resource):
     def render_OPTIONS(self, request):
     def render_OPTIONS(self, request):
         send_cors(request)
         send_cors(request)
         return b''
         return b''
-

+ 5 - 3
sydent/http/servlets/blindlysignstuffservlet.py

@@ -22,7 +22,7 @@ import signedjson.key
 import signedjson.sign
 import signedjson.sign
 from sydent.db.invite_tokens import JoinTokenStore
 from sydent.db.invite_tokens import JoinTokenStore
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -30,16 +30,18 @@ logger = logging.getLogger(__name__)
 class BlindlySignStuffServlet(Resource):
 class BlindlySignStuffServlet(Resource):
     isLeaf = True
     isLeaf = True
 
 
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
         self.sydent = syd
         self.server_name = syd.server_name
         self.server_name = syd.server_name
         self.tokenStore = JoinTokenStore(syd)
         self.tokenStore = JoinTokenStore(syd)
+        self.require_auth = require_auth
 
 
     @jsonwrap
     @jsonwrap
     def render_POST(self, request):
     def render_POST(self, request):
         send_cors(request)
         send_cors(request)
 
 
-        authIfV2(self.sydent, request)
+        if self.require_auth:
+            authV2(self.sydent, request)
 
 
         args = get_args(request, ("private_key", "token", "mxid"))
         args = get_args(request, ("private_key", "token", "mxid"))
 
 

+ 0 - 3
sydent/http/servlets/bulklookupservlet.py

@@ -21,7 +21,6 @@ from sydent.db.threepid_associations import GlobalAssociationStore
 import logging
 import logging
 
 
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
-from sydent.http.auth import authIfV2
 
 
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
@@ -45,8 +44,6 @@ class BulkLookupServlet(Resource):
         """
         """
         send_cors(request)
         send_cors(request)
 
 
-        authIfV2(self.sydent, request)
-
         args = get_args(request, ('threepids',))
         args = get_args(request, ('threepids',))
 
 
         threepids = args['threepids']
         threepids = args['threepids']

+ 17 - 6
sydent/http/servlets/emailservlet.py

@@ -19,7 +19,7 @@ import logging
 
 
 from twisted.web.resource import Resource
 from twisted.web.resource import Resource
 
 
-from sydent.util.stringutils import is_valid_client_secret
+from sydent.util.stringutils import is_valid_client_secret, MAX_EMAIL_ADDRESS_LENGTH
 from sydent.util.emailutils import EmailAddressException, EmailSendException
 from sydent.util.emailutils import EmailAddressException, EmailSendException
 from sydent.validators import (
 from sydent.validators import (
     IncorrectClientSecretException,
     IncorrectClientSecretException,
@@ -31,7 +31,7 @@ from sydent.validators import (
 from sydent.validators.common import validate_next_link
 from sydent.validators.common import validate_next_link
 
 
 from sydent.http.servlets import get_args, jsonwrap, send_cors
 from sydent.http.servlets import get_args, jsonwrap, send_cors
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -39,14 +39,16 @@ logger = logging.getLogger(__name__)
 class EmailRequestCodeServlet(Resource):
 class EmailRequestCodeServlet(Resource):
     isLeaf = True
     isLeaf = True
 
 
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
         self.sydent = syd
+        self.require_auth = require_auth
 
 
     @jsonwrap
     @jsonwrap
     def render_POST(self, request):
     def render_POST(self, request):
         send_cors(request)
         send_cors(request)
 
 
-        authIfV2(self.sydent, request)
+        if self.require_auth:
+            authV2(self.sydent, request)
 
 
         args = get_args(request, ('email', 'client_secret', 'send_attempt'))
         args = get_args(request, ('email', 'client_secret', 'send_attempt'))
 
 
@@ -61,6 +63,13 @@ class EmailRequestCodeServlet(Resource):
                 'error': 'Invalid client_secret provided'
                 'error': 'Invalid client_secret provided'
             }
             }
 
 
+        if not (0 < len(email) <= MAX_EMAIL_ADDRESS_LENGTH):
+            request.setResponseCode(400)
+            return {
+                'errcode': 'M_INVALID_PARAM',
+                'error': 'Invalid email provided'
+            }
+
         ipaddress = self.sydent.ip_from_request(request)
         ipaddress = self.sydent.ip_from_request(request)
         brand = self.sydent.brand_from_request(request)
         brand = self.sydent.brand_from_request(request)
 
 
@@ -99,8 +108,9 @@ class EmailRequestCodeServlet(Resource):
 class EmailValidateCodeServlet(Resource):
 class EmailValidateCodeServlet(Resource):
     isLeaf = True
     isLeaf = True
 
 
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
         self.sydent = syd
+        self.require_auth = require_auth
 
 
     def render_GET(self, request):
     def render_GET(self, request):
         args = get_args(request, ('nextLink',), required=False)
         args = get_args(request, ('nextLink',), required=False)
@@ -137,7 +147,8 @@ class EmailValidateCodeServlet(Resource):
     def render_POST(self, request):
     def render_POST(self, request):
         send_cors(request)
         send_cors(request)
 
 
-        authIfV2(self.sydent, request)
+        if self.require_auth:
+            authV2(self.sydent, request)
 
 
         return self.do_validate_request(request)
         return self.do_validate_request(request)
 
 

+ 5 - 3
sydent/http/servlets/getvalidated3pidservlet.py

@@ -18,7 +18,7 @@ from __future__ import absolute_import
 from twisted.web.resource import Resource
 from twisted.web.resource import Resource
 
 
 from sydent.http.servlets import jsonwrap, get_args
 from sydent.http.servlets import jsonwrap, get_args
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 from sydent.db.valsession import ThreePidValSessionStore
 from sydent.db.valsession import ThreePidValSessionStore
 from sydent.util.stringutils import is_valid_client_secret
 from sydent.util.stringutils import is_valid_client_secret
 from sydent.validators import (
 from sydent.validators import (
@@ -32,12 +32,14 @@ from sydent.validators import (
 class GetValidated3pidServlet(Resource):
 class GetValidated3pidServlet(Resource):
     isLeaf = True
     isLeaf = True
 
 
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
         self.sydent = syd
+        self.require_auth = require_auth
 
 
     @jsonwrap
     @jsonwrap
     def render_GET(self, request):
     def render_GET(self, request):
-        authIfV2(self.sydent, request)
+        if self.require_auth:
+            authV2(self.sydent, request)
 
 
         args = get_args(request, ('sid', 'client_secret'))
         args = get_args(request, ('sid', 'client_secret'))
 
 

+ 2 - 2
sydent/http/servlets/hashdetailsservlet.py

@@ -16,7 +16,7 @@
 from __future__ import absolute_import
 from __future__ import absolute_import
 
 
 from twisted.web.resource import Resource
 from twisted.web.resource import Resource
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 
 
 import logging
 import logging
 
 
@@ -48,7 +48,7 @@ class HashDetailsServlet(Resource):
         """
         """
         send_cors(request)
         send_cors(request)
 
 
-        authIfV2(self.sydent, request)
+        authV2(self.sydent, request)
 
 
         return {
         return {
             "algorithms": self.known_algorithms,
             "algorithms": self.known_algorithms,

+ 2 - 2
sydent/http/servlets/logoutservlet.py

@@ -21,7 +21,7 @@ import logging
 
 
 from sydent.http.servlets import jsonwrap, send_cors
 from sydent.http.servlets import jsonwrap, send_cors
 from sydent.db.accounts import AccountStore
 from sydent.db.accounts import AccountStore
-from sydent.http.auth import authIfV2, tokenFromRequest
+from sydent.http.auth import authV2, tokenFromRequest
 
 
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ class LogoutServlet(Resource):
         """
         """
         send_cors(request)
         send_cors(request)
 
 
-        authIfV2(self.sydent, request, False)
+        authV2(self.sydent, request, False)
 
 
         token = tokenFromRequest(request)
         token = tokenFromRequest(request)
 
 

+ 1 - 3
sydent/http/servlets/lookupservlet.py

@@ -23,9 +23,9 @@ import logging
 import signedjson.sign
 import signedjson.sign
 
 
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
-from sydent.http.auth import authIfV2
 from sydent.util import json_decoder
 from sydent.util import json_decoder
 
 
+
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 
 
@@ -48,8 +48,6 @@ class LookupServlet(Resource):
         """
         """
         send_cors(request)
         send_cors(request)
 
 
-        authIfV2(self.sydent, request)
-
         args = get_args(request, ('medium', 'address'))
         args = get_args(request, ('medium', 'address'))
 
 
         medium = args['medium']
         medium = args['medium']

+ 2 - 2
sydent/http/servlets/lookupv2servlet.py

@@ -21,7 +21,7 @@ import logging
 
 
 from sydent.http.servlets import get_args, jsonwrap, send_cors
 from sydent.http.servlets import get_args, jsonwrap, send_cors
 from sydent.db.threepid_associations import GlobalAssociationStore
 from sydent.db.threepid_associations import GlobalAssociationStore
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 from sydent.http.servlets.hashdetailsservlet import HashDetailsServlet
 from sydent.http.servlets.hashdetailsservlet import HashDetailsServlet
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
@@ -62,7 +62,7 @@ class LookupV2Servlet(Resource):
         """
         """
         send_cors(request)
         send_cors(request)
 
 
-        authIfV2(self.sydent, request)
+        authV2(self.sydent, request)
 
 
         args = get_args(request, ('addresses', 'algorithm', 'pepper'))
         args = get_args(request, ('addresses', 'algorithm', 'pepper'))
 
 

+ 9 - 5
sydent/http/servlets/msisdnservlet.py

@@ -30,7 +30,7 @@ from sydent.validators import (
 
 
 from sydent.http.servlets import get_args, jsonwrap, send_cors
 from sydent.http.servlets import get_args, jsonwrap, send_cors
 from sydent.util.stringutils import is_valid_client_secret
 from sydent.util.stringutils import is_valid_client_secret
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 from sydent.util.stringutils import is_valid_client_secret
 from sydent.util.stringutils import is_valid_client_secret
 
 
 
 
@@ -40,14 +40,16 @@ logger = logging.getLogger(__name__)
 class MsisdnRequestCodeServlet(Resource):
 class MsisdnRequestCodeServlet(Resource):
     isLeaf = True
     isLeaf = True
 
 
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
         self.sydent = syd
+        self.require_auth = require_auth
 
 
     @jsonwrap
     @jsonwrap
     def render_POST(self, request):
     def render_POST(self, request):
         send_cors(request)
         send_cors(request)
 
 
-        authIfV2(self.sydent, request)
+        if self.require_auth:
+            authV2(self.sydent, request)
 
 
         args = get_args(request, ('phone_number', 'country', 'client_secret', 'send_attempt'))
         args = get_args(request, ('phone_number', 'country', 'client_secret', 'send_attempt'))
 
 
@@ -115,8 +117,9 @@ class MsisdnRequestCodeServlet(Resource):
 class MsisdnValidateCodeServlet(Resource):
 class MsisdnValidateCodeServlet(Resource):
     isLeaf = True
     isLeaf = True
 
 
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
         self.sydent = syd
+        self.require_auth = require_auth
 
 
     def render_GET(self, request):
     def render_GET(self, request):
         send_cors(request)
         send_cors(request)
@@ -150,7 +153,8 @@ class MsisdnValidateCodeServlet(Resource):
     def render_POST(self, request):
     def render_POST(self, request):
         send_cors(request)
         send_cors(request)
 
 
-        authIfV2(self.sydent, request)
+        if self.require_auth:
+            authV2(self.sydent, request)
 
 
         return self.do_validate_request(request)
         return self.do_validate_request(request)
 
 

+ 50 - 3
sydent/http/servlets/registerservlet.py

@@ -25,7 +25,7 @@ from six.moves import urllib
 from sydent.http.servlets import get_args, jsonwrap, deferjsonwrap, send_cors
 from sydent.http.servlets import get_args, jsonwrap, deferjsonwrap, send_cors
 from sydent.http.httpclient import FederationHttpClient
 from sydent.http.httpclient import FederationHttpClient
 from sydent.users.tokens import issueToken
 from sydent.users.tokens import issueToken
-
+from sydent.util.stringutils import is_valid_matrix_server_name
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -47,15 +47,62 @@ class RegisterServlet(Resource):
 
 
         args = get_args(request, ('matrix_server_name', 'access_token'))
         args = get_args(request, ('matrix_server_name', 'access_token'))
 
 
+        matrix_server = args['matrix_server_name'].lower()
+
+        if not is_valid_matrix_server_name(matrix_server):
+            request.setResponseCode(400)
+            return {
+                'errcode': 'M_INVALID_PARAM',
+                'error': 'matrix_server_name must be a valid Matrix server name (IP address or hostname)'
+            }
+
         result = yield self.client.get_json(
         result = yield self.client.get_json(
-            "matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s" % (
-                args['matrix_server_name'], urllib.parse.quote(args['access_token']),
+            "matrix://%s/_matrix/federation/v1/openid/userinfo?access_token=%s"
+            % (
+                matrix_server,
+                urllib.parse.quote(args['access_token']),
             ),
             ),
+            1024 * 5,
         )
         )
+
         if 'sub' not in result:
         if 'sub' not in result:
             raise Exception("Invalid response from homeserver")
             raise Exception("Invalid response from homeserver")
 
 
         user_id = result['sub']
         user_id = result['sub']
+
+        if not isinstance(user_id, str):
+            request.setResponseCode(500)
+            return {
+                'errcode': 'M_UNKNOWN',
+                'error': 'The Matrix homeserver returned a malformed reply'
+            }
+
+        user_id_components = user_id.split(':', 1)
+
+        # Ensure there's a localpart and domain in the returned user ID.
+        if len(user_id_components) != 2:
+            request.setResponseCode(500)
+            return {
+                'errcode': 'M_UNKNOWN',
+                'error': 'The Matrix homeserver returned an invalid MXID'
+            }
+
+        user_id_server = user_id_components[1]
+
+        if not is_valid_matrix_server_name(user_id_server):
+            request.setResponseCode(500)
+            return {
+                'errcode': 'M_UNKNOWN',
+                'error': 'The Matrix homeserver returned an invalid MXID'
+            }
+
+        if user_id_server != matrix_server:
+            request.setResponseCode(500)
+            return {
+                'errcode': 'M_UNKNOWN',
+                'error': 'The Matrix homeserver returned a MXID belonging to another homeserver'
+            }
+
         tok = yield issueToken(self.sydent, user_id)
         tok = yield issueToken(self.sydent, user_id)
 
 
         # XXX: `token` is correct for the spec, but we released with `access_token`
         # XXX: `token` is correct for the spec, but we released with `access_token`

+ 24 - 6
sydent/http/servlets/store_invite_servlet.py

@@ -28,28 +28,35 @@ from unpaddedbase64 import encode_base64
 from sydent.db.invite_tokens import JoinTokenStore
 from sydent.db.invite_tokens import JoinTokenStore
 from sydent.db.threepid_associations import GlobalAssociationStore
 from sydent.db.threepid_associations import GlobalAssociationStore
 
 
-from sydent.http.servlets import get_args, send_cors, jsonwrap
-from sydent.http.auth import authIfV2
+from sydent.http.servlets import get_args, send_cors, jsonwrap, MatrixRestError
+from sydent.http.auth import authV2
 from sydent.util.emailutils import sendEmail
 from sydent.util.emailutils import sendEmail
+from sydent.util.stringutils import MAX_EMAIL_ADDRESS_LENGTH
 
 
 
 
 class StoreInviteServlet(Resource):
 class StoreInviteServlet(Resource):
-    def __init__(self, syd):
+    def __init__(self, syd, require_auth=False):
         self.sydent = syd
         self.sydent = syd
         self.random = random.SystemRandom()
         self.random = random.SystemRandom()
+        self.require_auth = require_auth
 
 
     @jsonwrap
     @jsonwrap
     def render_POST(self, request):
     def render_POST(self, request):
         send_cors(request)
         send_cors(request)
 
 
-        authIfV2(self.sydent, request)
-
         args = get_args(request, ("medium", "address", "room_id", "sender",))
         args = get_args(request, ("medium", "address", "room_id", "sender",))
         medium = args["medium"]
         medium = args["medium"]
         address = args["address"]
         address = args["address"]
         roomId = args["room_id"]
         roomId = args["room_id"]
         sender = args["sender"]
         sender = args["sender"]
 
 
+        verified_sender = None
+        if self.require_auth:
+            account = authV2(self.sydent, request)
+            verified_sender = sender
+            if account.userId != sender:
+                raise MatrixRestError(403, "M_UNAUTHORIZED", "'sender' doesn't match")
+
         globalAssocStore = GlobalAssociationStore(self.sydent)
         globalAssocStore = GlobalAssociationStore(self.sydent)
         mxid = globalAssocStore.getMxid(medium, address)
         mxid = globalAssocStore.getMxid(medium, address)
         if mxid:
         if mxid:
@@ -67,6 +74,13 @@ class StoreInviteServlet(Resource):
                 "error": "Didn't understand medium '%s'" % (medium,),
                 "error": "Didn't understand medium '%s'" % (medium,),
             }
             }
 
 
+        if not (0 < len(address) <= MAX_EMAIL_ADDRESS_LENGTH):
+            request.setResponseCode(400)
+            return {
+                'errcode': 'M_INVALID_PARAM',
+                'error': 'Invalid email provided'
+            }
+
         token = self._randomString(128)
         token = self._randomString(128)
 
 
         tokenStore = JoinTokenStore(self.sydent)
         tokenStore = JoinTokenStore(self.sydent)
@@ -103,9 +117,13 @@ class StoreInviteServlet(Resource):
         for k in extra_substitutions:
         for k in extra_substitutions:
             substitutions.setdefault(k, '')
             substitutions.setdefault(k, '')
 
 
+        substitutions["bracketed_verified_sender"] = ""
+        if verified_sender:
+            substitutions["bracketed_verified_sender"] = "(%s) " % (verified_sender,)
+
         substitutions["ephemeral_private_key"] = ephemeralPrivateKeyBase64
         substitutions["ephemeral_private_key"] = ephemeralPrivateKeyBase64
         if substitutions["room_name"] != '':
         if substitutions["room_name"] != '':
-            substitutions["bracketed_room_name"] = "(%s)" % substitutions["room_name"]
+            substitutions["bracketed_room_name"] = "(%s) " % substitutions["room_name"]
 
 
         substitutions["web_client_location"] = self.sydent.default_web_client_location
         substitutions["web_client_location"] = self.sydent.default_web_client_location
         if 'org.matrix.web_client_location' in substitutions:
         if 'org.matrix.web_client_location' in substitutions:

+ 2 - 3
sydent/http/servlets/termsservlet.py

@@ -21,7 +21,7 @@ import logging
 
 
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
 from sydent.terms.terms import get_terms
 from sydent.terms.terms import get_terms
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 from sydent.db.terms import TermsStore
 from sydent.db.terms import TermsStore
 from sydent.db.accounts import AccountStore
 from sydent.db.accounts import AccountStore
 
 
@@ -54,7 +54,7 @@ class TermsServlet(Resource):
         """
         """
         send_cors(request)
         send_cors(request)
 
 
-        account = authIfV2(self.sydent, request, False)
+        account = authV2(self.sydent, request, False)
 
 
         args = get_args(request, ("user_accepts",))
         args = get_args(request, ("user_accepts",))
 
 
@@ -80,4 +80,3 @@ class TermsServlet(Resource):
     def render_OPTIONS(self, request):
     def render_OPTIONS(self, request):
         send_cors(request)
         send_cors(request)
         return b''
         return b''
-

+ 6 - 3
sydent/http/servlets/threepidbindservlet.py

@@ -21,7 +21,7 @@ from twisted.web.resource import Resource
 from sydent.util.stringutils import is_valid_client_secret
 from sydent.util.stringutils import is_valid_client_secret
 from sydent.db.valsession import ThreePidValSessionStore
 from sydent.db.valsession import ThreePidValSessionStore
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
 from sydent.http.servlets import get_args, jsonwrap, send_cors, MatrixRestError
-from sydent.http.auth import authIfV2
+from sydent.http.auth import authV2
 from sydent.util.stringutils import is_valid_client_secret
 from sydent.util.stringutils import is_valid_client_secret
 from sydent.validators import SessionExpiredException, IncorrectClientSecretException, InvalidSessionIdException,\
 from sydent.validators import SessionExpiredException, IncorrectClientSecretException, InvalidSessionIdException,\
     SessionNotValidatedException
     SessionNotValidatedException
@@ -30,14 +30,17 @@ from sydent.threepid.bind import BindingNotPermittedException
 
 
 
 
 class ThreePidBindServlet(Resource):
 class ThreePidBindServlet(Resource):
-    def __init__(self, sydent):
+    def __init__(self, sydent, require_auth=False):
         self.sydent = sydent
         self.sydent = sydent
+        self.require_auth = require_auth
 
 
     @jsonwrap
     @jsonwrap
     def render_POST(self, request):
     def render_POST(self, request):
         send_cors(request)
         send_cors(request)
 
 
-        account = authIfV2(self.sydent, request)
+        account = None
+        if self.require_auth:
+            account = authV2(self.sydent, request)
 
 
         args = get_args(request, ('sid', 'client_secret', 'mxid'))
         args = get_args(request, ('sid', 'client_secret', 'mxid'))
 
 

+ 7 - 2
sydent/http/servlets/threepidunbindservlet.py

@@ -19,7 +19,7 @@ from __future__ import absolute_import
 import json
 import json
 import logging
 import logging
 
 
-from sydent.hs_federation.verifier import NoAuthenticationError
+from sydent.hs_federation.verifier import NoAuthenticationError, InvalidServerName
 from signedjson.sign import SignatureVerifyException
 from signedjson.sign import SignatureVerifyException
 
 
 from sydent.http.servlets import dict_to_json_bytes
 from sydent.http.servlets import dict_to_json_bytes
@@ -144,7 +144,12 @@ class ThreePidUnbindServlet(Resource):
                     request.write(dict_to_json_bytes({'errcode': 'M_FORBIDDEN', 'error': str(ex)}))
                     request.write(dict_to_json_bytes({'errcode': 'M_FORBIDDEN', 'error': str(ex)}))
                     request.finish()
                     request.finish()
                     return
                     return
-                except:
+                except InvalidServerName as ex:
+                    request.setResponseCode(400)
+                    request.write(dict_to_json_bytes({'errcode': 'M_INVALID_PARAM', 'error': str(ex)}))
+                    request.finish()
+                    return
+                except Exception:
                     logger.exception("Exception whilst authenticating unbind request")
                     logger.exception("Exception whilst authenticating unbind request")
                     request.setResponseCode(500)
                     request.setResponseCode(500)
                     request.write(dict_to_json_bytes({'errcode': 'M_UNKNOWN', 'error': 'Internal Server Error'}))
                     request.write(dict_to_json_bytes({'errcode': 'M_UNKNOWN', 'error': 'Internal Server Error'}))

+ 47 - 9
sydent/sydent.py

@@ -24,7 +24,7 @@ import copy
 import logging
 import logging
 import logging.handlers
 import logging.handlers
 import os
 import os
-import re
+from typing import Set
 
 
 import twisted.internet.reactor
 import twisted.internet.reactor
 from twisted.internet import task
 from twisted.internet import task
@@ -49,6 +49,7 @@ from sydent.hs_federation.verifier import Verifier
 
 
 from sydent.util.hash import sha256_and_url_safe_base64
 from sydent.util.hash import sha256_and_url_safe_base64
 from sydent.util.tokenutils import generateAlphanumericTokenOfLength
 from sydent.util.tokenutils import generateAlphanumericTokenOfLength
+from sydent.util.ip_range import generate_ip_set, DEFAULT_IP_RANGE_BLACKLIST
 
 
 from sydent.sign.ed25519 import SydentEd25519
 from sydent.sign.ed25519 import SydentEd25519
 
 
@@ -84,13 +85,6 @@ from sydent.replication.pusher import Pusher
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
-
-def set_from_comma_sep_string(rawstr):
-    if rawstr == '':
-        return set()
-    return {x.strip() for x in rawstr.split(',')}
-
-
 CONFIG_DEFAULTS = {
 CONFIG_DEFAULTS = {
     'general': {
     'general': {
         'server.name': os.environ.get('SYDENT_SERVER_NAME', ''),
         'server.name': os.environ.get('SYDENT_SERVER_NAME', ''),
@@ -136,6 +130,26 @@ CONFIG_DEFAULTS = {
         # Whether clients and homeservers can register an association using v1 endpoints.
         # Whether clients and homeservers can register an association using v1 endpoints.
         'enable_v1_associations': 'true',
         'enable_v1_associations': 'true',
         'delete_tokens_on_bind': 'true',
         'delete_tokens_on_bind': 'true',
+
+        # Prevent outgoing requests from being sent to the following blacklisted
+        # IP address CIDR ranges. If this option is not specified or empty then
+        # it defaults to private IP address ranges.
+        #
+        # The blacklist applies to all outbound requests except replication
+        # requests.
+        #
+        # (0.0.0.0 and :: are always blacklisted, whether or not they are
+        # explicitly listed here, since they correspond to unroutable
+        # addresses.)
+        'ip.blacklist': '',
+
+        # List of IP address CIDR ranges that should be allowed for outbound
+        # requests. This is useful for specifying exceptions to wide-ranging
+        # blacklisted target IP ranges.
+        #
+        # This whitelist overrides `ip.blacklist` and defaults to an empty
+        # list.
+        'ip.whitelist': '',
     },
     },
     'db': {
     'db': {
         'db.file': os.environ.get('SYDENT_DB_PATH', 'sydent.db'),
         'db.file': os.environ.get('SYDENT_DB_PATH', 'sydent.db'),
@@ -245,9 +259,10 @@ CONFIG_DEFAULTS = {
 
 
 
 
 class Sydent:
 class Sydent:
-    def __init__(self, cfg, reactor=twisted.internet.reactor):
+    def __init__(self, cfg, reactor=twisted.internet.reactor, use_tls_for_federation=True):
         self.reactor = reactor
         self.reactor = reactor
         self.config_file = get_config_file_path()
         self.config_file = get_config_file_path()
+        self.use_tls_for_federation = use_tls_for_federation
 
 
         self.cfg = cfg
         self.cfg = cfg
 
 
@@ -328,6 +343,15 @@ class Sydent:
             self.cfg.get("general", "delete_tokens_on_bind")
             self.cfg.get("general", "delete_tokens_on_bind")
         )
         )
 
 
+        ip_blacklist = set_from_comma_sep_string(self.cfg.get("general", "ip.blacklist"))
+        if not ip_blacklist:
+            ip_blacklist = DEFAULT_IP_RANGE_BLACKLIST
+
+        ip_whitelist = set_from_comma_sep_string(self.cfg.get("general", "ip.whitelist"))
+
+        self.ip_blacklist = generate_ip_set(ip_blacklist)
+        self.ip_whitelist = generate_ip_set(ip_whitelist)
+
         self.username_reveal_characters = int(self.cfg.get(
         self.username_reveal_characters = int(self.cfg.get(
             "email", "email.third_party_invite_username_reveal_characters"
             "email", "email.third_party_invite_username_reveal_characters"
         ))
         ))
@@ -400,9 +424,13 @@ class Sydent:
         self.servlets.v1 = V1Servlet(self)
         self.servlets.v1 = V1Servlet(self)
         self.servlets.v2 = V2Servlet(self)
         self.servlets.v2 = V2Servlet(self)
         self.servlets.emailRequestCode = EmailRequestCodeServlet(self)
         self.servlets.emailRequestCode = EmailRequestCodeServlet(self)
+        self.servlets.emailRequestCodeV2 = EmailRequestCodeServlet(self, require_auth=True)
         self.servlets.emailValidate = EmailValidateCodeServlet(self)
         self.servlets.emailValidate = EmailValidateCodeServlet(self)
+        self.servlets.emailValidateV2 = EmailValidateCodeServlet(self, require_auth=True)
         self.servlets.msisdnRequestCode = MsisdnRequestCodeServlet(self)
         self.servlets.msisdnRequestCode = MsisdnRequestCodeServlet(self)
+        self.servlets.msisdnRequestCodeV2 = MsisdnRequestCodeServlet(self, require_auth=True)
         self.servlets.msisdnValidate = MsisdnValidateCodeServlet(self)
         self.servlets.msisdnValidate = MsisdnValidateCodeServlet(self)
+        self.servlets.msisdnValidateV2 = MsisdnValidateCodeServlet(self, require_auth=True)
         self.servlets.lookup = LookupServlet(self)
         self.servlets.lookup = LookupServlet(self)
         self.servlets.bulk_lookup = BulkLookupServlet(self)
         self.servlets.bulk_lookup = BulkLookupServlet(self)
         self.servlets.hash_details = HashDetailsServlet(self, lookup_pepper)
         self.servlets.hash_details = HashDetailsServlet(self, lookup_pepper)
@@ -411,11 +439,15 @@ class Sydent:
         self.servlets.pubkeyIsValid = PubkeyIsValidServlet(self)
         self.servlets.pubkeyIsValid = PubkeyIsValidServlet(self)
         self.servlets.ephemeralPubkeyIsValid = EphemeralPubkeyIsValidServlet(self)
         self.servlets.ephemeralPubkeyIsValid = EphemeralPubkeyIsValidServlet(self)
         self.servlets.threepidBind = ThreePidBindServlet(self)
         self.servlets.threepidBind = ThreePidBindServlet(self)
+        self.servlets.threepidBindV2 = ThreePidBindServlet(self, require_auth=True)
         self.servlets.threepidUnbind = ThreePidUnbindServlet(self)
         self.servlets.threepidUnbind = ThreePidUnbindServlet(self)
         self.servlets.replicationPush = ReplicationPushServlet(self)
         self.servlets.replicationPush = ReplicationPushServlet(self)
         self.servlets.getValidated3pid = GetValidated3pidServlet(self)
         self.servlets.getValidated3pid = GetValidated3pidServlet(self)
+        self.servlets.getValidated3pidV2 = GetValidated3pidServlet(self, require_auth=True)
         self.servlets.storeInviteServlet = StoreInviteServlet(self)
         self.servlets.storeInviteServlet = StoreInviteServlet(self)
+        self.servlets.storeInviteServletV2 = StoreInviteServlet(self, require_auth=True)
         self.servlets.blindlySignStuffServlet = BlindlySignStuffServlet(self)
         self.servlets.blindlySignStuffServlet = BlindlySignStuffServlet(self)
+        self.servlets.blindlySignStuffServletV2 = BlindlySignStuffServlet(self, require_auth=True)
         self.servlets.profileReplicationServlet = ProfileReplicationServlet(self)
         self.servlets.profileReplicationServlet = ProfileReplicationServlet(self)
         self.servlets.userDirectorySearchServlet = UserDirectorySearchServlet(self)
         self.servlets.userDirectorySearchServlet = UserDirectorySearchServlet(self)
         self.servlets.termsServlet = TermsServlet(self)
         self.servlets.termsServlet = TermsServlet(self)
@@ -657,6 +689,12 @@ def parse_cfg_bool(value):
     return value.lower() == "true"
     return value.lower() == "true"
 
 
 
 
+def set_from_comma_sep_string(rawstr: str) -> Set[str]:
+    if rawstr == '':
+        return set()
+    return {x.strip() for x in rawstr.split(',')}
+
+
 def run_gc():
 def run_gc():
     threshold = gc.get_threshold()
     threshold = gc.get_threshold()
     counts = gc.get_count()
     counts = gc.get_count()

+ 14 - 2
sydent/threepid/bind.py

@@ -32,6 +32,8 @@ from sydent.http.httpclient import FederationHttpClient
 
 
 from sydent.threepid import ThreepidAssociation
 from sydent.threepid import ThreepidAssociation
 
 
+from sydent.util.stringutils import is_valid_matrix_server_name
+
 from twisted.internet import defer
 from twisted.internet import defer
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
@@ -180,6 +182,7 @@ class ThreepidBinder:
         """
         """
         mxid = assoc["mxid"]
         mxid = assoc["mxid"]
         mxid_parts = mxid.split(":", 1)
         mxid_parts = mxid.split(":", 1)
+
         if len(mxid_parts) != 2:
         if len(mxid_parts) != 2:
             logger.error(
             logger.error(
                 "Can't notify on bind for unparseable mxid %s. Not retrying.",
                 "Can't notify on bind for unparseable mxid %s. Not retrying.",
@@ -187,8 +190,17 @@ class ThreepidBinder:
             )
             )
             return
             return
 
 
+        matrix_server = mxid_parts[1]
+
+        if not is_valid_matrix_server_name(matrix_server):
+            logger.error(
+                "MXID server part '%s' not a valid Matrix server name. Not retrying.",
+                matrix_server,
+            )
+            return
+
         post_url = "matrix://%s/_matrix/federation/v1/3pid/onbind" % (
         post_url = "matrix://%s/_matrix/federation/v1/3pid/onbind" % (
-            mxid_parts[1],
+            matrix_server,
         )
         )
 
 
         logger.info("Making bind callback to: %s", post_url)
         logger.info("Making bind callback to: %s", post_url)
@@ -229,7 +241,7 @@ class ThreepidBinder:
                 "Successfully deleted invite for %s from the store",
                 "Successfully deleted invite for %s from the store",
                 assoc["address"],
                 assoc["address"],
             )
             )
-        except Exception as e:
+        except Exception:
             logger.exception(
             logger.exception(
                 "Couldn't remove invite for %s from the store",
                 "Couldn't remove invite for %s from the store",
                 assoc["address"],
                 assoc["address"],

+ 5 - 0
sydent/util/emailutils.py

@@ -35,6 +35,7 @@ else:
 import email.utils
 import email.utils
 
 
 from sydent.util import time_msec
 from sydent.util import time_msec
+from sydent.util.tokenutils import generateAlphanumericTokenOfLength
 
 
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -75,6 +76,10 @@ def sendEmail(sydent, templateFile, mailTo, substitutions):
         allSubstitutions[k+"_forhtml"] = escape(v)
         allSubstitutions[k+"_forhtml"] = escape(v)
         allSubstitutions[k+"_forurl"] = urllib.parse.quote(v)
         allSubstitutions[k+"_forurl"] = urllib.parse.quote(v)
 
 
+    # We add randomize the multipart boundary to stop user input from
+    # conflicting with it.
+    allSubstitutions["multipart_boundary"] = generateAlphanumericTokenOfLength(32)
+
     mailString = open(templateFile, encoding="utf-8").read() % allSubstitutions
     mailString = open(templateFile, encoding="utf-8").read() % allSubstitutions
     parsedFrom = email.utils.parseaddr(mailFrom)[1]
     parsedFrom = email.utils.parseaddr(mailFrom)[1]
     parsedTo = email.utils.parseaddr(mailTo)[1]
     parsedTo = email.utils.parseaddr(mailTo)[1]

+ 118 - 0
sydent/util/ip_range.py

@@ -0,0 +1,118 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+import itertools
+from typing import Iterable, Optional
+
+from netaddr import AddrFormatError, IPNetwork, IPSet
+
+# IP ranges that are considered private / unroutable / don't make sense.
+DEFAULT_IP_RANGE_BLACKLIST = [
+    # Localhost
+    "127.0.0.0/8",
+    # Private networks.
+    "10.0.0.0/8",
+    "172.16.0.0/12",
+    "192.168.0.0/16",
+    # Carrier grade NAT.
+    "100.64.0.0/10",
+    # Address registry.
+    "192.0.0.0/24",
+    # Link-local networks.
+    "169.254.0.0/16",
+    # Formerly used for 6to4 relay.
+    "192.88.99.0/24",
+    # Testing networks.
+    "198.18.0.0/15",
+    "192.0.2.0/24",
+    "198.51.100.0/24",
+    "203.0.113.0/24",
+    # Multicast.
+    "224.0.0.0/4",
+    # Localhost
+    "::1/128",
+    # Link-local addresses.
+    "fe80::/10",
+    # Unique local addresses.
+    "fc00::/7",
+    # Testing networks.
+    "2001:db8::/32",
+    # Multicast.
+    "ff00::/8",
+    # Site-local addresses
+    "fec0::/10",
+]
+
+
+def generate_ip_set(
+    ip_addresses: Optional[Iterable[str]],
+    extra_addresses: Optional[Iterable[str]] = None,
+    config_path: Optional[Iterable[str]] = None,
+) -> IPSet:
+    """
+    Generate an IPSet from a list of IP addresses or CIDRs.
+
+    Additionally, for each IPv4 network in the list of IP addresses, also
+    includes the corresponding IPv6 networks.
+
+    This includes:
+
+    * IPv4-Compatible IPv6 Address (see RFC 4291, section 2.5.5.1)
+    * IPv4-Mapped IPv6 Address (see RFC 4291, section 2.5.5.2)
+    * 6to4 Address (see RFC 3056, section 2)
+
+    Args:
+        ip_addresses: An iterable of IP addresses or CIDRs.
+        extra_addresses: An iterable of IP addresses or CIDRs.
+        config_path: The path in the configuration for error messages.
+
+    Returns:
+        A new IP set.
+    """
+    result = IPSet()
+    for ip in itertools.chain(ip_addresses or (), extra_addresses or ()):
+        try:
+            network = IPNetwork(ip)
+        except AddrFormatError as e:
+            raise Exception(
+                "Invalid IP range provided: %s." % (ip,), config_path
+            ) from e
+        result.add(network)
+
+        # It is possible that these already exist in the set, but that's OK.
+        if ":" not in str(network):
+            result.add(IPNetwork(network).ipv6(ipv4_compatible=True))
+            result.add(IPNetwork(network).ipv6(ipv4_compatible=False))
+            result.add(_6to4(network))
+
+    return result
+
+
+def _6to4(network: IPNetwork) -> IPNetwork:
+    """Convert an IPv4 network into a 6to4 IPv6 network per RFC 3056."""
+
+    # 6to4 networks consist of:
+    # * 2002 as the first 16 bits
+    # * The first IPv4 address in the network hex-encoded as the next 32 bits
+    # * The new prefix length needs to include the bits from the 2002 prefix.
+    hex_network = hex(network.first)[2:]
+    hex_network = ("0" * (8 - len(hex_network))) + hex_network
+    return IPNetwork(
+        "2002:%s:%s::/%d"
+        % (
+            hex_network[:4],
+            hex_network[4:],
+            16 + network.prefixlen,
+        )
+    )

+ 104 - 3
sydent/util/stringutils.py

@@ -13,18 +13,119 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 import re
 import re
+from typing import Optional, Tuple
+
+from twisted.internet.abstract import isIPAddress, isIPv6Address
 
 
 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
 # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken
-client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$")
+CLIENT_SECRET_REGEX = re.compile(r"^[0-9a-zA-Z\.=_\-]+$")
+
+# hostname/domain name
+# https://regex101.com/r/OyN1lg/2
+hostname_regex = re.compile(
+    r"^(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)(?:\.[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?)*$",
+    flags=re.IGNORECASE)
+
+# it's unclear what the maximum length of an email address is. RFC3696 (as corrected
+# by errata) says:
+#    the upper limit on address lengths should normally be considered to be 254.
+#
+# In practice, mail servers appear to be more tolerant and allow 400 characters
+# or so. Let's allow 500, which should be plenty for everyone.
+#
+MAX_EMAIL_ADDRESS_LENGTH = 500
 
 
 
 
 def is_valid_client_secret(client_secret):
 def is_valid_client_secret(client_secret):
     """Validate that a given string matches the client_secret regex defined by the spec
     """Validate that a given string matches the client_secret regex defined by the spec
 
 
     :param client_secret: The client_secret to validate
     :param client_secret: The client_secret to validate
-    :type client_secret: unicode
+    :type client_secret: str
 
 
     :return: Whether the client_secret is valid
     :return: Whether the client_secret is valid
     :rtype: bool
     :rtype: bool
     """
     """
-    return client_secret_regex.match(client_secret) is not None
+    return (
+        0 < len(client_secret) <= 255
+        and CLIENT_SECRET_REGEX.match(client_secret) is not None
+    )
+
+
+def is_valid_hostname(string: str) -> bool:
+    """Validate that a given string is a valid hostname or domain name.
+
+    For domain names, this only validates that the form is right (for
+    instance, it doesn't check that the TLD is valid).
+
+    :param string: The string to validate
+    :type string: str
+
+    :return: Whether the input is a valid hostname
+    :rtype: bool
+    """
+
+    return hostname_regex.match(string) is not None
+
+
+def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]:
+    """Split a server name into host/port parts.
+
+    No validation is done on the host part. The port part is validated to be
+    a valid port number.
+
+    Args:
+        server_name: server name to parse
+
+    Returns:
+        host/port parts.
+
+    Raises:
+        ValueError if the server name could not be parsed.
+    """
+    try:
+        if server_name[-1] == "]":
+            # ipv6 literal, hopefully
+            return server_name, None
+
+        host_port = server_name.rsplit(":", 1)
+        host = host_port[0]
+        port = host_port[1] if host_port[1:] else None
+
+        if port:
+            port_num = int(port)
+
+            # exclude things like '08090' or ' 8090'
+            if port != str(port_num) or not (1 <= port_num < 65536):
+                raise ValueError("Invalid port")
+
+        return host, port
+    except Exception:
+        raise ValueError("Invalid server name '%s'" % server_name)
+
+
+def is_valid_matrix_server_name(string: str) -> bool:
+    """Validate that the given string is a valid Matrix server name.
+
+    A string is a valid Matrix server name if it is one of the following, plus
+    an optional port:
+
+    a. IPv4 address
+    b. IPv6 literal (`[IPV6_ADDRESS]`)
+    c. A valid hostname
+
+    :param string: The string to validate
+    :type string: str
+
+    :return: Whether the input is a valid Matrix server name
+    :rtype: bool
+    """
+
+    try:
+        host, port = parse_server_name(string)
+    except ValueError:
+        return False
+
+    valid_ipv4_addr = isIPAddress(host)
+    valid_ipv6_literal = host[0] == "[" and host[-1] == "]" and isIPv6Address(host[1:-1])
+
+    return valid_ipv4_addr or valid_ipv6_literal or is_valid_hostname(host)

+ 2 - 2
tests/test_auth.py

@@ -44,7 +44,7 @@ class AuthTestCase(unittest.TestCase):
         self.sydent.db.commit()
         self.sydent.db.commit()
 
 
     def test_can_read_token_from_headers(self):
     def test_can_read_token_from_headers(self):
-        """Tests that Sydent correct extracts an auth token from request headers"""
+        """Tests that Sydent correctly extracts an auth token from request headers"""
         self.sydent.run()
         self.sydent.run()
 
 
         request, _ = make_request(
         request, _ = make_request(
@@ -59,7 +59,7 @@ class AuthTestCase(unittest.TestCase):
         self.assertEqual(token, self.test_token)
         self.assertEqual(token, self.test_token)
 
 
     def test_can_read_token_from_query_parameters(self):
     def test_can_read_token_from_query_parameters(self):
-        """Tests that Sydent correct extracts an auth token from query parameters"""
+        """Tests that Sydent correctly extracts an auth token from query parameters"""
         self.sydent.run()
         self.sydent.run()
 
 
         request, _ = make_request(
         request, _ = make_request(

+ 243 - 0
tests/test_blacklisting.py

@@ -0,0 +1,243 @@
+#  Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+
+
+from mock import patch
+from netaddr import IPSet
+from twisted.internet import defer
+from twisted.internet.error import DNSLookupError
+from twisted.test.proto_helpers import StringTransport
+from twisted.trial.unittest import TestCase
+from twisted.web.client import Agent
+
+from sydent.http.blacklisting_reactor import BlacklistingReactorWrapper
+from sydent.http.srvresolver import Server
+from tests.utils import make_request, make_sydent
+
+
+class BlacklistingAgentTest(TestCase):
+    def setUp(self):
+        config = {
+            "general": {
+                "ip.blacklist": "5.0.0.0/8",
+                "ip.whitelist": "5.1.1.1",
+            },
+        }
+
+        self.sydent = make_sydent(test_config=config)
+
+        self.reactor = self.sydent.reactor
+
+        self.safe_domain, self.safe_ip = b"safe.test", b"1.2.3.4"
+        self.unsafe_domain, self.unsafe_ip = b"danger.test", b"5.6.7.8"
+        self.allowed_domain, self.allowed_ip = b"allowed.test", b"5.1.1.1"
+
+        # Configure the reactor's DNS resolver.
+        for (domain, ip) in (
+            (self.safe_domain, self.safe_ip),
+            (self.unsafe_domain, self.unsafe_ip),
+            (self.allowed_domain, self.allowed_ip),
+        ):
+            self.reactor.lookups[domain.decode()] = ip.decode()
+            self.reactor.lookups[ip.decode()] = ip.decode()
+
+        self.ip_whitelist = self.sydent.ip_whitelist
+        self.ip_blacklist = self.sydent.ip_blacklist
+
+    def test_reactor(self):
+        """Apply the blacklisting reactor and ensure it properly blocks
+        connections to particular domains and IPs.
+        """
+        agent = Agent(
+            BlacklistingReactorWrapper(
+                self.reactor,
+                ip_whitelist=self.ip_whitelist,
+                ip_blacklist=self.ip_blacklist,
+            ),
+        )
+
+        # The unsafe domains and IPs should be rejected.
+        for domain in (self.unsafe_domain, self.unsafe_ip):
+            self.failureResultOf(
+                agent.request(b"GET", b"http://" + domain), DNSLookupError
+            )
+
+        self.reactor.tcpClients = []
+
+        # The safe domains IPs should be accepted.
+        for domain in (
+            self.safe_domain,
+            self.allowed_domain,
+            self.safe_ip,
+            self.allowed_ip,
+        ):
+            agent.request(b"GET", b"http://" + domain)
+
+            # Grab the latest TCP connection.
+            (
+                host,
+                port,
+                client_factory,
+                _timeout,
+                _bindAddress,
+            ) = self.reactor.tcpClients.pop()
+
+    @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
+    def test_federation_client_allowed_ip(self, resolver):
+        self.sydent.run()
+
+        request, channel = make_request(
+            self.sydent.reactor,
+            "POST",
+            "/_matrix/identity/v2/account/register",
+            {
+                "access_token": "foo",
+                "expires_in": 300,
+                "matrix_server_name": "example.com",
+                "token_type": "Bearer",
+            },
+        )
+
+        resolver.return_value = defer.succeed(
+            [
+                Server(
+                    host=self.allowed_domain,
+                    port=443,
+                    priority=1,
+                    weight=1,
+                    expires=100,
+                )
+            ]
+        )
+
+        request.render(self.sydent.servlets.registerServlet)
+
+        transport, protocol = self._get_http_request(
+            self.allowed_ip.decode("ascii"), 443
+        )
+
+        self.assertRegex(
+            transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
+        )
+        self.assertRegex(transport.value(), b"Host: example.com")
+
+        # Send it the HTTP response
+        res_json = '{ "sub": "@test:example.com" }'.encode("ascii")
+        protocol.dataReceived(
+            b"HTTP/1.1 200 OK\r\n"
+            b"Server: Fake\r\n"
+            b"Content-Type: application/json\r\n"
+            b"Content-Length: %i\r\n"
+            b"\r\n"
+            b"%s" % (len(res_json), res_json)
+        )
+
+        self.assertEqual(channel.code, 200)
+
+    @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
+    def test_federation_client_safe_ip(self, resolver):
+        self.sydent.run()
+
+        request, channel = make_request(
+            self.sydent.reactor,
+            "POST",
+            "/_matrix/identity/v2/account/register",
+            {
+                "access_token": "foo",
+                "expires_in": 300,
+                "matrix_server_name": "example.com",
+                "token_type": "Bearer",
+            },
+        )
+
+        resolver.return_value = defer.succeed(
+            [
+                Server(
+                    host=self.safe_domain,
+                    port=443,
+                    priority=1,
+                    weight=1,
+                    expires=100,
+                )
+            ]
+        )
+
+        request.render(self.sydent.servlets.registerServlet)
+
+        transport, protocol = self._get_http_request(self.safe_ip.decode("ascii"), 443)
+
+        self.assertRegex(
+            transport.value(), b"^GET /_matrix/federation/v1/openid/userinfo"
+        )
+        self.assertRegex(transport.value(), b"Host: example.com")
+
+        # Send it the HTTP response
+        res_json = '{ "sub": "@test:example.com" }'.encode("ascii")
+        protocol.dataReceived(
+            b"HTTP/1.1 200 OK\r\n"
+            b"Server: Fake\r\n"
+            b"Content-Type: application/json\r\n"
+            b"Content-Length: %i\r\n"
+            b"\r\n"
+            b"%s" % (len(res_json), res_json)
+        )
+
+        self.assertEqual(channel.code, 200)
+
+    @patch("sydent.http.srvresolver.SrvResolver.resolve_service")
+    def test_federation_client_unsafe_ip(self, resolver):
+        self.sydent.run()
+
+        request, channel = make_request(
+            self.sydent.reactor,
+            "POST",
+            "/_matrix/identity/v2/account/register",
+            {
+                "access_token": "foo",
+                "expires_in": 300,
+                "matrix_server_name": "example.com",
+                "token_type": "Bearer",
+            },
+        )
+
+        resolver.return_value = defer.succeed(
+            [
+                Server(
+                    host=self.unsafe_domain,
+                    port=443,
+                    priority=1,
+                    weight=1,
+                    expires=100,
+                )
+            ]
+        )
+
+        request.render(self.sydent.servlets.registerServlet)
+
+        self.assertNot(self.reactor.tcpClients)
+
+        self.assertEqual(channel.code, 500)
+
+    def _get_http_request(self, expected_host, expected_port):
+        clients = self.reactor.tcpClients
+        (host, port, factory, _timeout, _bindAddress) = clients[-1]
+        self.assertEqual(host, expected_host)
+        self.assertEqual(port, expected_port)
+
+        # complete the connection and wire it up to a fake transport
+        protocol = factory.buildProtocol(None)
+        transport = StringTransport()
+        protocol.makeConnection(transport)
+
+        return transport, protocol

+ 46 - 0
tests/test_register.py

@@ -0,0 +1,46 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.trial import unittest
+
+from tests.utils import make_request, make_sydent
+
+
+class RegisterTestCase(unittest.TestCase):
+    """Tests Sydent's register servlet"""
+
+    def setUp(self):
+        # Create a new sydent
+        self.sydent = make_sydent()
+
+    def test_sydent_rejects_invalid_hostname(self):
+        """Tests that the /register endpoint rejects an invalid hostname passed as matrix_server_name"""
+        self.sydent.run()
+
+        bad_hostname = "example.com#"
+
+        request, channel = make_request(
+            self.sydent.reactor,
+            "POST",
+            "/_matrix/identity/v2/account/register",
+            content={
+                "matrix_server_name": bad_hostname,
+                "access_token": "foo"
+            })
+
+        request.render(self.sydent.servlets.registerServlet)
+
+        self.assertEqual(channel.code, 400)

+ 37 - 0
tests/test_util.py

@@ -0,0 +1,37 @@
+from twisted.trial import unittest
+from sydent.util.stringutils import is_valid_matrix_server_name
+
+
+class UtilTests(unittest.TestCase):
+    """Tests Sydent utility functions."""
+
+    def test_is_valid_matrix_server_name(self):
+        """Tests that the is_valid_matrix_server_name function accepts only
+        valid hostnames (or domain names), with optional port number.
+        """
+        self.assertTrue(is_valid_matrix_server_name("9.9.9.9"))
+        self.assertTrue(is_valid_matrix_server_name("9.9.9.9:4242"))
+        self.assertTrue(is_valid_matrix_server_name("[::]"))
+        self.assertTrue(is_valid_matrix_server_name("[::]:4242"))
+        self.assertTrue(is_valid_matrix_server_name("[a:b:c::]:4242"))
+
+        self.assertTrue(is_valid_matrix_server_name("example.com"))
+        self.assertTrue(is_valid_matrix_server_name("EXAMPLE.COM"))
+        self.assertTrue(is_valid_matrix_server_name("ExAmPlE.CoM"))
+        self.assertTrue(is_valid_matrix_server_name("example.com:4242"))
+        self.assertTrue(is_valid_matrix_server_name("localhost"))
+        self.assertTrue(is_valid_matrix_server_name("localhost:9000"))
+        self.assertTrue(is_valid_matrix_server_name("a.b.c.d:1234"))
+
+        self.assertFalse(is_valid_matrix_server_name("[:::]"))
+        self.assertFalse(is_valid_matrix_server_name("a:b:c::"))
+
+        self.assertFalse(is_valid_matrix_server_name("example.com:65536"))
+        self.assertFalse(is_valid_matrix_server_name("example.com:0"))
+        self.assertFalse(is_valid_matrix_server_name("example.com:-1"))
+        self.assertFalse(is_valid_matrix_server_name("example.com:a"))
+        self.assertFalse(is_valid_matrix_server_name("example.com: "))
+        self.assertFalse(is_valid_matrix_server_name("example.com:04242"))
+        self.assertFalse(is_valid_matrix_server_name("example.com: 4242"))
+        self.assertFalse(is_valid_matrix_server_name("example.com/example.com"))
+        self.assertFalse(is_valid_matrix_server_name("example.com#example.com"))

+ 42 - 14
tests/utils.py

@@ -2,9 +2,19 @@ import json
 from io import BytesIO
 from io import BytesIO
 import logging
 import logging
 import os
 import os
-
+from typing import Dict
 import attr
 import attr
 from six import text_type
 from six import text_type
+from zope.interface import implementer
+from twisted.internet._resolver import SimpleResolverComplexifier
+from twisted.internet.defer import fail, succeed
+from twisted.internet.error import DNSLookupError
+from twisted.internet.interfaces import (
+    IHostnameResolver,
+    IReactorPluggableNameResolver,
+    IResolverSimple,
+)
+
 from twisted.internet import address
 from twisted.internet import address
 import twisted.logger
 import twisted.logger
 from twisted.web.http_headers import Headers
 from twisted.web.http_headers import Headers
@@ -53,13 +63,13 @@ def make_sydent(test_config={}):
     # Use an in-memory SQLite database. Note that the database isn't cleaned up between
     # Use an in-memory SQLite database. Note that the database isn't cleaned up between
     # tests, so by default the same database will be used for each test if changed to be
     # tests, so by default the same database will be used for each test if changed to be
     # a file on disk.
     # a file on disk.
-    if 'db' not in test_config:
-        test_config['db'] = {'db.file': ':memory:'}
+    if "db" not in test_config:
+        test_config["db"] = {"db.file": ":memory:"}
     else:
     else:
-        test_config['db'].setdefault('db.file', ':memory:')
+        test_config["db"].setdefault("db.file", ":memory:")
 
 
-    reactor = MemoryReactorClock()
-    return Sydent(reactor=reactor, cfg=parse_config_dict(test_config))
+    reactor = ResolvingMemoryReactorClock()
+    return Sydent(reactor=reactor, cfg=parse_config_dict(test_config), use_tls_for_federation=False)
 
 
 
 
 @attr.s
 @attr.s
@@ -149,6 +159,7 @@ class FakeChannel(object):
 
 
 class FakeSite:
 class FakeSite:
     """A fake Twisted Web Site."""
     """A fake Twisted Web Site."""
+
     pass
     pass
 
 
 
 
@@ -191,10 +202,7 @@ def make_request(
         path = path.encode("ascii")
         path = path.encode("ascii")
 
 
     # Decorate it to be the full path, if we're using shorthand
     # Decorate it to be the full path, if we're using shorthand
-    if (
-        shorthand
-        and not path.startswith(b"/_matrix")
-    ):
+    if shorthand and not path.startswith(b"/_matrix"):
         path = b"/_matrix/identity/v2/" + path
         path = b"/_matrix/identity/v2/" + path
         path = path.replace(b"//", b"/")
         path = path.replace(b"//", b"/")
 
 
@@ -253,10 +261,7 @@ def setup_logging():
     """
     """
     root_logger = logging.getLogger()
     root_logger = logging.getLogger()
 
 
-    log_format = (
-        "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s"
-        " - %(message)s"
-    )
+    log_format = "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s" " - %(message)s"
 
 
     handler = ToTwistedHandler()
     handler = ToTwistedHandler()
     formatter = logging.Formatter(log_format)
     formatter = logging.Formatter(log_format)
@@ -268,3 +273,26 @@ def setup_logging():
 
 
 
 
 setup_logging()
 setup_logging()
+
+
+@implementer(IReactorPluggableNameResolver)
+class ResolvingMemoryReactorClock(MemoryReactorClock):
+    """
+    A MemoryReactorClock that supports name resolution.
+    """
+
+    def __init__(self):
+        lookups = self.lookups = {}  # type: Dict[str, str]
+
+        @implementer(IResolverSimple)
+        class FakeResolver:
+            def getHostByName(self, name, timeout=None):
+                if name not in lookups:
+                    return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
+                return succeed(lookups[name])
+
+        self.nameResolver = SimpleResolverComplexifier(FakeResolver())
+        super().__init__()
+
+    def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
+        raise NotImplementedError()