utils.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd
# Copyright 2019-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  18. import json
  19. import re
  20. import time
  21. import urllib.parse
  22. from typing import Any, Dict, Mapping, MutableMapping, Optional
  23. from mock import patch
  24. import attr
  25. from twisted.web.resource import Resource
  26. from twisted.web.server import Site
  27. from synapse.api.constants import Membership
  28. from synapse.types import JsonDict
  29. from tests.server import FakeChannel, FakeSite, make_request
  30. from tests.test_utils import FakeResponse
  31. from tests.test_utils.html_parsers import TestHtmlParser
  32. @attr.s
  33. class RestHelper:
  34. """Contains extra helper functions to quickly and clearly perform a given
  35. REST action, which isn't the focus of the test.
  36. """
  37. hs = attr.ib()
  38. site = attr.ib(type=Site)
  39. auth_user_id = attr.ib()
  40. def create_room_as(
  41. self,
  42. room_creator: str = None,
  43. is_public: bool = True,
  44. room_version: str = None,
  45. tok: str = None,
  46. expect_code: int = 200,
  47. ) -> str:
  48. """
  49. Create a room.
  50. Args:
  51. room_creator: The user ID to create the room with.
  52. is_public: If True, the `visibility` parameter will be set to the
  53. default (public). Otherwise, the `visibility` parameter will be set
  54. to "private".
  55. room_version: The room version to create the room as. Defaults to Synapse's
  56. default room version.
  57. tok: The access token to use in the request.
  58. expect_code: The expected HTTP response code.
  59. Returns:
  60. The ID of the newly created room.
  61. """
  62. temp_id = self.auth_user_id
  63. self.auth_user_id = room_creator
  64. path = "/_matrix/client/r0/createRoom"
  65. content = {}
  66. if not is_public:
  67. content["visibility"] = "private"
  68. if room_version:
  69. content["room_version"] = room_version
  70. if tok:
  71. path = path + "?access_token=%s" % tok
  72. channel = make_request(
  73. self.hs.get_reactor(),
  74. self.site,
  75. "POST",
  76. path,
  77. json.dumps(content).encode("utf8"),
  78. )
  79. assert channel.result["code"] == b"%d" % expect_code, channel.result
  80. self.auth_user_id = temp_id
  81. if expect_code == 200:
  82. return channel.json_body["room_id"]
  83. def invite(self, room=None, src=None, targ=None, expect_code=200, tok=None):
  84. self.change_membership(
  85. room=room,
  86. src=src,
  87. targ=targ,
  88. tok=tok,
  89. membership=Membership.INVITE,
  90. expect_code=expect_code,
  91. )
  92. def join(self, room=None, user=None, expect_code=200, tok=None):
  93. self.change_membership(
  94. room=room,
  95. src=user,
  96. targ=user,
  97. tok=tok,
  98. membership=Membership.JOIN,
  99. expect_code=expect_code,
  100. )
  101. def leave(self, room=None, user=None, expect_code=200, tok=None):
  102. self.change_membership(
  103. room=room,
  104. src=user,
  105. targ=user,
  106. tok=tok,
  107. membership=Membership.LEAVE,
  108. expect_code=expect_code,
  109. )
  110. def change_membership(
  111. self,
  112. room: str,
  113. src: str,
  114. targ: str,
  115. membership: str,
  116. extra_data: dict = {},
  117. tok: Optional[str] = None,
  118. expect_code: int = 200,
  119. ) -> None:
  120. """
  121. Send a membership state event into a room.
  122. Args:
  123. room: The ID of the room to send to
  124. src: The mxid of the event sender
  125. targ: The mxid of the event's target. The state key
  126. membership: The type of membership event
  127. extra_data: Extra information to include in the content of the event
  128. tok: The user access token to use
  129. expect_code: The expected HTTP response code
  130. """
  131. temp_id = self.auth_user_id
  132. self.auth_user_id = src
  133. path = "/_matrix/client/r0/rooms/%s/state/m.room.member/%s" % (room, targ)
  134. if tok:
  135. path = path + "?access_token=%s" % tok
  136. data = {"membership": membership}
  137. data.update(extra_data)
  138. channel = make_request(
  139. self.hs.get_reactor(),
  140. self.site,
  141. "PUT",
  142. path,
  143. json.dumps(data).encode("utf8"),
  144. )
  145. assert int(channel.result["code"]) == expect_code, (
  146. "Expected: %d, got: %d, resp: %r"
  147. % (expect_code, int(channel.result["code"]), channel.result["body"])
  148. )
  149. self.auth_user_id = temp_id
  150. def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
  151. if body is None:
  152. body = "body_text_here"
  153. content = {"msgtype": "m.text", "body": body}
  154. return self.send_event(
  155. room_id, "m.room.message", content, txn_id, tok, expect_code
  156. )
  157. def send_event(
  158. self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
  159. ):
  160. if txn_id is None:
  161. txn_id = "m%s" % (str(time.time()))
  162. path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id)
  163. if tok:
  164. path = path + "?access_token=%s" % tok
  165. channel = make_request(
  166. self.hs.get_reactor(),
  167. self.site,
  168. "PUT",
  169. path,
  170. json.dumps(content).encode("utf8"),
  171. )
  172. assert int(channel.result["code"]) == expect_code, (
  173. "Expected: %d, got: %d, resp: %r"
  174. % (expect_code, int(channel.result["code"]), channel.result["body"])
  175. )
  176. return channel.json_body
  177. def _read_write_state(
  178. self,
  179. room_id: str,
  180. event_type: str,
  181. body: Optional[Dict[str, Any]],
  182. tok: str,
  183. expect_code: int = 200,
  184. state_key: str = "",
  185. method: str = "GET",
  186. ) -> Dict:
  187. """Read or write some state from a given room
  188. Args:
  189. room_id:
  190. event_type: The type of state event
  191. body: Body that is sent when making the request. The content of the state event.
  192. If None, the request to the server will have an empty body
  193. tok: The access token to use
  194. expect_code: The HTTP code to expect in the response
  195. state_key:
  196. method: "GET" or "PUT" for reading or writing state, respectively
  197. Returns:
  198. The response body from the server
  199. Raises:
  200. AssertionError: if expect_code doesn't match the HTTP code we received
  201. """
  202. path = "/_matrix/client/r0/rooms/%s/state/%s/%s" % (
  203. room_id,
  204. event_type,
  205. state_key,
  206. )
  207. if tok:
  208. path = path + "?access_token=%s" % tok
  209. # Set request body if provided
  210. content = b""
  211. if body is not None:
  212. content = json.dumps(body).encode("utf8")
  213. channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
  214. assert int(channel.result["code"]) == expect_code, (
  215. "Expected: %d, got: %d, resp: %r"
  216. % (expect_code, int(channel.result["code"]), channel.result["body"])
  217. )
  218. return channel.json_body
  219. def get_state(
  220. self,
  221. room_id: str,
  222. event_type: str,
  223. tok: str,
  224. expect_code: int = 200,
  225. state_key: str = "",
  226. ):
  227. """Gets some state from a room
  228. Args:
  229. room_id:
  230. event_type: The type of state event
  231. tok: The access token to use
  232. expect_code: The HTTP code to expect in the response
  233. state_key:
  234. Returns:
  235. The response body from the server
  236. Raises:
  237. AssertionError: if expect_code doesn't match the HTTP code we received
  238. """
  239. return self._read_write_state(
  240. room_id, event_type, None, tok, expect_code, state_key, method="GET"
  241. )
  242. def send_state(
  243. self,
  244. room_id: str,
  245. event_type: str,
  246. body: Dict[str, Any],
  247. tok: str,
  248. expect_code: int = 200,
  249. state_key: str = "",
  250. ):
  251. """Set some state in a room
  252. Args:
  253. room_id:
  254. event_type: The type of state event
  255. body: Body that is sent when making the request. The content of the state event.
  256. tok: The access token to use
  257. expect_code: The HTTP code to expect in the response
  258. state_key:
  259. Returns:
  260. The response body from the server
  261. Raises:
  262. AssertionError: if expect_code doesn't match the HTTP code we received
  263. """
  264. return self._read_write_state(
  265. room_id, event_type, body, tok, expect_code, state_key, method="PUT"
  266. )
  267. def upload_media(
  268. self,
  269. resource: Resource,
  270. image_data: bytes,
  271. tok: str,
  272. filename: str = "test.png",
  273. expect_code: int = 200,
  274. ) -> dict:
  275. """Upload a piece of test media to the media repo
  276. Args:
  277. resource: The resource that will handle the upload request
  278. image_data: The image data to upload
  279. tok: The user token to use during the upload
  280. filename: The filename of the media to be uploaded
  281. expect_code: The return code to expect from attempting to upload the media
  282. """
  283. image_length = len(image_data)
  284. path = "/_matrix/media/r0/upload?filename=%s" % (filename,)
  285. channel = make_request(
  286. self.hs.get_reactor(),
  287. FakeSite(resource),
  288. "POST",
  289. path,
  290. content=image_data,
  291. access_token=tok,
  292. custom_headers=[(b"Content-Length", str(image_length))],
  293. )
  294. assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
  295. expect_code,
  296. int(channel.result["code"]),
  297. channel.result["body"],
  298. )
  299. return channel.json_body
  300. def login_via_oidc(self, remote_user_id: str) -> JsonDict:
  301. """Log in (as a new user) via OIDC
  302. Returns the result of the final token login.
  303. Requires that "oidc_config" in the homeserver config be set appropriately
  304. (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
  305. "public_base_url".
  306. Also requires the login servlet and the OIDC callback resource to be mounted at
  307. the normal places.
  308. """
  309. client_redirect_url = "https://x"
  310. channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
  311. # expect a confirmation page
  312. assert channel.code == 200, channel.result
  313. # fish the matrix login token out of the body of the confirmation page
  314. m = re.search(
  315. 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
  316. channel.text_body,
  317. )
  318. assert m, channel.text_body
  319. login_token = m.group(1)
  320. # finally, submit the matrix login token to the login API, which gives us our
  321. # matrix access token and device id.
  322. channel = make_request(
  323. self.hs.get_reactor(),
  324. self.site,
  325. "POST",
  326. "/login",
  327. content={"type": "m.login.token", "token": login_token},
  328. )
  329. assert channel.code == 200
  330. return channel.json_body
  331. def auth_via_oidc(
  332. self,
  333. user_info_dict: JsonDict,
  334. client_redirect_url: Optional[str] = None,
  335. ui_auth_session_id: Optional[str] = None,
  336. ) -> FakeChannel:
  337. """Perform an OIDC authentication flow via a mock OIDC provider.
  338. This can be used for either login or user-interactive auth.
  339. Starts by making a request to the relevant synapse redirect endpoint, which is
  340. expected to serve a 302 to the OIDC provider. We then make a request to the
  341. OIDC callback endpoint, intercepting the HTTP requests that will get sent back
  342. to the OIDC provider.
  343. Requires that "oidc_config" in the homeserver config be set appropriately
  344. (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
  345. "public_base_url".
  346. Also requires the login servlet and the OIDC callback resource to be mounted at
  347. the normal places.
  348. Args:
  349. user_info_dict: the remote userinfo that the OIDC provider should present.
  350. Typically this should be '{"sub": "<remote user id>"}'.
  351. client_redirect_url: for a login flow, the client redirect URL to pass to
  352. the login redirect endpoint
  353. ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
  354. of the UI auth.
  355. Returns:
  356. A FakeChannel containing the result of calling the OIDC callback endpoint.
  357. Note that the response code may be a 200, 302 or 400 depending on how things
  358. went.
  359. """
  360. cookies = {}
  361. # if we're doing a ui auth, hit the ui auth redirect endpoint
  362. if ui_auth_session_id:
  363. # can't set the client redirect url for UI Auth
  364. assert client_redirect_url is None
  365. oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
  366. else:
  367. # otherwise, hit the login redirect endpoint
  368. oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
  369. # we now have a URI for the OIDC IdP, but we skip that and go straight
  370. # back to synapse's OIDC callback resource. However, we do need the "state"
  371. # param that synapse passes to the IdP via query params, as well as the cookie
  372. # that synapse passes to the client.
  373. oauth_uri_path, _ = oauth_uri.split("?", 1)
  374. assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
  375. "unexpected SSO URI " + oauth_uri_path
  376. )
  377. return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
  378. def complete_oidc_auth(
  379. self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
  380. ) -> FakeChannel:
  381. """Mock out an OIDC authentication flow
  382. Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
  383. initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
  384. Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
  385. sent back to the OIDC provider.
  386. Requires the OIDC callback resource to be mounted at the normal place.
  387. Args:
  388. oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
  389. from initiate_sso_login or initiate_sso_ui_auth).
  390. cookies: the cookies set by synapse's redirect endpoint, which will be
  391. sent back to the callback endpoint.
  392. user_info_dict: the remote userinfo that the OIDC provider should present.
  393. Typically this should be '{"sub": "<remote user id>"}'.
  394. Returns:
  395. A FakeChannel containing the result of calling the OIDC callback endpoint.
  396. """
  397. _, oauth_uri_qs = oauth_uri.split("?", 1)
  398. params = urllib.parse.parse_qs(oauth_uri_qs)
  399. callback_uri = "%s?%s" % (
  400. urllib.parse.urlparse(params["redirect_uri"][0]).path,
  401. urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
  402. )
  403. # before we hit the callback uri, stub out some methods in the http client so
  404. # that we don't have to handle full HTTPS requests.
  405. # (expected url, json response) pairs, in the order we expect them.
  406. expected_requests = [
  407. # first we get a hit to the token endpoint, which we tell to return
  408. # a dummy OIDC access token
  409. (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
  410. # and then one to the user_info endpoint, which returns our remote user id.
  411. (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
  412. ]
  413. async def mock_req(method: str, uri: str, data=None, headers=None):
  414. (expected_uri, resp_obj) = expected_requests.pop(0)
  415. assert uri == expected_uri
  416. resp = FakeResponse(
  417. code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
  418. )
  419. return resp
  420. with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
  421. # now hit the callback URI with the right params and a made-up code
  422. channel = make_request(
  423. self.hs.get_reactor(),
  424. self.site,
  425. "GET",
  426. callback_uri,
  427. custom_headers=[
  428. ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
  429. ],
  430. )
  431. return channel
  432. def initiate_sso_login(
  433. self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
  434. ) -> str:
  435. """Make a request to the login-via-sso redirect endpoint, and return the target
  436. Assumes that exactly one SSO provider has been configured. Requires the login
  437. servlet to be mounted.
  438. Args:
  439. client_redirect_url: the client redirect URL to pass to the login redirect
  440. endpoint
  441. cookies: any cookies returned will be added to this dict
  442. Returns:
  443. the URI that the client gets redirected to (ie, the SSO server)
  444. """
  445. params = {}
  446. if client_redirect_url:
  447. params["redirectUrl"] = client_redirect_url
  448. # hit the redirect url (which will issue a cookie and state)
  449. channel = make_request(
  450. self.hs.get_reactor(),
  451. self.site,
  452. "GET",
  453. "/_matrix/client/r0/login/sso/redirect?" + urllib.parse.urlencode(params),
  454. )
  455. assert channel.code == 302
  456. channel.extract_cookies(cookies)
  457. return channel.headers.getRawHeaders("Location")[0]
  458. def initiate_sso_ui_auth(
  459. self, ui_auth_session_id: str, cookies: MutableMapping[str, str]
  460. ) -> str:
  461. """Make a request to the ui-auth-via-sso endpoint, and return the target
  462. Assumes that exactly one SSO provider has been configured. Requires the
  463. AuthRestServlet to be mounted.
  464. Args:
  465. ui_auth_session_id: the session id of the UI auth
  466. cookies: any cookies returned will be added to this dict
  467. Returns:
  468. the URI that the client gets linked to (ie, the SSO server)
  469. """
  470. sso_redirect_endpoint = (
  471. "/_matrix/client/r0/auth/m.login.sso/fallback/web?"
  472. + urllib.parse.urlencode({"session": ui_auth_session_id})
  473. )
  474. # hit the redirect url (which will issue a cookie and state)
  475. channel = make_request(
  476. self.hs.get_reactor(), self.site, "GET", sso_redirect_endpoint
  477. )
  478. # that should serve a confirmation page
  479. assert channel.code == 200, channel.text_body
  480. channel.extract_cookies(cookies)
  481. # parse the confirmation page to fish out the link.
  482. p = TestHtmlParser()
  483. p.feed(channel.text_body)
  484. p.close()
  485. assert len(p.links) == 1, "not exactly one link in confirmation page"
  486. oauth_uri = p.links[0]
  487. return oauth_uri
  488. # an 'oidc_config' suitable for login_via_oidc.
  489. TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
  490. TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
  491. TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
  492. TEST_OIDC_CONFIG = {
  493. "enabled": True,
  494. "discover": False,
  495. "issuer": "https://issuer.test",
  496. "client_id": "test-client-id",
  497. "client_secret": "test-client-secret",
  498. "scopes": ["profile"],
  499. "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
  500. "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
  501. "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
  502. "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
  503. }