|
@@ -13,13 +13,15 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
|
import logging
|
|
|
-import urllib
|
|
|
-from typing import TYPE_CHECKING, Dict, Optional, Tuple
|
|
|
+import urllib.parse
|
|
|
+from typing import TYPE_CHECKING, Dict, Optional
|
|
|
from xml.etree import ElementTree as ET
|
|
|
|
|
|
+import attr
|
|
|
+
|
|
|
from twisted.web.client import PartialDownloadError
|
|
|
|
|
|
-from synapse.api.errors import Codes, LoginError
|
|
|
+from synapse.api.errors import HttpResponseException
|
|
|
from synapse.http.site import SynapseRequest
|
|
|
from synapse.types import UserID, map_username_to_mxid_localpart
|
|
|
|
|
@@ -29,6 +31,26 @@ if TYPE_CHECKING:
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
+class CasError(Exception):
|
|
|
+ """Used to catch errors when validating the CAS ticket.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, error, error_description=None):
|
|
|
+ self.error = error
|
|
|
+ self.error_description = error_description
|
|
|
+
|
|
|
+ def __str__(self):
|
|
|
+ if self.error_description:
|
|
|
+ return "{}: {}".format(self.error, self.error_description)
|
|
|
+ return self.error
|
|
|
+
|
|
|
+
|
|
|
+@attr.s(slots=True, frozen=True)
|
|
|
+class CasResponse:
|
|
|
+ username = attr.ib(type=str)
|
|
|
+ attributes = attr.ib(type=Dict[str, Optional[str]])
|
|
|
+
|
|
|
+
|
|
|
class CasHandler:
|
|
|
"""
|
|
|
Utility class for to handle the response from a CAS SSO service.
|
|
@@ -50,6 +72,8 @@ class CasHandler:
|
|
|
|
|
|
self._http_client = hs.get_proxied_http_client()
|
|
|
|
|
|
+ self._sso_handler = hs.get_sso_handler()
|
|
|
+
|
|
|
def _build_service_param(self, args: Dict[str, str]) -> str:
|
|
|
"""
|
|
|
Generates a value to use as the "service" parameter when redirecting or
|
|
@@ -69,14 +93,20 @@ class CasHandler:
|
|
|
|
|
|
async def _validate_ticket(
|
|
|
self, ticket: str, service_args: Dict[str, str]
|
|
|
- ) -> Tuple[str, Optional[str]]:
|
|
|
+ ) -> CasResponse:
|
|
|
"""
|
|
|
- Validate a CAS ticket with the server, parse the response, and return the user and display name.
|
|
|
+ Validate a CAS ticket with the server, and return the parsed the response.
|
|
|
|
|
|
Args:
|
|
|
ticket: The CAS ticket from the client.
|
|
|
service_args: Additional arguments to include in the service URL.
|
|
|
Should be the same as those passed to `get_redirect_url`.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ CasError: If there's an error parsing the CAS response.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ The parsed CAS response.
|
|
|
"""
|
|
|
uri = self._cas_server_url + "/proxyValidate"
|
|
|
args = {
|
|
@@ -89,66 +119,65 @@ class CasHandler:
|
|
|
# Twisted raises this error if the connection is closed,
|
|
|
# even if that's being used old-http style to signal end-of-data
|
|
|
body = pde.response
|
|
|
+ except HttpResponseException as e:
|
|
|
+ description = (
|
|
|
+ (
|
|
|
+ 'Authorization server responded with a "{status}" error '
|
|
|
+ "while exchanging the authorization code."
|
|
|
+ ).format(status=e.code),
|
|
|
+ )
|
|
|
+ raise CasError("server_error", description) from e
|
|
|
|
|
|
- user, attributes = self._parse_cas_response(body)
|
|
|
- displayname = attributes.pop(self._cas_displayname_attribute, None)
|
|
|
-
|
|
|
- for required_attribute, required_value in self._cas_required_attributes.items():
|
|
|
- # If required attribute was not in CAS Response - Forbidden
|
|
|
- if required_attribute not in attributes:
|
|
|
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
|
|
-
|
|
|
- # Also need to check value
|
|
|
- if required_value is not None:
|
|
|
- actual_value = attributes[required_attribute]
|
|
|
- # If required attribute value does not match expected - Forbidden
|
|
|
- if required_value != actual_value:
|
|
|
- raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
|
|
|
-
|
|
|
- return user, displayname
|
|
|
+ return self._parse_cas_response(body)
|
|
|
|
|
|
- def _parse_cas_response(
|
|
|
- self, cas_response_body: bytes
|
|
|
- ) -> Tuple[str, Dict[str, Optional[str]]]:
|
|
|
+ def _parse_cas_response(self, cas_response_body: bytes) -> CasResponse:
|
|
|
"""
|
|
|
Retrieve the user and other parameters from the CAS response.
|
|
|
|
|
|
Args:
|
|
|
cas_response_body: The response from the CAS query.
|
|
|
|
|
|
+ Raises:
|
|
|
+ CasError: If there's an error parsing the CAS response.
|
|
|
+
|
|
|
Returns:
|
|
|
- A tuple of the user and a mapping of other attributes.
|
|
|
+ The parsed CAS response.
|
|
|
"""
|
|
|
+
|
|
|
+ # Ensure the response is valid.
|
|
|
+ root = ET.fromstring(cas_response_body)
|
|
|
+ if not root.tag.endswith("serviceResponse"):
|
|
|
+ raise CasError(
|
|
|
+ "missing_service_response",
|
|
|
+ "root of CAS response is not serviceResponse",
|
|
|
+ )
|
|
|
+
|
|
|
+ success = root[0].tag.endswith("authenticationSuccess")
|
|
|
+ if not success:
|
|
|
+ raise CasError("unsucessful_response", "Unsuccessful CAS response")
|
|
|
+
|
|
|
+ # Iterate through the nodes and pull out the user and any extra attributes.
|
|
|
user = None
|
|
|
attributes = {}
|
|
|
- try:
|
|
|
- root = ET.fromstring(cas_response_body)
|
|
|
- if not root.tag.endswith("serviceResponse"):
|
|
|
- raise Exception("root of CAS response is not serviceResponse")
|
|
|
- success = root[0].tag.endswith("authenticationSuccess")
|
|
|
- for child in root[0]:
|
|
|
- if child.tag.endswith("user"):
|
|
|
- user = child.text
|
|
|
- if child.tag.endswith("attributes"):
|
|
|
- for attribute in child:
|
|
|
- # ElementTree library expands the namespace in
|
|
|
- # attribute tags to the full URL of the namespace.
|
|
|
- # We don't care about namespace here and it will always
|
|
|
- # be encased in curly braces, so we remove them.
|
|
|
- tag = attribute.tag
|
|
|
- if "}" in tag:
|
|
|
- tag = tag.split("}")[1]
|
|
|
- attributes[tag] = attribute.text
|
|
|
- if user is None:
|
|
|
- raise Exception("CAS response does not contain user")
|
|
|
- except Exception:
|
|
|
- logger.exception("Error parsing CAS response")
|
|
|
- raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
|
|
|
- if not success:
|
|
|
- raise LoginError(
|
|
|
- 401, "Unsuccessful CAS response", errcode=Codes.UNAUTHORIZED
|
|
|
- )
|
|
|
- return user, attributes
|
|
|
+ for child in root[0]:
|
|
|
+ if child.tag.endswith("user"):
|
|
|
+ user = child.text
|
|
|
+ if child.tag.endswith("attributes"):
|
|
|
+ for attribute in child:
|
|
|
+ # ElementTree library expands the namespace in
|
|
|
+ # attribute tags to the full URL of the namespace.
|
|
|
+ # We don't care about namespace here and it will always
|
|
|
+ # be encased in curly braces, so we remove them.
|
|
|
+ tag = attribute.tag
|
|
|
+ if "}" in tag:
|
|
|
+ tag = tag.split("}")[1]
|
|
|
+ attributes[tag] = attribute.text
|
|
|
+
|
|
|
+ # Ensure a user was found.
|
|
|
+ if user is None:
|
|
|
+ raise CasError("no_user", "CAS response does not contain user")
|
|
|
+
|
|
|
+ return CasResponse(user, attributes)
|
|
|
|
|
|
def get_redirect_url(self, service_args: Dict[str, str]) -> str:
|
|
|
"""
|
|
@@ -201,7 +230,68 @@ class CasHandler:
|
|
|
args["redirectUrl"] = client_redirect_url
|
|
|
if session:
|
|
|
args["session"] = session
|
|
|
- username, user_display_name = await self._validate_ticket(ticket, args)
|
|
|
+
|
|
|
+ try:
|
|
|
+ cas_response = await self._validate_ticket(ticket, args)
|
|
|
+ except CasError as e:
|
|
|
+ logger.exception("Could not validate ticket")
|
|
|
+ self._sso_handler.render_error(request, e.error, e.error_description, 401)
|
|
|
+ return
|
|
|
+
|
|
|
+ await self._handle_cas_response(
|
|
|
+ request, cas_response, client_redirect_url, session
|
|
|
+ )
|
|
|
+
|
|
|
+ async def _handle_cas_response(
|
|
|
+ self,
|
|
|
+ request: SynapseRequest,
|
|
|
+ cas_response: CasResponse,
|
|
|
+ client_redirect_url: Optional[str],
|
|
|
+ session: Optional[str],
|
|
|
+ ) -> None:
|
|
|
+ """Handle a CAS response to a ticket request.
|
|
|
+
|
|
|
+ Assumes that the response has been validated. Maps the user onto an MXID,
|
|
|
+ registering them if necessary, and returns a response to the browser.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ request: the incoming request from the browser. We'll respond to it with an
|
|
|
+ HTML page or a redirect
|
|
|
+
|
|
|
+ cas_response: The parsed CAS response.
|
|
|
+
|
|
|
+ client_redirect_url: the redirectUrl parameter from the `/cas/ticket` HTTP request, if given.
|
|
|
+ This should be the same as the redirectUrl from the original `/login/sso/redirect` request.
|
|
|
+
|
|
|
+ session: The session parameter from the `/cas/ticket` HTTP request, if given.
|
|
|
+ This should be the UI Auth session id.
|
|
|
+ """
|
|
|
+
|
|
|
+ # Ensure that the attributes of the logged in user meet the required
|
|
|
+ # attributes.
|
|
|
+ for required_attribute, required_value in self._cas_required_attributes.items():
|
|
|
+ # If required attribute was not in CAS Response - Forbidden
|
|
|
+ if required_attribute not in cas_response.attributes:
|
|
|
+ self._sso_handler.render_error(
|
|
|
+ request,
|
|
|
+ "unauthorised",
|
|
|
+ "You are not authorised to log in here.",
|
|
|
+ 401,
|
|
|
+ )
|
|
|
+ return
|
|
|
+
|
|
|
+ # Also need to check value
|
|
|
+ if required_value is not None:
|
|
|
+ actual_value = cas_response.attributes[required_attribute]
|
|
|
+ # If required attribute value does not match expected - Forbidden
|
|
|
+ if required_value != actual_value:
|
|
|
+ self._sso_handler.render_error(
|
|
|
+ request,
|
|
|
+ "unauthorised",
|
|
|
+ "You are not authorised to log in here.",
|
|
|
+ 401,
|
|
|
+ )
|
|
|
+ return
|
|
|
|
|
|
# Pull out the user-agent and IP from the request.
|
|
|
user_agent = request.get_user_agent("")
|
|
@@ -209,7 +299,7 @@ class CasHandler:
|
|
|
|
|
|
# Get the matrix ID from the CAS username.
|
|
|
user_id = await self._map_cas_user_to_matrix_user(
|
|
|
- username, user_display_name, user_agent, ip_address
|
|
|
+ cas_response, user_agent, ip_address
|
|
|
)
|
|
|
|
|
|
if session:
|
|
@@ -225,18 +315,13 @@ class CasHandler:
|
|
|
)
|
|
|
|
|
|
async def _map_cas_user_to_matrix_user(
|
|
|
- self,
|
|
|
- remote_user_id: str,
|
|
|
- display_name: Optional[str],
|
|
|
- user_agent: str,
|
|
|
- ip_address: str,
|
|
|
+ self, cas_response: CasResponse, user_agent: str, ip_address: str,
|
|
|
) -> str:
|
|
|
"""
|
|
|
Given a CAS username, retrieve the user ID for it and possibly register the user.
|
|
|
|
|
|
Args:
|
|
|
- remote_user_id: The username from the CAS response.
|
|
|
- display_name: The display name from the CAS response.
|
|
|
+ cas_response: The parsed CAS response.
|
|
|
user_agent: The user agent of the client making the request.
|
|
|
ip_address: The IP address of the client making the request.
|
|
|
|
|
@@ -244,15 +329,17 @@ class CasHandler:
|
|
|
The user ID associated with this response.
|
|
|
"""
|
|
|
|
|
|
- localpart = map_username_to_mxid_localpart(remote_user_id)
|
|
|
+ localpart = map_username_to_mxid_localpart(cas_response.username)
|
|
|
user_id = UserID(localpart, self._hostname).to_string()
|
|
|
registered_user_id = await self._auth_handler.check_user_exists(user_id)
|
|
|
|
|
|
+ displayname = cas_response.attributes.get(self._cas_displayname_attribute, None)
|
|
|
+
|
|
|
# If the user does not exist, register it.
|
|
|
if not registered_user_id:
|
|
|
registered_user_id = await self._registration_handler.register_user(
|
|
|
localpart=localpart,
|
|
|
- default_display_name=display_name,
|
|
|
+ default_display_name=displayname,
|
|
|
user_agent_ips=[(user_agent, ip_address)],
|
|
|
)
|
|
|
|