test__base.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2015, 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from mock import Mock
  16. from twisted.internet import defer
  17. from synapse.util.async_helpers import ObservableDeferred
  18. from synapse.util.caches.descriptors import Cache, cached
  19. from tests import unittest
  20. class CacheTestCase(unittest.TestCase):
  21. def setUp(self):
  22. self.cache = Cache("test")
  23. def test_empty(self):
  24. failed = False
  25. try:
  26. self.cache.get("foo")
  27. except KeyError:
  28. failed = True
  29. self.assertTrue(failed)
  30. def test_hit(self):
  31. self.cache.prefill("foo", 123)
  32. self.assertEquals(self.cache.get("foo"), 123)
  33. def test_invalidate(self):
  34. self.cache.prefill(("foo",), 123)
  35. self.cache.invalidate(("foo",))
  36. failed = False
  37. try:
  38. self.cache.get(("foo",))
  39. except KeyError:
  40. failed = True
  41. self.assertTrue(failed)
  42. def test_eviction(self):
  43. cache = Cache("test", max_entries=2)
  44. cache.prefill(1, "one")
  45. cache.prefill(2, "two")
  46. cache.prefill(3, "three") # 1 will be evicted
  47. failed = False
  48. try:
  49. cache.get(1)
  50. except KeyError:
  51. failed = True
  52. self.assertTrue(failed)
  53. cache.get(2)
  54. cache.get(3)
  55. def test_eviction_lru(self):
  56. cache = Cache("test", max_entries=2)
  57. cache.prefill(1, "one")
  58. cache.prefill(2, "two")
  59. # Now access 1 again, thus causing 2 to be least-recently used
  60. cache.get(1)
  61. cache.prefill(3, "three")
  62. failed = False
  63. try:
  64. cache.get(2)
  65. except KeyError:
  66. failed = True
  67. self.assertTrue(failed)
  68. cache.get(1)
  69. cache.get(3)
  70. class CacheDecoratorTestCase(unittest.TestCase):
  71. @defer.inlineCallbacks
  72. def test_passthrough(self):
  73. class A(object):
  74. @cached()
  75. def func(self, key):
  76. return key
  77. a = A()
  78. self.assertEquals((yield a.func("foo")), "foo")
  79. self.assertEquals((yield a.func("bar")), "bar")
  80. @defer.inlineCallbacks
  81. def test_hit(self):
  82. callcount = [0]
  83. class A(object):
  84. @cached()
  85. def func(self, key):
  86. callcount[0] += 1
  87. return key
  88. a = A()
  89. yield a.func("foo")
  90. self.assertEquals(callcount[0], 1)
  91. self.assertEquals((yield a.func("foo")), "foo")
  92. self.assertEquals(callcount[0], 1)
  93. @defer.inlineCallbacks
  94. def test_invalidate(self):
  95. callcount = [0]
  96. class A(object):
  97. @cached()
  98. def func(self, key):
  99. callcount[0] += 1
  100. return key
  101. a = A()
  102. yield a.func("foo")
  103. self.assertEquals(callcount[0], 1)
  104. a.func.invalidate(("foo",))
  105. yield a.func("foo")
  106. self.assertEquals(callcount[0], 2)
  107. def test_invalidate_missing(self):
  108. class A(object):
  109. @cached()
  110. def func(self, key):
  111. return key
  112. A().func.invalidate(("what",))
  113. @defer.inlineCallbacks
  114. def test_max_entries(self):
  115. callcount = [0]
  116. class A(object):
  117. @cached(max_entries=10)
  118. def func(self, key):
  119. callcount[0] += 1
  120. return key
  121. a = A()
  122. for k in range(0, 12):
  123. yield a.func(k)
  124. self.assertEquals(callcount[0], 12)
  125. # There must have been at least 2 evictions, meaning if we calculate
  126. # all 12 values again, we must get called at least 2 more times
  127. for k in range(0, 12):
  128. yield a.func(k)
  129. self.assertTrue(
  130. callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
  131. )
  132. def test_prefill(self):
  133. callcount = [0]
  134. d = defer.succeed(123)
  135. class A(object):
  136. @cached()
  137. def func(self, key):
  138. callcount[0] += 1
  139. return d
  140. a = A()
  141. a.func.prefill(("foo",), ObservableDeferred(d))
  142. self.assertEquals(a.func("foo"), d.result)
  143. self.assertEquals(callcount[0], 0)
  144. @defer.inlineCallbacks
  145. def test_invalidate_context(self):
  146. callcount = [0]
  147. callcount2 = [0]
  148. class A(object):
  149. @cached()
  150. def func(self, key):
  151. callcount[0] += 1
  152. return key
  153. @cached(cache_context=True)
  154. def func2(self, key, cache_context):
  155. callcount2[0] += 1
  156. return self.func(key, on_invalidate=cache_context.invalidate)
  157. a = A()
  158. yield a.func2("foo")
  159. self.assertEquals(callcount[0], 1)
  160. self.assertEquals(callcount2[0], 1)
  161. a.func.invalidate(("foo",))
  162. yield a.func("foo")
  163. self.assertEquals(callcount[0], 2)
  164. self.assertEquals(callcount2[0], 1)
  165. yield a.func2("foo")
  166. self.assertEquals(callcount[0], 2)
  167. self.assertEquals(callcount2[0], 2)
  168. @defer.inlineCallbacks
  169. def test_eviction_context(self):
  170. callcount = [0]
  171. callcount2 = [0]
  172. class A(object):
  173. @cached(max_entries=4) # HACK: This makes it 2 due to cache factor
  174. def func(self, key):
  175. callcount[0] += 1
  176. return key
  177. @cached(cache_context=True)
  178. def func2(self, key, cache_context):
  179. callcount2[0] += 1
  180. return self.func(key, on_invalidate=cache_context.invalidate)
  181. a = A()
  182. yield a.func2("foo")
  183. yield a.func2("foo2")
  184. self.assertEquals(callcount[0], 2)
  185. self.assertEquals(callcount2[0], 2)
  186. yield a.func2("foo")
  187. self.assertEquals(callcount[0], 2)
  188. self.assertEquals(callcount2[0], 2)
  189. yield a.func("foo3")
  190. self.assertEquals(callcount[0], 3)
  191. self.assertEquals(callcount2[0], 2)
  192. yield a.func2("foo")
  193. self.assertEquals(callcount[0], 4)
  194. self.assertEquals(callcount2[0], 3)
  195. @defer.inlineCallbacks
  196. def test_double_get(self):
  197. callcount = [0]
  198. callcount2 = [0]
  199. class A(object):
  200. @cached()
  201. def func(self, key):
  202. callcount[0] += 1
  203. return key
  204. @cached(cache_context=True)
  205. def func2(self, key, cache_context):
  206. callcount2[0] += 1
  207. return self.func(key, on_invalidate=cache_context.invalidate)
  208. a = A()
  209. a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
  210. yield a.func2("foo")
  211. self.assertEquals(callcount[0], 1)
  212. self.assertEquals(callcount2[0], 1)
  213. a.func2.invalidate(("foo",))
  214. self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
  215. yield a.func2("foo")
  216. a.func2.invalidate(("foo",))
  217. self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
  218. self.assertEquals(callcount[0], 1)
  219. self.assertEquals(callcount2[0], 2)
  220. a.func.invalidate(("foo",))
  221. self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
  222. yield a.func("foo")
  223. self.assertEquals(callcount[0], 2)
  224. self.assertEquals(callcount2[0], 2)
  225. yield a.func2("foo")
  226. self.assertEquals(callcount[0], 2)
  227. self.assertEquals(callcount2[0], 3)