test__base.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2015, 2016 OpenMarket Ltd
  3. # Copyright 2019 New Vector Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from mock import Mock
  17. from twisted.internet import defer
  18. from synapse.util.async_helpers import ObservableDeferred
  19. from synapse.util.caches.descriptors import Cache, cached
  20. from tests import unittest
  21. class CacheTestCase(unittest.HomeserverTestCase):
  22. def prepare(self, reactor, clock, homeserver):
  23. self.cache = Cache("test")
  24. def test_empty(self):
  25. failed = False
  26. try:
  27. self.cache.get("foo")
  28. except KeyError:
  29. failed = True
  30. self.assertTrue(failed)
  31. def test_hit(self):
  32. self.cache.prefill("foo", 123)
  33. self.assertEquals(self.cache.get("foo"), 123)
  34. def test_invalidate(self):
  35. self.cache.prefill(("foo",), 123)
  36. self.cache.invalidate(("foo",))
  37. failed = False
  38. try:
  39. self.cache.get(("foo",))
  40. except KeyError:
  41. failed = True
  42. self.assertTrue(failed)
  43. def test_eviction(self):
  44. cache = Cache("test", max_entries=2)
  45. cache.prefill(1, "one")
  46. cache.prefill(2, "two")
  47. cache.prefill(3, "three") # 1 will be evicted
  48. failed = False
  49. try:
  50. cache.get(1)
  51. except KeyError:
  52. failed = True
  53. self.assertTrue(failed)
  54. cache.get(2)
  55. cache.get(3)
  56. def test_eviction_lru(self):
  57. cache = Cache("test", max_entries=2)
  58. cache.prefill(1, "one")
  59. cache.prefill(2, "two")
  60. # Now access 1 again, thus causing 2 to be least-recently used
  61. cache.get(1)
  62. cache.prefill(3, "three")
  63. failed = False
  64. try:
  65. cache.get(2)
  66. except KeyError:
  67. failed = True
  68. self.assertTrue(failed)
  69. cache.get(1)
  70. cache.get(3)
  71. class CacheDecoratorTestCase(unittest.HomeserverTestCase):
  72. @defer.inlineCallbacks
  73. def test_passthrough(self):
  74. class A:
  75. @cached()
  76. def func(self, key):
  77. return key
  78. a = A()
  79. self.assertEquals((yield a.func("foo")), "foo")
  80. self.assertEquals((yield a.func("bar")), "bar")
  81. @defer.inlineCallbacks
  82. def test_hit(self):
  83. callcount = [0]
  84. class A:
  85. @cached()
  86. def func(self, key):
  87. callcount[0] += 1
  88. return key
  89. a = A()
  90. yield a.func("foo")
  91. self.assertEquals(callcount[0], 1)
  92. self.assertEquals((yield a.func("foo")), "foo")
  93. self.assertEquals(callcount[0], 1)
  94. @defer.inlineCallbacks
  95. def test_invalidate(self):
  96. callcount = [0]
  97. class A:
  98. @cached()
  99. def func(self, key):
  100. callcount[0] += 1
  101. return key
  102. a = A()
  103. yield a.func("foo")
  104. self.assertEquals(callcount[0], 1)
  105. a.func.invalidate(("foo",))
  106. yield a.func("foo")
  107. self.assertEquals(callcount[0], 2)
  108. def test_invalidate_missing(self):
  109. class A:
  110. @cached()
  111. def func(self, key):
  112. return key
  113. A().func.invalidate(("what",))
  114. @defer.inlineCallbacks
  115. def test_max_entries(self):
  116. callcount = [0]
  117. class A:
  118. @cached(max_entries=10)
  119. def func(self, key):
  120. callcount[0] += 1
  121. return key
  122. a = A()
  123. for k in range(0, 12):
  124. yield a.func(k)
  125. self.assertEquals(callcount[0], 12)
  126. # There must have been at least 2 evictions, meaning if we calculate
  127. # all 12 values again, we must get called at least 2 more times
  128. for k in range(0, 12):
  129. yield a.func(k)
  130. self.assertTrue(
  131. callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
  132. )
  133. def test_prefill(self):
  134. callcount = [0]
  135. d = defer.succeed(123)
  136. class A:
  137. @cached()
  138. def func(self, key):
  139. callcount[0] += 1
  140. return d
  141. a = A()
  142. a.func.prefill(("foo",), ObservableDeferred(d))
  143. self.assertEquals(a.func("foo").result, d.result)
  144. self.assertEquals(callcount[0], 0)
  145. @defer.inlineCallbacks
  146. def test_invalidate_context(self):
  147. callcount = [0]
  148. callcount2 = [0]
  149. class A:
  150. @cached()
  151. def func(self, key):
  152. callcount[0] += 1
  153. return key
  154. @cached(cache_context=True)
  155. def func2(self, key, cache_context):
  156. callcount2[0] += 1
  157. return self.func(key, on_invalidate=cache_context.invalidate)
  158. a = A()
  159. yield a.func2("foo")
  160. self.assertEquals(callcount[0], 1)
  161. self.assertEquals(callcount2[0], 1)
  162. a.func.invalidate(("foo",))
  163. yield a.func("foo")
  164. self.assertEquals(callcount[0], 2)
  165. self.assertEquals(callcount2[0], 1)
  166. yield a.func2("foo")
  167. self.assertEquals(callcount[0], 2)
  168. self.assertEquals(callcount2[0], 2)
  169. @defer.inlineCallbacks
  170. def test_eviction_context(self):
  171. callcount = [0]
  172. callcount2 = [0]
  173. class A:
  174. @cached(max_entries=2)
  175. def func(self, key):
  176. callcount[0] += 1
  177. return key
  178. @cached(cache_context=True)
  179. def func2(self, key, cache_context):
  180. callcount2[0] += 1
  181. return self.func(key, on_invalidate=cache_context.invalidate)
  182. a = A()
  183. yield a.func2("foo")
  184. yield a.func2("foo2")
  185. self.assertEquals(callcount[0], 2)
  186. self.assertEquals(callcount2[0], 2)
  187. yield a.func2("foo")
  188. self.assertEquals(callcount[0], 2)
  189. self.assertEquals(callcount2[0], 2)
  190. yield a.func("foo3")
  191. self.assertEquals(callcount[0], 3)
  192. self.assertEquals(callcount2[0], 2)
  193. yield a.func2("foo")
  194. self.assertEquals(callcount[0], 4)
  195. self.assertEquals(callcount2[0], 3)
  196. @defer.inlineCallbacks
  197. def test_double_get(self):
  198. callcount = [0]
  199. callcount2 = [0]
  200. class A:
  201. @cached()
  202. def func(self, key):
  203. callcount[0] += 1
  204. return key
  205. @cached(cache_context=True)
  206. def func2(self, key, cache_context):
  207. callcount2[0] += 1
  208. return self.func(key, on_invalidate=cache_context.invalidate)
  209. a = A()
  210. a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
  211. yield a.func2("foo")
  212. self.assertEquals(callcount[0], 1)
  213. self.assertEquals(callcount2[0], 1)
  214. a.func2.invalidate(("foo",))
  215. self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
  216. yield a.func2("foo")
  217. a.func2.invalidate(("foo",))
  218. self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
  219. self.assertEquals(callcount[0], 1)
  220. self.assertEquals(callcount2[0], 2)
  221. a.func.invalidate(("foo",))
  222. self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
  223. yield a.func("foo")
  224. self.assertEquals(callcount[0], 2)
  225. self.assertEquals(callcount2[0], 2)
  226. yield a.func2("foo")
  227. self.assertEquals(callcount[0], 2)
  228. self.assertEquals(callcount2[0], 3)
  229. class UpsertManyTests(unittest.HomeserverTestCase):
  230. def prepare(self, reactor, clock, hs):
  231. self.storage = hs.get_datastore()
  232. self.table_name = "table_" + hs.get_secrets().token_hex(6)
  233. self.get_success(
  234. self.storage.db_pool.runInteraction(
  235. "create",
  236. lambda x, *a: x.execute(*a),
  237. "CREATE TABLE %s (id INTEGER, username TEXT, value TEXT)"
  238. % (self.table_name,),
  239. )
  240. )
  241. self.get_success(
  242. self.storage.db_pool.runInteraction(
  243. "index",
  244. lambda x, *a: x.execute(*a),
  245. "CREATE UNIQUE INDEX %sindex ON %s(id, username)"
  246. % (self.table_name, self.table_name),
  247. )
  248. )
  249. def _dump_to_tuple(self, res):
  250. for i in res:
  251. yield (i["id"], i["username"], i["value"])
  252. def test_upsert_many(self):
  253. """
  254. Upsert_many will perform the upsert operation across a batch of data.
  255. """
  256. # Add some data to an empty table
  257. key_names = ["id", "username"]
  258. value_names = ["value"]
  259. key_values = [[1, "user1"], [2, "user2"]]
  260. value_values = [["hello"], ["there"]]
  261. self.get_success(
  262. self.storage.db_pool.runInteraction(
  263. "test",
  264. self.storage.db_pool.simple_upsert_many_txn,
  265. self.table_name,
  266. key_names,
  267. key_values,
  268. value_names,
  269. value_values,
  270. )
  271. )
  272. # Check results are what we expect
  273. res = self.get_success(
  274. self.storage.db_pool.simple_select_list(
  275. self.table_name, None, ["id, username, value"]
  276. )
  277. )
  278. self.assertEqual(
  279. set(self._dump_to_tuple(res)),
  280. {(1, "user1", "hello"), (2, "user2", "there")},
  281. )
  282. # Update only user2
  283. key_values = [[2, "user2"]]
  284. value_values = [["bleb"]]
  285. self.get_success(
  286. self.storage.db_pool.runInteraction(
  287. "test",
  288. self.storage.db_pool.simple_upsert_many_txn,
  289. self.table_name,
  290. key_names,
  291. key_values,
  292. value_names,
  293. value_values,
  294. )
  295. )
  296. # Check results are what we expect
  297. res = self.get_success(
  298. self.storage.db_pool.simple_select_list(
  299. self.table_name, None, ["id, username, value"]
  300. )
  301. )
  302. self.assertEqual(
  303. set(self._dump_to_tuple(res)),
  304. {(1, "user1", "hello"), (2, "user2", "bleb")},
  305. )