Browse Source

Fix up logcontexts

Erik Johnston 8 years ago
parent
commit
2c1fbea531

+ 3 - 1
synapse/api/auth.py

@@ -24,6 +24,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules
 from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
 from synapse.types import Requester, RoomID, UserID, EventID
 from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_context_over_fn
 from unpaddedbase64 import decode_base64
 
 import logging
@@ -529,7 +530,8 @@ class Auth(object):
                 default=[""]
             )[0]
             if user and access_token and ip_addr:
-                self.store.insert_client_ip(
+                preserve_context_over_fn(
+                    self.store.insert_client_ip,
                     user=user,
                     access_token=access_token,
                     ip=ip_addr,

+ 2 - 0
synapse/app/homeserver.py

@@ -709,6 +709,8 @@ def run(hs):
         phone_home_task.start(60 * 60 * 24, now=False)
 
     def in_thread():
+        # Uncomment to enable tracing of log context changes.
+        # sys.settrace(logcontext_tracer)
         with LoggingContext("run"):
             change_resource_limit(hs.config.soft_file_limit)
             reactor.run()

+ 46 - 37
synapse/crypto/keyring.py

@@ -18,6 +18,10 @@ from synapse.api.errors import SynapseError, Codes
 from synapse.util.retryutils import get_retry_limiter
 from synapse.util import unwrapFirstError
 from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import (
+    preserve_context_over_deferred, preserve_context_over_fn, PreserveLoggingContext,
+    preserve_fn
+)
 
 from twisted.internet import defer
 
@@ -142,40 +146,43 @@ class Keyring(object):
             for server_name, _ in server_and_json
         }
 
-        # We want to wait for any previous lookups to complete before
-        # proceeding.
-        wait_on_deferred = self.wait_for_previous_lookups(
-            [server_name for server_name, _ in server_and_json],
-            server_to_deferred,
-        )
+        with PreserveLoggingContext():
 
-        # Actually start fetching keys.
-        wait_on_deferred.addBoth(
-            lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
-        )
+            # We want to wait for any previous lookups to complete before
+            # proceeding.
+            wait_on_deferred = self.wait_for_previous_lookups(
+                [server_name for server_name, _ in server_and_json],
+                server_to_deferred,
+            )
 
-        # When we've finished fetching all the keys for a given server_name,
-        # resolve the deferred passed to `wait_for_previous_lookups` so that
-        # any lookups waiting will proceed.
-        server_to_gids = {}
+            # Actually start fetching keys.
+            wait_on_deferred.addBoth(
+                lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+            )
+
+            # When we've finished fetching all the keys for a given server_name,
+            # resolve the deferred passed to `wait_for_previous_lookups` so that
+            # any lookups waiting will proceed.
+            server_to_gids = {}
 
-        def remove_deferreds(res, server_name, group_id):
-            server_to_gids[server_name].discard(group_id)
-            if not server_to_gids[server_name]:
-                d = server_to_deferred.pop(server_name, None)
-                if d:
-                    d.callback(None)
-            return res
+            def remove_deferreds(res, server_name, group_id):
+                server_to_gids[server_name].discard(group_id)
+                if not server_to_gids[server_name]:
+                    d = server_to_deferred.pop(server_name, None)
+                    if d:
+                        d.callback(None)
+                return res
 
-        for g_id, deferred in deferreds.items():
-            server_name = group_id_to_group[g_id].server_name
-            server_to_gids.setdefault(server_name, set()).add(g_id)
-            deferred.addBoth(remove_deferreds, server_name, g_id)
+            for g_id, deferred in deferreds.items():
+                server_name = group_id_to_group[g_id].server_name
+                server_to_gids.setdefault(server_name, set()).add(g_id)
+                deferred.addBoth(remove_deferreds, server_name, g_id)
 
         # Pass those keys to handle_key_deferred so that the json object
         # signatures can be verified
         return [
-            handle_key_deferred(
+            preserve_context_over_fn(
+                handle_key_deferred,
                 group_id_to_group[g_id],
                 deferreds[g_id],
             )
@@ -198,12 +205,13 @@ class Keyring(object):
                 if server_name in self.key_downloads
             ]
             if wait_on:
-                yield defer.DeferredList(wait_on)
+                with PreserveLoggingContext():
+                    yield defer.DeferredList(wait_on)
             else:
                 break
 
         for server_name, deferred in server_to_deferred.items():
-            d = ObservableDeferred(deferred)
+            d = ObservableDeferred(preserve_context_over_deferred(deferred))
             self.key_downloads[server_name] = d
 
             def rm(r, server_name):
@@ -244,12 +252,13 @@ class Keyring(object):
                 for group in group_id_to_group.values():
                     for key_id in group.key_ids:
                         if key_id in merged_results[group.server_name]:
-                            group_id_to_deferred[group.group_id].callback((
-                                group.group_id,
-                                group.server_name,
-                                key_id,
-                                merged_results[group.server_name][key_id],
-                            ))
+                            with PreserveLoggingContext():
+                                group_id_to_deferred[group.group_id].callback((
+                                    group.group_id,
+                                    group.server_name,
+                                    key_id,
+                                    merged_results[group.server_name][key_id],
+                                ))
                             break
                     else:
                         missing_groups.setdefault(
@@ -504,7 +513,7 @@ class Keyring(object):
 
         yield defer.gatherResults(
             [
-                self.store_keys(
+                preserve_fn(self.store_keys)(
                     server_name=key_server_name,
                     from_server=server_name,
                     verify_keys=verify_keys,
@@ -573,7 +582,7 @@ class Keyring(object):
 
         yield defer.gatherResults(
             [
-                self.store.store_server_keys_json(
+                preserve_fn(self.store.store_server_keys_json)(
                     server_name=server_name,
                     key_id=key_id,
                     from_server=server_name,
@@ -675,7 +684,7 @@ class Keyring(object):
         # TODO(markjh): Store whether the keys have expired.
         yield defer.gatherResults(
             [
-                self.store.store_server_verify_key(
+                preserve_fn(self.store.store_server_verify_key)(
                     server_name, server_name, key.time_added, key
                 )
                 for key_id, key in verify_keys.items()

+ 1 - 3
synapse/federation/federation_server.py

@@ -126,10 +126,8 @@ class FederationServer(FederationBase):
         results = []
 
         for pdu in pdu_list:
-            d = self._handle_new_pdu(transaction.origin, pdu)
-
             try:
-                yield d
+                yield self._handle_new_pdu(transaction.origin, pdu)
                 results.append({})
             except FederationError as e:
                 self.send_failure(e, transaction.origin)

+ 0 - 3
synapse/federation/transaction_queue.py

@@ -103,7 +103,6 @@ class TransactionQueue(object):
         else:
             return not destination.startswith("localhost")
 
-    @defer.inlineCallbacks
     def enqueue_pdu(self, pdu, destinations, order):
         # We loop through all destinations to see whether we already have
         # a transaction in progress. If we do, stick it in the pending_pdus
@@ -141,8 +140,6 @@ class TransactionQueue(object):
 
             deferreds.append(deferred)
 
-        yield defer.DeferredList(deferreds, consumeErrors=True)
-
     # NO inlineCallbacks
     def enqueue_edu(self, edu):
         destination = edu.destination

+ 1 - 9
synapse/handlers/_base.py

@@ -293,19 +293,11 @@ class BaseHandler(object):
 
         with PreserveLoggingContext():
             # Don't block waiting on waking up all the listeners.
-            notify_d = self.notifier.on_new_room_event(
+            self.notifier.on_new_room_event(
                 event, event_stream_id, max_stream_id,
                 extra_users=extra_users
             )
 
-        def log_failure(f):
-            logger.warn(
-                "Failed to notify about %s: %s",
-                event.event_id, f.value
-            )
-
-        notify_d.addErrback(log_failure)
-
         # If invite, remove room_state from unsigned before sending.
         event.unsigned.pop("invite_room_state", None)
 

+ 9 - 2
synapse/handlers/events.py

@@ -18,6 +18,7 @@ from twisted.internet import defer
 from synapse.util.logutils import log_function
 from synapse.types import UserID
 from synapse.events.utils import serialize_event
+from synapse.util.logcontext import preserve_context_over_fn
 
 from ._base import BaseHandler
 
@@ -29,11 +30,17 @@ logger = logging.getLogger(__name__)
 
 
 def started_user_eventstream(distributor, user):
-    return distributor.fire("started_user_eventstream", user)
+    return preserve_context_over_fn(
+        distributor.fire,
+        "started_user_eventstream", user
+    )
 
 
 def stopped_user_eventstream(distributor, user):
-    return distributor.fire("stopped_user_eventstream", user)
+    return preserve_context_over_fn(
+        distributor.fire,
+        "stopped_user_eventstream", user
+    )
 
 
 class EventStreamHandler(BaseHandler):

+ 5 - 45
synapse/handlers/federation.py

@@ -221,19 +221,11 @@ class FederationHandler(BaseHandler):
                 extra_users.append(target_user)
 
             with PreserveLoggingContext():
-                d = self.notifier.on_new_room_event(
+                self.notifier.on_new_room_event(
                     event, event_stream_id, max_stream_id,
                     extra_users=extra_users
                 )
 
-            def log_failure(f):
-                logger.warn(
-                    "Failed to notify about %s: %s",
-                    event.event_id, f.value
-                )
-
-            d.addErrback(log_failure)
-
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
                 prev_state = context.current_state.get((event.type, event.state_key))
@@ -643,19 +635,11 @@ class FederationHandler(BaseHandler):
             )
 
             with PreserveLoggingContext():
-                d = self.notifier.on_new_room_event(
+                self.notifier.on_new_room_event(
                     event, event_stream_id, max_stream_id,
                     extra_users=[joinee]
                 )
 
-            def log_failure(f):
-                logger.warn(
-                    "Failed to notify about %s: %s",
-                    event.event_id, f.value
-                )
-
-            d.addErrback(log_failure)
-
             logger.debug("Finished joining %s to %s", joinee, room_id)
         finally:
             room_queue = self.room_queues[room_id]
@@ -730,18 +714,10 @@ class FederationHandler(BaseHandler):
             extra_users.append(target_user)
 
         with PreserveLoggingContext():
-            d = self.notifier.on_new_room_event(
+            self.notifier.on_new_room_event(
                 event, event_stream_id, max_stream_id, extra_users=extra_users
             )
 
-        def log_failure(f):
-            logger.warn(
-                "Failed to notify about %s: %s",
-                event.event_id, f.value
-            )
-
-        d.addErrback(log_failure)
-
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.JOIN:
                 user = UserID.from_string(event.state_key)
@@ -811,19 +787,11 @@ class FederationHandler(BaseHandler):
 
         target_user = UserID.from_string(event.state_key)
         with PreserveLoggingContext():
-            d = self.notifier.on_new_room_event(
+            self.notifier.on_new_room_event(
                 event, event_stream_id, max_stream_id,
                 extra_users=[target_user],
             )
 
-        def log_failure(f):
-            logger.warn(
-                "Failed to notify about %s: %s",
-                event.event_id, f.value
-            )
-
-        d.addErrback(log_failure)
-
         defer.returnValue(event)
 
     @defer.inlineCallbacks
@@ -948,18 +916,10 @@ class FederationHandler(BaseHandler):
             extra_users.append(target_user)
 
         with PreserveLoggingContext():
-            d = self.notifier.on_new_room_event(
+            self.notifier.on_new_room_event(
                 event, event_stream_id, max_stream_id, extra_users=extra_users
             )
 
-        def log_failure(f):
-            logger.warn(
-                "Failed to notify about %s: %s",
-                event.event_id, f.value
-            )
-
-        d.addErrback(log_failure)
-
         new_pdu = event
 
         destinations = set()

+ 11 - 9
synapse/handlers/presence.py

@@ -378,9 +378,9 @@ class PresenceHandler(BaseHandler):
         was_polling = target_user in self._user_cachemap
 
         if now_online and not was_polling:
-            self.start_polling_presence(target_user, state=state)
+            yield self.start_polling_presence(target_user, state=state)
         elif not now_online and was_polling:
-            self.stop_polling_presence(target_user)
+            yield self.stop_polling_presence(target_user)
 
         # TODO(paul): perform a presence push as part of start/stop poll so
         #   we don't have to do this all the time
@@ -394,7 +394,8 @@ class PresenceHandler(BaseHandler):
         if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
             return
 
-        self.changed_presencelike_data(user, {"last_active": now})
+        with PreserveLoggingContext():
+            self.changed_presencelike_data(user, {"last_active": now})
 
     def get_joined_rooms_for_user(self, user):
         """Get the list of rooms a user is joined to.
@@ -466,11 +467,12 @@ class PresenceHandler(BaseHandler):
                 local_user, room_ids=[room_id], add_to_cache=False
             )
 
-            self.push_update_to_local_and_remote(
-                observed_user=local_user,
-                users_to_push=[user],
-                statuscache=statuscache,
-            )
+            with PreserveLoggingContext():
+                self.push_update_to_local_and_remote(
+                    observed_user=local_user,
+                    users_to_push=[user],
+                    statuscache=statuscache,
+                )
 
     @defer.inlineCallbacks
     def send_presence_invite(self, observer_user, observed_user):
@@ -556,7 +558,7 @@ class PresenceHandler(BaseHandler):
             observer_user.localpart, observed_user.to_string()
         )
 
-        self.start_polling_presence(
+        yield self.start_polling_presence(
             observer_user, target_user=observed_user
         )
 

+ 1 - 1
synapse/handlers/register.py

@@ -186,7 +186,7 @@ class RegistrationHandler(BaseHandler):
             token=token,
             password_hash=""
         )
-        registered_user(self.distributor, user)
+        yield registered_user(self.distributor, user)
         defer.returnValue((user_id, token))
 
     @defer.inlineCallbacks

+ 9 - 2
synapse/handlers/room.py

@@ -25,6 +25,7 @@ from synapse.api.constants import (
 from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
 from synapse.util import stringutils, unwrapFirstError
 from synapse.util.async import run_on_reactor
+from synapse.util.logcontext import preserve_context_over_fn
 
 from signedjson.sign import verify_signed_json
 from signedjson.key import decode_verify_key_bytes
@@ -46,11 +47,17 @@ def collect_presencelike_data(distributor, user, content):
 
 
 def user_left_room(distributor, user, room_id):
-    return distributor.fire("user_left_room", user=user, room_id=room_id)
+    return preserve_context_over_fn(
+        distributor.fire,
+        "user_left_room", user=user, room_id=room_id
+    )
 
 
 def user_joined_room(distributor, user, room_id):
-    return distributor.fire("user_joined_room", user=user, room_id=room_id)
+    return preserve_context_over_fn(
+        distributor.fire,
+        "user_joined_room", user=user, room_id=room_id
+    )
 
 
 class RoomCreationHandler(BaseHandler):

+ 21 - 19
synapse/handlers/sync.py

@@ -18,7 +18,7 @@ from ._base import BaseHandler
 from synapse.streams.config import PaginationConfig
 from synapse.api.constants import Membership, EventTypes
 from synapse.util import unwrapFirstError
-from synapse.util.logcontext import LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 
 from twisted.internet import defer
 
@@ -241,15 +241,16 @@ class SyncHandler(BaseHandler):
         deferreds = []
         for event in room_list:
             if event.membership == Membership.JOIN:
-                room_sync_deferred = self.full_state_sync_for_joined_room(
-                    room_id=event.room_id,
-                    sync_config=sync_config,
-                    now_token=now_token,
-                    timeline_since_token=timeline_since_token,
-                    ephemeral_by_room=ephemeral_by_room,
-                    tags_by_room=tags_by_room,
-                    account_data_by_room=account_data_by_room,
-                )
+                with PreserveLoggingContext(LoggingContext.current_context()):
+                    room_sync_deferred = self.full_state_sync_for_joined_room(
+                        room_id=event.room_id,
+                        sync_config=sync_config,
+                        now_token=now_token,
+                        timeline_since_token=timeline_since_token,
+                        ephemeral_by_room=ephemeral_by_room,
+                        tags_by_room=tags_by_room,
+                        account_data_by_room=account_data_by_room,
+                    )
                 room_sync_deferred.addCallback(joined.append)
                 deferreds.append(room_sync_deferred)
             elif event.membership == Membership.INVITE:
@@ -262,15 +263,16 @@ class SyncHandler(BaseHandler):
                 leave_token = now_token.copy_and_replace(
                     "room_key", "s%d" % (event.stream_ordering,)
                 )
-                room_sync_deferred = self.full_state_sync_for_archived_room(
-                    sync_config=sync_config,
-                    room_id=event.room_id,
-                    leave_event_id=event.event_id,
-                    leave_token=leave_token,
-                    timeline_since_token=timeline_since_token,
-                    tags_by_room=tags_by_room,
-                    account_data_by_room=account_data_by_room,
-                )
+                with PreserveLoggingContext(LoggingContext.current_context()):
+                    room_sync_deferred = self.full_state_sync_for_archived_room(
+                        sync_config=sync_config,
+                        room_id=event.room_id,
+                        leave_event_id=event.event_id,
+                        leave_token=leave_token,
+                        timeline_since_token=timeline_since_token,
+                        tags_by_room=tags_by_room,
+                        account_data_by_room=account_data_by_room,
+                    )
                 room_sync_deferred.addCallback(archived.append)
                 deferreds.append(room_sync_deferred)
 

+ 2 - 3
synapse/http/server.py

@@ -99,9 +99,8 @@ def request_handler(request_handler):
             request_context.request = request_id
             with request.processing():
                 try:
-                    d = request_handler(self, request)
-                    with PreserveLoggingContext():
-                        yield d
+                    with PreserveLoggingContext(request_context):
+                        yield request_handler(self, request)
                 except CodeMessageException as e:
                     code = e.code
                     if isinstance(e, SynapseError):

+ 29 - 29
synapse/notifier.py

@@ -18,7 +18,8 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 
 from synapse.util.logutils import log_function
-from synapse.util.async import run_on_reactor, ObservableDeferred
+from synapse.util.async import ObservableDeferred
+from synapse.util.logcontext import PreserveLoggingContext
 from synapse.types import StreamToken
 import synapse.metrics
 
@@ -73,7 +74,8 @@ class _NotifierUserStream(object):
         self.current_token = current_token
         self.last_notified_ms = time_now_ms
 
-        self.notify_deferred = ObservableDeferred(defer.Deferred())
+        with PreserveLoggingContext():
+            self.notify_deferred = ObservableDeferred(defer.Deferred())
 
     def notify(self, stream_key, stream_id, time_now_ms):
         """Notify any listeners for this user of a new event from an
@@ -88,8 +90,10 @@ class _NotifierUserStream(object):
         )
         self.last_notified_ms = time_now_ms
         noify_deferred = self.notify_deferred
-        self.notify_deferred = ObservableDeferred(defer.Deferred())
-        noify_deferred.callback(self.current_token)
+
+        with PreserveLoggingContext():
+            self.notify_deferred = ObservableDeferred(defer.Deferred())
+            noify_deferred.callback(self.current_token)
 
     def remove(self, notifier):
         """ Remove this listener from all the indexes in the Notifier
@@ -184,8 +188,6 @@ class Notifier(object):
             lambda: count(bool, self.appservice_to_user_streams.values()),
         )
 
-    @log_function
-    @defer.inlineCallbacks
     def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
                           extra_users=[]):
         """ Used by handlers to inform the notifier something has happened
@@ -199,12 +201,11 @@ class Notifier(object):
         until all previous events have been persisted before notifying
         the client streams.
         """
-        yield run_on_reactor()
-
-        self.pending_new_room_events.append((
-            room_stream_id, event, extra_users
-        ))
-        self._notify_pending_new_room_events(max_room_stream_id)
+        with PreserveLoggingContext():
+            self.pending_new_room_events.append((
+                room_stream_id, event, extra_users
+            ))
+            self._notify_pending_new_room_events(max_room_stream_id)
 
     def _notify_pending_new_room_events(self, max_room_stream_id):
         """Notify for the room events that were queued waiting for a previous
@@ -251,31 +252,29 @@ class Notifier(object):
             extra_streams=app_streams,
         )
 
-    @defer.inlineCallbacks
-    @log_function
     def on_new_event(self, stream_key, new_token, users=[], rooms=[],
                      extra_streams=set()):
         """ Used to inform listeners that something has happend event wise.
 
         Will wake up all listeners for the given users and rooms.
         """
-        yield run_on_reactor()
-        user_streams = set()
+        with PreserveLoggingContext():
+            user_streams = set()
 
-        for user in users:
-            user_stream = self.user_to_user_stream.get(str(user))
-            if user_stream is not None:
-                user_streams.add(user_stream)
+            for user in users:
+                user_stream = self.user_to_user_stream.get(str(user))
+                if user_stream is not None:
+                    user_streams.add(user_stream)
 
-        for room in rooms:
-            user_streams |= self.room_to_user_streams.get(room, set())
+            for room in rooms:
+                user_streams |= self.room_to_user_streams.get(room, set())
 
-        time_now_ms = self.clock.time_msec()
-        for user_stream in user_streams:
-            try:
-                user_stream.notify(stream_key, new_token, time_now_ms)
-            except:
-                logger.exception("Failed to notify listener")
+            time_now_ms = self.clock.time_msec()
+            for user_stream in user_streams:
+                try:
+                    user_stream.notify(stream_key, new_token, time_now_ms)
+                except:
+                    logger.exception("Failed to notify listener")
 
     @defer.inlineCallbacks
     def wait_for_events(self, user_id, timeout, callback, room_ids=None,
@@ -325,7 +324,8 @@ class Notifier(object):
                     # that we don't miss any current_token updates.
                     prev_token = current_token
                     listener = user_stream.new_listener(prev_token)
-                    yield listener.deferred
+                    with PreserveLoggingContext():
+                        yield listener.deferred
                 except defer.CancelledError:
                     break
 

+ 1 - 1
synapse/push/__init__.py

@@ -111,7 +111,7 @@ class Pusher(object):
                     self.user_id, config, timeout=0, affect_presence=False
                 )
                 self.last_token = chunk['end']
-                self.store.update_pusher_last_token(
+                yield self.store.update_pusher_last_token(
                     self.app_id, self.pushkey, self.user_id, self.last_token
                 )
                 logger.info("New pusher %s for user %s starting from token %s",

+ 5 - 4
synapse/push/pusherpool.py

@@ -18,6 +18,7 @@ from twisted.internet import defer
 
 from httppusher import HttpPusher
 from synapse.push import PusherConfigException
+from synapse.util.logcontext import preserve_fn
 
 import logging
 
@@ -76,7 +77,7 @@ class PusherPool:
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     app_id, pushkey, p['user_name']
                 )
-                self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+                yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
     @defer.inlineCallbacks
     def remove_pushers_by_user(self, user_id):
@@ -91,7 +92,7 @@ class PusherPool:
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     p['app_id'], p['pushkey'], p['user_name']
                 )
-                self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
+                yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
     @defer.inlineCallbacks
     def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
@@ -110,7 +111,7 @@ class PusherPool:
             lang=lang,
             data=data,
         )
-        self._refresh_pusher(app_id, pushkey, user_id)
+        yield self._refresh_pusher(app_id, pushkey, user_id)
 
     def _create_pusher(self, pusherdict):
         if pusherdict['kind'] == 'http':
@@ -166,7 +167,7 @@ class PusherPool:
                 if fullid in self.pushers:
                     self.pushers[fullid].stop()
                 self.pushers[fullid] = p
-                p.start()
+                preserve_fn(p.start)()
 
         logger.info("Started pushers")
 

+ 2 - 2
synapse/rest/client/v2_alpha/account_data.py

@@ -57,7 +57,7 @@ class AccountDataServlet(RestServlet):
             user_id, account_data_type, body
         )
 
-        yield self.notifier.on_new_event(
+        self.notifier.on_new_event(
             "account_data_key", max_id, users=[user_id]
         )
 
@@ -99,7 +99,7 @@ class RoomAccountDataServlet(RestServlet):
             user_id, room_id, account_data_type, body
         )
 
-        yield self.notifier.on_new_event(
+        self.notifier.on_new_event(
             "account_data_key", max_id, users=[user_id]
         )
 

+ 2 - 2
synapse/rest/client/v2_alpha/tags.py

@@ -80,7 +80,7 @@ class TagServlet(RestServlet):
 
         max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
 
-        yield self.notifier.on_new_event(
+        self.notifier.on_new_event(
             "account_data_key", max_id, users=[user_id]
         )
 
@@ -94,7 +94,7 @@ class TagServlet(RestServlet):
 
         max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag)
 
-        yield self.notifier.on_new_event(
+        self.notifier.on_new_event(
             "account_data_key", max_id, users=[user_id]
         )
 

+ 9 - 9
synapse/storage/_base.py

@@ -15,7 +15,7 @@
 import logging
 
 from synapse.api.errors import StoreError
-from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
+from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.caches.descriptors import Cache
 import synapse.metrics
@@ -298,10 +298,10 @@ class SQLBaseStore(object):
                     func, *args, **kwargs
                 )
 
-        result = yield preserve_context_over_fn(
-            self._db_pool.runWithConnection,
-            inner_func, *args, **kwargs
-        )
+        with PreserveLoggingContext():
+            result = yield self._db_pool.runWithConnection(
+                inner_func, *args, **kwargs
+            )
 
         for after_callback, after_args in after_callbacks:
             after_callback(*after_args)
@@ -326,10 +326,10 @@ class SQLBaseStore(object):
 
                 return func(conn, *args, **kwargs)
 
-        result = yield preserve_context_over_fn(
-            self._db_pool.runWithConnection,
-            inner_func, *args, **kwargs
-        )
+        with PreserveLoggingContext():
+            result = yield self._db_pool.runWithConnection(
+                inner_func, *args, **kwargs
+            )
 
         defer.returnValue(result)
 

+ 20 - 14
synapse/storage/events.py

@@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
 from synapse.events import FrozenEvent, USE_FROZEN_DICTS
 from synapse.events.utils import prune_event
 
-from synapse.util.logcontext import preserve_context_over_deferred
+from synapse.util.logcontext import preserve_fn, PreserveLoggingContext
 from synapse.util.logutils import log_function
 from synapse.api.constants import EventTypes
 
@@ -664,14 +664,16 @@ class EventsStore(SQLBaseStore):
                     for ids, d in lst:
                         if not d.called:
                             try:
-                                d.callback([
-                                    res[i]
-                                    for i in ids
-                                    if i in res
-                                ])
+                                with PreserveLoggingContext():
+                                    d.callback([
+                                        res[i]
+                                        for i in ids
+                                        if i in res
+                                    ])
                             except:
                                 logger.exception("Failed to callback")
-                reactor.callFromThread(fire, event_list, row_dict)
+                with PreserveLoggingContext():
+                    reactor.callFromThread(fire, event_list, row_dict)
             except Exception as e:
                 logger.exception("do_fetch")
 
@@ -679,10 +681,12 @@ class EventsStore(SQLBaseStore):
                 def fire(evs):
                     for _, d in evs:
                         if not d.called:
-                            d.errback(e)
+                            with PreserveLoggingContext():
+                                d.errback(e)
 
                 if event_list:
-                    reactor.callFromThread(fire, event_list)
+                    with PreserveLoggingContext():
+                        reactor.callFromThread(fire, event_list)
 
     @defer.inlineCallbacks
     def _enqueue_events(self, events, check_redacted=True,
@@ -709,18 +713,20 @@ class EventsStore(SQLBaseStore):
                 should_start = False
 
         if should_start:
-            self.runWithConnection(
-                self._do_fetch
-            )
+            with PreserveLoggingContext():
+                self.runWithConnection(
+                    self._do_fetch
+                )
 
-        rows = yield preserve_context_over_deferred(events_d)
+        with PreserveLoggingContext():
+            rows = yield events_d
 
         if not allow_rejected:
             rows[:] = [r for r in rows if not r["rejects"]]
 
         res = yield defer.gatherResults(
             [
-                self._get_event_from_row(
+                preserve_fn(self._get_event_from_row)(
                     row["internal_metadata"], row["json"], row["redacts"],
                     check_redacted=check_redacted,
                     get_prev_content=get_prev_content,

+ 3 - 2
synapse/storage/presence.py

@@ -68,8 +68,9 @@ class PresenceStore(SQLBaseStore):
             for row in rows
         })
 
+    @defer.inlineCallbacks
     def set_presence_state(self, user_localpart, new_state):
-        res = self._simple_update_one(
+        res = yield self._simple_update_one(
             table="presence",
             keyvalues={"user_id": user_localpart},
             updatevalues={"state": new_state["state"],
@@ -79,7 +80,7 @@ class PresenceStore(SQLBaseStore):
         )
 
         self.get_presence_state.invalidate((user_localpart,))
-        return res
+        defer.returnValue(res)
 
     def allow_presence_visible(self, observed_localpart, observer_userid):
         return self._simple_insert(

+ 5 - 4
synapse/storage/stream.py

@@ -39,6 +39,7 @@ from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 from synapse.api.constants import EventTypes
 from synapse.types import RoomStreamToken
+from synapse.util.logcontext import preserve_fn
 
 import logging
 
@@ -170,12 +171,12 @@ class StreamStore(SQLBaseStore):
         room_ids = list(room_ids)
         for rm_ids in (room_ids[i:i + 20] for i in xrange(0, len(room_ids), 20)):
             res = yield defer.gatherResults([
-                self.get_room_events_stream_for_room(
-                    room_id, from_key, to_key, limit
-                ).addCallback(lambda r, rm: (rm, r), room_id)
+                preserve_fn(self.get_room_events_stream_for_room)(
+                    room_id, from_key, to_key, limit,
+                )
                 for room_id in room_ids
             ])
-            results.update(dict(res))
+            results.update(dict(zip(rm_ids, res)))
 
         defer.returnValue(results)
 

+ 2 - 4
synapse/util/__init__.py

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.logcontext import PreserveLoggingContext
 
 from twisted.internet import defer, reactor, task
 
@@ -61,10 +61,8 @@ class Clock(object):
             *args: Postional arguments to pass to function.
             **kwargs: Key arguments to pass to function.
         """
-        current_context = LoggingContext.current_context()
-
         def wrapped_callback(*args, **kwargs):
-            with PreserveLoggingContext(current_context):
+            with PreserveLoggingContext():
                 callback(*args, **kwargs)
 
         with PreserveLoggingContext():

+ 8 - 3
synapse/util/async.py

@@ -16,13 +16,16 @@
 
 from twisted.internet import defer, reactor
 
-from .logcontext import preserve_context_over_deferred
+from .logcontext import PreserveLoggingContext
 
 
+@defer.inlineCallbacks
 def sleep(seconds):
     d = defer.Deferred()
-    reactor.callLater(seconds, d.callback, seconds)
-    return preserve_context_over_deferred(d)
+    with PreserveLoggingContext():
+        reactor.callLater(seconds, d.callback, seconds)
+        res = yield d
+    defer.returnValue(res)
 
 
 def run_on_reactor():
@@ -54,6 +57,7 @@ class ObservableDeferred(object):
             object.__setattr__(self, "_result", (True, r))
             while self._observers:
                 try:
+                    # TODO: Handle errors here.
                     self._observers.pop().callback(r)
                 except:
                     pass
@@ -63,6 +67,7 @@ class ObservableDeferred(object):
             object.__setattr__(self, "_result", (False, f))
             while self._observers:
                 try:
+                    # TODO: Handle errors here.
                     self._observers.pop().errback(f)
                 except:
                     pass

+ 11 - 5
synapse/util/caches/descriptors.py

@@ -18,6 +18,9 @@ from synapse.util.async import ObservableDeferred
 from synapse.util import unwrapFirstError
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.treecache import TreeCache
+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
+)
 
 from . import caches_by_name, DEBUG_CACHES, cache_counter
 
@@ -190,7 +193,7 @@ class CacheDescriptor(object):
                         defer.returnValue(cached_result)
                     observer.addCallback(check_result)
 
-                return observer
+                return preserve_context_over_deferred(observer)
             except KeyError:
                 # Get the sequence number of the cache before reading from the
                 # database so that we can tell if the cache is invalidated
@@ -198,6 +201,7 @@ class CacheDescriptor(object):
                 sequence = self.cache.sequence
 
                 ret = defer.maybeDeferred(
+                    preserve_context_over_fn,
                     self.function_to_call,
                     obj, *args, **kwargs
                 )
@@ -211,7 +215,7 @@ class CacheDescriptor(object):
                 ret = ObservableDeferred(ret, consumeErrors=True)
                 self.cache.update(sequence, cache_key, ret)
 
-                return ret.observe()
+                return preserve_context_over_deferred(ret.observe())
 
         wrapped.invalidate = self.cache.invalidate
         wrapped.invalidate_all = self.cache.invalidate_all
@@ -299,6 +303,7 @@ class CacheListDescriptor(object):
                 args_to_call[self.list_name] = missing
 
                 ret_d = defer.maybeDeferred(
+                    preserve_context_over_fn,
                     self.function_to_call,
                     **args_to_call
                 )
@@ -308,7 +313,8 @@ class CacheListDescriptor(object):
                 # We need to create deferreds for each arg in the list so that
                 # we can insert the new deferred into the cache.
                 for arg in missing:
-                    observer = ret_d.observe()
+                    with PreserveLoggingContext():
+                        observer = ret_d.observe()
                     observer.addCallback(lambda r, arg: r.get(arg, None), arg)
 
                     observer = ObservableDeferred(observer)
@@ -327,10 +333,10 @@ class CacheListDescriptor(object):
 
                     cached[arg] = res
 
-            return defer.gatherResults(
+            return preserve_context_over_deferred(defer.gatherResults(
                 cached.values(),
                 consumeErrors=True,
-            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res))
+            ).addErrback(unwrapFirstError).addCallback(lambda res: dict(res)))
 
         obj.__dict__[self.orig.__name__] = wrapped
 

+ 2 - 1
synapse/util/caches/snapshot_cache.py

@@ -87,7 +87,8 @@ class SnapshotCache(object):
             # expire from the rotation of that cache.
             self.next_result_cache[key] = result
             self.pending_result_cache.pop(key, None)
+            return r
 
-        result.observe().addBoth(shuffle_along)
+        result.addBoth(shuffle_along)
 
         return result.observe()

+ 9 - 6
synapse/util/distributor.py

@@ -15,9 +15,7 @@
 
 from twisted.internet import defer
 
-from synapse.util.logcontext import (
-    PreserveLoggingContext, preserve_context_over_deferred,
-)
+from synapse.util.logcontext import PreserveLoggingContext
 
 from synapse.util import unwrapFirstError
 
@@ -97,6 +95,7 @@ class Signal(object):
         Each observer callable may return a Deferred."""
         self.observers.append(observer)
 
+    @defer.inlineCallbacks
     def fire(self, *args, **kwargs):
         """Invokes every callable in the observer list, passing in the args and
         kwargs. Exceptions thrown by observers are logged but ignored. It is
@@ -116,6 +115,7 @@ class Signal(object):
                         failure.getTracebackObject()))
                 if not self.suppress_failures:
                     return failure
+
             return defer.maybeDeferred(observer, *args, **kwargs).addErrback(eb)
 
         with PreserveLoggingContext():
@@ -124,8 +124,11 @@ class Signal(object):
                 for observer in self.observers
             ]
 
-            d = defer.gatherResults(deferreds, consumeErrors=True)
+            res = yield defer.gatherResults(
+                deferreds, consumeErrors=True
+            ).addErrback(unwrapFirstError)
 
-        d.addErrback(unwrapFirstError)
+        defer.returnValue(res)
 
-        return preserve_context_over_deferred(d)
+    def __repr__(self):
+        return "<Signal name=%r>" % (self.name,)

+ 94 - 4
synapse/util/logcontext.py

@@ -48,7 +48,7 @@ class LoggingContext(object):
 
     __slots__ = [
         "parent_context", "name", "usage_start", "usage_end", "main_thread",
-        "__dict__", "tag",
+        "__dict__", "tag", "alive",
     ]
 
     thread_local = threading.local()
@@ -88,6 +88,7 @@ class LoggingContext(object):
         self.usage_start = None
         self.main_thread = threading.current_thread()
         self.tag = ""
+        self.alive = True
 
     def __str__(self):
         return "%s@%x" % (self.name, id(self))
@@ -106,6 +107,7 @@ class LoggingContext(object):
             The context that was previously active
         """
         current = cls.current_context()
+
         if current is not context:
             current.stop()
             cls.thread_local.current_context = context
@@ -117,6 +119,7 @@ class LoggingContext(object):
         if self.parent_context is not None:
             raise Exception("Attempt to enter logging context multiple times")
         self.parent_context = self.set_current_context(self)
+        self.alive = True
         return self
 
     def __exit__(self, type, value, traceback):
@@ -136,6 +139,7 @@ class LoggingContext(object):
                     self
                 )
         self.parent_context = None
+        self.alive = False
 
     def __getattr__(self, name):
         """Delegate member lookup to parent context"""
@@ -213,7 +217,7 @@ class PreserveLoggingContext(object):
     exited. Used to restore the context after a function using
     @defer.inlineCallbacks is resumed by a callback from the reactor."""
 
-    __slots__ = ["current_context", "new_context"]
+    __slots__ = ["current_context", "new_context", "has_parent"]
 
     def __init__(self, new_context=LoggingContext.sentinel):
         self.new_context = new_context
@@ -224,11 +228,26 @@ class PreserveLoggingContext(object):
             self.new_context
         )
 
+        if self.current_context:
+            self.has_parent = self.current_context.parent_context is not None
+            if not self.current_context.alive:
+                logger.warn(
+                    "Entering dead context: %s",
+                    self.current_context,
+                )
+
     def __exit__(self, type, value, traceback):
         """Restores the current logging context"""
-        LoggingContext.set_current_context(self.current_context)
+        context = LoggingContext.set_current_context(self.current_context)
+
+        if context != self.new_context:
+            logger.warn(
+                "Unexpected logging context: %s is not %s",
+                context, self.new_context,
+            )
+
         if self.current_context is not LoggingContext.sentinel:
-            if self.current_context.parent_context is None:
+            if not self.current_context.alive:
                 logger.warn(
                     "Restoring dead context: %s",
                     self.current_context,
@@ -289,3 +308,74 @@ def preserve_context_over_deferred(deferred):
     d = _PreservingContextDeferred(current_context)
     deferred.chainDeferred(d)
     return d
+
+
+def preserve_fn(f):
+    """Ensures that function is called with correct context and that context is
+    restored after return. Useful for wrapping functions that return a deferred
+    which you don't yield on.
+    """
+    current = LoggingContext.current_context()
+
+    def g(*args, **kwargs):
+        with PreserveLoggingContext(current):
+            return f(*args, **kwargs)
+
+    return g
+
+
+# modules to ignore in `logcontext_tracer`
+_to_ignore = [
+    "synapse.util.logcontext",
+    "synapse.http.server",
+    "synapse.storage._base",
+    "synapse.util.async",
+]
+
+
+def logcontext_tracer(frame, event, arg):
+    """A tracer that logs whenever a logcontext "unexpectedly" changes within
+    a function. Probably inaccurate.
+
+    Use by calling `sys.settrace(logcontext_tracer)` in the main thread.
+    """
+    if event == 'call':
+        name = frame.f_globals["__name__"]
+        if name.startswith("synapse"):
+            if name == "synapse.util.logcontext":
+                if frame.f_code.co_name in ["__enter__", "__exit__"]:
+                    tracer = frame.f_back.f_trace
+                    if tracer:
+                        tracer.just_changed = True
+
+            tracer = frame.f_trace
+            if tracer:
+                return tracer
+
+            if not any(name.startswith(ig) for ig in _to_ignore):
+                return LineTracer()
+
+
+class LineTracer(object):
+    __slots__ = ["context", "just_changed"]
+
+    def __init__(self):
+        self.context = LoggingContext.current_context()
+        self.just_changed = False
+
+    def __call__(self, frame, event, arg):
+        if event in 'line':
+            if self.just_changed:
+                self.context = LoggingContext.current_context()
+                self.just_changed = False
+            else:
+                c = LoggingContext.current_context()
+                if c != self.context:
+                    logger.info(
+                        "Context changed! %s -> %s, %s, %s",
+                        self.context, c,
+                        frame.f_code.co_filename, frame.f_lineno
+                    )
+                    self.context = c
+
+        return self

+ 35 - 0
synapse/util/logutils.py

@@ -168,3 +168,38 @@ def trace_function(f):
 
     wrapped.__name__ = func_name
     return wrapped
+
+
+def get_previous_frames():
+    s = inspect.currentframe().f_back.f_back
+    to_return = []
+    while s:
+        if s.f_globals["__name__"].startswith("synapse"):
+            filename, lineno, function, _, _ = inspect.getframeinfo(s)
+            args_string = inspect.formatargvalues(*inspect.getargvalues(s))
+
+            to_return.append("{{  %s:%d %s - Args: %s }}" % (
+                filename, lineno, function, args_string
+            ))
+
+        s = s.f_back
+
+    return ", ". join(to_return)
+
+
+def get_previous_frame(ignore=[]):
+    s = inspect.currentframe().f_back.f_back
+
+    while s:
+        if s.f_globals["__name__"].startswith("synapse"):
+            if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
+                filename, lineno, function, _, _ = inspect.getframeinfo(s)
+                args_string = inspect.formatargvalues(*inspect.getargvalues(s))
+
+                return "{{  %s:%d %s - Args: %s }}" % (
+                    filename, lineno, function, args_string
+                )
+
+        s = s.f_back
+
+    return None

+ 6 - 4
synapse/util/metrics.py

@@ -68,16 +68,18 @@ class Measure(object):
         block_timer.inc_by(duration, self.name)
 
         context = LoggingContext.current_context()
-        if not context:
-            return
 
         if context != self.start_context:
             logger.warn(
-                "Context have unexpectedly changed %r, %r",
-                context, self.start_context
+                "Context have unexpectedly changed from '%s' to '%s'. (%r)",
+                context, self.start_context, self.name
             )
             return
 
+        if not context:
+            logger.warn("Expected context. (%r)", self.name)
+            return
+
         ru_utime, ru_stime = context.get_resource_usage()
 
         block_ru_utime.inc_by(ru_utime, self.name)

+ 2 - 1
synapse/util/ratelimitutils.py

@@ -18,6 +18,7 @@ from twisted.internet import defer
 from synapse.api.errors import LimitExceededError
 
 from synapse.util.async import sleep
+from synapse.util.logcontext import preserve_fn
 
 import collections
 import contextlib
@@ -163,7 +164,7 @@ class _PerHostRatelimiter(object):
                 "Ratelimit [%s]: sleeping req",
                 id(request_id),
             )
-            ret_defer = sleep(self.sleep_msec / 1000.0)
+            ret_defer = preserve_fn(sleep)(self.sleep_msec / 1000.0)
 
             self.sleeping_requests.add(request_id)