test__base.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. from tests import unittest
  16. from twisted.internet import defer
  17. from mock import Mock
  18. from synapse.util.async import ObservableDeferred
  19. from synapse.util.caches.descriptors import Cache, cached
  20. class CacheTestCase(unittest.TestCase):
  21. def setUp(self):
  22. self.cache = Cache("test")
  23. def test_empty(self):
  24. failed = False
  25. try:
  26. self.cache.get("foo")
  27. except KeyError:
  28. failed = True
  29. self.assertTrue(failed)
  30. def test_hit(self):
  31. self.cache.prefill("foo", 123)
  32. self.assertEquals(self.cache.get("foo"), 123)
  33. def test_invalidate(self):
  34. self.cache.prefill(("foo",), 123)
  35. self.cache.invalidate(("foo",))
  36. failed = False
  37. try:
  38. self.cache.get(("foo",))
  39. except KeyError:
  40. failed = True
  41. self.assertTrue(failed)
  42. def test_eviction(self):
  43. cache = Cache("test", max_entries=2)
  44. cache.prefill(1, "one")
  45. cache.prefill(2, "two")
  46. cache.prefill(3, "three") # 1 will be evicted
  47. failed = False
  48. try:
  49. cache.get(1)
  50. except KeyError:
  51. failed = True
  52. self.assertTrue(failed)
  53. cache.get(2)
  54. cache.get(3)
  55. def test_eviction_lru(self):
  56. cache = Cache("test", max_entries=2)
  57. cache.prefill(1, "one")
  58. cache.prefill(2, "two")
  59. # Now access 1 again, thus causing 2 to be least-recently used
  60. cache.get(1)
  61. cache.prefill(3, "three")
  62. failed = False
  63. try:
  64. cache.get(2)
  65. except KeyError:
  66. failed = True
  67. self.assertTrue(failed)
  68. cache.get(1)
  69. cache.get(3)
  70. class CacheDecoratorTestCase(unittest.TestCase):
  71. @defer.inlineCallbacks
  72. def test_passthrough(self):
  73. class A(object):
  74. @cached()
  75. def func(self, key):
  76. return key
  77. a = A()
  78. self.assertEquals((yield a.func("foo")), "foo")
  79. self.assertEquals((yield a.func("bar")), "bar")
  80. @defer.inlineCallbacks
  81. def test_hit(self):
  82. callcount = [0]
  83. class A(object):
  84. @cached()
  85. def func(self, key):
  86. callcount[0] += 1
  87. return key
  88. a = A()
  89. yield a.func("foo")
  90. self.assertEquals(callcount[0], 1)
  91. self.assertEquals((yield a.func("foo")), "foo")
  92. self.assertEquals(callcount[0], 1)
  93. @defer.inlineCallbacks
  94. def test_invalidate(self):
  95. callcount = [0]
  96. class A(object):
  97. @cached()
  98. def func(self, key):
  99. callcount[0] += 1
  100. return key
  101. a = A()
  102. yield a.func("foo")
  103. self.assertEquals(callcount[0], 1)
  104. a.func.invalidate(("foo",))
  105. yield a.func("foo")
  106. self.assertEquals(callcount[0], 2)
  107. def test_invalidate_missing(self):
  108. class A(object):
  109. @cached()
  110. def func(self, key):
  111. return key
  112. A().func.invalidate(("what",))
  113. @defer.inlineCallbacks
  114. def test_max_entries(self):
  115. callcount = [0]
  116. class A(object):
  117. @cached(max_entries=10)
  118. def func(self, key):
  119. callcount[0] += 1
  120. return key
  121. a = A()
  122. for k in range(0, 12):
  123. yield a.func(k)
  124. self.assertEquals(callcount[0], 12)
  125. # There must have been at least 2 evictions, meaning if we calculate
  126. # all 12 values again, we must get called at least 2 more times
  127. for k in range(0, 12):
  128. yield a.func(k)
  129. self.assertTrue(
  130. callcount[0] >= 14,
  131. msg="Expected callcount >= 14, got %d" % (callcount[0])
  132. )
  133. def test_prefill(self):
  134. callcount = [0]
  135. d = defer.succeed(123)
  136. class A(object):
  137. @cached()
  138. def func(self, key):
  139. callcount[0] += 1
  140. return d
  141. a = A()
  142. a.func.prefill(("foo",), ObservableDeferred(d))
  143. self.assertEquals(a.func("foo"), d.result)
  144. self.assertEquals(callcount[0], 0)
  145. @defer.inlineCallbacks
  146. def test_invalidate_context(self):
  147. callcount = [0]
  148. callcount2 = [0]
  149. class A(object):
  150. @cached()
  151. def func(self, key):
  152. callcount[0] += 1
  153. return key
  154. @cached(cache_context=True)
  155. def func2(self, key, cache_context):
  156. callcount2[0] += 1
  157. return self.func(key, on_invalidate=cache_context.invalidate)
  158. a = A()
  159. yield a.func2("foo")
  160. self.assertEquals(callcount[0], 1)
  161. self.assertEquals(callcount2[0], 1)
  162. a.func.invalidate(("foo",))
  163. yield a.func("foo")
  164. self.assertEquals(callcount[0], 2)
  165. self.assertEquals(callcount2[0], 1)
  166. yield a.func2("foo")
  167. self.assertEquals(callcount[0], 2)
  168. self.assertEquals(callcount2[0], 2)
  169. @defer.inlineCallbacks
  170. def test_eviction_context(self):
  171. callcount = [0]
  172. callcount2 = [0]
  173. class A(object):
  174. @cached(max_entries=4) # HACK: This makes it 2 due to cache factor
  175. def func(self, key):
  176. callcount[0] += 1
  177. return key
  178. @cached(cache_context=True)
  179. def func2(self, key, cache_context):
  180. callcount2[0] += 1
  181. return self.func(key, on_invalidate=cache_context.invalidate)
  182. a = A()
  183. yield a.func2("foo")
  184. yield a.func2("foo2")
  185. self.assertEquals(callcount[0], 2)
  186. self.assertEquals(callcount2[0], 2)
  187. yield a.func2("foo")
  188. self.assertEquals(callcount[0], 2)
  189. self.assertEquals(callcount2[0], 2)
  190. yield a.func("foo3")
  191. self.assertEquals(callcount[0], 3)
  192. self.assertEquals(callcount2[0], 2)
  193. yield a.func2("foo")
  194. self.assertEquals(callcount[0], 4)
  195. self.assertEquals(callcount2[0], 3)
  196. @defer.inlineCallbacks
  197. def test_double_get(self):
  198. callcount = [0]
  199. callcount2 = [0]
  200. class A(object):
  201. @cached()
  202. def func(self, key):
  203. callcount[0] += 1
  204. return key
  205. @cached(cache_context=True)
  206. def func2(self, key, cache_context):
  207. callcount2[0] += 1
  208. return self.func(key, on_invalidate=cache_context.invalidate)
  209. a = A()
  210. a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
  211. yield a.func2("foo")
  212. self.assertEquals(callcount[0], 1)
  213. self.assertEquals(callcount2[0], 1)
  214. a.func2.invalidate(("foo",))
  215. self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
  216. yield a.func2("foo")
  217. a.func2.invalidate(("foo",))
  218. self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
  219. self.assertEquals(callcount[0], 1)
  220. self.assertEquals(callcount2[0], 2)
  221. a.func.invalidate(("foo",))
  222. self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
  223. yield a.func("foo")
  224. self.assertEquals(callcount[0], 2)
  225. self.assertEquals(callcount2[0], 2)
  226. yield a.func2("foo")
  227. self.assertEquals(callcount[0], 2)
  228. self.assertEquals(callcount2[0], 3)