test_visibility.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from twisted.internet import defer
  17. from twisted.internet.defer import succeed
  18. from synapse.api.constants import RoomVersions
  19. from synapse.events import FrozenEvent
  20. from synapse.visibility import filter_events_for_server
  21. import tests.unittest
  22. from tests.utils import create_room, setup_test_homeserver
# Module-level logger, named after this test module.
logger = logging.getLogger(__name__)

# ID of the room that the tests in this file create and inject events into.
TEST_ROOM_ID = "!TEST:ROOM"
  25. class FilterEventsForServerTestCase(tests.unittest.TestCase):
  26. @defer.inlineCallbacks
  27. def setUp(self):
  28. self.hs = yield setup_test_homeserver(self.addCleanup)
  29. self.event_creation_handler = self.hs.get_event_creation_handler()
  30. self.event_builder_factory = self.hs.get_event_builder_factory()
  31. self.store = self.hs.get_datastore()
  32. yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
  33. @defer.inlineCallbacks
  34. def test_filtering(self):
  35. #
  36. # The events to be filtered consist of 10 membership events (it doesn't
  37. # really matter if they are joins or leaves, so let's make them joins).
  38. # One of those membership events is going to be for a user on the
  39. # server we are filtering for (so we can check the filtering is doing
  40. # the right thing).
  41. #
  42. # before we do that, we persist some other events to act as state.
  43. self.inject_visibility("@admin:hs", "joined")
  44. for i in range(0, 10):
  45. yield self.inject_room_member("@resident%i:hs" % i)
  46. events_to_filter = []
  47. for i in range(0, 10):
  48. user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
  49. evt = yield self.inject_room_member(user, extra_content={"a": "b"})
  50. events_to_filter.append(evt)
  51. filtered = yield filter_events_for_server(
  52. self.store, "test_server", events_to_filter
  53. )
  54. # the result should be 5 redacted events, and 5 unredacted events.
  55. for i in range(0, 5):
  56. self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
  57. self.assertNotIn("a", filtered[i].content)
  58. for i in range(5, 10):
  59. self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
  60. self.assertEqual(filtered[i].content["a"], "b")
  61. @tests.unittest.DEBUG
  62. @defer.inlineCallbacks
  63. def test_erased_user(self):
  64. # 4 message events, from erased and unerased users, with a membership
  65. # change in the middle of them.
  66. events_to_filter = []
  67. evt = yield self.inject_message("@unerased:local_hs")
  68. events_to_filter.append(evt)
  69. evt = yield self.inject_message("@erased:local_hs")
  70. events_to_filter.append(evt)
  71. evt = yield self.inject_room_member("@joiner:remote_hs")
  72. events_to_filter.append(evt)
  73. evt = yield self.inject_message("@unerased:local_hs")
  74. events_to_filter.append(evt)
  75. evt = yield self.inject_message("@erased:local_hs")
  76. events_to_filter.append(evt)
  77. # the erasey user gets erased
  78. yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
  79. # ... and the filtering happens.
  80. filtered = yield filter_events_for_server(
  81. self.store, "test_server", events_to_filter
  82. )
  83. for i in range(0, len(events_to_filter)):
  84. self.assertEqual(
  85. events_to_filter[i].event_id,
  86. filtered[i].event_id,
  87. "Unexpected event at result position %i" % (i,),
  88. )
  89. for i in (0, 3):
  90. self.assertEqual(
  91. events_to_filter[i].content["body"],
  92. filtered[i].content["body"],
  93. "Unexpected event content at result position %i" % (i,),
  94. )
  95. for i in (1, 4):
  96. self.assertNotIn("body", filtered[i].content)
  97. @defer.inlineCallbacks
  98. def inject_visibility(self, user_id, visibility):
  99. content = {"history_visibility": visibility}
  100. builder = self.event_builder_factory.new(
  101. RoomVersions.V1,
  102. {
  103. "type": "m.room.history_visibility",
  104. "sender": user_id,
  105. "state_key": "",
  106. "room_id": TEST_ROOM_ID,
  107. "content": content,
  108. }
  109. )
  110. event, context = yield self.event_creation_handler.create_new_client_event(
  111. builder
  112. )
  113. yield self.hs.get_datastore().persist_event(event, context)
  114. defer.returnValue(event)
  115. @defer.inlineCallbacks
  116. def inject_room_member(self, user_id, membership="join", extra_content={}):
  117. content = {"membership": membership}
  118. content.update(extra_content)
  119. builder = self.event_builder_factory.new(
  120. RoomVersions.V1,
  121. {
  122. "type": "m.room.member",
  123. "sender": user_id,
  124. "state_key": user_id,
  125. "room_id": TEST_ROOM_ID,
  126. "content": content,
  127. }
  128. )
  129. event, context = yield self.event_creation_handler.create_new_client_event(
  130. builder
  131. )
  132. yield self.hs.get_datastore().persist_event(event, context)
  133. defer.returnValue(event)
  134. @defer.inlineCallbacks
  135. def inject_message(self, user_id, content=None):
  136. if content is None:
  137. content = {"body": "testytest", "msgtype": "m.text"}
  138. builder = self.event_builder_factory.new(
  139. RoomVersions.V1,
  140. {
  141. "type": "m.room.message",
  142. "sender": user_id,
  143. "room_id": TEST_ROOM_ID,
  144. "content": content,
  145. }
  146. )
  147. event, context = yield self.event_creation_handler.create_new_client_event(
  148. builder
  149. )
  150. yield self.hs.get_datastore().persist_event(event, context)
  151. defer.returnValue(event)
  152. @defer.inlineCallbacks
  153. def test_large_room(self):
  154. # see what happens when we have a large room with hundreds of thousands
  155. # of membership events
  156. # As above, the events to be filtered consist of 10 membership events,
  157. # where one of them is for a user on the server we are filtering for.
  158. import cProfile
  159. import pstats
  160. import time
  161. # we stub out the store, because building up all that state the normal
  162. # way is very slow.
  163. test_store = _TestStore()
  164. # our initial state is 100000 membership events and one
  165. # history_visibility event.
  166. room_state = []
  167. history_visibility_evt = FrozenEvent(
  168. {
  169. "event_id": "$history_vis",
  170. "type": "m.room.history_visibility",
  171. "sender": "@resident_user_0:test.com",
  172. "state_key": "",
  173. "room_id": TEST_ROOM_ID,
  174. "content": {"history_visibility": "joined"},
  175. }
  176. )
  177. room_state.append(history_visibility_evt)
  178. test_store.add_event(history_visibility_evt)
  179. for i in range(0, 100000):
  180. user = "@resident_user_%i:test.com" % (i,)
  181. evt = FrozenEvent(
  182. {
  183. "event_id": "$res_event_%i" % (i,),
  184. "type": "m.room.member",
  185. "state_key": user,
  186. "sender": user,
  187. "room_id": TEST_ROOM_ID,
  188. "content": {"membership": "join", "extra": "zzz,"},
  189. }
  190. )
  191. room_state.append(evt)
  192. test_store.add_event(evt)
  193. events_to_filter = []
  194. for i in range(0, 10):
  195. user = "@user%i:%s" % (i, "test_server" if i == 5 else "other_server")
  196. evt = FrozenEvent(
  197. {
  198. "event_id": "$evt%i" % (i,),
  199. "type": "m.room.member",
  200. "state_key": user,
  201. "sender": user,
  202. "room_id": TEST_ROOM_ID,
  203. "content": {"membership": "join", "extra": "zzz"},
  204. }
  205. )
  206. events_to_filter.append(evt)
  207. room_state.append(evt)
  208. test_store.add_event(evt)
  209. test_store.set_state_ids_for_event(
  210. evt, {(e.type, e.state_key): e.event_id for e in room_state}
  211. )
  212. pr = cProfile.Profile()
  213. pr.enable()
  214. logger.info("Starting filtering")
  215. start = time.time()
  216. filtered = yield filter_events_for_server(
  217. test_store, "test_server", events_to_filter
  218. )
  219. logger.info("Filtering took %f seconds", time.time() - start)
  220. pr.disable()
  221. with open("filter_events_for_server.profile", "w+") as f:
  222. ps = pstats.Stats(pr, stream=f).sort_stats('cumulative')
  223. ps.print_stats()
  224. # the result should be 5 redacted events, and 5 unredacted events.
  225. for i in range(0, 5):
  226. self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
  227. self.assertNotIn("extra", filtered[i].content)
  228. for i in range(5, 10):
  229. self.assertEqual(events_to_filter[i].event_id, filtered[i].event_id)
  230. self.assertEqual(filtered[i].content["extra"], "zzz")
  231. test_large_room.skip = "Disabled by default because it's slow"
  232. class _TestStore(object):
  233. """Implements a few methods of the DataStore, so that we can test
  234. filter_events_for_server
  235. """
  236. def __init__(self):
  237. # data for get_events: a map from event_id to event
  238. self.events = {}
  239. # data for get_state_ids_for_events mock: a map from event_id to
  240. # a map from (type_state_key) -> event_id for the state at that
  241. # event
  242. self.state_ids_for_events = {}
  243. def add_event(self, event):
  244. self.events[event.event_id] = event
  245. def set_state_ids_for_event(self, event, state):
  246. self.state_ids_for_events[event.event_id] = state
  247. def get_state_ids_for_events(self, events, types):
  248. res = {}
  249. include_memberships = False
  250. for (type, state_key) in types:
  251. if type == "m.room.history_visibility":
  252. continue
  253. if type != "m.room.member" or state_key is not None:
  254. raise RuntimeError(
  255. "Unimplemented: get_state_ids with type (%s, %s)"
  256. % (type, state_key)
  257. )
  258. include_memberships = True
  259. if include_memberships:
  260. for event_id in events:
  261. res[event_id] = self.state_ids_for_events[event_id]
  262. else:
  263. k = ("m.room.history_visibility", "")
  264. for event_id in events:
  265. hve = self.state_ids_for_events[event_id][k]
  266. res[event_id] = {k: hve}
  267. return succeed(res)
  268. def get_events(self, events):
  269. return succeed({event_id: self.events[event_id] for event_id in events})
  270. def are_users_erased(self, users):
  271. return succeed({u: False for u in users})