Browse Source

Merge pull request #1620 from matrix-org/erikj/concurrent_room_access

Limit the number of events that can be created on a given room concurrently
Erik Johnston 7 years ago
parent
commit
d53a80af25
3 changed files with 161 additions and 27 deletions
  1. 33 27
      synapse/handlers/message.py
  2. 58 0
      synapse/util/async.py
  3. 70 0
      tests/util/test_limiter.py

+ 33 - 27
synapse/handlers/message.py

@@ -24,7 +24,7 @@ from synapse.push.action_generator import ActionGenerator
 from synapse.types import (
     UserID, RoomAlias, RoomStreamToken,
 )
-from synapse.util.async import run_on_reactor, ReadWriteLock
+from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
 from synapse.util.logcontext import preserve_fn
 from synapse.util.metrics import measure_func
 from synapse.visibility import filter_events_for_client
@@ -50,6 +50,10 @@ class MessageHandler(BaseHandler):
 
         self.pagination_lock = ReadWriteLock()
 
+        # We arbitrarily limit concurrent event creation for a room to 5.
+        # This is to stop us from diverging history *too* much.
+        self.limiter = Limiter(max_count=5)
+
     @defer.inlineCallbacks
     def purge_history(self, room_id, event_id):
         event = yield self.store.get_event(event_id)
@@ -191,36 +195,38 @@ class MessageHandler(BaseHandler):
         """
         builder = self.event_builder_factory.new(event_dict)
 
-        self.validator.validate_new(builder)
-
-        if builder.type == EventTypes.Member:
-            membership = builder.content.get("membership", None)
-            target = UserID.from_string(builder.state_key)
+        with (yield self.limiter.queue(builder.room_id)):
+            self.validator.validate_new(builder)
+
+            if builder.type == EventTypes.Member:
+                membership = builder.content.get("membership", None)
+                target = UserID.from_string(builder.state_key)
+
+                if membership in {Membership.JOIN, Membership.INVITE}:
+                    # If event doesn't include a display name, add one.
+                    profile = self.hs.get_handlers().profile_handler
+                    content = builder.content
+
+                    try:
+                        content["displayname"] = yield profile.get_displayname(target)
+                        content["avatar_url"] = yield profile.get_avatar_url(target)
+                    except Exception as e:
+                        logger.info(
+                            "Failed to get profile information for %r: %s",
+                            target, e
+                        )
 
-            if membership in {Membership.JOIN, Membership.INVITE}:
-                # If event doesn't include a display name, add one.
-                profile = self.hs.get_handlers().profile_handler
-                content = builder.content
+            if token_id is not None:
+                builder.internal_metadata.token_id = token_id
 
-                try:
-                    content["displayname"] = yield profile.get_displayname(target)
-                    content["avatar_url"] = yield profile.get_avatar_url(target)
-                except Exception as e:
-                    logger.info(
-                        "Failed to get profile information for %r: %s",
-                        target, e
-                    )
+            if txn_id is not None:
+                builder.internal_metadata.txn_id = txn_id
 
-        if token_id is not None:
-            builder.internal_metadata.token_id = token_id
-
-        if txn_id is not None:
-            builder.internal_metadata.txn_id = txn_id
+            event, context = yield self._create_new_client_event(
+                builder=builder,
+                prev_event_ids=prev_event_ids,
+            )
 
-        event, context = yield self._create_new_client_event(
-            builder=builder,
-            prev_event_ids=prev_event_ids,
-        )
         defer.returnValue((event, context))
 
     @defer.inlineCallbacks

+ 58 - 0
synapse/util/async.py

@@ -197,6 +197,64 @@ class Linearizer(object):
         defer.returnValue(_ctx_manager())
 
 
+class Limiter(object):
+    """Limits concurrent access to resources based on a key. Useful to ensure
+    only a few thing happen at a time on a given resource.
+
+    Example:
+
+        with (yield limiter.queue("test_key")):
+            # do some work.
+
+    """
+    def __init__(self, max_count):
+        """
+        Args:
+            max_count(int): The maximum number of concurrent access
+        """
+        self.max_count = max_count
+
+        # key_to_defer is a map from the key to a 2 element list where
+        # the first element is the number of things executing
+        # the second element is a list of deferreds for the things blocked from
+        # executing.
+        self.key_to_defer = {}
+
+    @defer.inlineCallbacks
+    def queue(self, key):
+        entry = self.key_to_defer.setdefault(key, [0, []])
+
+        # If the number of things executing is greater than the maximum
+        # then add a deferred to the list of blocked items
+        # When on of the things currently executing finishes it will callback
+        # this item so that it can continue executing.
+        if entry[0] >= self.max_count:
+            new_defer = defer.Deferred()
+            entry[1].append(new_defer)
+            with PreserveLoggingContext():
+                yield new_defer
+
+        entry[0] += 1
+
+        @contextmanager
+        def _ctx_manager():
+            try:
+                yield
+            finally:
+                # We've finished executing so check if there are any things
+                # blocked waiting to execute and start one of them
+                entry[0] -= 1
+                try:
+                    entry[1].pop(0).callback(None)
+                except IndexError:
+                    # If nothing else is executing for this key then remove it
+                    # from the map
+                    if entry[0] == 0:
+                        self.key_to_defer.pop(key, None)
+
+        defer.returnValue(_ctx_manager())
+
+
 class ReadWriteLock(object):
     """A deferred style read write lock.
 

+ 70 - 0
tests/util/test_limiter.py

@@ -0,0 +1,70 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tests import unittest
+
+from twisted.internet import defer
+
+from synapse.util.async import Limiter
+
+
+class LimiterTestCase(unittest.TestCase):
+
+    @defer.inlineCallbacks
+    def test_limiter(self):
+        limiter = Limiter(3)
+
+        key = object()
+
+        d1 = limiter.queue(key)
+        cm1 = yield d1
+
+        d2 = limiter.queue(key)
+        cm2 = yield d2
+
+        d3 = limiter.queue(key)
+        cm3 = yield d3
+
+        d4 = limiter.queue(key)
+        self.assertFalse(d4.called)
+
+        d5 = limiter.queue(key)
+        self.assertFalse(d5.called)
+
+        with cm1:
+            self.assertFalse(d4.called)
+            self.assertFalse(d5.called)
+
+        self.assertTrue(d4.called)
+        self.assertFalse(d5.called)
+
+        with cm3:
+            self.assertFalse(d5.called)
+
+        self.assertTrue(d5.called)
+
+        with cm2:
+            pass
+
+        with (yield d4):
+            pass
+
+        with (yield d5):
+            pass
+
+        d6 = limiter.queue(key)
+        with (yield d6):
+            pass