test_descriptors.py 22 KB


# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  16. import logging
  17. from typing import Set
  18. import mock
  19. from twisted.internet import defer, reactor
  20. from synapse.api.errors import SynapseError
  21. from synapse.logging.context import (
  22. SENTINEL_CONTEXT,
  23. LoggingContext,
  24. PreserveLoggingContext,
  25. current_context,
  26. make_deferred_yieldable,
  27. )
  28. from synapse.util.caches import descriptors
  29. from synapse.util.caches.descriptors import cached, lru_cache
  30. from tests import unittest
  31. from tests.test_utils import get_awaitable_result
  32. logger = logging.getLogger(__name__)
  33. class LruCacheDecoratorTestCase(unittest.TestCase):
  34. def test_base(self):
  35. class Cls:
  36. def __init__(self):
  37. self.mock = mock.Mock()
  38. @lru_cache()
  39. def fn(self, arg1, arg2):
  40. return self.mock(arg1, arg2)
  41. obj = Cls()
  42. obj.mock.return_value = "fish"
  43. r = obj.fn(1, 2)
  44. self.assertEqual(r, "fish")
  45. obj.mock.assert_called_once_with(1, 2)
  46. obj.mock.reset_mock()
  47. # a call with different params should call the mock again
  48. obj.mock.return_value = "chips"
  49. r = obj.fn(1, 3)
  50. self.assertEqual(r, "chips")
  51. obj.mock.assert_called_once_with(1, 3)
  52. obj.mock.reset_mock()
  53. # the two values should now be cached
  54. r = obj.fn(1, 2)
  55. self.assertEqual(r, "fish")
  56. r = obj.fn(1, 3)
  57. self.assertEqual(r, "chips")
  58. obj.mock.assert_not_called()
  59. def run_on_reactor():
  60. d = defer.Deferred()
  61. reactor.callLater(0, d.callback, 0)
  62. return make_deferred_yieldable(d)
  63. class DescriptorTestCase(unittest.TestCase):
  64. @defer.inlineCallbacks
  65. def test_cache(self):
  66. class Cls:
  67. def __init__(self):
  68. self.mock = mock.Mock()
  69. @descriptors.cached()
  70. def fn(self, arg1, arg2):
  71. return self.mock(arg1, arg2)
  72. obj = Cls()
  73. obj.mock.return_value = "fish"
  74. r = yield obj.fn(1, 2)
  75. self.assertEqual(r, "fish")
  76. obj.mock.assert_called_once_with(1, 2)
  77. obj.mock.reset_mock()
  78. # a call with different params should call the mock again
  79. obj.mock.return_value = "chips"
  80. r = yield obj.fn(1, 3)
  81. self.assertEqual(r, "chips")
  82. obj.mock.assert_called_once_with(1, 3)
  83. obj.mock.reset_mock()
  84. # the two values should now be cached
  85. r = yield obj.fn(1, 2)
  86. self.assertEqual(r, "fish")
  87. r = yield obj.fn(1, 3)
  88. self.assertEqual(r, "chips")
  89. obj.mock.assert_not_called()
  90. @defer.inlineCallbacks
  91. def test_cache_num_args(self):
  92. """Only the first num_args arguments should matter to the cache"""
  93. class Cls:
  94. def __init__(self):
  95. self.mock = mock.Mock()
  96. @descriptors.cached(num_args=1)
  97. def fn(self, arg1, arg2):
  98. return self.mock(arg1, arg2)
  99. obj = Cls()
  100. obj.mock.return_value = "fish"
  101. r = yield obj.fn(1, 2)
  102. self.assertEqual(r, "fish")
  103. obj.mock.assert_called_once_with(1, 2)
  104. obj.mock.reset_mock()
  105. # a call with different params should call the mock again
  106. obj.mock.return_value = "chips"
  107. r = yield obj.fn(2, 3)
  108. self.assertEqual(r, "chips")
  109. obj.mock.assert_called_once_with(2, 3)
  110. obj.mock.reset_mock()
  111. # the two values should now be cached; we should be able to vary
  112. # the second argument and still get the cached result.
  113. r = yield obj.fn(1, 4)
  114. self.assertEqual(r, "fish")
  115. r = yield obj.fn(2, 5)
  116. self.assertEqual(r, "chips")
  117. obj.mock.assert_not_called()
  118. def test_cache_with_sync_exception(self):
  119. """If the wrapped function throws synchronously, things should continue to work"""
  120. class Cls:
  121. @cached()
  122. def fn(self, arg1):
  123. raise SynapseError(100, "mai spoon iz too big!!1")
  124. obj = Cls()
  125. # this should fail immediately
  126. d = obj.fn(1)
  127. self.failureResultOf(d, SynapseError)
  128. # ... leaving the cache empty
  129. self.assertEqual(len(obj.fn.cache.cache), 0)
  130. # and a second call should result in a second exception
  131. d = obj.fn(1)
  132. self.failureResultOf(d, SynapseError)
  133. def test_cache_with_async_exception(self):
  134. """The wrapped function returns a failure"""
  135. class Cls:
  136. result = None
  137. call_count = 0
  138. @cached()
  139. def fn(self, arg1):
  140. self.call_count += 1
  141. return self.result
  142. obj = Cls()
  143. callbacks = set() # type: Set[str]
  144. # set off an asynchronous request
  145. obj.result = origin_d = defer.Deferred()
  146. d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
  147. self.assertFalse(d1.called)
  148. # a second request should also return a deferred, but should not call the
  149. # function itself.
  150. d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
  151. self.assertFalse(d2.called)
  152. self.assertEqual(obj.call_count, 1)
  153. # no callbacks yet
  154. self.assertEqual(callbacks, set())
  155. # the original request fails
  156. e = Exception("bzz")
  157. origin_d.errback(e)
  158. # ... which should cause the lookups to fail similarly
  159. self.assertIs(self.failureResultOf(d1, Exception).value, e)
  160. self.assertIs(self.failureResultOf(d2, Exception).value, e)
  161. # ... and the callbacks to have been, uh, called.
  162. self.assertEqual(callbacks, {"d1", "d2"})
  163. # ... leaving the cache empty
  164. self.assertEqual(len(obj.fn.cache.cache), 0)
  165. # and a second call should work as normal
  166. obj.result = defer.succeed(100)
  167. d3 = obj.fn(1)
  168. self.assertEqual(self.successResultOf(d3), 100)
  169. self.assertEqual(obj.call_count, 2)
  170. def test_cache_logcontexts(self):
  171. """Check that logcontexts are set and restored correctly when
  172. using the cache."""
  173. complete_lookup = defer.Deferred()
  174. class Cls:
  175. @descriptors.cached()
  176. def fn(self, arg1):
  177. @defer.inlineCallbacks
  178. def inner_fn():
  179. with PreserveLoggingContext():
  180. yield complete_lookup
  181. return 1
  182. return inner_fn()
  183. @defer.inlineCallbacks
  184. def do_lookup():
  185. with LoggingContext() as c1:
  186. c1.name = "c1"
  187. r = yield obj.fn(1)
  188. self.assertEqual(current_context(), c1)
  189. return r
  190. def check_result(r):
  191. self.assertEqual(r, 1)
  192. obj = Cls()
  193. # set off a deferred which will do a cache lookup
  194. d1 = do_lookup()
  195. self.assertEqual(current_context(), SENTINEL_CONTEXT)
  196. d1.addCallback(check_result)
  197. # and another
  198. d2 = do_lookup()
  199. self.assertEqual(current_context(), SENTINEL_CONTEXT)
  200. d2.addCallback(check_result)
  201. # let the lookup complete
  202. complete_lookup.callback(None)
  203. return defer.gatherResults([d1, d2])
  204. def test_cache_logcontexts_with_exception(self):
  205. """Check that the cache sets and restores logcontexts correctly when
  206. the lookup function throws an exception"""
  207. class Cls:
  208. @descriptors.cached()
  209. def fn(self, arg1):
  210. @defer.inlineCallbacks
  211. def inner_fn():
  212. # we want this to behave like an asynchronous function
  213. yield run_on_reactor()
  214. raise SynapseError(400, "blah")
  215. return inner_fn()
  216. @defer.inlineCallbacks
  217. def do_lookup():
  218. with LoggingContext() as c1:
  219. c1.name = "c1"
  220. try:
  221. d = obj.fn(1)
  222. self.assertEqual(
  223. current_context(),
  224. SENTINEL_CONTEXT,
  225. )
  226. yield d
  227. self.fail("No exception thrown")
  228. except SynapseError:
  229. pass
  230. self.assertEqual(current_context(), c1)
  231. # the cache should now be empty
  232. self.assertEqual(len(obj.fn.cache.cache), 0)
  233. obj = Cls()
  234. # set off a deferred which will do a cache lookup
  235. d1 = do_lookup()
  236. self.assertEqual(current_context(), SENTINEL_CONTEXT)
  237. return d1
  238. @defer.inlineCallbacks
  239. def test_cache_default_args(self):
  240. class Cls:
  241. def __init__(self):
  242. self.mock = mock.Mock()
  243. @descriptors.cached()
  244. def fn(self, arg1, arg2=2, arg3=3):
  245. return self.mock(arg1, arg2, arg3)
  246. obj = Cls()
  247. obj.mock.return_value = "fish"
  248. r = yield obj.fn(1, 2, 3)
  249. self.assertEqual(r, "fish")
  250. obj.mock.assert_called_once_with(1, 2, 3)
  251. obj.mock.reset_mock()
  252. # a call with same params shouldn't call the mock again
  253. r = yield obj.fn(1, 2)
  254. self.assertEqual(r, "fish")
  255. obj.mock.assert_not_called()
  256. obj.mock.reset_mock()
  257. # a call with different params should call the mock again
  258. obj.mock.return_value = "chips"
  259. r = yield obj.fn(2, 3)
  260. self.assertEqual(r, "chips")
  261. obj.mock.assert_called_once_with(2, 3, 3)
  262. obj.mock.reset_mock()
  263. # the two values should now be cached
  264. r = yield obj.fn(1, 2)
  265. self.assertEqual(r, "fish")
  266. r = yield obj.fn(2, 3)
  267. self.assertEqual(r, "chips")
  268. obj.mock.assert_not_called()
  269. def test_cache_iterable(self):
  270. class Cls:
  271. def __init__(self):
  272. self.mock = mock.Mock()
  273. @descriptors.cached(iterable=True)
  274. def fn(self, arg1, arg2):
  275. return self.mock(arg1, arg2)
  276. obj = Cls()
  277. obj.mock.return_value = ["spam", "eggs"]
  278. r = obj.fn(1, 2)
  279. self.assertEqual(r.result, ["spam", "eggs"])
  280. obj.mock.assert_called_once_with(1, 2)
  281. obj.mock.reset_mock()
  282. # a call with different params should call the mock again
  283. obj.mock.return_value = ["chips"]
  284. r = obj.fn(1, 3)
  285. self.assertEqual(r.result, ["chips"])
  286. obj.mock.assert_called_once_with(1, 3)
  287. obj.mock.reset_mock()
  288. # the two values should now be cached
  289. self.assertEqual(len(obj.fn.cache.cache), 3)
  290. r = obj.fn(1, 2)
  291. self.assertEqual(r.result, ["spam", "eggs"])
  292. r = obj.fn(1, 3)
  293. self.assertEqual(r.result, ["chips"])
  294. obj.mock.assert_not_called()
  295. def test_cache_iterable_with_sync_exception(self):
  296. """If the wrapped function throws synchronously, things should continue to work"""
  297. class Cls:
  298. @descriptors.cached(iterable=True)
  299. def fn(self, arg1):
  300. raise SynapseError(100, "mai spoon iz too big!!1")
  301. obj = Cls()
  302. # this should fail immediately
  303. d = obj.fn(1)
  304. self.failureResultOf(d, SynapseError)
  305. # ... leaving the cache empty
  306. self.assertEqual(len(obj.fn.cache.cache), 0)
  307. # and a second call should result in a second exception
  308. d = obj.fn(1)
  309. self.failureResultOf(d, SynapseError)
  310. def test_invalidate_cascade(self):
  311. """Invalidations should cascade up through cache contexts"""
  312. class Cls:
  313. @cached(cache_context=True)
  314. async def func1(self, key, cache_context):
  315. return await self.func2(key, on_invalidate=cache_context.invalidate)
  316. @cached(cache_context=True)
  317. async def func2(self, key, cache_context):
  318. return self.func3(key, on_invalidate=cache_context.invalidate)
  319. @lru_cache(cache_context=True)
  320. def func3(self, key, cache_context):
  321. self.invalidate = cache_context.invalidate
  322. return 42
  323. obj = Cls()
  324. top_invalidate = mock.Mock()
  325. r = get_awaitable_result(obj.func1("k1", on_invalidate=top_invalidate))
  326. self.assertEqual(r, 42)
  327. obj.invalidate()
  328. top_invalidate.assert_called_once()
  329. class CacheDecoratorTestCase(unittest.HomeserverTestCase):
  330. """More tests for @cached
  331. The following is a set of tests that got lost in a different file for a while.
  332. There are probably duplicates of the tests in DescriptorTestCase. Ideally the
  333. duplicates would be removed and the two sets of classes combined.
  334. """
  335. @defer.inlineCallbacks
  336. def test_passthrough(self):
  337. class A:
  338. @cached()
  339. def func(self, key):
  340. return key
  341. a = A()
  342. self.assertEquals((yield a.func("foo")), "foo")
  343. self.assertEquals((yield a.func("bar")), "bar")
  344. @defer.inlineCallbacks
  345. def test_hit(self):
  346. callcount = [0]
  347. class A:
  348. @cached()
  349. def func(self, key):
  350. callcount[0] += 1
  351. return key
  352. a = A()
  353. yield a.func("foo")
  354. self.assertEquals(callcount[0], 1)
  355. self.assertEquals((yield a.func("foo")), "foo")
  356. self.assertEquals(callcount[0], 1)
  357. @defer.inlineCallbacks
  358. def test_invalidate(self):
  359. callcount = [0]
  360. class A:
  361. @cached()
  362. def func(self, key):
  363. callcount[0] += 1
  364. return key
  365. a = A()
  366. yield a.func("foo")
  367. self.assertEquals(callcount[0], 1)
  368. a.func.invalidate(("foo",))
  369. yield a.func("foo")
  370. self.assertEquals(callcount[0], 2)
  371. def test_invalidate_missing(self):
  372. class A:
  373. @cached()
  374. def func(self, key):
  375. return key
  376. A().func.invalidate(("what",))
  377. @defer.inlineCallbacks
  378. def test_max_entries(self):
  379. callcount = [0]
  380. class A:
  381. @cached(max_entries=10)
  382. def func(self, key):
  383. callcount[0] += 1
  384. return key
  385. a = A()
  386. for k in range(0, 12):
  387. yield a.func(k)
  388. self.assertEquals(callcount[0], 12)
  389. # There must have been at least 2 evictions, meaning if we calculate
  390. # all 12 values again, we must get called at least 2 more times
  391. for k in range(0, 12):
  392. yield a.func(k)
  393. self.assertTrue(
  394. callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
  395. )
  396. def test_prefill(self):
  397. callcount = [0]
  398. d = defer.succeed(123)
  399. class A:
  400. @cached()
  401. def func(self, key):
  402. callcount[0] += 1
  403. return d
  404. a = A()
  405. a.func.prefill(("foo",), 456)
  406. self.assertEquals(a.func("foo").result, 456)
  407. self.assertEquals(callcount[0], 0)
  408. @defer.inlineCallbacks
  409. def test_invalidate_context(self):
  410. callcount = [0]
  411. callcount2 = [0]
  412. class A:
  413. @cached()
  414. def func(self, key):
  415. callcount[0] += 1
  416. return key
  417. @cached(cache_context=True)
  418. def func2(self, key, cache_context):
  419. callcount2[0] += 1
  420. return self.func(key, on_invalidate=cache_context.invalidate)
  421. a = A()
  422. yield a.func2("foo")
  423. self.assertEquals(callcount[0], 1)
  424. self.assertEquals(callcount2[0], 1)
  425. a.func.invalidate(("foo",))
  426. yield a.func("foo")
  427. self.assertEquals(callcount[0], 2)
  428. self.assertEquals(callcount2[0], 1)
  429. yield a.func2("foo")
  430. self.assertEquals(callcount[0], 2)
  431. self.assertEquals(callcount2[0], 2)
  432. @defer.inlineCallbacks
  433. def test_eviction_context(self):
  434. callcount = [0]
  435. callcount2 = [0]
  436. class A:
  437. @cached(max_entries=2)
  438. def func(self, key):
  439. callcount[0] += 1
  440. return key
  441. @cached(cache_context=True)
  442. def func2(self, key, cache_context):
  443. callcount2[0] += 1
  444. return self.func(key, on_invalidate=cache_context.invalidate)
  445. a = A()
  446. yield a.func2("foo")
  447. yield a.func2("foo2")
  448. self.assertEquals(callcount[0], 2)
  449. self.assertEquals(callcount2[0], 2)
  450. yield a.func2("foo")
  451. self.assertEquals(callcount[0], 2)
  452. self.assertEquals(callcount2[0], 2)
  453. yield a.func("foo3")
  454. self.assertEquals(callcount[0], 3)
  455. self.assertEquals(callcount2[0], 2)
  456. yield a.func2("foo")
  457. self.assertEquals(callcount[0], 4)
  458. self.assertEquals(callcount2[0], 3)
  459. @defer.inlineCallbacks
  460. def test_double_get(self):
  461. callcount = [0]
  462. callcount2 = [0]
  463. class A:
  464. @cached()
  465. def func(self, key):
  466. callcount[0] += 1
  467. return key
  468. @cached(cache_context=True)
  469. def func2(self, key, cache_context):
  470. callcount2[0] += 1
  471. return self.func(key, on_invalidate=cache_context.invalidate)
  472. a = A()
  473. a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
  474. yield a.func2("foo")
  475. self.assertEquals(callcount[0], 1)
  476. self.assertEquals(callcount2[0], 1)
  477. a.func2.invalidate(("foo",))
  478. self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
  479. yield a.func2("foo")
  480. a.func2.invalidate(("foo",))
  481. self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
  482. self.assertEquals(callcount[0], 1)
  483. self.assertEquals(callcount2[0], 2)
  484. a.func.invalidate(("foo",))
  485. self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
  486. yield a.func("foo")
  487. self.assertEquals(callcount[0], 2)
  488. self.assertEquals(callcount2[0], 2)
  489. yield a.func2("foo")
  490. self.assertEquals(callcount[0], 2)
  491. self.assertEquals(callcount2[0], 3)
  492. class CachedListDescriptorTestCase(unittest.TestCase):
  493. @defer.inlineCallbacks
  494. def test_cache(self):
  495. class Cls:
  496. def __init__(self):
  497. self.mock = mock.Mock()
  498. @descriptors.cached()
  499. def fn(self, arg1, arg2):
  500. pass
  501. @descriptors.cachedList("fn", "args1")
  502. async def list_fn(self, args1, arg2):
  503. assert current_context().request == "c1"
  504. # we want this to behave like an asynchronous function
  505. await run_on_reactor()
  506. assert current_context().request == "c1"
  507. return self.mock(args1, arg2)
  508. with LoggingContext() as c1:
  509. c1.request = "c1"
  510. obj = Cls()
  511. obj.mock.return_value = {10: "fish", 20: "chips"}
  512. d1 = obj.list_fn([10, 20], 2)
  513. self.assertEqual(current_context(), SENTINEL_CONTEXT)
  514. r = yield d1
  515. self.assertEqual(current_context(), c1)
  516. obj.mock.assert_called_once_with([10, 20], 2)
  517. self.assertEqual(r, {10: "fish", 20: "chips"})
  518. obj.mock.reset_mock()
  519. # a call with different params should call the mock again
  520. obj.mock.return_value = {30: "peas"}
  521. r = yield obj.list_fn([20, 30], 2)
  522. obj.mock.assert_called_once_with([30], 2)
  523. self.assertEqual(r, {20: "chips", 30: "peas"})
  524. obj.mock.reset_mock()
  525. # all the values should now be cached
  526. r = yield obj.fn(10, 2)
  527. self.assertEqual(r, "fish")
  528. r = yield obj.fn(20, 2)
  529. self.assertEqual(r, "chips")
  530. r = yield obj.fn(30, 2)
  531. self.assertEqual(r, "peas")
  532. r = yield obj.list_fn([10, 20, 30], 2)
  533. obj.mock.assert_not_called()
  534. self.assertEqual(r, {10: "fish", 20: "chips", 30: "peas"})
  535. @defer.inlineCallbacks
  536. def test_invalidate(self):
  537. """Make sure that invalidation callbacks are called."""
  538. class Cls:
  539. def __init__(self):
  540. self.mock = mock.Mock()
  541. @descriptors.cached()
  542. def fn(self, arg1, arg2):
  543. pass
  544. @descriptors.cachedList("fn", "args1")
  545. async def list_fn(self, args1, arg2):
  546. # we want this to behave like an asynchronous function
  547. await run_on_reactor()
  548. return self.mock(args1, arg2)
  549. obj = Cls()
  550. invalidate0 = mock.Mock()
  551. invalidate1 = mock.Mock()
  552. # cache miss
  553. obj.mock.return_value = {10: "fish", 20: "chips"}
  554. r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
  555. obj.mock.assert_called_once_with([10, 20], 2)
  556. self.assertEqual(r1, {10: "fish", 20: "chips"})
  557. obj.mock.reset_mock()
  558. # cache hit
  559. r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
  560. obj.mock.assert_not_called()
  561. self.assertEqual(r2, {10: "fish", 20: "chips"})
  562. invalidate0.assert_not_called()
  563. invalidate1.assert_not_called()
  564. # now if we invalidate the keys, both invalidations should get called
  565. obj.fn.invalidate((10, 2))
  566. invalidate0.assert_called_once()
  567. invalidate1.assert_called_once()