test_state_store.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. from typing import Dict, List, Sequence, Tuple
  2. from unittest.mock import patch
  3. from twisted.internet.defer import Deferred, ensureDeferred
  4. from synapse.storage.state import StateFilter
  5. from synapse.types import MutableStateMap, StateMap
  6. from tests.unittest import HomeserverTestCase
  7. class StateGroupInflightCachingTestCase(HomeserverTestCase):
  def setUp(self) -> None:
      """
      Run the base set-up, then swap the state store's database lookup for a
      fake whose responses the tests complete by hand.
      """
      super().setUp()

      # Replacing `_get_state_groups_from_groups` lets the tests pretend the
      # database is slow: a lookup stays pending until the test answers it.
      patch_target = (
          "synapse.storage.databases.state.store.StateGroupDataStore"
          "._get_state_groups_from_groups"
      )
      patcher = patch(patch_target, self._fake_get_state_groups_from_groups)
      patcher.start()
      # Undo the patch when the test finishes.
      self.addCleanup(patcher.stop)

      # One (groups, state filter, pending Deferred) entry per call that
      # reached the fake.
      self.gsgfg_calls: List[
          Tuple[Tuple[int, ...], StateFilter, Deferred[Dict[int, StateMap[str]]]]
      ] = []
  def prepare(self, reactor, clock, homeserver) -> None:
      """
      Run the base class's preparation, then keep handles on the homeserver's
      state storage and state datastore for the tests to use.
      """
      super().prepare(reactor, clock, homeserver)

      storage = homeserver.get_storage()
      self.state_storage = storage.state

      datastores = homeserver.get_datastores()
      self.state_datastore = datastores.state
  def _fake_get_state_groups_from_groups(
      self, groups: Sequence[int], state_filter: StateFilter
  ) -> "Deferred[Dict[int, StateMap[str]]]":
      """
      Stand-in for `_get_state_groups_from_groups` that never answers by itself.

      The call is recorded in `self.gsgfg_calls` as a
      `(tuple(groups), state_filter, deferred)` entry, and the same unfired
      Deferred is returned to the caller. The request therefore stays pending
      until a test fires that Deferred (see `_complete_request_fake`).

      Args:
          groups: the state groups being requested.
          state_filter: the state filter passed with the request.

      Returns:
          A new Deferred that has not yet been fired.
      """
      d: "Deferred[Dict[int, StateMap[str]]]" = Deferred()
      # Store the groups as a tuple so the recorded call is an immutable
      # snapshot that can be compared directly in assertions.
      self.gsgfg_calls.append((tuple(groups), state_filter, d))
      return d
  def _complete_request_fake(
      self,
      groups: Tuple[int, ...],
      state_filter: StateFilter,
      d: "Deferred[Dict[int, StateMap[str]]]",
  ) -> None:
      """
      Build a made-up database response for `groups` and fire `d` with it.

      For each group, every state type in `state_filter.types` contributes
      one entry per listed state key. A type whose key set is `None` gets two
      invented keys instead. If `state_filter.include_others` is set, one
      extra unrelated entry is added to each group.
      """
      # Invented state keys used when a type has no explicit key set.
      wildcard_keys = ("wild wombat", "wild spqr")

      fake_response: Dict[int, StateMap[str]] = {}

      for group_id in groups:
          entries: MutableStateMap[str] = {}
          fake_response[group_id] = entries

          for event_type, wanted_keys in state_filter.types.items():
              keys_to_fake = wildcard_keys if wanted_keys is None else wanted_keys
              for key in keys_to_fake:
                  entries[(event_type, key)] = f"{group_id} {event_type} {key}"

          if state_filter.include_others:
              entries[("something.else", "wild")] = "card"

      d.callback(fake_response)
  def test_duplicate_requests_deduplicated(self) -> None:
      """
      Two identical state requests made while the first is still in flight
      should result in a single database call, and both should receive the
      same result once that call completes.
      """
      expected_state = {("something.else", "wild"): "card"}

      # First request: group 42 with the 'all' state filter.
      first_lookup = self.state_datastore._get_state_for_group_using_inflight_cache(
          42, StateFilter.all()
      )
      first_request = ensureDeferred(first_lookup)
      self.pump(by=0.1)

      # It should have reached the (fake) database and still be waiting.
      self.assertEqual(len(self.gsgfg_calls), 1)
      self.assertFalse(first_request.called)

      # Second, identical request while the first is unanswered.
      second_lookup = self.state_datastore._get_state_for_group_using_inflight_cache(
          42, StateFilter.all()
      )
      second_request = ensureDeferred(second_lookup)
      self.pump(by=0.1)

      # No further database call, and neither request has finished yet.
      self.assertEqual(len(self.gsgfg_calls), 1)
      self.assertFalse(first_request.called)
      self.assertFalse(second_request.called)

      # Check what the single recorded database call asked for.
      requested_groups, requested_filter, pending = self.gsgfg_calls[0]
      self.assertEqual(requested_groups, (42,))
      self.assertEqual(requested_filter, StateFilter.all())

      # Answer the database call; both requests should now resolve.
      self._complete_request_fake(requested_groups, requested_filter, pending)
      self.assertEqual(self.get_success(first_request), expected_state)
      self.assertEqual(self.get_success(second_request), expected_state)