test_descriptors.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2016 OpenMarket Ltd
  3. # Copyright 2018 New Vector Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import logging
  17. from functools import partial
  18. import mock
  19. from twisted.internet import defer, reactor
  20. from synapse.api.errors import SynapseError
  21. from synapse.logging.context import (
  22. LoggingContext,
  23. PreserveLoggingContext,
  24. make_deferred_yieldable,
  25. )
  26. from synapse.util.caches import descriptors
  27. from synapse.util.caches.descriptors import cached
  28. from tests import unittest
  29. logger = logging.getLogger(__name__)
  30. def run_on_reactor():
  31. d = defer.Deferred()
  32. reactor.callLater(0, d.callback, 0)
  33. return make_deferred_yieldable(d)
  34. class CacheTestCase(unittest.TestCase):
  35. def test_invalidate_all(self):
  36. cache = descriptors.Cache("testcache")
  37. callback_record = [False, False]
  38. def record_callback(idx):
  39. callback_record[idx] = True
  40. # add a couple of pending entries
  41. d1 = defer.Deferred()
  42. cache.set("key1", d1, partial(record_callback, 0))
  43. d2 = defer.Deferred()
  44. cache.set("key2", d2, partial(record_callback, 1))
  45. # lookup should return observable deferreds
  46. self.assertFalse(cache.get("key1").has_called())
  47. self.assertFalse(cache.get("key2").has_called())
  48. # let one of the lookups complete
  49. d2.callback("result2")
  50. # for now at least, the cache will return real results rather than an
  51. # observabledeferred
  52. self.assertEqual(cache.get("key2"), "result2")
  53. # now do the invalidation
  54. cache.invalidate_all()
  55. # lookup should return none
  56. self.assertIsNone(cache.get("key1", None))
  57. self.assertIsNone(cache.get("key2", None))
  58. # both callbacks should have been callbacked
  59. self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
  60. self.assertTrue(callback_record[1], "Invalidation callback for key2 not called")
  61. # letting the other lookup complete should do nothing
  62. d1.callback("result1")
  63. self.assertIsNone(cache.get("key1", None))
  64. class DescriptorTestCase(unittest.TestCase):
  65. @defer.inlineCallbacks
  66. def test_cache(self):
  67. class Cls(object):
  68. def __init__(self):
  69. self.mock = mock.Mock()
  70. @descriptors.cached()
  71. def fn(self, arg1, arg2):
  72. return self.mock(arg1, arg2)
  73. obj = Cls()
  74. obj.mock.return_value = "fish"
  75. r = yield obj.fn(1, 2)
  76. self.assertEqual(r, "fish")
  77. obj.mock.assert_called_once_with(1, 2)
  78. obj.mock.reset_mock()
  79. # a call with different params should call the mock again
  80. obj.mock.return_value = "chips"
  81. r = yield obj.fn(1, 3)
  82. self.assertEqual(r, "chips")
  83. obj.mock.assert_called_once_with(1, 3)
  84. obj.mock.reset_mock()
  85. # the two values should now be cached
  86. r = yield obj.fn(1, 2)
  87. self.assertEqual(r, "fish")
  88. r = yield obj.fn(1, 3)
  89. self.assertEqual(r, "chips")
  90. obj.mock.assert_not_called()
  91. @defer.inlineCallbacks
  92. def test_cache_num_args(self):
  93. """Only the first num_args arguments should matter to the cache"""
  94. class Cls(object):
  95. def __init__(self):
  96. self.mock = mock.Mock()
  97. @descriptors.cached(num_args=1)
  98. def fn(self, arg1, arg2):
  99. return self.mock(arg1, arg2)
  100. obj = Cls()
  101. obj.mock.return_value = "fish"
  102. r = yield obj.fn(1, 2)
  103. self.assertEqual(r, "fish")
  104. obj.mock.assert_called_once_with(1, 2)
  105. obj.mock.reset_mock()
  106. # a call with different params should call the mock again
  107. obj.mock.return_value = "chips"
  108. r = yield obj.fn(2, 3)
  109. self.assertEqual(r, "chips")
  110. obj.mock.assert_called_once_with(2, 3)
  111. obj.mock.reset_mock()
  112. # the two values should now be cached; we should be able to vary
  113. # the second argument and still get the cached result.
  114. r = yield obj.fn(1, 4)
  115. self.assertEqual(r, "fish")
  116. r = yield obj.fn(2, 5)
  117. self.assertEqual(r, "chips")
  118. obj.mock.assert_not_called()
  119. def test_cache_with_sync_exception(self):
  120. """If the wrapped function throws synchronously, things should continue to work
  121. """
  122. class Cls(object):
  123. @cached()
  124. def fn(self, arg1):
  125. raise SynapseError(100, "mai spoon iz too big!!1")
  126. obj = Cls()
  127. # this should fail immediately
  128. d = obj.fn(1)
  129. self.failureResultOf(d, SynapseError)
  130. # ... leaving the cache empty
  131. self.assertEqual(len(obj.fn.cache.cache), 0)
  132. # and a second call should result in a second exception
  133. d = obj.fn(1)
  134. self.failureResultOf(d, SynapseError)
  135. def test_cache_logcontexts(self):
  136. """Check that logcontexts are set and restored correctly when
  137. using the cache."""
  138. complete_lookup = defer.Deferred()
  139. class Cls(object):
  140. @descriptors.cached()
  141. def fn(self, arg1):
  142. @defer.inlineCallbacks
  143. def inner_fn():
  144. with PreserveLoggingContext():
  145. yield complete_lookup
  146. return 1
  147. return inner_fn()
  148. @defer.inlineCallbacks
  149. def do_lookup():
  150. with LoggingContext() as c1:
  151. c1.name = "c1"
  152. r = yield obj.fn(1)
  153. self.assertEqual(LoggingContext.current_context(), c1)
  154. return r
  155. def check_result(r):
  156. self.assertEqual(r, 1)
  157. obj = Cls()
  158. # set off a deferred which will do a cache lookup
  159. d1 = do_lookup()
  160. self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
  161. d1.addCallback(check_result)
  162. # and another
  163. d2 = do_lookup()
  164. self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
  165. d2.addCallback(check_result)
  166. # let the lookup complete
  167. complete_lookup.callback(None)
  168. return defer.gatherResults([d1, d2])
  169. def test_cache_logcontexts_with_exception(self):
  170. """Check that the cache sets and restores logcontexts correctly when
  171. the lookup function throws an exception"""
  172. class Cls(object):
  173. @descriptors.cached()
  174. def fn(self, arg1):
  175. @defer.inlineCallbacks
  176. def inner_fn():
  177. # we want this to behave like an asynchronous function
  178. yield run_on_reactor()
  179. raise SynapseError(400, "blah")
  180. return inner_fn()
  181. @defer.inlineCallbacks
  182. def do_lookup():
  183. with LoggingContext() as c1:
  184. c1.name = "c1"
  185. try:
  186. d = obj.fn(1)
  187. self.assertEqual(
  188. LoggingContext.current_context(), LoggingContext.sentinel
  189. )
  190. yield d
  191. self.fail("No exception thrown")
  192. except SynapseError:
  193. pass
  194. self.assertEqual(LoggingContext.current_context(), c1)
  195. # the cache should now be empty
  196. self.assertEqual(len(obj.fn.cache.cache), 0)
  197. obj = Cls()
  198. # set off a deferred which will do a cache lookup
  199. d1 = do_lookup()
  200. self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
  201. return d1
  202. @defer.inlineCallbacks
  203. def test_cache_default_args(self):
  204. class Cls(object):
  205. def __init__(self):
  206. self.mock = mock.Mock()
  207. @descriptors.cached()
  208. def fn(self, arg1, arg2=2, arg3=3):
  209. return self.mock(arg1, arg2, arg3)
  210. obj = Cls()
  211. obj.mock.return_value = "fish"
  212. r = yield obj.fn(1, 2, 3)
  213. self.assertEqual(r, "fish")
  214. obj.mock.assert_called_once_with(1, 2, 3)
  215. obj.mock.reset_mock()
  216. # a call with same params shouldn't call the mock again
  217. r = yield obj.fn(1, 2)
  218. self.assertEqual(r, "fish")
  219. obj.mock.assert_not_called()
  220. obj.mock.reset_mock()
  221. # a call with different params should call the mock again
  222. obj.mock.return_value = "chips"
  223. r = yield obj.fn(2, 3)
  224. self.assertEqual(r, "chips")
  225. obj.mock.assert_called_once_with(2, 3, 3)
  226. obj.mock.reset_mock()
  227. # the two values should now be cached
  228. r = yield obj.fn(1, 2)
  229. self.assertEqual(r, "fish")
  230. r = yield obj.fn(2, 3)
  231. self.assertEqual(r, "chips")
  232. obj.mock.assert_not_called()
  233. def test_cache_iterable(self):
  234. class Cls(object):
  235. def __init__(self):
  236. self.mock = mock.Mock()
  237. @descriptors.cached(iterable=True)
  238. def fn(self, arg1, arg2):
  239. return self.mock(arg1, arg2)
  240. obj = Cls()
  241. obj.mock.return_value = ["spam", "eggs"]
  242. r = obj.fn(1, 2)
  243. self.assertEqual(r, ["spam", "eggs"])
  244. obj.mock.assert_called_once_with(1, 2)
  245. obj.mock.reset_mock()
  246. # a call with different params should call the mock again
  247. obj.mock.return_value = ["chips"]
  248. r = obj.fn(1, 3)
  249. self.assertEqual(r, ["chips"])
  250. obj.mock.assert_called_once_with(1, 3)
  251. obj.mock.reset_mock()
  252. # the two values should now be cached
  253. self.assertEqual(len(obj.fn.cache.cache), 3)
  254. r = obj.fn(1, 2)
  255. self.assertEqual(r, ["spam", "eggs"])
  256. r = obj.fn(1, 3)
  257. self.assertEqual(r, ["chips"])
  258. obj.mock.assert_not_called()
  259. def test_cache_iterable_with_sync_exception(self):
  260. """If the wrapped function throws synchronously, things should continue to work
  261. """
  262. class Cls(object):
  263. @descriptors.cached(iterable=True)
  264. def fn(self, arg1):
  265. raise SynapseError(100, "mai spoon iz too big!!1")
  266. obj = Cls()
  267. # this should fail immediately
  268. d = obj.fn(1)
  269. self.failureResultOf(d, SynapseError)
  270. # ... leaving the cache empty
  271. self.assertEqual(len(obj.fn.cache.cache), 0)
  272. # and a second call should result in a second exception
  273. d = obj.fn(1)
  274. self.failureResultOf(d, SynapseError)
  275. class CachedListDescriptorTestCase(unittest.TestCase):
  276. @defer.inlineCallbacks
  277. def test_cache(self):
  278. class Cls(object):
  279. def __init__(self):
  280. self.mock = mock.Mock()
  281. @descriptors.cached()
  282. def fn(self, arg1, arg2):
  283. pass
  284. @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
  285. def list_fn(self, args1, arg2):
  286. assert LoggingContext.current_context().request == "c1"
  287. # we want this to behave like an asynchronous function
  288. yield run_on_reactor()
  289. assert LoggingContext.current_context().request == "c1"
  290. return self.mock(args1, arg2)
  291. with LoggingContext() as c1:
  292. c1.request = "c1"
  293. obj = Cls()
  294. obj.mock.return_value = {10: "fish", 20: "chips"}
  295. d1 = obj.list_fn([10, 20], 2)
  296. self.assertEqual(LoggingContext.current_context(), LoggingContext.sentinel)
  297. r = yield d1
  298. self.assertEqual(LoggingContext.current_context(), c1)
  299. obj.mock.assert_called_once_with([10, 20], 2)
  300. self.assertEqual(r, {10: "fish", 20: "chips"})
  301. obj.mock.reset_mock()
  302. # a call with different params should call the mock again
  303. obj.mock.return_value = {30: "peas"}
  304. r = yield obj.list_fn([20, 30], 2)
  305. obj.mock.assert_called_once_with([30], 2)
  306. self.assertEqual(r, {20: "chips", 30: "peas"})
  307. obj.mock.reset_mock()
  308. # all the values should now be cached
  309. r = yield obj.fn(10, 2)
  310. self.assertEqual(r, "fish")
  311. r = yield obj.fn(20, 2)
  312. self.assertEqual(r, "chips")
  313. r = yield obj.fn(30, 2)
  314. self.assertEqual(r, "peas")
  315. r = yield obj.list_fn([10, 20, 30], 2)
  316. obj.mock.assert_not_called()
  317. self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})
  318. @defer.inlineCallbacks
  319. def test_invalidate(self):
  320. """Make sure that invalidation callbacks are called."""
  321. class Cls(object):
  322. def __init__(self):
  323. self.mock = mock.Mock()
  324. @descriptors.cached()
  325. def fn(self, arg1, arg2):
  326. pass
  327. @descriptors.cachedList("fn", "args1", inlineCallbacks=True)
  328. def list_fn(self, args1, arg2):
  329. # we want this to behave like an asynchronous function
  330. yield run_on_reactor()
  331. return self.mock(args1, arg2)
  332. obj = Cls()
  333. invalidate0 = mock.Mock()
  334. invalidate1 = mock.Mock()
  335. # cache miss
  336. obj.mock.return_value = {10: "fish", 20: "chips"}
  337. r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
  338. obj.mock.assert_called_once_with([10, 20], 2)
  339. self.assertEqual(r1, {10: "fish", 20: "chips"})
  340. obj.mock.reset_mock()
  341. # cache hit
  342. r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
  343. obj.mock.assert_not_called()
  344. self.assertEqual(r2, {10: "fish", 20: "chips"})
  345. invalidate0.assert_not_called()
  346. invalidate1.assert_not_called()
  347. # now if we invalidate the keys, both invalidations should get called
  348. obj.fn.invalidate((10, 2))
  349. invalidate0.assert_called_once()
  350. invalidate1.assert_called_once()