|
@@ -18,7 +18,6 @@ import xml.etree.ElementTree as ET
|
|
|
|
|
|
from six.moves import urllib
|
|
|
|
|
|
-from twisted.internet import defer
|
|
|
from twisted.web.client import PartialDownloadError
|
|
|
|
|
|
from synapse.api.errors import Codes, LoginError, SynapseError
|
|
@@ -130,8 +129,7 @@ class LoginRestServlet(RestServlet):
|
|
|
def on_OPTIONS(self, request):
|
|
|
return 200, {}
|
|
|
|
|
|
- @defer.inlineCallbacks
|
|
|
- def on_POST(self, request):
|
|
|
+ async def on_POST(self, request):
|
|
|
self._address_ratelimiter.ratelimit(
|
|
|
request.getClientIP(),
|
|
|
time_now_s=self.hs.clock.time(),
|
|
@@ -145,11 +143,11 @@ class LoginRestServlet(RestServlet):
|
|
|
if self.jwt_enabled and (
|
|
|
login_submission["type"] == LoginRestServlet.JWT_TYPE
|
|
|
):
|
|
|
- result = yield self.do_jwt_login(login_submission)
|
|
|
+ result = await self.do_jwt_login(login_submission)
|
|
|
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
|
|
|
- result = yield self.do_token_login(login_submission)
|
|
|
+ result = await self.do_token_login(login_submission)
|
|
|
else:
|
|
|
- result = yield self._do_other_login(login_submission)
|
|
|
+ result = await self._do_other_login(login_submission)
|
|
|
except KeyError:
|
|
|
raise SynapseError(400, "Missing JSON keys.")
|
|
|
|
|
@@ -158,8 +156,7 @@ class LoginRestServlet(RestServlet):
|
|
|
result["well_known"] = well_known_data
|
|
|
return 200, result
|
|
|
|
|
|
- @defer.inlineCallbacks
|
|
|
- def _do_other_login(self, login_submission):
|
|
|
+ async def _do_other_login(self, login_submission):
|
|
|
"""Handle non-token/saml/jwt logins
|
|
|
|
|
|
Args:
|
|
@@ -219,20 +216,20 @@ class LoginRestServlet(RestServlet):
|
|
|
(
|
|
|
canonical_user_id,
|
|
|
callback_3pid,
|
|
|
- ) = yield self.auth_handler.check_password_provider_3pid(
|
|
|
+ ) = await self.auth_handler.check_password_provider_3pid(
|
|
|
medium, address, login_submission["password"]
|
|
|
)
|
|
|
if canonical_user_id:
|
|
|
# Authentication through password provider and 3pid succeeded
|
|
|
|
|
|
- result = yield self._complete_login(
|
|
|
+ result = await self._complete_login(
|
|
|
canonical_user_id, login_submission, callback_3pid
|
|
|
)
|
|
|
return result
|
|
|
|
|
|
# No password providers were able to handle this 3pid
|
|
|
# Check local store
|
|
|
- user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
|
|
+ user_id = await self.hs.get_datastore().get_user_id_by_threepid(
|
|
|
medium, address
|
|
|
)
|
|
|
if not user_id:
|
|
@@ -280,7 +277,7 @@ class LoginRestServlet(RestServlet):
|
|
|
)
|
|
|
|
|
|
try:
|
|
|
- canonical_user_id, callback = yield self.auth_handler.validate_login(
|
|
|
+ canonical_user_id, callback = await self.auth_handler.validate_login(
|
|
|
identifier["user"], login_submission
|
|
|
)
|
|
|
except LoginError:
|
|
@@ -297,13 +294,12 @@ class LoginRestServlet(RestServlet):
|
|
|
)
|
|
|
raise
|
|
|
|
|
|
- result = yield self._complete_login(
|
|
|
+ result = await self._complete_login(
|
|
|
canonical_user_id, login_submission, callback
|
|
|
)
|
|
|
return result
|
|
|
|
|
|
- @defer.inlineCallbacks
|
|
|
- def _complete_login(
|
|
|
+ async def _complete_login(
|
|
|
self, user_id, login_submission, callback=None, create_non_existant_users=False
|
|
|
):
|
|
|
"""Called when we've successfully authed the user and now need to
|
|
@@ -337,15 +333,15 @@ class LoginRestServlet(RestServlet):
|
|
|
)
|
|
|
|
|
|
if create_non_existant_users:
|
|
|
- user_id = yield self.auth_handler.check_user_exists(user_id)
|
|
|
+ user_id = await self.auth_handler.check_user_exists(user_id)
|
|
|
if not user_id:
|
|
|
- user_id = yield self.registration_handler.register_user(
|
|
|
+ user_id = await self.registration_handler.register_user(
|
|
|
localpart=UserID.from_string(user_id).localpart
|
|
|
)
|
|
|
|
|
|
device_id = login_submission.get("device_id")
|
|
|
initial_display_name = login_submission.get("initial_device_display_name")
|
|
|
- device_id, access_token = yield self.registration_handler.register_device(
|
|
|
+ device_id, access_token = await self.registration_handler.register_device(
|
|
|
user_id, device_id, initial_display_name
|
|
|
)
|
|
|
|
|
@@ -357,23 +353,21 @@ class LoginRestServlet(RestServlet):
|
|
|
}
|
|
|
|
|
|
if callback is not None:
|
|
|
- yield callback(result)
|
|
|
+ await callback(result)
|
|
|
|
|
|
return result
|
|
|
|
|
|
- @defer.inlineCallbacks
|
|
|
- def do_token_login(self, login_submission):
|
|
|
+ async def do_token_login(self, login_submission):
|
|
|
token = login_submission["token"]
|
|
|
auth_handler = self.auth_handler
|
|
|
- user_id = yield auth_handler.validate_short_term_login_token_and_get_user_id(
|
|
|
+ user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
|
|
|
token
|
|
|
)
|
|
|
|
|
|
- result = yield self._complete_login(user_id, login_submission)
|
|
|
+ result = await self._complete_login(user_id, login_submission)
|
|
|
return result
|
|
|
|
|
|
- @defer.inlineCallbacks
|
|
|
- def do_jwt_login(self, login_submission):
|
|
|
+ async def do_jwt_login(self, login_submission):
|
|
|
token = login_submission.get("token", None)
|
|
|
if token is None:
|
|
|
raise LoginError(
|
|
@@ -397,7 +391,7 @@ class LoginRestServlet(RestServlet):
|
|
|
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
|
|
|
|
|
|
user_id = UserID(user, self.hs.hostname).to_string()
|
|
|
- result = yield self._complete_login(
|
|
|
+ result = await self._complete_login(
|
|
|
user_id, login_submission, create_non_existant_users=True
|
|
|
)
|
|
|
return result
|
|
@@ -460,8 +454,7 @@ class CasTicketServlet(RestServlet):
|
|
|
self._sso_auth_handler = SSOAuthHandler(hs)
|
|
|
self._http_client = hs.get_proxied_http_client()
|
|
|
|
|
|
- @defer.inlineCallbacks
|
|
|
- def on_GET(self, request):
|
|
|
+ async def on_GET(self, request):
|
|
|
client_redirect_url = parse_string(request, "redirectUrl", required=True)
|
|
|
uri = self.cas_server_url + "/proxyValidate"
|
|
|
args = {
|
|
@@ -469,12 +462,12 @@ class CasTicketServlet(RestServlet):
|
|
|
"service": self.cas_service_url,
|
|
|
}
|
|
|
try:
|
|
|
- body = yield self._http_client.get_raw(uri, args)
|
|
|
+ body = await self._http_client.get_raw(uri, args)
|
|
|
except PartialDownloadError as pde:
|
|
|
# Twisted raises this error if the connection is closed,
|
|
|
# even if that's being used old-http style to signal end-of-data
|
|
|
body = pde.response
|
|
|
- result = yield self.handle_cas_response(request, body, client_redirect_url)
|
|
|
+ result = await self.handle_cas_response(request, body, client_redirect_url)
|
|
|
return result
|
|
|
|
|
|
def handle_cas_response(self, request, cas_response_body, client_redirect_url):
|
|
@@ -555,8 +548,7 @@ class SSOAuthHandler(object):
|
|
|
self._registration_handler = hs.get_registration_handler()
|
|
|
self._macaroon_gen = hs.get_macaroon_generator()
|
|
|
|
|
|
- @defer.inlineCallbacks
|
|
|
- def on_successful_auth(
|
|
|
+ async def on_successful_auth(
|
|
|
self, username, request, client_redirect_url, user_display_name=None
|
|
|
):
|
|
|
"""Called once the user has successfully authenticated with the SSO.
|
|
@@ -582,9 +574,9 @@ class SSOAuthHandler(object):
|
|
|
"""
|
|
|
localpart = map_username_to_mxid_localpart(username)
|
|
|
user_id = UserID(localpart, self._hostname).to_string()
|
|
|
- registered_user_id = yield self._auth_handler.check_user_exists(user_id)
|
|
|
+ registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
|
|
if not registered_user_id:
|
|
|
- registered_user_id = yield self._registration_handler.register_user(
|
|
|
+ registered_user_id = await self._registration_handler.register_user(
|
|
|
localpart=localpart, default_display_name=user_display_name
|
|
|
)
|
|
|
|