test_cache.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. # Copyright 2023 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from unittest.mock import Mock, call
  15. from synapse.storage.database import LoggingTransaction
  16. from tests.replication._base import BaseMultiWorkerStreamTestCase
  17. from tests.unittest import HomeserverTestCase
  18. class CacheInvalidationTestCase(HomeserverTestCase):
  19. def setUp(self) -> None:
  20. super().setUp()
  21. self.store = self.hs.get_datastores().main
  22. def test_bulk_invalidation(self) -> None:
  23. master_invalidate = Mock()
  24. self.store._get_cached_user_device.invalidate = master_invalidate
  25. keys_to_invalidate = [
  26. ("a", "b"),
  27. ("c", "d"),
  28. ("e", "f"),
  29. ("g", "h"),
  30. ]
  31. def test_txn(txn: LoggingTransaction) -> None:
  32. self.store._invalidate_cache_and_stream_bulk(
  33. txn,
  34. # This is an arbitrarily chosen cached store function. It was chosen
  35. # because it takes more than one argument. We'll use this later to
  36. # check that the invalidation was actioned over replication.
  37. cache_func=self.store._get_cached_user_device,
  38. key_tuples=keys_to_invalidate,
  39. )
  40. self.get_success(
  41. self.store.db_pool.runInteraction(
  42. "test_invalidate_cache_and_stream_bulk", test_txn
  43. )
  44. )
  45. master_invalidate.assert_has_calls(
  46. [call(key_list) for key_list in keys_to_invalidate],
  47. any_order=True,
  48. )
  49. class CacheInvalidationOverReplicationTestCase(BaseMultiWorkerStreamTestCase):
  50. def setUp(self) -> None:
  51. super().setUp()
  52. self.store = self.hs.get_datastores().main
  53. def test_bulk_invalidation_replicates(self) -> None:
  54. """Like test_bulk_invalidation, but also checks the invalidations replicate."""
  55. master_invalidate = Mock()
  56. worker_invalidate = Mock()
  57. self.store._get_cached_user_device.invalidate = master_invalidate
  58. worker = self.make_worker_hs("synapse.app.generic_worker")
  59. worker_ds = worker.get_datastores().main
  60. worker_ds._get_cached_user_device.invalidate = worker_invalidate
  61. keys_to_invalidate = [
  62. ("a", "b"),
  63. ("c", "d"),
  64. ("e", "f"),
  65. ("g", "h"),
  66. ]
  67. def test_txn(txn: LoggingTransaction) -> None:
  68. self.store._invalidate_cache_and_stream_bulk(
  69. txn,
  70. # This is an arbitrarily chosen cached store function. It was chosen
  71. # because it takes more than one argument. We'll use this later to
  72. # check that the invalidation was actioned over replication.
  73. cache_func=self.store._get_cached_user_device,
  74. key_tuples=keys_to_invalidate,
  75. )
  76. assert self.store._cache_id_gen is not None
  77. initial_token = self.store._cache_id_gen.get_current_token()
  78. self.get_success(
  79. self.database_pool.runInteraction(
  80. "test_invalidate_cache_and_stream_bulk", test_txn
  81. )
  82. )
  83. second_token = self.store._cache_id_gen.get_current_token()
  84. self.assertGreaterEqual(second_token, initial_token + len(keys_to_invalidate))
  85. self.get_success(
  86. worker.get_replication_data_handler().wait_for_stream_position(
  87. "master", "caches", second_token
  88. )
  89. )
  90. master_invalidate.assert_has_calls(
  91. [call(key_list) for key_list in keys_to_invalidate],
  92. any_order=True,
  93. )
  94. worker_invalidate.assert_has_calls(
  95. [call(key_list) for key_list in keys_to_invalidate],
  96. any_order=True,
  97. )