|
@@ -13,7 +13,6 @@
|
|
|
# limitations under the License.
|
|
|
import time
|
|
|
import urllib.parse
|
|
|
-from http import HTTPStatus
|
|
|
from typing import Any, Dict, List, Optional
|
|
|
from unittest.mock import Mock
|
|
|
from urllib.parse import urlencode
|
|
@@ -134,10 +133,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
|
|
|
|
|
if i == 5:
|
|
|
- self.assertEqual(channel.result["code"], b"429", channel.result)
|
|
|
+ self.assertEqual(channel.code, 429, msg=channel.result)
|
|
|
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
|
|
else:
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
|
|
|
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
|
|
|
# than 1min.
|
|
@@ -152,7 +151,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
}
|
|
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
|
|
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
|
|
|
@override_config(
|
|
|
{
|
|
@@ -179,10 +178,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
|
|
|
|
|
if i == 5:
|
|
|
- self.assertEqual(channel.result["code"], b"429", channel.result)
|
|
|
+ self.assertEqual(channel.code, 429, msg=channel.result)
|
|
|
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
|
|
else:
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
|
|
|
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
|
|
|
# than 1min.
|
|
@@ -197,7 +196,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
}
|
|
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
|
|
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
|
|
|
@override_config(
|
|
|
{
|
|
@@ -224,10 +223,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
|
|
|
|
|
if i == 5:
|
|
|
- self.assertEqual(channel.result["code"], b"429", channel.result)
|
|
|
+ self.assertEqual(channel.code, 429, msg=channel.result)
|
|
|
retry_after_ms = int(channel.json_body["retry_after_ms"])
|
|
|
else:
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
|
|
|
# Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
|
|
|
# than 1min.
|
|
@@ -242,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
}
|
|
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
|
|
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
|
|
|
@override_config({"session_lifetime": "24h"})
|
|
|
def test_soft_logout(self) -> None:
|
|
@@ -250,7 +249,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
|
|
|
# we shouldn't be able to make requests without an access token
|
|
|
channel = self.make_request(b"GET", TEST_URL)
|
|
|
- self.assertEqual(channel.result["code"], b"401", channel.result)
|
|
|
+ self.assertEqual(channel.code, 401, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
|
|
|
|
|
|
# log in as normal
|
|
@@ -261,20 +260,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
}
|
|
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
|
|
|
|
|
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, channel.result)
|
|
|
access_token = channel.json_body["access_token"]
|
|
|
device_id = channel.json_body["device_id"]
|
|
|
|
|
|
# we should now be able to make requests with the access token
|
|
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, channel.result)
|
|
|
|
|
|
# time passes
|
|
|
self.reactor.advance(24 * 3600)
|
|
|
|
|
|
# ... and we should be soft-logouted
|
|
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
|
|
+ self.assertEqual(channel.code, 401, channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
|
|
self.assertEqual(channel.json_body["soft_logout"], True)
|
|
|
|
|
@@ -288,7 +287,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
# more requests with the expired token should still return a soft-logout
|
|
|
self.reactor.advance(3600)
|
|
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
|
|
+ self.assertEqual(channel.code, 401, channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
|
|
self.assertEqual(channel.json_body["soft_logout"], True)
|
|
|
|
|
@@ -296,7 +295,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
self._delete_device(access_token_2, "kermit", "monkey", device_id)
|
|
|
|
|
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
|
|
+ self.assertEqual(channel.code, 401, channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
|
|
self.assertEqual(channel.json_body["soft_logout"], False)
|
|
|
|
|
@@ -307,7 +306,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
channel = self.make_request(
|
|
|
b"DELETE", "devices/" + device_id, access_token=access_token
|
|
|
)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
|
|
+ self.assertEqual(channel.code, 401, channel.result)
|
|
|
# check it's a UI-Auth fail
|
|
|
self.assertEqual(
|
|
|
set(channel.json_body.keys()),
|
|
@@ -330,7 +329,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
access_token=access_token,
|
|
|
content={"auth": auth},
|
|
|
)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, channel.result)
|
|
|
|
|
|
@override_config({"session_lifetime": "24h"})
|
|
|
def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
|
|
@@ -341,20 +340,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
|
|
|
# we should now be able to make requests with the access token
|
|
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, channel.result)
|
|
|
|
|
|
# time passes
|
|
|
self.reactor.advance(24 * 3600)
|
|
|
|
|
|
# ... and we should be soft-logouted
|
|
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
|
|
+ self.assertEqual(channel.code, 401, channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
|
|
self.assertEqual(channel.json_body["soft_logout"], True)
|
|
|
|
|
|
# Now try to hard logout this session
|
|
|
channel = self.make_request(b"POST", "/logout", access_token=access_token)
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
|
|
|
@override_config({"session_lifetime": "24h"})
|
|
|
def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
|
|
@@ -367,20 +366,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
|
|
|
# we should now be able to make requests with the access token
|
|
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, channel.result)
|
|
|
|
|
|
# time passes
|
|
|
self.reactor.advance(24 * 3600)
|
|
|
|
|
|
# ... and we should be soft-logouted
|
|
|
channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
|
|
|
+ self.assertEqual(channel.code, 401, channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
|
|
|
self.assertEqual(channel.json_body["soft_logout"], True)
|
|
|
|
|
|
# Now try to hard log out all of the user's sessions
|
|
|
channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
|
|
|
def test_login_with_overly_long_device_id_fails(self) -> None:
|
|
|
self.register_user("mickey", "cheese")
|
|
@@ -466,7 +465,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|
|
def test_get_login_flows(self) -> None:
|
|
|
"""GET /login should return password and SSO flows"""
|
|
|
channel = self.make_request("GET", "/_matrix/client/r0/login")
|
|
|
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, channel.result)
|
|
|
|
|
|
expected_flow_types = [
|
|
|
"m.login.cas",
|
|
@@ -494,14 +493,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|
|
"""/login/sso/redirect should redirect to an identity picker"""
|
|
|
# first hit the redirect url, which should redirect to our idp picker
|
|
|
channel = self._make_sso_redirect_request(None)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
|
|
+ self.assertEqual(channel.code, 302, channel.result)
|
|
|
location_headers = channel.headers.getRawHeaders("Location")
|
|
|
assert location_headers
|
|
|
uri = location_headers[0]
|
|
|
|
|
|
# hitting that picker should give us some HTML
|
|
|
channel = self.make_request("GET", uri)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, channel.result)
|
|
|
|
|
|
# parse the form to check it has fields assumed elsewhere in this class
|
|
|
html = channel.result["body"].decode("utf-8")
|
|
@@ -530,7 +529,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|
|
+ "&idp=cas",
|
|
|
shorthand=False,
|
|
|
)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
|
|
+ self.assertEqual(channel.code, 302, channel.result)
|
|
|
location_headers = channel.headers.getRawHeaders("Location")
|
|
|
assert location_headers
|
|
|
cas_uri = location_headers[0]
|
|
@@ -555,7 +554,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|
|
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
|
|
+ "&idp=saml",
|
|
|
)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
|
|
+ self.assertEqual(channel.code, 302, channel.result)
|
|
|
location_headers = channel.headers.getRawHeaders("Location")
|
|
|
assert location_headers
|
|
|
saml_uri = location_headers[0]
|
|
@@ -579,7 +578,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|
|
+ urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
|
|
|
+ "&idp=oidc",
|
|
|
)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
|
|
+ self.assertEqual(channel.code, 302, channel.result)
|
|
|
location_headers = channel.headers.getRawHeaders("Location")
|
|
|
assert location_headers
|
|
|
oidc_uri = location_headers[0]
|
|
@@ -606,7 +605,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|
|
channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
|
|
|
|
|
|
# that should serve a confirmation page
|
|
|
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, channel.result)
|
|
|
content_type_headers = channel.headers.getRawHeaders("Content-Type")
|
|
|
assert content_type_headers
|
|
|
self.assertTrue(content_type_headers[-1].startswith("text/html"))
|
|
@@ -634,7 +633,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|
|
"/login",
|
|
|
content={"type": "m.login.token", "token": login_token},
|
|
|
)
|
|
|
- self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
|
|
|
+ self.assertEqual(chan.code, 200, chan.result)
|
|
|
self.assertEqual(chan.json_body["user_id"], "@user1:test")
|
|
|
|
|
|
def test_multi_sso_redirect_to_unknown(self) -> None:
|
|
@@ -643,18 +642,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
|
|
|
"GET",
|
|
|
"/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
|
|
|
)
|
|
|
- self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
|
|
|
+ self.assertEqual(channel.code, 400, channel.result)
|
|
|
|
|
|
def test_client_idp_redirect_to_unknown(self) -> None:
|
|
|
"""If the client tries to pick an unknown IdP, return a 404"""
|
|
|
channel = self._make_sso_redirect_request("xxx")
|
|
|
- self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
|
|
|
+ self.assertEqual(channel.code, 404, channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
|
|
|
|
|
def test_client_idp_redirect_to_oidc(self) -> None:
|
|
|
"""If the client pick a known IdP, redirect to it"""
|
|
|
channel = self._make_sso_redirect_request("oidc")
|
|
|
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
|
|
+ self.assertEqual(channel.code, 302, channel.result)
|
|
|
location_headers = channel.headers.getRawHeaders("Location")
|
|
|
assert location_headers
|
|
|
oidc_uri = location_headers[0]
|
|
@@ -765,7 +764,7 @@ class CASTestCase(unittest.HomeserverTestCase):
|
|
|
channel = self.make_request("GET", cas_ticket_url)
|
|
|
|
|
|
# Test that the response is HTML.
|
|
|
- self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, channel.result)
|
|
|
content_type_header_value = ""
|
|
|
for header in channel.result.get("headers", []):
|
|
|
if header[0] == b"Content-Type":
|
|
@@ -878,17 +877,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|
|
def test_login_jwt_valid_registered(self) -> None:
|
|
|
self.register_user("kermit", "monkey")
|
|
|
channel = self.jwt_login({"sub": "kermit"})
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
|
|
|
|
|
def test_login_jwt_valid_unregistered(self) -> None:
|
|
|
channel = self.jwt_login({"sub": "frog"})
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
|
|
|
|
|
def test_login_jwt_invalid_signature(self) -> None:
|
|
|
channel = self.jwt_login({"sub": "frog"}, "notsecret")
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
|
|
self.assertEqual(
|
|
|
channel.json_body["error"],
|
|
@@ -897,7 +896,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|
|
|
|
|
def test_login_jwt_expired(self) -> None:
|
|
|
channel = self.jwt_login({"sub": "frog", "exp": 864000})
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
|
|
self.assertEqual(
|
|
|
channel.json_body["error"],
|
|
@@ -907,7 +906,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|
|
def test_login_jwt_not_before(self) -> None:
|
|
|
now = int(time.time())
|
|
|
channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
|
|
self.assertEqual(
|
|
|
channel.json_body["error"],
|
|
@@ -916,7 +915,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|
|
|
|
|
def test_login_no_sub(self) -> None:
|
|
|
channel = self.jwt_login({"username": "root"})
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
|
|
self.assertEqual(channel.json_body["error"], "Invalid JWT")
|
|
|
|
|
@@ -925,12 +924,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|
|
"""Test validating the issuer claim."""
|
|
|
# A valid issuer.
|
|
|
channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
|
|
|
|
|
# An invalid issuer.
|
|
|
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
|
|
self.assertEqual(
|
|
|
channel.json_body["error"],
|
|
@@ -939,7 +938,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|
|
|
|
|
# Not providing an issuer.
|
|
|
channel = self.jwt_login({"sub": "kermit"})
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
|
|
self.assertEqual(
|
|
|
channel.json_body["error"],
|
|
@@ -949,7 +948,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|
|
def test_login_iss_no_config(self) -> None:
|
|
|
"""Test providing an issuer claim without requiring it in the configuration."""
|
|
|
channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
|
|
|
|
|
@override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
|
|
@@ -957,12 +956,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|
|
"""Test validating the audience claim."""
|
|
|
# A valid audience.
|
|
|
channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
|
|
|
|
|
# An invalid audience.
|
|
|
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
|
|
self.assertEqual(
|
|
|
channel.json_body["error"],
|
|
@@ -971,7 +970,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|
|
|
|
|
# Not providing an audience.
|
|
|
channel = self.jwt_login({"sub": "kermit"})
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
|
|
self.assertEqual(
|
|
|
channel.json_body["error"],
|
|
@@ -981,7 +980,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|
|
def test_login_aud_no_config(self) -> None:
|
|
|
"""Test providing an audience without requiring it in the configuration."""
|
|
|
channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
|
|
self.assertEqual(
|
|
|
channel.json_body["error"],
|
|
@@ -991,20 +990,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
|
|
|
def test_login_default_sub(self) -> None:
|
|
|
"""Test reading user ID from the default subject claim."""
|
|
|
channel = self.jwt_login({"sub": "kermit"})
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
|
|
|
|
|
@override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
|
|
|
def test_login_custom_sub(self) -> None:
|
|
|
"""Test reading user ID from a custom subject claim."""
|
|
|
channel = self.jwt_login({"username": "frog"})
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["user_id"], "@frog:test")
|
|
|
|
|
|
def test_login_no_token(self) -> None:
|
|
|
params = {"type": "org.matrix.login.jwt"}
|
|
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
|
|
self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
|
|
|
|
|
@@ -1086,12 +1085,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
|
|
|
|
|
|
def test_login_jwt_valid(self) -> None:
|
|
|
channel = self.jwt_login({"sub": "kermit"})
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["user_id"], "@kermit:test")
|
|
|
|
|
|
def test_login_jwt_invalid_signature(self) -> None:
|
|
|
channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
|
|
|
self.assertEqual(
|
|
|
channel.json_body["error"],
|
|
@@ -1152,7 +1151,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
b"POST", LOGIN_URL, params, access_token=self.service.token
|
|
|
)
|
|
|
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
|
|
|
def test_login_appservice_user_bot(self) -> None:
|
|
|
"""Test that the appservice bot can use /login"""
|
|
@@ -1166,7 +1165,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
b"POST", LOGIN_URL, params, access_token=self.service.token
|
|
|
)
|
|
|
|
|
|
- self.assertEqual(channel.result["code"], b"200", channel.result)
|
|
|
+ self.assertEqual(channel.code, 200, msg=channel.result)
|
|
|
|
|
|
def test_login_appservice_wrong_user(self) -> None:
|
|
|
"""Test that non-as users cannot login with the as token"""
|
|
@@ -1180,7 +1179,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
b"POST", LOGIN_URL, params, access_token=self.service.token
|
|
|
)
|
|
|
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
|
|
|
def test_login_appservice_wrong_as(self) -> None:
|
|
|
"""Test that as users cannot login with wrong as token"""
|
|
@@ -1194,7 +1193,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
b"POST", LOGIN_URL, params, access_token=self.another_service.token
|
|
|
)
|
|
|
|
|
|
- self.assertEqual(channel.result["code"], b"403", channel.result)
|
|
|
+ self.assertEqual(channel.code, 403, msg=channel.result)
|
|
|
|
|
|
def test_login_appservice_no_token(self) -> None:
|
|
|
"""Test that users must provide a token when using the appservice
|
|
@@ -1208,7 +1207,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
|
|
}
|
|
|
channel = self.make_request(b"POST", LOGIN_URL, params)
|
|
|
|
|
|
- self.assertEqual(channel.result["code"], b"401", channel.result)
|
|
|
+ self.assertEqual(channel.code, 401, msg=channel.result)
|
|
|
|
|
|
|
|
|
@skip_unless(HAS_OIDC, "requires OIDC")
|
|
@@ -1246,7 +1245,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
|
|
)
|
|
|
|
|
|
# that should redirect to the username picker
|
|
|
- self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
|
|
|
+ self.assertEqual(channel.code, 302, channel.result)
|
|
|
location_headers = channel.headers.getRawHeaders("Location")
|
|
|
assert location_headers
|
|
|
picker_url = location_headers[0]
|
|
@@ -1290,7 +1289,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
|
|
("Content-Length", str(len(content))),
|
|
|
],
|
|
|
)
|
|
|
- self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
|
|
|
+ self.assertEqual(chan.code, 302, chan.result)
|
|
|
location_headers = chan.headers.getRawHeaders("Location")
|
|
|
assert location_headers
|
|
|
|
|
@@ -1300,7 +1299,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
|
|
path=location_headers[0],
|
|
|
custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
|
|
|
)
|
|
|
- self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
|
|
|
+ self.assertEqual(chan.code, 302, chan.result)
|
|
|
location_headers = chan.headers.getRawHeaders("Location")
|
|
|
assert location_headers
|
|
|
|
|
@@ -1325,5 +1324,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
|
|
|
"/login",
|
|
|
content={"type": "m.login.token", "token": login_token},
|
|
|
)
|
|
|
- self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
|
|
|
+ self.assertEqual(chan.code, 200, chan.result)
|
|
|
self.assertEqual(chan.json_body["user_id"], "@bobby:test")
|