test_v2.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import itertools
  16. from six.moves import zip
  17. import attr
  18. from synapse.api.constants import EventTypes, JoinRules, Membership
  19. from synapse.event_auth import auth_types_for_event
  20. from synapse.events import FrozenEvent
  21. from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
  22. from synapse.types import EventID
  23. from tests import unittest
  24. ALICE = "@alice:example.com"
  25. BOB = "@bob:example.com"
  26. CHARLIE = "@charlie:example.com"
  27. EVELYN = "@evelyn:example.com"
  28. ZARA = "@zara:example.com"
  29. ROOM_ID = "!test:example.com"
  30. MEMBERSHIP_CONTENT_JOIN = {"membership": Membership.JOIN}
  31. MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
  32. ORIGIN_SERVER_TS = 0
  33. class FakeEvent(object):
  34. """A fake event we use as a convenience.
  35. NOTE: Again as a convenience we use "node_ids" rather than event_ids to
  36. refer to events. The event_id has node_id as localpart and example.com
  37. as domain.
  38. """
  39. def __init__(self, id, sender, type, state_key, content):
  40. self.node_id = id
  41. self.event_id = EventID(id, "example.com").to_string()
  42. self.sender = sender
  43. self.type = type
  44. self.state_key = state_key
  45. self.content = content
  46. def to_event(self, auth_events, prev_events):
  47. """Given the auth_events and prev_events, convert to a Frozen Event
  48. Args:
  49. auth_events (list[str]): list of event_ids
  50. prev_events (list[str]): list of event_ids
  51. Returns:
  52. FrozenEvent
  53. """
  54. global ORIGIN_SERVER_TS
  55. ts = ORIGIN_SERVER_TS
  56. ORIGIN_SERVER_TS = ORIGIN_SERVER_TS + 1
  57. event_dict = {
  58. "auth_events": [(a, {}) for a in auth_events],
  59. "prev_events": [(p, {}) for p in prev_events],
  60. "event_id": self.node_id,
  61. "sender": self.sender,
  62. "type": self.type,
  63. "content": self.content,
  64. "origin_server_ts": ts,
  65. "room_id": ROOM_ID,
  66. }
  67. if self.state_key is not None:
  68. event_dict["state_key"] = self.state_key
  69. return FrozenEvent(event_dict)
  70. # All graphs start with this set of events
  71. INITIAL_EVENTS = [
  72. FakeEvent(
  73. id="CREATE",
  74. sender=ALICE,
  75. type=EventTypes.Create,
  76. state_key="",
  77. content={"creator": ALICE},
  78. ),
  79. FakeEvent(
  80. id="IMA",
  81. sender=ALICE,
  82. type=EventTypes.Member,
  83. state_key=ALICE,
  84. content=MEMBERSHIP_CONTENT_JOIN,
  85. ),
  86. FakeEvent(
  87. id="IPOWER",
  88. sender=ALICE,
  89. type=EventTypes.PowerLevels,
  90. state_key="",
  91. content={"users": {ALICE: 100}},
  92. ),
  93. FakeEvent(
  94. id="IJR",
  95. sender=ALICE,
  96. type=EventTypes.JoinRules,
  97. state_key="",
  98. content={"join_rule": JoinRules.PUBLIC},
  99. ),
  100. FakeEvent(
  101. id="IMB",
  102. sender=BOB,
  103. type=EventTypes.Member,
  104. state_key=BOB,
  105. content=MEMBERSHIP_CONTENT_JOIN,
  106. ),
  107. FakeEvent(
  108. id="IMC",
  109. sender=CHARLIE,
  110. type=EventTypes.Member,
  111. state_key=CHARLIE,
  112. content=MEMBERSHIP_CONTENT_JOIN,
  113. ),
  114. FakeEvent(
  115. id="IMZ",
  116. sender=ZARA,
  117. type=EventTypes.Member,
  118. state_key=ZARA,
  119. content=MEMBERSHIP_CONTENT_JOIN,
  120. ),
  121. FakeEvent(
  122. id="START",
  123. sender=ZARA,
  124. type=EventTypes.Message,
  125. state_key=None,
  126. content={},
  127. ),
  128. FakeEvent(
  129. id="END",
  130. sender=ZARA,
  131. type=EventTypes.Message,
  132. state_key=None,
  133. content={},
  134. ),
  135. ]
  136. INITIAL_EDGES = [
  137. "START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE",
  138. ]
  139. class StateTestCase(unittest.TestCase):
  140. def test_ban_vs_pl(self):
  141. events = [
  142. FakeEvent(
  143. id="PA",
  144. sender=ALICE,
  145. type=EventTypes.PowerLevels,
  146. state_key="",
  147. content={
  148. "users": {
  149. ALICE: 100,
  150. BOB: 50,
  151. }
  152. },
  153. ),
  154. FakeEvent(
  155. id="MA",
  156. sender=ALICE,
  157. type=EventTypes.Member,
  158. state_key=ALICE,
  159. content={"membership": Membership.JOIN},
  160. ),
  161. FakeEvent(
  162. id="MB",
  163. sender=ALICE,
  164. type=EventTypes.Member,
  165. state_key=BOB,
  166. content={"membership": Membership.BAN},
  167. ),
  168. FakeEvent(
  169. id="PB",
  170. sender=BOB,
  171. type=EventTypes.PowerLevels,
  172. state_key='',
  173. content={
  174. "users": {
  175. ALICE: 100,
  176. BOB: 50,
  177. },
  178. },
  179. ),
  180. ]
  181. edges = [
  182. ["END", "MB", "MA", "PA", "START"],
  183. ["END", "PB", "PA"],
  184. ]
  185. expected_state_ids = ["PA", "MA", "MB"]
  186. self.do_check(events, edges, expected_state_ids)
  187. def test_join_rule_evasion(self):
  188. events = [
  189. FakeEvent(
  190. id="JR",
  191. sender=ALICE,
  192. type=EventTypes.JoinRules,
  193. state_key="",
  194. content={"join_rules": JoinRules.PRIVATE},
  195. ),
  196. FakeEvent(
  197. id="ME",
  198. sender=EVELYN,
  199. type=EventTypes.Member,
  200. state_key=EVELYN,
  201. content={"membership": Membership.JOIN},
  202. ),
  203. ]
  204. edges = [
  205. ["END", "JR", "START"],
  206. ["END", "ME", "START"],
  207. ]
  208. expected_state_ids = ["JR"]
  209. self.do_check(events, edges, expected_state_ids)
  210. def test_offtopic_pl(self):
  211. events = [
  212. FakeEvent(
  213. id="PA",
  214. sender=ALICE,
  215. type=EventTypes.PowerLevels,
  216. state_key="",
  217. content={
  218. "users": {
  219. ALICE: 100,
  220. BOB: 50,
  221. }
  222. },
  223. ),
  224. FakeEvent(
  225. id="PB",
  226. sender=BOB,
  227. type=EventTypes.PowerLevels,
  228. state_key='',
  229. content={
  230. "users": {
  231. ALICE: 100,
  232. BOB: 50,
  233. CHARLIE: 50,
  234. },
  235. },
  236. ),
  237. FakeEvent(
  238. id="PC",
  239. sender=CHARLIE,
  240. type=EventTypes.PowerLevels,
  241. state_key='',
  242. content={
  243. "users": {
  244. ALICE: 100,
  245. BOB: 50,
  246. CHARLIE: 0,
  247. },
  248. },
  249. ),
  250. ]
  251. edges = [
  252. ["END", "PC", "PB", "PA", "START"],
  253. ["END", "PA"],
  254. ]
  255. expected_state_ids = ["PC"]
  256. self.do_check(events, edges, expected_state_ids)
  257. def test_topic_basic(self):
  258. events = [
  259. FakeEvent(
  260. id="T1",
  261. sender=ALICE,
  262. type=EventTypes.Topic,
  263. state_key="",
  264. content={},
  265. ),
  266. FakeEvent(
  267. id="PA1",
  268. sender=ALICE,
  269. type=EventTypes.PowerLevels,
  270. state_key='',
  271. content={
  272. "users": {
  273. ALICE: 100,
  274. BOB: 50,
  275. },
  276. },
  277. ),
  278. FakeEvent(
  279. id="T2",
  280. sender=ALICE,
  281. type=EventTypes.Topic,
  282. state_key="",
  283. content={},
  284. ),
  285. FakeEvent(
  286. id="PA2",
  287. sender=ALICE,
  288. type=EventTypes.PowerLevels,
  289. state_key='',
  290. content={
  291. "users": {
  292. ALICE: 100,
  293. BOB: 0,
  294. },
  295. },
  296. ),
  297. FakeEvent(
  298. id="PB",
  299. sender=BOB,
  300. type=EventTypes.PowerLevels,
  301. state_key='',
  302. content={
  303. "users": {
  304. ALICE: 100,
  305. BOB: 50,
  306. },
  307. },
  308. ),
  309. FakeEvent(
  310. id="T3",
  311. sender=BOB,
  312. type=EventTypes.Topic,
  313. state_key="",
  314. content={},
  315. ),
  316. ]
  317. edges = [
  318. ["END", "PA2", "T2", "PA1", "T1", "START"],
  319. ["END", "T3", "PB", "PA1"],
  320. ]
  321. expected_state_ids = ["PA2", "T2"]
  322. self.do_check(events, edges, expected_state_ids)
  323. def test_topic_reset(self):
  324. events = [
  325. FakeEvent(
  326. id="T1",
  327. sender=ALICE,
  328. type=EventTypes.Topic,
  329. state_key="",
  330. content={},
  331. ),
  332. FakeEvent(
  333. id="PA",
  334. sender=ALICE,
  335. type=EventTypes.PowerLevels,
  336. state_key='',
  337. content={
  338. "users": {
  339. ALICE: 100,
  340. BOB: 50,
  341. },
  342. },
  343. ),
  344. FakeEvent(
  345. id="T2",
  346. sender=BOB,
  347. type=EventTypes.Topic,
  348. state_key="",
  349. content={},
  350. ),
  351. FakeEvent(
  352. id="MB",
  353. sender=ALICE,
  354. type=EventTypes.Member,
  355. state_key=BOB,
  356. content={"membership": Membership.BAN},
  357. ),
  358. ]
  359. edges = [
  360. ["END", "MB", "T2", "PA", "T1", "START"],
  361. ["END", "T1"],
  362. ]
  363. expected_state_ids = ["T1", "MB", "PA"]
  364. self.do_check(events, edges, expected_state_ids)
  365. def test_topic(self):
  366. events = [
  367. FakeEvent(
  368. id="T1",
  369. sender=ALICE,
  370. type=EventTypes.Topic,
  371. state_key="",
  372. content={},
  373. ),
  374. FakeEvent(
  375. id="PA1",
  376. sender=ALICE,
  377. type=EventTypes.PowerLevels,
  378. state_key='',
  379. content={
  380. "users": {
  381. ALICE: 100,
  382. BOB: 50,
  383. },
  384. },
  385. ),
  386. FakeEvent(
  387. id="T2",
  388. sender=ALICE,
  389. type=EventTypes.Topic,
  390. state_key="",
  391. content={},
  392. ),
  393. FakeEvent(
  394. id="PA2",
  395. sender=ALICE,
  396. type=EventTypes.PowerLevels,
  397. state_key='',
  398. content={
  399. "users": {
  400. ALICE: 100,
  401. BOB: 0,
  402. },
  403. },
  404. ),
  405. FakeEvent(
  406. id="PB",
  407. sender=BOB,
  408. type=EventTypes.PowerLevels,
  409. state_key='',
  410. content={
  411. "users": {
  412. ALICE: 100,
  413. BOB: 50,
  414. },
  415. },
  416. ),
  417. FakeEvent(
  418. id="T3",
  419. sender=BOB,
  420. type=EventTypes.Topic,
  421. state_key="",
  422. content={},
  423. ),
  424. FakeEvent(
  425. id="MZ1",
  426. sender=ZARA,
  427. type=EventTypes.Message,
  428. state_key=None,
  429. content={},
  430. ),
  431. FakeEvent(
  432. id="T4",
  433. sender=ALICE,
  434. type=EventTypes.Topic,
  435. state_key="",
  436. content={},
  437. ),
  438. ]
  439. edges = [
  440. ["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"],
  441. ["END", "MZ1", "T3", "PB", "PA1"],
  442. ]
  443. expected_state_ids = ["T4", "PA2"]
  444. self.do_check(events, edges, expected_state_ids)
  445. def do_check(self, events, edges, expected_state_ids):
  446. """Take a list of events and edges and calculate the state of the
  447. graph at END, and asserts it matches `expected_state_ids`
  448. Args:
  449. events (list[FakeEvent])
  450. edges (list[list[str]]): A list of chains of event edges, e.g.
  451. `[[A, B, C]]` are edges A->B and B->C.
  452. expected_state_ids (list[str]): The expected state at END, (excluding
  453. the keys that haven't changed since START).
  454. """
  455. # We want to sort the events into topological order for processing.
  456. graph = {}
  457. # node_id -> FakeEvent
  458. fake_event_map = {}
  459. for ev in itertools.chain(INITIAL_EVENTS, events):
  460. graph[ev.node_id] = set()
  461. fake_event_map[ev.node_id] = ev
  462. for a, b in pairwise(INITIAL_EDGES):
  463. graph[a].add(b)
  464. for edge_list in edges:
  465. for a, b in pairwise(edge_list):
  466. graph[a].add(b)
  467. # event_id -> FrozenEvent
  468. event_map = {}
  469. # node_id -> state
  470. state_at_event = {}
  471. # We copy the map as the sort consumes the graph
  472. graph_copy = {k: set(v) for k, v in graph.items()}
  473. for node_id in lexicographical_topological_sort(graph_copy, key=lambda e: e):
  474. fake_event = fake_event_map[node_id]
  475. event_id = fake_event.event_id
  476. prev_events = list(graph[node_id])
  477. if len(prev_events) == 0:
  478. state_before = {}
  479. elif len(prev_events) == 1:
  480. state_before = dict(state_at_event[prev_events[0]])
  481. else:
  482. state_d = resolve_events_with_store(
  483. [state_at_event[n] for n in prev_events],
  484. event_map=event_map,
  485. state_res_store=TestStateResolutionStore(event_map),
  486. )
  487. self.assertTrue(state_d.called)
  488. state_before = state_d.result
  489. state_after = dict(state_before)
  490. if fake_event.state_key is not None:
  491. state_after[(fake_event.type, fake_event.state_key)] = event_id
  492. auth_types = set(auth_types_for_event(fake_event))
  493. auth_events = []
  494. for key in auth_types:
  495. if key in state_before:
  496. auth_events.append(state_before[key])
  497. event = fake_event.to_event(auth_events, prev_events)
  498. state_at_event[node_id] = state_after
  499. event_map[event_id] = event
  500. expected_state = {}
  501. for node_id in expected_state_ids:
  502. # expected_state_ids are node IDs rather than event IDs,
  503. # so we have to convert
  504. event_id = EventID(node_id, "example.com").to_string()
  505. event = event_map[event_id]
  506. key = (event.type, event.state_key)
  507. expected_state[key] = event_id
  508. start_state = state_at_event["START"]
  509. end_state = {
  510. key: value
  511. for key, value in state_at_event["END"].items()
  512. if key in expected_state or start_state.get(key) != value
  513. }
  514. self.assertEqual(expected_state, end_state)
  515. class LexicographicalTestCase(unittest.TestCase):
  516. def test_simple(self):
  517. graph = {
  518. "l": {"o"},
  519. "m": {"n", "o"},
  520. "n": {"o"},
  521. "o": set(),
  522. "p": {"o"},
  523. }
  524. res = list(lexicographical_topological_sort(graph, key=lambda x: x))
  525. self.assertEqual(["o", "l", "n", "m", "p"], res)
  526. def pairwise(iterable):
  527. "s -> (s0,s1), (s1,s2), (s2, s3), ..."
  528. a, b = itertools.tee(iterable)
  529. next(b, None)
  530. return zip(a, b)
  531. @attr.s
  532. class TestStateResolutionStore(object):
  533. event_map = attr.ib()
  534. def get_events(self, event_ids, allow_rejected=False):
  535. """Get events from the database
  536. Args:
  537. event_ids (list): The event_ids of the events to fetch
  538. allow_rejected (bool): If True return rejected events.
  539. Returns:
  540. Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
  541. """
  542. return {
  543. eid: self.event_map[eid]
  544. for eid in event_ids
  545. if eid in self.event_map
  546. }
  547. def get_auth_chain(self, event_ids):
  548. """Gets the full auth chain for a set of events (including rejected
  549. events).
  550. Includes the given event IDs in the result.
  551. Note that:
  552. 1. All events must be state events.
  553. 2. For v1 rooms this may not have the full auth chain in the
  554. presence of rejected events
  555. Args:
  556. event_ids (list): The event IDs of the events to fetch the auth
  557. chain for. Must be state events.
  558. Returns:
  559. Deferred[list[str]]: List of event IDs of the auth chain.
  560. """
  561. # Simple DFS for auth chain
  562. result = set()
  563. stack = list(event_ids)
  564. while stack:
  565. event_id = stack.pop()
  566. if event_id in result:
  567. continue
  568. result.add(event_id)
  569. event = self.event_map[event_id]
  570. for aid, _ in event.auth_events:
  571. stack.append(aid)
  572. return list(result)