test_state.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627
  1. from immutabledict import immutabledict
  2. from synapse.api.constants import EventTypes
  3. from synapse.types.state import StateFilter
  4. from tests.unittest import TestCase
  5. class StateFilterDifferenceTestCase(TestCase):
  6. def assert_difference(
  7. self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
  8. ) -> None:
  9. self.assertEqual(
  10. minuend.approx_difference(subtrahend),
  11. expected,
  12. f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
  13. )
  14. def test_state_filter_difference_no_include_other_minus_no_include_other(
  15. self,
  16. ) -> None:
  17. """
  18. Tests the StateFilter.approx_difference method
  19. where, in a.approx_difference(b), both a and b do not have the
  20. include_others flag set.
  21. """
  22. # (wildcard on state keys) - (wildcard on state keys):
  23. self.assert_difference(
  24. StateFilter.freeze(
  25. {EventTypes.Member: None, EventTypes.Create: None},
  26. include_others=False,
  27. ),
  28. StateFilter.freeze(
  29. {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
  30. include_others=False,
  31. ),
  32. StateFilter.freeze({EventTypes.Create: None}, include_others=False),
  33. )
  34. # (wildcard on state keys) - (specific state keys)
  35. # This one is an over-approximation because we can't represent
  36. # 'all state keys except a few named examples'
  37. self.assert_difference(
  38. StateFilter.freeze({EventTypes.Member: None}, include_others=False),
  39. StateFilter.freeze(
  40. {EventTypes.Member: {"@wombat:spqr"}},
  41. include_others=False,
  42. ),
  43. StateFilter.freeze({EventTypes.Member: None}, include_others=False),
  44. )
  45. # (wildcard on state keys) - (no state keys)
  46. self.assert_difference(
  47. StateFilter.freeze(
  48. {EventTypes.Member: None},
  49. include_others=False,
  50. ),
  51. StateFilter.freeze(
  52. {
  53. EventTypes.Member: set(),
  54. },
  55. include_others=False,
  56. ),
  57. StateFilter.freeze(
  58. {EventTypes.Member: None},
  59. include_others=False,
  60. ),
  61. )
  62. # (specific state keys) - (wildcard on state keys):
  63. self.assert_difference(
  64. StateFilter.freeze(
  65. {
  66. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  67. EventTypes.CanonicalAlias: {""},
  68. },
  69. include_others=False,
  70. ),
  71. StateFilter.freeze(
  72. {EventTypes.Member: None},
  73. include_others=False,
  74. ),
  75. StateFilter.freeze(
  76. {EventTypes.CanonicalAlias: {""}},
  77. include_others=False,
  78. ),
  79. )
  80. # (specific state keys) - (specific state keys)
  81. self.assert_difference(
  82. StateFilter.freeze(
  83. {
  84. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  85. EventTypes.CanonicalAlias: {""},
  86. },
  87. include_others=False,
  88. ),
  89. StateFilter.freeze(
  90. {
  91. EventTypes.Member: {"@wombat:spqr"},
  92. },
  93. include_others=False,
  94. ),
  95. StateFilter.freeze(
  96. {
  97. EventTypes.Member: {"@spqr:spqr"},
  98. EventTypes.CanonicalAlias: {""},
  99. },
  100. include_others=False,
  101. ),
  102. )
  103. # (specific state keys) - (no state keys)
  104. self.assert_difference(
  105. StateFilter.freeze(
  106. {
  107. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  108. EventTypes.CanonicalAlias: {""},
  109. },
  110. include_others=False,
  111. ),
  112. StateFilter.freeze(
  113. {
  114. EventTypes.Member: set(),
  115. },
  116. include_others=False,
  117. ),
  118. StateFilter.freeze(
  119. {
  120. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  121. EventTypes.CanonicalAlias: {""},
  122. },
  123. include_others=False,
  124. ),
  125. )
  126. def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
  127. """
  128. Tests the StateFilter.approx_difference method
  129. where, in a.approx_difference(b), only a has the include_others flag set.
  130. """
  131. # (wildcard on state keys) - (wildcard on state keys):
  132. self.assert_difference(
  133. StateFilter.freeze(
  134. {EventTypes.Member: None, EventTypes.Create: None},
  135. include_others=True,
  136. ),
  137. StateFilter.freeze(
  138. {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
  139. include_others=False,
  140. ),
  141. StateFilter.freeze(
  142. {
  143. EventTypes.Create: None,
  144. EventTypes.Member: set(),
  145. EventTypes.CanonicalAlias: set(),
  146. },
  147. include_others=True,
  148. ),
  149. )
  150. # (wildcard on state keys) - (specific state keys)
  151. # This one is an over-approximation because we can't represent
  152. # 'all state keys except a few named examples'
  153. # This also shows that the resultant state filter is normalised.
  154. self.assert_difference(
  155. StateFilter.freeze({EventTypes.Member: None}, include_others=True),
  156. StateFilter.freeze(
  157. {
  158. EventTypes.Member: {"@wombat:spqr"},
  159. EventTypes.Create: {""},
  160. },
  161. include_others=False,
  162. ),
  163. StateFilter(types=immutabledict(), include_others=True),
  164. )
  165. # (wildcard on state keys) - (no state keys)
  166. self.assert_difference(
  167. StateFilter.freeze(
  168. {EventTypes.Member: None},
  169. include_others=True,
  170. ),
  171. StateFilter.freeze(
  172. {
  173. EventTypes.Member: set(),
  174. },
  175. include_others=False,
  176. ),
  177. StateFilter(
  178. types=immutabledict(),
  179. include_others=True,
  180. ),
  181. )
  182. # (specific state keys) - (wildcard on state keys):
  183. self.assert_difference(
  184. StateFilter.freeze(
  185. {
  186. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  187. EventTypes.CanonicalAlias: {""},
  188. },
  189. include_others=True,
  190. ),
  191. StateFilter.freeze(
  192. {EventTypes.Member: None},
  193. include_others=False,
  194. ),
  195. StateFilter.freeze(
  196. {
  197. EventTypes.CanonicalAlias: {""},
  198. EventTypes.Member: set(),
  199. },
  200. include_others=True,
  201. ),
  202. )
  203. # (specific state keys) - (specific state keys)
  204. self.assert_difference(
  205. StateFilter.freeze(
  206. {
  207. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  208. EventTypes.CanonicalAlias: {""},
  209. },
  210. include_others=True,
  211. ),
  212. StateFilter.freeze(
  213. {
  214. EventTypes.Member: {"@wombat:spqr"},
  215. },
  216. include_others=False,
  217. ),
  218. StateFilter.freeze(
  219. {
  220. EventTypes.Member: {"@spqr:spqr"},
  221. EventTypes.CanonicalAlias: {""},
  222. },
  223. include_others=True,
  224. ),
  225. )
  226. # (specific state keys) - (no state keys)
  227. self.assert_difference(
  228. StateFilter.freeze(
  229. {
  230. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  231. EventTypes.CanonicalAlias: {""},
  232. },
  233. include_others=True,
  234. ),
  235. StateFilter.freeze(
  236. {
  237. EventTypes.Member: set(),
  238. },
  239. include_others=False,
  240. ),
  241. StateFilter.freeze(
  242. {
  243. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  244. EventTypes.CanonicalAlias: {""},
  245. },
  246. include_others=True,
  247. ),
  248. )
  249. def test_state_filter_difference_include_other_minus_include_other(self) -> None:
  250. """
  251. Tests the StateFilter.approx_difference method
  252. where, in a.approx_difference(b), both a and b have the include_others
  253. flag set.
  254. """
  255. # (wildcard on state keys) - (wildcard on state keys):
  256. self.assert_difference(
  257. StateFilter.freeze(
  258. {EventTypes.Member: None, EventTypes.Create: None},
  259. include_others=True,
  260. ),
  261. StateFilter.freeze(
  262. {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
  263. include_others=True,
  264. ),
  265. StateFilter(types=immutabledict(), include_others=False),
  266. )
  267. # (wildcard on state keys) - (specific state keys)
  268. # This one is an over-approximation because we can't represent
  269. # 'all state keys except a few named examples'
  270. self.assert_difference(
  271. StateFilter.freeze({EventTypes.Member: None}, include_others=True),
  272. StateFilter.freeze(
  273. {
  274. EventTypes.Member: {"@wombat:spqr"},
  275. EventTypes.CanonicalAlias: {""},
  276. },
  277. include_others=True,
  278. ),
  279. StateFilter.freeze(
  280. {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
  281. include_others=False,
  282. ),
  283. )
  284. # (wildcard on state keys) - (no state keys)
  285. self.assert_difference(
  286. StateFilter.freeze(
  287. {EventTypes.Member: None},
  288. include_others=True,
  289. ),
  290. StateFilter.freeze(
  291. {
  292. EventTypes.Member: set(),
  293. },
  294. include_others=True,
  295. ),
  296. StateFilter.freeze(
  297. {EventTypes.Member: None},
  298. include_others=False,
  299. ),
  300. )
  301. # (specific state keys) - (wildcard on state keys):
  302. self.assert_difference(
  303. StateFilter.freeze(
  304. {
  305. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  306. EventTypes.CanonicalAlias: {""},
  307. },
  308. include_others=True,
  309. ),
  310. StateFilter.freeze(
  311. {EventTypes.Member: None},
  312. include_others=True,
  313. ),
  314. StateFilter(
  315. types=immutabledict(),
  316. include_others=False,
  317. ),
  318. )
  319. # (specific state keys) - (specific state keys)
  320. # This one is an over-approximation because we can't represent
  321. # 'all state keys except a few named examples'
  322. self.assert_difference(
  323. StateFilter.freeze(
  324. {
  325. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  326. EventTypes.CanonicalAlias: {""},
  327. EventTypes.Create: {""},
  328. },
  329. include_others=True,
  330. ),
  331. StateFilter.freeze(
  332. {
  333. EventTypes.Member: {"@wombat:spqr"},
  334. EventTypes.Create: set(),
  335. },
  336. include_others=True,
  337. ),
  338. StateFilter.freeze(
  339. {
  340. EventTypes.Member: {"@spqr:spqr"},
  341. EventTypes.Create: {""},
  342. },
  343. include_others=False,
  344. ),
  345. )
  346. # (specific state keys) - (no state keys)
  347. self.assert_difference(
  348. StateFilter.freeze(
  349. {
  350. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  351. EventTypes.CanonicalAlias: {""},
  352. },
  353. include_others=True,
  354. ),
  355. StateFilter.freeze(
  356. {
  357. EventTypes.Member: set(),
  358. },
  359. include_others=True,
  360. ),
  361. StateFilter.freeze(
  362. {
  363. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  364. },
  365. include_others=False,
  366. ),
  367. )
  368. def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
  369. """
  370. Tests the StateFilter.approx_difference method
  371. where, in a.approx_difference(b), only b has the include_others flag set.
  372. """
  373. # (wildcard on state keys) - (wildcard on state keys):
  374. self.assert_difference(
  375. StateFilter.freeze(
  376. {EventTypes.Member: None, EventTypes.Create: None},
  377. include_others=False,
  378. ),
  379. StateFilter.freeze(
  380. {EventTypes.Member: None, EventTypes.CanonicalAlias: None},
  381. include_others=True,
  382. ),
  383. StateFilter(types=immutabledict(), include_others=False),
  384. )
  385. # (wildcard on state keys) - (specific state keys)
  386. # This one is an over-approximation because we can't represent
  387. # 'all state keys except a few named examples'
  388. self.assert_difference(
  389. StateFilter.freeze({EventTypes.Member: None}, include_others=False),
  390. StateFilter.freeze(
  391. {EventTypes.Member: {"@wombat:spqr"}},
  392. include_others=True,
  393. ),
  394. StateFilter.freeze({EventTypes.Member: None}, include_others=False),
  395. )
  396. # (wildcard on state keys) - (no state keys)
  397. self.assert_difference(
  398. StateFilter.freeze(
  399. {EventTypes.Member: None},
  400. include_others=False,
  401. ),
  402. StateFilter.freeze(
  403. {
  404. EventTypes.Member: set(),
  405. },
  406. include_others=True,
  407. ),
  408. StateFilter.freeze(
  409. {EventTypes.Member: None},
  410. include_others=False,
  411. ),
  412. )
  413. # (specific state keys) - (wildcard on state keys):
  414. self.assert_difference(
  415. StateFilter.freeze(
  416. {
  417. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  418. EventTypes.CanonicalAlias: {""},
  419. },
  420. include_others=False,
  421. ),
  422. StateFilter.freeze(
  423. {EventTypes.Member: None},
  424. include_others=True,
  425. ),
  426. StateFilter(
  427. types=immutabledict(),
  428. include_others=False,
  429. ),
  430. )
  431. # (specific state keys) - (specific state keys)
  432. # This one is an over-approximation because we can't represent
  433. # 'all state keys except a few named examples'
  434. self.assert_difference(
  435. StateFilter.freeze(
  436. {
  437. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  438. EventTypes.CanonicalAlias: {""},
  439. },
  440. include_others=False,
  441. ),
  442. StateFilter.freeze(
  443. {
  444. EventTypes.Member: {"@wombat:spqr"},
  445. },
  446. include_others=True,
  447. ),
  448. StateFilter.freeze(
  449. {
  450. EventTypes.Member: {"@spqr:spqr"},
  451. },
  452. include_others=False,
  453. ),
  454. )
  455. # (specific state keys) - (no state keys)
  456. self.assert_difference(
  457. StateFilter.freeze(
  458. {
  459. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  460. EventTypes.CanonicalAlias: {""},
  461. },
  462. include_others=False,
  463. ),
  464. StateFilter.freeze(
  465. {
  466. EventTypes.Member: set(),
  467. },
  468. include_others=True,
  469. ),
  470. StateFilter.freeze(
  471. {
  472. EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
  473. },
  474. include_others=False,
  475. ),
  476. )
  477. def test_state_filter_difference_simple_cases(self) -> None:
  478. """
  479. Tests some very simple cases of the StateFilter approx_difference,
  480. that are not explicitly tested by the more in-depth tests.
  481. """
  482. self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
  483. self.assert_difference(
  484. StateFilter.all(),
  485. StateFilter.none(),
  486. StateFilter.all(),
  487. )
  488. class StateFilterTestCase(TestCase):
  489. def test_return_expanded(self) -> None:
  490. """
  491. Tests the behaviour of the return_expanded() function that expands
  492. StateFilters to include more state types (for the sake of cache hit rate).
  493. """
  494. self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
  495. self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
  496. # Concrete-only state filters stay the same
  497. # (Case: mixed filter)
  498. self.assertEqual(
  499. StateFilter.freeze(
  500. {
  501. EventTypes.Member: {"@wombat:test", "@alicia:test"},
  502. "some.other.state.type": {""},
  503. },
  504. include_others=False,
  505. ).return_expanded(),
  506. StateFilter.freeze(
  507. {
  508. EventTypes.Member: {"@wombat:test", "@alicia:test"},
  509. "some.other.state.type": {""},
  510. },
  511. include_others=False,
  512. ),
  513. )
  514. # Concrete-only state filters stay the same
  515. # (Case: non-member-only filter)
  516. self.assertEqual(
  517. StateFilter.freeze(
  518. {"some.other.state.type": {""}}, include_others=False
  519. ).return_expanded(),
  520. StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
  521. )
  522. # Concrete-only state filters stay the same
  523. # (Case: member-only filter)
  524. self.assertEqual(
  525. StateFilter.freeze(
  526. {
  527. EventTypes.Member: {"@wombat:test", "@alicia:test"},
  528. },
  529. include_others=False,
  530. ).return_expanded(),
  531. StateFilter.freeze(
  532. {
  533. EventTypes.Member: {"@wombat:test", "@alicia:test"},
  534. },
  535. include_others=False,
  536. ),
  537. )
  538. # Wildcard member-only state filters stay the same
  539. self.assertEqual(
  540. StateFilter.freeze(
  541. {EventTypes.Member: None},
  542. include_others=False,
  543. ).return_expanded(),
  544. StateFilter.freeze(
  545. {EventTypes.Member: None},
  546. include_others=False,
  547. ),
  548. )
  549. # If there is a wildcard in the non-member portion of the filter,
  550. # it's expanded to include ALL non-member events.
  551. # (Case: mixed filter)
  552. self.assertEqual(
  553. StateFilter.freeze(
  554. {
  555. EventTypes.Member: {"@wombat:test", "@alicia:test"},
  556. "some.other.state.type": None,
  557. },
  558. include_others=False,
  559. ).return_expanded(),
  560. StateFilter.freeze(
  561. {EventTypes.Member: {"@wombat:test", "@alicia:test"}},
  562. include_others=True,
  563. ),
  564. )
  565. # If there is a wildcard in the non-member portion of the filter,
  566. # it's expanded to include ALL non-member events.
  567. # (Case: non-member-only filter)
  568. self.assertEqual(
  569. StateFilter.freeze(
  570. {
  571. "some.other.state.type": None,
  572. },
  573. include_others=False,
  574. ).return_expanded(),
  575. StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
  576. )
  577. self.assertEqual(
  578. StateFilter.freeze(
  579. {
  580. "some.other.state.type": None,
  581. "yet.another.state.type": {"wombat"},
  582. },
  583. include_others=False,
  584. ).return_expanded(),
  585. StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
  586. )