test_api.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242
  1. # Copyright 2022 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, List, Mapping, Optional, Sequence, Union
  15. from unittest.mock import Mock
  16. from twisted.test.proto_helpers import MemoryReactor
  17. from synapse.appservice import ApplicationService
  18. from synapse.server import HomeServer
  19. from synapse.types import JsonDict
  20. from synapse.util import Clock
  21. from tests import unittest
  22. from tests.unittest import override_config
  23. PROTOCOL = "myproto"
  24. TOKEN = "myastoken"
  25. URL = "http://mytestservice"
  26. class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
  27. def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
  28. self.api = hs.get_application_service_api()
  29. self.service = ApplicationService(
  30. id="unique_identifier",
  31. sender="@as:test",
  32. url=URL,
  33. token="unused",
  34. hs_token=TOKEN,
  35. )
  36. def test_query_3pe_authenticates_token_via_header(self) -> None:
  37. """
  38. Tests that 3pe queries to the appservice are authenticated
  39. with the appservice's token.
  40. """
  41. SUCCESS_RESULT_USER = [
  42. {
  43. "protocol": PROTOCOL,
  44. "userid": "@a:user",
  45. "fields": {
  46. "more": "fields",
  47. },
  48. }
  49. ]
  50. SUCCESS_RESULT_LOCATION = [
  51. {
  52. "protocol": PROTOCOL,
  53. "alias": "#a:room",
  54. "fields": {
  55. "more": "fields",
  56. },
  57. }
  58. ]
  59. URL_USER = f"{URL}/_matrix/app/v1/thirdparty/user/{PROTOCOL}"
  60. URL_LOCATION = f"{URL}/_matrix/app/v1/thirdparty/location/{PROTOCOL}"
  61. self.request_url = None
  62. async def get_json(
  63. url: str,
  64. args: Mapping[Any, Any],
  65. headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
  66. ) -> List[JsonDict]:
  67. # Ensure the access token is passed as a header.
  68. if not headers or not headers.get("Authorization"):
  69. raise RuntimeError("Access token not provided")
  70. # ... and not as a query param
  71. if b"access_token" in args:
  72. raise RuntimeError(
  73. "Access token should not be passed as a query param."
  74. )
  75. self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
  76. self.request_url = url
  77. if url == URL_USER:
  78. return SUCCESS_RESULT_USER
  79. elif url == URL_LOCATION:
  80. return SUCCESS_RESULT_LOCATION
  81. else:
  82. raise RuntimeError(
  83. "URL provided was invalid. This should never be seen."
  84. )
  85. # We assign to a method, which mypy doesn't like.
  86. self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment]
  87. result = self.get_success(
  88. self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]})
  89. )
  90. self.assertEqual(self.request_url, URL_USER)
  91. self.assertEqual(result, SUCCESS_RESULT_USER)
  92. result = self.get_success(
  93. self.api.query_3pe(
  94. self.service, "location", PROTOCOL, {b"some": [b"field"]}
  95. )
  96. )
  97. self.assertEqual(self.request_url, URL_LOCATION)
  98. self.assertEqual(result, SUCCESS_RESULT_LOCATION)
  99. @override_config({"use_appservice_legacy_authorization": True})
  100. def test_query_3pe_authenticates_token_via_param(self) -> None:
  101. """
  102. Tests that 3pe queries to the appservice are authenticated
  103. with the appservice's token.
  104. """
  105. SUCCESS_RESULT_USER = [
  106. {
  107. "protocol": PROTOCOL,
  108. "userid": "@a:user",
  109. "fields": {
  110. "more": "fields",
  111. },
  112. }
  113. ]
  114. SUCCESS_RESULT_LOCATION = [
  115. {
  116. "protocol": PROTOCOL,
  117. "alias": "#a:room",
  118. "fields": {
  119. "more": "fields",
  120. },
  121. }
  122. ]
  123. URL_USER = f"{URL}/_matrix/app/v1/thirdparty/user/{PROTOCOL}"
  124. URL_LOCATION = f"{URL}/_matrix/app/v1/thirdparty/location/{PROTOCOL}"
  125. self.request_url = None
  126. async def get_json(
  127. url: str,
  128. args: Mapping[Any, Any],
  129. headers: Optional[
  130. Mapping[Union[str, bytes], Sequence[Union[str, bytes]]]
  131. ] = None,
  132. ) -> List[JsonDict]:
  133. # Ensure the access token is passed as a both a query param and in the headers.
  134. if not args.get(b"access_token"):
  135. raise RuntimeError("Access token should be provided in query params.")
  136. if not headers or not headers.get("Authorization"):
  137. raise RuntimeError("Access token should be provided in auth headers.")
  138. self.assertEqual(args.get(b"access_token"), TOKEN)
  139. self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
  140. self.request_url = url
  141. if url == URL_USER:
  142. return SUCCESS_RESULT_USER
  143. elif url == URL_LOCATION:
  144. return SUCCESS_RESULT_LOCATION
  145. else:
  146. raise RuntimeError(
  147. "URL provided was invalid. This should never be seen."
  148. )
  149. # We assign to a method, which mypy doesn't like.
  150. self.api.get_json = Mock(side_effect=get_json) # type: ignore[assignment]
  151. result = self.get_success(
  152. self.api.query_3pe(self.service, "user", PROTOCOL, {b"some": [b"field"]})
  153. )
  154. self.assertEqual(self.request_url, URL_USER)
  155. self.assertEqual(result, SUCCESS_RESULT_USER)
  156. result = self.get_success(
  157. self.api.query_3pe(
  158. self.service, "location", PROTOCOL, {b"some": [b"field"]}
  159. )
  160. )
  161. self.assertEqual(self.request_url, URL_LOCATION)
  162. self.assertEqual(result, SUCCESS_RESULT_LOCATION)
  163. def test_claim_keys(self) -> None:
  164. """
  165. Tests that the /keys/claim response is properly parsed for missing
  166. keys.
  167. """
  168. RESPONSE: JsonDict = {
  169. "@alice:example.org": {
  170. "DEVICE_1": {
  171. "signed_curve25519:AAAAHg": {
  172. # We don't really care about the content of the keys,
  173. # they get passed back transparently.
  174. },
  175. "signed_curve25519:BBBBHg": {},
  176. },
  177. "DEVICE_2": {"signed_curve25519:CCCCHg": {}},
  178. },
  179. }
  180. async def post_json_get_json(
  181. uri: str,
  182. post_json: Any,
  183. headers: Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
  184. ) -> JsonDict:
  185. # Ensure the access token is passed as both a header and query arg.
  186. if not headers.get("Authorization"):
  187. raise RuntimeError("Access token not provided")
  188. self.assertEqual(headers.get("Authorization"), [f"Bearer {TOKEN}"])
  189. return RESPONSE
  190. # We assign to a method, which mypy doesn't like.
  191. self.api.post_json_get_json = Mock(side_effect=post_json_get_json) # type: ignore[assignment]
  192. MISSING_KEYS = [
  193. # Known user, known device, missing algorithm.
  194. ("@alice:example.org", "DEVICE_2", "xyz", 1),
  195. # Known user, missing device.
  196. ("@alice:example.org", "DEVICE_3", "signed_curve25519", 1),
  197. # Unknown user.
  198. ("@bob:example.org", "DEVICE_4", "signed_curve25519", 1),
  199. ]
  200. claimed_keys, missing = self.get_success(
  201. self.api.claim_client_keys(
  202. self.service,
  203. [
  204. # Found devices
  205. ("@alice:example.org", "DEVICE_1", "signed_curve25519", 1),
  206. ("@alice:example.org", "DEVICE_2", "signed_curve25519", 1),
  207. ]
  208. + MISSING_KEYS,
  209. )
  210. )
  211. self.assertEqual(claimed_keys, RESPONSE)
  212. self.assertEqual(missing, MISSING_KEYS)