1
0

test_resource.py 7.6 KB


# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. from synapse.replication.resource import ReplicationResource
  16. from synapse.types import Requester, UserID
  17. from twisted.internet import defer
  18. from tests import unittest
  19. from tests.utils import setup_test_homeserver, requester_for_user
  20. from mock import Mock, NonCallableMock
  21. import json
  22. import contextlib
  23. class ReplicationResourceCase(unittest.TestCase):
  24. @defer.inlineCallbacks
  25. def setUp(self):
  26. self.hs = yield setup_test_homeserver(
  27. "red",
  28. http_client=None,
  29. replication_layer=Mock(),
  30. ratelimiter=NonCallableMock(spec_set=[
  31. "send_message",
  32. ]),
  33. )
  34. self.user_id = "@seeing:red"
  35. self.user = UserID.from_string(self.user_id)
  36. self.hs.get_ratelimiter().send_message.return_value = (True, 0)
  37. self.resource = ReplicationResource(self.hs)
  38. @defer.inlineCallbacks
  39. def test_streams(self):
  40. # Passing "-1" returns the current stream positions
  41. code, body = yield self.get(streams="-1")
  42. self.assertEquals(code, 200)
  43. self.assertEquals(body["streams"]["field_names"], ["name", "position"])
  44. position = body["streams"]["position"]
  45. # Passing the current position returns an empty response after the
  46. # timeout
  47. get = self.get(streams=str(position), timeout="0")
  48. self.hs.clock.advance_time_msec(1)
  49. code, body = yield get
  50. self.assertEquals(code, 200)
  51. self.assertEquals(body, {})
  52. @defer.inlineCallbacks
  53. def test_events_and_state(self):
  54. get = self.get(events="-1", state="-1", timeout="0")
  55. yield self.hs.get_handlers().room_creation_handler.create_room(
  56. Requester(self.user, "", False), {}
  57. )
  58. code, body = yield get
  59. self.assertEquals(code, 200)
  60. self.assertEquals(body["events"]["field_names"], [
  61. "position", "internal", "json", "state_group"
  62. ])
  63. self.assertEquals(body["state_groups"]["field_names"], [
  64. "position", "room_id", "event_id"
  65. ])
  66. self.assertEquals(body["state_group_state"]["field_names"], [
  67. "position", "type", "state_key", "event_id"
  68. ])
  69. @defer.inlineCallbacks
  70. def test_presence(self):
  71. get = self.get(presence="-1")
  72. yield self.hs.get_handlers().presence_handler.set_state(
  73. self.user, {"presence": "online"}
  74. )
  75. code, body = yield get
  76. self.assertEquals(code, 200)
  77. self.assertEquals(body["presence"]["field_names"], [
  78. "position", "user_id", "state", "last_active_ts",
  79. "last_federation_update_ts", "last_user_sync_ts",
  80. "status_msg", "currently_active",
  81. ])
  82. @defer.inlineCallbacks
  83. def test_typing(self):
  84. room_id = yield self.create_room()
  85. get = self.get(typing="-1")
  86. yield self.hs.get_handlers().typing_notification_handler.started_typing(
  87. self.user, self.user, room_id, timeout=2
  88. )
  89. code, body = yield get
  90. self.assertEquals(code, 200)
  91. self.assertEquals(body["typing"]["field_names"], [
  92. "position", "room_id", "typing"
  93. ])
  94. @defer.inlineCallbacks
  95. def test_receipts(self):
  96. room_id = yield self.create_room()
  97. event_id = yield self.send_text_message(room_id, "Hello, World")
  98. get = self.get(receipts="-1")
  99. yield self.hs.get_handlers().receipts_handler.received_client_receipt(
  100. room_id, "m.read", self.user_id, event_id
  101. )
  102. code, body = yield get
  103. self.assertEquals(code, 200)
  104. self.assertEquals(body["receipts"]["field_names"], [
  105. "position", "room_id", "receipt_type", "user_id", "event_id", "data"
  106. ])
  107. def _test_timeout(stream):
  108. """Check that a request for the given stream timesout"""
  109. @defer.inlineCallbacks
  110. def test_timeout(self):
  111. get = self.get(**{stream: "-1", "timeout": "0"})
  112. self.hs.clock.advance_time_msec(1)
  113. code, body = yield get
  114. self.assertEquals(code, 200)
  115. self.assertEquals(body, {})
  116. test_timeout.__name__ = "test_timeout_%s" % (stream)
  117. return test_timeout
  118. test_timeout_events = _test_timeout("events")
  119. test_timeout_presence = _test_timeout("presence")
  120. test_timeout_typing = _test_timeout("typing")
  121. test_timeout_receipts = _test_timeout("receipts")
  122. test_timeout_user_account_data = _test_timeout("user_account_data")
  123. test_timeout_room_account_data = _test_timeout("room_account_data")
  124. test_timeout_tag_account_data = _test_timeout("tag_account_data")
  125. test_timeout_backfill = _test_timeout("backfill")
  126. test_timeout_push_rules = _test_timeout("push_rules")
  127. test_timeout_pushers = _test_timeout("pushers")
  128. test_timeout_state = _test_timeout("state")
  129. @defer.inlineCallbacks
  130. def send_text_message(self, room_id, message):
  131. handler = self.hs.get_handlers().message_handler
  132. event = yield handler.create_and_send_nonmember_event(
  133. requester_for_user(self.user),
  134. {
  135. "type": "m.room.message",
  136. "content": {"body": "message", "msgtype": "m.text"},
  137. "room_id": room_id,
  138. "sender": self.user.to_string(),
  139. }
  140. )
  141. defer.returnValue(event.event_id)
  142. @defer.inlineCallbacks
  143. def create_room(self):
  144. result = yield self.hs.get_handlers().room_creation_handler.create_room(
  145. Requester(self.user, "", False), {}
  146. )
  147. defer.returnValue(result["room_id"])
  148. @defer.inlineCallbacks
  149. def get(self, **params):
  150. request = NonCallableMock(spec_set=[
  151. "write", "finish", "setResponseCode", "setHeader", "args",
  152. "method", "processing"
  153. ])
  154. request.method = "GET"
  155. request.args = {k: [v] for k, v in params.items()}
  156. @contextlib.contextmanager
  157. def processing():
  158. yield
  159. request.processing = processing
  160. yield self.resource._async_render_GET(request)
  161. self.assertTrue(request.finish.called)
  162. if request.setResponseCode.called:
  163. response_code = request.setResponseCode.call_args[0][0]
  164. else:
  165. response_code = 200
  166. response_json = "".join(
  167. call[0][0] for call in request.write.call_args_list
  168. )
  169. response_body = json.loads(response_json)
  170. if response_code == 200:
  171. self.check_response(response_body)
  172. defer.returnValue((response_code, response_body))
  173. def check_response(self, response_body):
  174. for name, stream in response_body.items():
  175. self.assertIn("field_names", stream)
  176. field_names = stream["field_names"]
  177. self.assertIn("rows", stream)
  178. self.assertTrue(stream["rows"])
  179. for row in stream["rows"]:
  180. self.assertEquals(
  181. len(row), len(field_names),
  182. "%s: len(row = %r) == len(field_names = %r)" % (
  183. name, row, field_names
  184. )
  185. )