ui_auth.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2020 Matrix.org Foundation C.I.C.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import json
  16. from typing import Any, Dict, Optional, Union
  17. import attr
  18. from synapse.api.errors import StoreError
  19. from synapse.storage._base import SQLBaseStore
  20. from synapse.types import JsonDict
  21. from synapse.util import stringutils as stringutils
  @attr.s
  class UIAuthSessionData:
      """The details describing a single user interactive authentication session."""

      # The identifier of the session.
      session_id = attr.ib(type=str)

      # The dictionary taken from the root level of the client's request, i.e.
      # the whole dictionary rather than only the contents of its 'auth' key.
      clientdict = attr.ib(type=JsonDict)

      # The URI and method the session was initiated with. Both are checked at
      # every stage of the authentication so that the operation being asked for
      # cannot change part-way through.
      uri = attr.ib(type=str)
      method = attr.ib(type=str)

      # A string describing the operation that the current authentication is
      # authorising.
      description = attr.ib(type=str)
  class UIAuthWorkerStore(SQLBaseStore):
      """
      Manage user interactive authentication sessions.
      """

      async def create_ui_auth_session(
          self, clientdict: JsonDict, uri: str, method: str, description: str,
      ) -> UIAuthSessionData:
          """
          Creates a new user interactive authentication session.

          The session can be used to track the stages necessary to authenticate a
          user across multiple HTTP requests.

          Args:
              clientdict:
                  The dictionary from the client root level, not the 'auth' key.
              uri:
                  The URI this session was initiated with, this is checked at each
                  stage of the authentication to ensure that the asked for
                  operation has not changed.
              method:
                  The method this session was initiated with, this is checked at
                  each stage of the authentication to ensure that the asked for
                  operation has not changed.
              description:
                  A string description of the operation that the current
                  authentication is authorising.

          Returns:
              The newly created session.

          Raises:
              StoreError if a unique session ID cannot be generated.
          """
          # The clientdict gets stored as JSON.
          clientdict_json = json.dumps(clientdict)

          # Autogenerate a session ID and try to insert it. The ID may clash with
          # an existing row, so retry a bounded number of times before giving up.
          for _ in range(5):
              session_id = stringutils.random_string(24)

              try:
                  await self.db.simple_insert(
                      table="ui_auth_sessions",
                      values={
                          "session_id": session_id,
                          "clientdict": clientdict_json,
                          "uri": uri,
                          "method": method,
                          "description": description,
                          "serverdict": "{}",
                          "creation_time": self.hs.get_clock().time_msec(),
                      },
                      desc="create_ui_auth_session",
                  )
              except self.db.engine.module.IntegrityError:
                  # The insert was rejected (presumably a session ID clash);
                  # try again with a freshly generated ID.
                  continue

              return UIAuthSessionData(
                  session_id, clientdict, uri, method, description
              )

          raise StoreError(500, "Couldn't generate a session ID.")

      async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData:
          """Retrieve a UI auth session.

          Args:
              session_id: The ID of the session.

          Returns:
              The stored session: its client dictionary (decoded from JSON), URI,
              method and description.

          Raises:
              StoreError if the session is not found.
          """
          result = await self.db.simple_select_one(
              table="ui_auth_sessions",
              keyvalues={"session_id": session_id},
              retcols=("clientdict", "uri", "method", "description"),
              desc="get_ui_auth_session",
          )

          # The clientdict is stored as JSON text; decode it before handing it
          # back to the caller.
          result["clientdict"] = json.loads(result["clientdict"])

          return UIAuthSessionData(session_id, **result)

      async def mark_ui_auth_stage_complete(
          self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict],
      ) -> None:
          """
          Mark a session stage as completed.

          Args:
              session_id: The ID of the corresponding session.
              stage_type: The completed stage type.
              result: The result of the stage verification. It is stored as JSON.

          Raises:
              StoreError if the session cannot be found.
          """
          # Add (or update) the results of the current stage to the database.
          #
          # Note that we need to allow for the same stage to complete multiple
          # times here so that registration is idempotent.
          try:
              await self.db.simple_upsert(
                  table="ui_auth_sessions_credentials",
                  keyvalues={"session_id": session_id, "stage_type": stage_type},
                  values={"result": json.dumps(result)},
                  desc="mark_ui_auth_stage_complete",
              )
          except self.db.engine.module.IntegrityError:
              # NOTE(review): presumably raised by a constraint tying credentials
              # to an existing session -- confirm against the schema.
              raise StoreError(400, "Unknown session ID: %s" % (session_id,))

      async def get_completed_ui_auth_stages(
          self, session_id: str
      ) -> Dict[str, Union[str, bool, JsonDict]]:
          """
          Retrieve the completed stages of a UI authentication session.

          Args:
              session_id: The ID of the session.

          Returns:
              The completed stages mapped to the result of the verification of
              that auth-type.
          """
          rows = await self.db.simple_select_list(
              table="ui_auth_sessions_credentials",
              keyvalues={"session_id": session_id},
              retcols=("stage_type", "result"),
              desc="get_completed_ui_auth_stages",
          )

          # Each result is stored as JSON text, so decode it on the way out.
          return {row["stage_type"]: json.loads(row["result"]) for row in rows}

      async def set_ui_auth_clientdict(
          self, session_id: str, clientdict: JsonDict
      ) -> None:
          """
          Store an updated clientdict for a given session ID.

          Args:
              session_id: The ID of this session as returned from check_auth
              clientdict:
                  The dictionary from the client root level, not the 'auth' key.
          """
          # The clientdict gets stored as JSON.
          clientdict_json = json.dumps(clientdict)

          await self.db.simple_update_one(
              table="ui_auth_sessions",
              keyvalues={"session_id": session_id},
              updatevalues={"clientdict": clientdict_json},
              desc="set_ui_auth_client_dict",
          )

      async def set_ui_auth_session_data(
          self, session_id: str, key: str, value: Any
      ) -> None:
          """
          Store a key-value pair into the sessions data associated with this
          request. This data is stored server-side and cannot be modified by
          the client.

          Args:
              session_id: The ID of this session as returned from check_auth
              key: The key to store the data under
              value: The data to store. It must be JSON-serialisable, as the
                  session data is stored as JSON.

          Raises:
              StoreError if the session cannot be found.
          """
          await self.db.runInteraction(
              "set_ui_auth_session_data",
              self._set_ui_auth_session_data_txn,
              session_id,
              key,
              value,
          )

      def _set_ui_auth_session_data_txn(
          self, txn, session_id: str, key: str, value: Any
      ) -> None:
          """
          Transaction body for set_ui_auth_session_data.

          Reads the session's serverdict, sets `key` to `value` in it and writes
          the whole dictionary back, all using the given transaction.
          """
          # Get the current value.
          result = self.db.simple_select_one_txn(
              txn,
              table="ui_auth_sessions",
              keyvalues={"session_id": session_id},
              retcols=("serverdict",),
          )

          # Update it and add it back to the database.
          serverdict = json.loads(result["serverdict"])
          serverdict[key] = value

          self.db.simple_update_one_txn(
              txn,
              table="ui_auth_sessions",
              keyvalues={"session_id": session_id},
              updatevalues={"serverdict": json.dumps(serverdict)},
          )

      async def get_ui_auth_session_data(
          self, session_id: str, key: str, default: Optional[Any] = None
      ) -> Any:
          """
          Retrieve data stored with set_ui_auth_session_data.

          Args:
              session_id: The ID of this session as returned from check_auth
              key: The key the data was stored under
              default: Value to return if the key has not been set

          Returns:
              The value stored under `key`, or `default` if the key is not
              present in the session's data.

          Raises:
              StoreError if the session cannot be found.
          """
          result = await self.db.simple_select_one(
              table="ui_auth_sessions",
              keyvalues={"session_id": session_id},
              retcols=("serverdict",),
              desc="get_ui_auth_session_data",
          )

          serverdict = json.loads(result["serverdict"])

          return serverdict.get(key, default)
  class UIAuthStore(UIAuthWorkerStore):
      """Adds deletion of old user interactive authentication sessions."""

      def delete_old_ui_auth_sessions(self, expiration_time: int):
          """
          Remove sessions, and their completed credentials, whose creation time
          is at or before the expiration time.

          Args:
              expiration_time: Sessions created at or before this time are
                  removed. This is an epoch time in milliseconds.

          Returns:
              The result of running the deletion as a database interaction.
          """
          interaction = self.db.runInteraction(
              "delete_old_ui_auth_sessions",
              self._delete_old_ui_auth_sessions_txn,
              expiration_time,
          )
          return interaction

      def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
          # Collect the IDs of every session that is old enough to be removed.
          txn.execute(
              "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?",
              [expiration_time],
          )
          expired_ids = [row[0] for row in txn.fetchall()]

          # Remove the completed credentials first and only then the sessions
          # they belong to; the order of the tables below is significant.
          for table_name in ("ui_auth_sessions_credentials", "ui_auth_sessions"):
              self.db.simple_delete_many_txn(
                  txn,
                  table=table_name,
                  column="session_id",
                  iterable=expired_ids,
                  keyvalues={},
              )