Browse Source

Port to sortedcontainers (with tests!) (#3332)

Amber Brown 5 years ago
parent
commit
f7869f8f8b

+ 1 - 0
.gitignore

@@ -43,6 +43,7 @@ media_store/
 
 build/
 venv/
+venv*/
 
 localhost-800*/
 static/client/register/register_config.js

+ 2 - 1
setup.cfg

@@ -17,4 +17,5 @@ ignore =
 [flake8]
 max-line-length = 90
 #  W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
-ignore = W503
+#  E203 (whitespace before ':') flags slice spacing that PEP 8 allows, so ignore it.
+ignore = W503,E203

+ 7 - 7
synapse/federation/send_queue.py

@@ -35,7 +35,7 @@ from synapse.storage.presence import UserPresenceState
 from synapse.util.metrics import Measure
 from synapse.metrics import LaterGauge
 
-from blist import sorteddict
+from sortedcontainers import SortedDict
 from collections import namedtuple
 
 import logging
@@ -55,19 +55,19 @@ class FederationRemoteSendQueue(object):
         self.is_mine_id = hs.is_mine_id
 
         self.presence_map = {}  # Pending presence map user_id -> UserPresenceState
-        self.presence_changed = sorteddict()  # Stream position -> user_id
+        self.presence_changed = SortedDict()  # Stream position -> user_id
 
         self.keyed_edu = {}  # (destination, key) -> EDU
-        self.keyed_edu_changed = sorteddict()  # stream position -> (destination, key)
+        self.keyed_edu_changed = SortedDict()  # stream position -> (destination, key)
 
-        self.edus = sorteddict()  # stream position -> Edu
+        self.edus = SortedDict()  # stream position -> Edu
 
-        self.failures = sorteddict()  # stream position -> (destination, Failure)
+        self.failures = SortedDict()  # stream position -> (destination, Failure)
 
-        self.device_messages = sorteddict()  # stream position -> destination
+        self.device_messages = SortedDict()  # stream position -> destination
 
         self.pos = 1
-        self.pos_time = sorteddict()
+        self.pos_time = SortedDict()
 
         # EVERYTHING IS SAD. In particular, python only makes new scopes when
         # we make a new function, so we need to make a new function so the inner

+ 2 - 1
synapse/python_dependencies.py

@@ -50,7 +50,7 @@ REQUIREMENTS = {
     "bcrypt": ["bcrypt>=3.1.0"],
     "pillow": ["PIL"],
     "pydenticon": ["pydenticon"],
-    "blist": ["blist"],
+    "sortedcontainers": ["sortedcontainers"],
     "pysaml2>=3.0.0": ["saml2>=3.0.0"],
     "pymacaroons-pynacl": ["pymacaroons"],
     "msgpack-python>=0.3.0": ["msgpack"],
@@ -58,6 +58,7 @@ REQUIREMENTS = {
     "six": ["six"],
     "prometheus_client": ["prometheus_client"],
 }
+
 CONDITIONAL_REQUIREMENTS = {
     "web_client": {
         "matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"],

+ 31 - 26
synapse/util/caches/stream_change_cache.py

@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.caches import register_cache, CACHE_SIZE_FACTOR
+from synapse.util import caches
 
 
-from blist import sorteddict
+from sortedcontainers import SortedDict
 import logging
 
 
@@ -32,16 +32,18 @@ class StreamChangeCache(object):
     entities that may have changed since that position. If position key is too
     old then the cache will simply return all given entities.
     """
-    def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache={}):
-        self._max_size = int(max_size * CACHE_SIZE_FACTOR)
+
+    def __init__(self, name, current_stream_pos, max_size=10000, prefilled_cache=None):
+        self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR)
         self._entity_to_key = {}
-        self._cache = sorteddict()
+        self._cache = SortedDict()
         self._earliest_known_stream_pos = current_stream_pos
         self.name = name
-        self.metrics = register_cache("cache", self.name, self._cache)
+        self.metrics = caches.register_cache("cache", self.name, self._cache)
 
-        for entity, stream_pos in prefilled_cache.items():
-            self.entity_has_changed(entity, stream_pos)
+        if prefilled_cache:
+            for entity, stream_pos in prefilled_cache.items():
+                self.entity_has_changed(entity, stream_pos)
 
     def has_entity_changed(self, entity, stream_pos):
         """Returns True if the entity may have been updated since stream_pos
@@ -65,22 +67,25 @@ class StreamChangeCache(object):
         return False
 
     def get_entities_changed(self, entities, stream_pos):
-        """Returns subset of entities that have had new things since the
-        given position. If the position is too old it will just return the given list.
+        """
+        Returns the subset of entities that have had new things since the given
+        position, as a set.  Entities unknown to the cache will be returned.  If
+        the position is too old it will just return all of the given entities.
         """
         assert type(stream_pos) is int
 
         if stream_pos >= self._earliest_known_stream_pos:
-            keys = self._cache.keys()
-            i = keys.bisect_right(stream_pos)
+            not_known_entities = set(entities) - set(self._entity_to_key)
 
-            result = set(
-                self._cache[k] for k in keys[i:]
-            ).intersection(entities)
+            result = (
+                set(self._cache.values()[self._cache.bisect_right(stream_pos) :])
+                .intersection(entities)
+                .union(not_known_entities)
+            )
 
             self.metrics.inc_hits()
         else:
-            result = entities
+            result = set(entities)
             self.metrics.inc_misses()
 
         return result
@@ -90,12 +95,13 @@ class StreamChangeCache(object):
         """
         assert type(stream_pos) is int
 
+        if not self._cache:
+            # If we have no cache, nothing can have changed.
+            return False
+
         if stream_pos >= self._earliest_known_stream_pos:
             self.metrics.inc_hits()
-            keys = self._cache.keys()
-            i = keys.bisect_right(stream_pos)
-
-            return i < len(keys)
+            return self._cache.bisect_right(stream_pos) < len(self._cache)
         else:
             self.metrics.inc_misses()
             return True
@@ -107,10 +113,7 @@ class StreamChangeCache(object):
         assert type(stream_pos) is int
 
         if stream_pos >= self._earliest_known_stream_pos:
-            keys = self._cache.keys()
-            i = keys.bisect_right(stream_pos)
-
-            return [self._cache[k] for k in keys[i:]]
+            return self._cache.values()[self._cache.bisect_right(stream_pos) :]
         else:
             return None
 
@@ -129,8 +132,10 @@ class StreamChangeCache(object):
             self._entity_to_key[entity] = stream_pos
 
             while len(self._cache) > self._max_size:
-                k, r = self._cache.popitem()
-                self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos)
+                k, r = self._cache.popitem(0)
+                self._earliest_known_stream_pos = max(
+                    k, self._earliest_known_stream_pos,
+                )
                 self._entity_to_key.pop(r, None)
 
     def get_max_pos_of_last_change(self, entity):

+ 198 - 0
tests/util/test_stream_change_cache.py

@@ -0,0 +1,198 @@
+from tests import unittest
+from mock import patch
+
+from synapse.util.caches.stream_change_cache import StreamChangeCache
+
+
+class StreamChangeCacheTests(unittest.TestCase):
+    """
+    Tests for StreamChangeCache.
+    """
+
+    def test_prefilled_cache(self):
+        """
+        Providing a prefilled cache to StreamChangeCache will result in a cache
+        populated with the prefilled entries.
+        """
+        cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2})
+        self.assertTrue(cache.has_entity_changed("user@foo.com", 1))
+
+    def test_has_entity_changed(self):
+        """
+        StreamChangeCache.entity_has_changed will mark entities as changed, and
+        has_entity_changed will observe the changed entities.
+        """
+        cache = StreamChangeCache("#test", 3)
+
+        cache.entity_has_changed("user@foo.com", 6)
+        cache.entity_has_changed("bar@baz.net", 7)
+
+        # If it's been changed after that stream position, return True
+        self.assertTrue(cache.has_entity_changed("user@foo.com", 4))
+        self.assertTrue(cache.has_entity_changed("bar@baz.net", 4))
+
+        # If it's been changed at that stream position, return False
+        self.assertFalse(cache.has_entity_changed("user@foo.com", 6))
+
+        # If there are no changes after that stream position, return False
+        self.assertFalse(cache.has_entity_changed("user@foo.com", 7))
+
+        # If the entity does not exist, return False.
+        self.assertFalse(cache.has_entity_changed("not@here.website", 7))
+
+        # If we request before the stream cache's earliest known position,
+        # return True, whether it's a known entity or not.
+        self.assertTrue(cache.has_entity_changed("user@foo.com", 0))
+        self.assertTrue(cache.has_entity_changed("not@here.website", 0))
+
+    @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0)
+    def test_has_entity_changed_pops_off_start(self):
+        """
+        StreamChangeCache.entity_has_changed will respect the max size and
+        purge the oldest items upon reaching that max size.
+        """
+        cache = StreamChangeCache("#test", 1, max_size=2)
+
+        cache.entity_has_changed("user@foo.com", 2)
+        cache.entity_has_changed("bar@baz.net", 3)
+        cache.entity_has_changed("user@elsewhere.org", 4)
+
+        # The cache is at the max size, 2
+        self.assertEqual(len(cache._cache), 2)
+
+        # The oldest item has been popped off
+        self.assertTrue("user@foo.com" not in cache._entity_to_key)
+
+        # If we update an existing entity, it keeps the two existing entities
+        cache.entity_has_changed("bar@baz.net", 5)
+        self.assertEqual(
+            set(["bar@baz.net", "user@elsewhere.org"]), set(cache._entity_to_key)
+        )
+
+    def test_get_all_entities_changed(self):
+        """
+        StreamChangeCache.get_all_entities_changed will return all changed
+        entities since the given position.  If the position is before the start
+        of the known stream, it returns None instead.
+        """
+        cache = StreamChangeCache("#test", 1)
+
+        cache.entity_has_changed("user@foo.com", 2)
+        cache.entity_has_changed("bar@baz.net", 3)
+        cache.entity_has_changed("user@elsewhere.org", 4)
+
+        self.assertEqual(
+            cache.get_all_entities_changed(1),
+            ["user@foo.com", "bar@baz.net", "user@elsewhere.org"],
+        )
+        self.assertEqual(
+            cache.get_all_entities_changed(2), ["bar@baz.net", "user@elsewhere.org"]
+        )
+        self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"])
+        self.assertEqual(cache.get_all_entities_changed(0), None)
+
+    def test_has_any_entity_changed(self):
+        """
+        StreamChangeCache.has_any_entity_changed will return True if any
+        entities have been changed since the provided stream position, and
+        False if they have not.  If the cache has no entries, it will return
+        False for any stream position; if it has entries and the provided
+        stream position is before the earliest known one, it will return True.
+        """
+        cache = StreamChangeCache("#test", 1)
+
+        # With no entities, it returns False for the past, present, and future.
+        self.assertFalse(cache.has_any_entity_changed(0))
+        self.assertFalse(cache.has_any_entity_changed(1))
+        self.assertFalse(cache.has_any_entity_changed(2))
+
+        # We add an entity
+        cache.entity_has_changed("user@foo.com", 2)
+
+        # With an entity, it returns True for positions before the change
+        # (the past and the stream start position), and False for the stream
+        # position the entity was changed on and the ones after it.
+        self.assertTrue(cache.has_any_entity_changed(0))
+        self.assertTrue(cache.has_any_entity_changed(1))
+        self.assertFalse(cache.has_any_entity_changed(2))
+        self.assertFalse(cache.has_any_entity_changed(3))
+
+    def test_get_entities_changed(self):
+        """
+        StreamChangeCache.get_entities_changed will return the entities in the
+        given list that have changed since the provided stream ID.  If the
+        stream position is earlier than the earliest known position, it will
+        return all of the entities queried for.
+        """
+        cache = StreamChangeCache("#test", 1)
+
+        cache.entity_has_changed("user@foo.com", 2)
+        cache.entity_has_changed("bar@baz.net", 3)
+        cache.entity_has_changed("user@elsewhere.org", 4)
+
+        # Query all the entries, but mid-way through the stream. We should only
+        # get the ones after that point.
+        self.assertEqual(
+            cache.get_entities_changed(
+                ["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2
+            ),
+            set(["bar@baz.net", "user@elsewhere.org"]),
+        )
+
+        # Query all the entries mid-way through the stream, but include one
+        # that doesn't exist in it. We should get back the one that doesn't
+        # exist, too.
+        self.assertEqual(
+            cache.get_entities_changed(
+                [
+                    "user@foo.com",
+                    "bar@baz.net",
+                    "user@elsewhere.org",
+                    "not@here.website",
+                ],
+                stream_pos=2,
+            ),
+            set(["bar@baz.net", "user@elsewhere.org", "not@here.website"]),
+        )
+
+        # Query all the entries, but before the first known point. We will get
+        # all the entries we queried for, including ones that don't exist.
+        self.assertEqual(
+            cache.get_entities_changed(
+                [
+                    "user@foo.com",
+                    "bar@baz.net",
+                    "user@elsewhere.org",
+                    "not@here.website",
+                ],
+                stream_pos=0,
+            ),
+            set(
+                [
+                    "user@foo.com",
+                    "bar@baz.net",
+                    "user@elsewhere.org",
+                    "not@here.website",
+                ]
+            ),
+        )
+
+    def test_max_pos(self):
+        """
+        StreamChangeCache.get_max_pos_of_last_change will return the most
+        recent point where the entity could have changed.  If the entity is not
+        known, the stream start is provided instead.
+        """
+        cache = StreamChangeCache("#test", 1)
+
+        cache.entity_has_changed("user@foo.com", 2)
+        cache.entity_has_changed("bar@baz.net", 3)
+        cache.entity_has_changed("user@elsewhere.org", 4)
+
+        # Known entities will return the point where they were changed.
+        self.assertEqual(cache.get_max_pos_of_last_change("user@foo.com"), 2)
+        self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3)
+        self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4)
+
+        # Unknown entities will return the stream start position.
+        self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1)