saml2.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # Copyright 2018 New Vector Ltd
  2. # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from typing import Any, List, Set
  17. from synapse.config.sso import SsoAttributeRequirement
  18. from synapse.types import JsonDict
  19. from synapse.util.check_dependencies import check_requirements
  20. from synapse.util.module_loader import load_module, load_python_module
  21. from ._base import Config, ConfigError
  22. from ._util import validate_config
  23. logger = logging.getLogger(__name__)
  24. DEFAULT_USER_MAPPING_PROVIDER = "synapse.handlers.saml.DefaultSamlMappingProvider"
  25. # The module that DefaultSamlMappingProvider is in was renamed, we want to
  26. # transparently handle both the same.
  27. LEGACY_USER_MAPPING_PROVIDER = (
  28. "synapse.handlers.saml_handler.DefaultSamlMappingProvider"
  29. )
  30. def _dict_merge(merge_dict: dict, into_dict: dict) -> None:
  31. """Do a deep merge of two dicts
  32. Recursively merges `merge_dict` into `into_dict`:
  33. * For keys where both `merge_dict` and `into_dict` have a dict value, the values
  34. are recursively merged
  35. * For all other keys, the values in `into_dict` (if any) are overwritten with
  36. the value from `merge_dict`.
  37. Args:
  38. merge_dict: dict to merge
  39. into_dict: target dict to be modified
  40. """
  41. for k, v in merge_dict.items():
  42. if k not in into_dict:
  43. into_dict[k] = v
  44. continue
  45. current_val = into_dict[k]
  46. if isinstance(v, dict) and isinstance(current_val, dict):
  47. _dict_merge(v, current_val)
  48. continue
  49. # otherwise we just overwrite
  50. into_dict[k] = v
  51. class SAML2Config(Config):
  52. section = "saml2"
  53. def read_config(self, config: JsonDict, **kwargs: Any) -> None:
  54. self.saml2_enabled = False
  55. saml2_config = config.get("saml2_config")
  56. if not saml2_config or not saml2_config.get("enabled", True):
  57. return
  58. if not saml2_config.get("sp_config") and not saml2_config.get("config_path"):
  59. return
  60. check_requirements("saml2")
  61. self.saml2_enabled = True
  62. attribute_requirements = saml2_config.get("attribute_requirements") or []
  63. self.attribute_requirements = _parse_attribute_requirements_def(
  64. attribute_requirements
  65. )
  66. self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
  67. "grandfathered_mxid_source_attribute", "uid"
  68. )
  69. self.saml2_idp_entityid = saml2_config.get("idp_entityid", None)
  70. # user_mapping_provider may be None if the key is present but has no value
  71. ump_dict = saml2_config.get("user_mapping_provider") or {}
  72. # Use the default user mapping provider if not set
  73. ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
  74. if ump_dict.get("module") == LEGACY_USER_MAPPING_PROVIDER:
  75. ump_dict["module"] = DEFAULT_USER_MAPPING_PROVIDER
  76. # Ensure a config is present
  77. ump_dict["config"] = ump_dict.get("config") or {}
  78. if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
  79. # Load deprecated options for use by the default module
  80. old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
  81. if old_mxid_source_attribute:
  82. logger.warning(
  83. "The config option saml2_config.mxid_source_attribute is deprecated. "
  84. "Please use saml2_config.user_mapping_provider.config"
  85. ".mxid_source_attribute instead."
  86. )
  87. ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
  88. old_mxid_mapping = saml2_config.get("mxid_mapping")
  89. if old_mxid_mapping:
  90. logger.warning(
  91. "The config option saml2_config.mxid_mapping is deprecated. Please "
  92. "use saml2_config.user_mapping_provider.config.mxid_mapping instead."
  93. )
  94. ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
  95. # Retrieve an instance of the module's class
  96. # Pass the config dictionary to the module for processing
  97. (
  98. self.saml2_user_mapping_provider_class,
  99. self.saml2_user_mapping_provider_config,
  100. ) = load_module(ump_dict, ("saml2_config", "user_mapping_provider"))
  101. # Ensure loaded user mapping module has defined all necessary methods
  102. # Note parse_config() is already checked during the call to load_module
  103. required_methods = [
  104. "get_saml_attributes",
  105. "saml_response_to_user_attributes",
  106. "get_remote_user_id",
  107. ]
  108. missing_methods = [
  109. method
  110. for method in required_methods
  111. if not hasattr(self.saml2_user_mapping_provider_class, method)
  112. ]
  113. if missing_methods:
  114. raise ConfigError(
  115. "Class specified by saml2_config."
  116. "user_mapping_provider.module is missing required "
  117. "methods: %s" % (", ".join(missing_methods),)
  118. )
  119. # Get the desired saml auth response attributes from the module
  120. saml2_config_dict = self._default_saml_config_dict(
  121. *self.saml2_user_mapping_provider_class.get_saml_attributes(
  122. self.saml2_user_mapping_provider_config
  123. )
  124. )
  125. _dict_merge(
  126. merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
  127. )
  128. config_path = saml2_config.get("config_path", None)
  129. if config_path is not None:
  130. mod = load_python_module(config_path)
  131. config_dict_from_file = getattr(mod, "CONFIG", None)
  132. if config_dict_from_file is None:
  133. raise ConfigError(
  134. "Config path specified by saml2_config.config_path does not "
  135. "have a CONFIG property."
  136. )
  137. _dict_merge(merge_dict=config_dict_from_file, into_dict=saml2_config_dict)
  138. import saml2.config
  139. self.saml2_sp_config = saml2.config.SPConfig()
  140. self.saml2_sp_config.load(saml2_config_dict)
  141. # session lifetime: in milliseconds
  142. self.saml2_session_lifetime = self.parse_duration(
  143. saml2_config.get("saml_session_lifetime", "15m")
  144. )
  145. def _default_saml_config_dict(
  146. self, required_attributes: Set[str], optional_attributes: Set[str]
  147. ) -> JsonDict:
  148. """Generate a configuration dictionary with required and optional attributes that
  149. will be needed to process new user registration
  150. Args:
  151. required_attributes: SAML auth response attributes that are
  152. necessary to function
  153. optional_attributes: SAML auth response attributes that can be used to add
  154. additional information to Synapse user accounts, but are not required
  155. Returns:
  156. A SAML configuration dictionary
  157. """
  158. import saml2
  159. if self.saml2_grandfathered_mxid_source_attribute:
  160. optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
  161. optional_attributes -= required_attributes
  162. public_baseurl = self.root.server.public_baseurl
  163. metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml"
  164. response_url = public_baseurl + "_synapse/client/saml2/authn_response"
  165. return {
  166. "entityid": metadata_url,
  167. "service": {
  168. "sp": {
  169. "endpoints": {
  170. "assertion_consumer_service": [
  171. (response_url, saml2.BINDING_HTTP_POST)
  172. ]
  173. },
  174. "required_attributes": list(required_attributes),
  175. "optional_attributes": list(optional_attributes),
  176. # "name_id_format": saml2.saml.NAMEID_FORMAT_PERSISTENT,
  177. }
  178. },
  179. }
  180. ATTRIBUTE_REQUIREMENTS_SCHEMA = {
  181. "type": "array",
  182. "items": SsoAttributeRequirement.JSON_SCHEMA,
  183. }
  184. def _parse_attribute_requirements_def(
  185. attribute_requirements: Any,
  186. ) -> List[SsoAttributeRequirement]:
  187. validate_config(
  188. ATTRIBUTE_REQUIREMENTS_SCHEMA,
  189. attribute_requirements,
  190. config_path=("saml2_config", "attribute_requirements"),
  191. )
  192. return [SsoAttributeRequirement(**x) for x in attribute_requirements]