Browse Source

Convert internal pusher dicts to attrs classes. (#8940)

This improves type hinting and should use less memory.
Patrick Cloke 3 years ago
parent
commit
bd30cfe86a

+ 1 - 0
changelog.d/8940.misc

@@ -0,0 +1 @@
+Add type hints to push module.

+ 1 - 0
mypy.ini

@@ -65,6 +65,7 @@ files =
   synapse/state,
   synapse/storage/databases/main/appservice.py,
   synapse/storage/databases/main/events.py,
+  synapse/storage/databases/main/pusher.py,
   synapse/storage/databases/main/registration.py,
   synapse/storage/databases/main/stream.py,
   synapse/storage/databases/main/ui_auth.py,

+ 53 - 7
synapse/push/__init__.py

@@ -14,24 +14,70 @@
 # limitations under the License.
 
 import abc
-from typing import TYPE_CHECKING, Any, Dict
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
-from synapse.types import RoomStreamToken
+import attr
+
+from synapse.types import JsonDict, RoomStreamToken
 
 if TYPE_CHECKING:
     from synapse.app.homeserver import HomeServer
 
 
+@attr.s(slots=True)
+class PusherConfig:
+    """Parameters necessary to configure a pusher."""
+
+    id = attr.ib(type=Optional[str])
+    user_name = attr.ib(type=str)
+    access_token = attr.ib(type=Optional[int])
+    profile_tag = attr.ib(type=str)
+    kind = attr.ib(type=str)
+    app_id = attr.ib(type=str)
+    app_display_name = attr.ib(type=str)
+    device_display_name = attr.ib(type=str)
+    pushkey = attr.ib(type=str)
+    ts = attr.ib(type=int)
+    lang = attr.ib(type=Optional[str])
+    data = attr.ib(type=Optional[JsonDict])
+    last_stream_ordering = attr.ib(type=Optional[int])
+    last_success = attr.ib(type=Optional[int])
+    failing_since = attr.ib(type=Optional[int])
+
+    def as_dict(self) -> Dict[str, Any]:
+        """Information that can be retrieved about a pusher after creation."""
+        return {
+            "app_display_name": self.app_display_name,
+            "app_id": self.app_id,
+            "data": self.data,
+            "device_display_name": self.device_display_name,
+            "kind": self.kind,
+            "lang": self.lang,
+            "profile_tag": self.profile_tag,
+            "pushkey": self.pushkey,
+        }
+
+
+@attr.s(slots=True)
+class ThrottleParams:
+    """Parameters for controlling the rate of sending pushes via email."""
+
+    last_sent_ts = attr.ib(type=int)
+    throttle_ms = attr.ib(type=int)
+
+
 class Pusher(metaclass=abc.ABCMeta):
-    def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
+    def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
         self.hs = hs
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
 
-        self.pusher_id = pusherdict["id"]
-        self.user_id = pusherdict["user_name"]
-        self.app_id = pusherdict["app_id"]
-        self.pushkey = pusherdict["pushkey"]
+        self.pusher_id = pusher_config.id
+        self.user_id = pusher_config.user_name
+        self.app_id = pusher_config.app_id
+        self.pushkey = pusher_config.pushkey
+
+        self.last_stream_ordering = pusher_config.last_stream_ordering
 
         # This is the highest stream ordering we know it's safe to process.
         # When new events arrive, we'll be given a window of new events: we

+ 14 - 13
synapse/push/emailpusher.py

@@ -14,13 +14,13 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
 
 from twisted.internet.base import DelayedCall
 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
 
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.push import Pusher
+from synapse.push import Pusher, PusherConfig, ThrottleParams
 from synapse.push.mailer import Mailer
 
 if TYPE_CHECKING:
@@ -60,15 +60,14 @@ class EmailPusher(Pusher):
     factor out the common parts
     """
 
-    def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer):
-        super().__init__(hs, pusherdict)
+    def __init__(self, hs: "HomeServer", pusher_config: PusherConfig, mailer: Mailer):
+        super().__init__(hs, pusher_config)
         self.mailer = mailer
 
         self.store = self.hs.get_datastore()
-        self.email = pusherdict["pushkey"]
-        self.last_stream_ordering = pusherdict["last_stream_ordering"]
+        self.email = pusher_config.pushkey
         self.timed_call = None  # type: Optional[DelayedCall]
-        self.throttle_params = {}  # type: Dict[str, Dict[str, int]]
+        self.throttle_params = {}  # type: Dict[str, ThrottleParams]
         self._inited = False
 
         self._is_processing = False
@@ -132,6 +131,7 @@ class EmailPusher(Pusher):
 
             if not self._inited:
                 # this is our first loop: load up the throttle params
+                assert self.pusher_id is not None
                 self.throttle_params = await self.store.get_throttle_params_by_room(
                     self.pusher_id
                 )
@@ -157,6 +157,7 @@ class EmailPusher(Pusher):
         being run.
         """
         start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
+        assert start is not None
         unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
             self.user_id, start, self.max_stream_ordering
         )
@@ -244,13 +245,13 @@ class EmailPusher(Pusher):
 
     def get_room_throttle_ms(self, room_id: str) -> int:
         if room_id in self.throttle_params:
-            return self.throttle_params[room_id]["throttle_ms"]
+            return self.throttle_params[room_id].throttle_ms
         else:
             return 0
 
     def get_room_last_sent_ts(self, room_id: str) -> int:
         if room_id in self.throttle_params:
-            return self.throttle_params[room_id]["last_sent_ts"]
+            return self.throttle_params[room_id].last_sent_ts
         else:
             return 0
 
@@ -301,10 +302,10 @@ class EmailPusher(Pusher):
                 new_throttle_ms = min(
                     current_throttle_ms * THROTTLE_MULTIPLIER, THROTTLE_MAX_MS
                 )
-        self.throttle_params[room_id] = {
-            "last_sent_ts": self.clock.time_msec(),
-            "throttle_ms": new_throttle_ms,
-        }
+        self.throttle_params[room_id] = ThrottleParams(
+            self.clock.time_msec(), new_throttle_ms,
+        )
+        assert self.pusher_id is not None
         await self.store.set_throttle_params(
             self.pusher_id, room_id, self.throttle_params[room_id]
         )

+ 18 - 18
synapse/push/httppusher.py

@@ -25,7 +25,7 @@ from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.logging import opentracing
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.push import Pusher, PusherConfigException
+from synapse.push import Pusher, PusherConfig, PusherConfigException
 
 from . import push_rule_evaluator, push_tools
 
@@ -62,33 +62,29 @@ class HttpPusher(Pusher):
     # This one's in ms because we compare it against the clock
     GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
 
-    def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
-        super().__init__(hs, pusherdict)
+    def __init__(self, hs: "HomeServer", pusher_config: PusherConfig):
+        super().__init__(hs, pusher_config)
         self.storage = self.hs.get_storage()
-        self.app_display_name = pusherdict["app_display_name"]
-        self.device_display_name = pusherdict["device_display_name"]
-        self.pushkey_ts = pusherdict["ts"]
-        self.data = pusherdict["data"]
-        self.last_stream_ordering = pusherdict["last_stream_ordering"]
+        self.app_display_name = pusher_config.app_display_name
+        self.device_display_name = pusher_config.device_display_name
+        self.pushkey_ts = pusher_config.ts
+        self.data = pusher_config.data
         self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
-        self.failing_since = pusherdict["failing_since"]
+        self.failing_since = pusher_config.failing_since
         self.timed_call = None
         self._is_processing = False
         self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
 
-        if "data" not in pusherdict:
-            raise PusherConfigException("No 'data' key for HTTP pusher")
-        self.data = pusherdict["data"]
+        self.data = pusher_config.data
+        if self.data is None:
+            raise PusherConfigException("'data' key can not be null for HTTP pusher")
 
         self.name = "%s/%s/%s" % (
-            pusherdict["user_name"],
-            pusherdict["app_id"],
-            pusherdict["pushkey"],
+            pusher_config.user_name,
+            pusher_config.app_id,
+            pusher_config.pushkey,
         )
 
-        if self.data is None:
-            raise PusherConfigException("data can not be null for HTTP pusher")
-
         # Validate that there's a URL and it is of the proper form.
         if "url" not in self.data:
             raise PusherConfigException("'url' required in data for HTTP pusher")
@@ -180,6 +176,7 @@ class HttpPusher(Pusher):
         Never call this directly: use _process which will only allow this to
         run once per pusher.
         """
+        assert self.last_stream_ordering is not None
         unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_http(
             self.user_id, self.last_stream_ordering, self.max_stream_ordering
         )
@@ -208,6 +205,7 @@ class HttpPusher(Pusher):
                 http_push_processed_counter.inc()
                 self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
                 self.last_stream_ordering = push_action["stream_ordering"]
+                assert self.last_stream_ordering is not None
                 pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
                     self.app_id,
                     self.pushkey,
@@ -314,6 +312,8 @@ class HttpPusher(Pusher):
             #  or may do so (i.e. is encrypted so has unknown effects).
             priority = "high"
 
+        # This was checked in the __init__, but mypy doesn't seem to know that.
+        assert self.data is not None
         if self.data.get("format") == "event_id_only":
             d = {
                 "notification": {

+ 12 - 12
synapse/push/pusher.py

@@ -14,9 +14,9 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Callable, Dict, Optional
 
-from synapse.push import Pusher
+from synapse.push import Pusher, PusherConfig
 from synapse.push.emailpusher import EmailPusher
 from synapse.push.httppusher import HttpPusher
 from synapse.push.mailer import Mailer
@@ -34,7 +34,7 @@ class PusherFactory:
 
         self.pusher_types = {
             "http": HttpPusher
-        }  # type: Dict[str, Callable[[HomeServer, dict], Pusher]]
+        }  # type: Dict[str, Callable[[HomeServer, PusherConfig], Pusher]]
 
         logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
         if hs.config.email_enable_notifs:
@@ -47,18 +47,18 @@ class PusherFactory:
 
             logger.info("defined email pusher type")
 
-    def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
-        kind = pusherdict["kind"]
+    def create_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
+        kind = pusher_config.kind
         f = self.pusher_types.get(kind, None)
         if not f:
             return None
-        logger.debug("creating %s pusher for %r", kind, pusherdict)
-        return f(self.hs, pusherdict)
+        logger.debug("creating %s pusher for %r", kind, pusher_config)
+        return f(self.hs, pusher_config)
 
     def _create_email_pusher(
-        self, _hs: "HomeServer", pusherdict: Dict[str, Any]
+        self, _hs: "HomeServer", pusher_config: PusherConfig
     ) -> EmailPusher:
-        app_name = self._app_name_from_pusherdict(pusherdict)
+        app_name = self._app_name_from_pusherdict(pusher_config)
         mailer = self.mailers.get(app_name)
         if not mailer:
             mailer = Mailer(
@@ -68,10 +68,10 @@ class PusherFactory:
                 template_text=self._notif_template_text,
             )
             self.mailers[app_name] = mailer
-        return EmailPusher(self.hs, pusherdict, mailer)
+        return EmailPusher(self.hs, pusher_config, mailer)
 
-    def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str:
-        data = pusherdict["data"]
+    def _app_name_from_pusherdict(self, pusher_config: PusherConfig) -> str:
+        data = pusher_config.data
 
         if isinstance(data, dict):
             brand = data.get("brand")

+ 70 - 65
synapse/push/pusherpool.py

@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
 
 from prometheus_client import Gauge
 
@@ -23,9 +23,9 @@ from synapse.metrics.background_process_metrics import (
     run_as_background_process,
     wrap_as_background_process,
 )
-from synapse.push import Pusher, PusherConfigException
+from synapse.push import Pusher, PusherConfig, PusherConfigException
 from synapse.push.pusher import PusherFactory
-from synapse.types import RoomStreamToken
+from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.async_helpers import concurrently_execute
 
 if TYPE_CHECKING:
@@ -77,7 +77,7 @@ class PusherPool:
         # map from user id to app_id:pushkey to pusher
         self.pushers = {}  # type: Dict[str, Dict[str, Pusher]]
 
-    def start(self):
+    def start(self) -> None:
         """Starts the pushers off in a background process.
         """
         if not self._should_start_pushers:
@@ -87,16 +87,16 @@ class PusherPool:
 
     async def add_pusher(
         self,
-        user_id,
-        access_token,
-        kind,
-        app_id,
-        app_display_name,
-        device_display_name,
-        pushkey,
-        lang,
-        data,
-        profile_tag="",
+        user_id: str,
+        access_token: Optional[int],
+        kind: str,
+        app_id: str,
+        app_display_name: str,
+        device_display_name: str,
+        pushkey: str,
+        lang: Optional[str],
+        data: JsonDict,
+        profile_tag: str = "",
     ) -> Optional[Pusher]:
         """Creates a new pusher and adds it to the pool
 
@@ -111,21 +111,23 @@ class PusherPool:
         # recreated, added and started: this means we have only one
         # code path adding pushers.
         self.pusher_factory.create_pusher(
-            {
-                "id": None,
-                "user_name": user_id,
-                "kind": kind,
-                "app_id": app_id,
-                "app_display_name": app_display_name,
-                "device_display_name": device_display_name,
-                "pushkey": pushkey,
-                "ts": time_now_msec,
-                "lang": lang,
-                "data": data,
-                "last_stream_ordering": None,
-                "last_success": None,
-                "failing_since": None,
-            }
+            PusherConfig(
+                id=None,
+                user_name=user_id,
+                access_token=access_token,
+                profile_tag=profile_tag,
+                kind=kind,
+                app_id=app_id,
+                app_display_name=app_display_name,
+                device_display_name=device_display_name,
+                pushkey=pushkey,
+                ts=time_now_msec,
+                lang=lang,
+                data=data,
+                last_stream_ordering=None,
+                last_success=None,
+                failing_since=None,
+            )
         )
 
         # create the pusher setting last_stream_ordering to the current maximum
@@ -151,43 +153,44 @@ class PusherPool:
         return pusher
 
     async def remove_pushers_by_app_id_and_pushkey_not_user(
-        self, app_id, pushkey, not_user_id
-    ):
+        self, app_id: str, pushkey: str, not_user_id: str
+    ) -> None:
         to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
         for p in to_remove:
-            if p["user_name"] != not_user_id:
+            if p.user_name != not_user_id:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     app_id,
                     pushkey,
-                    p["user_name"],
+                    p.user_name,
                 )
-                await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+                await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
 
-    async def remove_pushers_by_access_token(self, user_id, access_tokens):
+    async def remove_pushers_by_access_token(
+        self, user_id: str, access_tokens: Iterable[int]
+    ) -> None:
         """Remove the pushers for a given user corresponding to a set of
         access_tokens.
 
         Args:
-            user_id (str): user to remove pushers for
-            access_tokens (Iterable[int]): access token *ids* to remove pushers
-                for
+            user_id: user to remove pushers for
+            access_tokens: access token *ids* to remove pushers for
         """
         if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
             return
 
         tokens = set(access_tokens)
         for p in await self.store.get_pushers_by_user_id(user_id):
-            if p["access_token"] in tokens:
+            if p.access_token in tokens:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
-                    p["app_id"],
-                    p["pushkey"],
-                    p["user_name"],
+                    p.app_id,
+                    p.pushkey,
+                    p.user_name,
                 )
-                await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
+                await self.remove_pusher(p.app_id, p.pushkey, p.user_name)
 
-    def on_new_notifications(self, max_token: RoomStreamToken):
+    def on_new_notifications(self, max_token: RoomStreamToken) -> None:
         if not self.pushers:
             # nothing to do here.
             return
@@ -206,7 +209,7 @@ class PusherPool:
         self._on_new_notifications(max_token)
 
     @wrap_as_background_process("on_new_notifications")
-    async def _on_new_notifications(self, max_token: RoomStreamToken):
+    async def _on_new_notifications(self, max_token: RoomStreamToken) -> None:
         # We just use the minimum stream ordering and ignore the vector clock
         # component. This is safe to do as long as we *always* ignore the vector
         # clock components.
@@ -236,7 +239,9 @@ class PusherPool:
         except Exception:
             logger.exception("Exception in pusher on_new_notifications")
 
-    async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
+    async def on_new_receipts(
+        self, min_stream_id: int, max_stream_id: int, affected_room_ids: Iterable[str]
+    ) -> None:
         if not self.pushers:
             # nothing to do here.
             return
@@ -280,14 +285,14 @@ class PusherPool:
 
         resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
 
-        pusher_dict = None
+        pusher_config = None
         for r in resultlist:
-            if r["user_name"] == user_id:
-                pusher_dict = r
+            if r.user_name == user_id:
+                pusher_config = r
 
         pusher = None
-        if pusher_dict:
-            pusher = await self._start_pusher(pusher_dict)
+        if pusher_config:
+            pusher = await self._start_pusher(pusher_config)
 
         return pusher
 
@@ -302,44 +307,44 @@ class PusherPool:
 
         logger.info("Started pushers")
 
-    async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
+    async def _start_pusher(self, pusher_config: PusherConfig) -> Optional[Pusher]:
         """Start the given pusher
 
         Args:
-            pusherdict: dict with the values pulled from the db table
+            pusher_config: The pusher configuration with the values pulled from the db table
 
         Returns:
             The newly created pusher or None.
         """
         if not self._pusher_shard_config.should_handle(
-            self._instance_name, pusherdict["user_name"]
+            self._instance_name, pusher_config.user_name
         ):
             return None
 
         try:
-            p = self.pusher_factory.create_pusher(pusherdict)
+            p = self.pusher_factory.create_pusher(pusher_config)
         except PusherConfigException as e:
             logger.warning(
                 "Pusher incorrectly configured id=%i, user=%s, appid=%s, pushkey=%s: %s",
-                pusherdict["id"],
-                pusherdict.get("user_name"),
-                pusherdict.get("app_id"),
-                pusherdict.get("pushkey"),
+                pusher_config.id,
+                pusher_config.user_name,
+                pusher_config.app_id,
+                pusher_config.pushkey,
                 e,
             )
             return None
         except Exception:
             logger.exception(
-                "Couldn't start pusher id %i: caught Exception", pusherdict["id"],
+                "Couldn't start pusher id %i: caught Exception", pusher_config.id,
             )
             return None
 
         if not p:
             return None
 
-        appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
+        appid_pushkey = "%s:%s" % (pusher_config.app_id, pusher_config.pushkey)
 
-        byuser = self.pushers.setdefault(pusherdict["user_name"], {})
+        byuser = self.pushers.setdefault(pusher_config.user_name, {})
         if appid_pushkey in byuser:
             byuser[appid_pushkey].on_stop()
         byuser[appid_pushkey] = p
@@ -349,8 +354,8 @@ class PusherPool:
         # Check if there *may* be push to process. We do this as this check is a
         # lot cheaper to do than actually fetching the exact rows we need to
         # push.
-        user_id = pusherdict["user_name"]
-        last_stream_ordering = pusherdict["last_stream_ordering"]
+        user_id = pusher_config.user_name
+        last_stream_ordering = pusher_config.last_stream_ordering
         if last_stream_ordering:
             have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
                 user_id, last_stream_ordering
@@ -364,7 +369,7 @@ class PusherPool:
 
         return p
 
-    async def remove_pusher(self, app_id, pushkey, user_id):
+    async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
         appid_pushkey = "%s:%s" % (app_id, pushkey)
 
         byuser = self.pushers.get(user_id, {})

+ 15 - 5
synapse/replication/slave/storage/_slaved_id_tracker.py

@@ -12,21 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import List, Optional, Tuple
 
+from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import _load_current_id
 
 
 class SlavedIdTracker:
-    def __init__(self, db_conn, table, column, extra_tables=[], step=1):
+    def __init__(
+        self,
+        db_conn: Connection,
+        table: str,
+        column: str,
+        extra_tables: Optional[List[Tuple[str, str]]] = None,
+        step: int = 1,
+    ):
         self.step = step
         self._current = _load_current_id(db_conn, table, column, step)
-        for table, column in extra_tables:
-            self.advance(None, _load_current_id(db_conn, table, column))
+        if extra_tables:
+            for table, column in extra_tables:
+                self.advance(None, _load_current_id(db_conn, table, column))
 
-    def advance(self, instance_name, new_id):
+    def advance(self, instance_name: Optional[str], new_id: int):
         self._current = (max if self.step > 0 else min)(self._current, new_id)
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         """
 
         Returns:

+ 12 - 5
synapse/replication/slave/storage/pushers.py

@@ -13,26 +13,33 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import TYPE_CHECKING
 
 from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.database import DatabasePool
 from synapse.storage.databases.main.pusher import PusherWorkerStore
+from synapse.storage.types import Connection
 
 from ._base import BaseSlavedStore
 from ._slaved_id_tracker import SlavedIdTracker
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 
 class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
-        self._pushers_id_gen = SlavedIdTracker(
+        self._pushers_id_gen = SlavedIdTracker(  # type: ignore
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )
 
-    def get_pushers_stream_token(self):
+    def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token, rows
+    ) -> None:
         if stream_name == PushersStream.NAME:
-            self._pushers_id_gen.advance(instance_name, token)
+            self._pushers_id_gen.advance(instance_name, token)  # type: ignore
         return super().process_replication_rows(stream_name, instance_name, token, rows)

+ 1 - 15
synapse/rest/admin/users.py

@@ -42,17 +42,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-_GET_PUSHERS_ALLOWED_KEYS = {
-    "app_display_name",
-    "app_id",
-    "data",
-    "device_display_name",
-    "kind",
-    "lang",
-    "profile_tag",
-    "pushkey",
-}
-
 
 class UsersRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
@@ -770,10 +759,7 @@ class PushersRestServlet(RestServlet):
 
         pushers = await self.store.get_pushers_by_user_id(user_id)
 
-        filtered_pushers = [
-            {k: v for k, v in p.items() if k in _GET_PUSHERS_ALLOWED_KEYS}
-            for p in pushers
-        ]
+        filtered_pushers = [p.as_dict() for p in pushers]
 
         return 200, {"pushers": filtered_pushers, "total": len(filtered_pushers)}
 

+ 1 - 14
synapse/rest/client/v1/pusher.py

@@ -28,17 +28,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 
 logger = logging.getLogger(__name__)
 
-ALLOWED_KEYS = {
-    "app_display_name",
-    "app_id",
-    "data",
-    "device_display_name",
-    "kind",
-    "lang",
-    "profile_tag",
-    "pushkey",
-}
-
 
 class PushersRestServlet(RestServlet):
     PATTERNS = client_patterns("/pushers$", v1=True)
@@ -54,9 +43,7 @@ class PushersRestServlet(RestServlet):
 
         pushers = await self.hs.get_datastore().get_pushers_by_user_id(user.to_string())
 
-        filtered_pushers = [
-            {k: v for k, v in p.items() if k in ALLOWED_KEYS} for p in pushers
-        ]
+        filtered_pushers = [p.as_dict() for p in pushers]
 
         return 200, {"pushers": filtered_pushers}
 

+ 0 - 3
synapse/storage/databases/main/__init__.py

@@ -149,9 +149,6 @@ class DataStore(
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
         self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
         self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
-        self._pushers_id_gen = StreamIdGenerator(
-            db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
-        )
         self._group_updates_id_gen = StreamIdGenerator(
             db_conn, "local_group_updates", "stream_id"
         )

+ 57 - 36
synapse/storage/databases/main/pusher.py

@@ -15,18 +15,32 @@
 # limitations under the License.
 
 import logging
-from typing import Iterable, Iterator, List, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
 
 from canonicaljson import encode_canonical_json
 
+from synapse.push import PusherConfig, ThrottleParams
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import DatabasePool
+from synapse.storage.types import Connection
+from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached, cachedList
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class PusherWorkerStore(SQLBaseStore):
-    def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[dict]:
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+        super().__init__(database, db_conn, hs)
+        self._pushers_id_gen = StreamIdGenerator(
+            db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+        )
+
+    def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
         """JSON-decode the data in the rows returned from the `pushers` table
 
         Drops any rows whose data cannot be decoded
@@ -44,21 +58,23 @@ class PusherWorkerStore(SQLBaseStore):
                 )
                 continue
 
-            yield r
+            yield PusherConfig(**r)
 
-    async def user_has_pusher(self, user_id):
+    async def user_has_pusher(self, user_id: str) -> bool:
         ret = await self.db_pool.simple_select_one_onecol(
             "pushers", {"user_name": user_id}, "id", allow_none=True
         )
         return ret is not None
 
-    def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
-        return self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
+    async def get_pushers_by_app_id_and_pushkey(
+        self, app_id: str, pushkey: str
+    ) -> Iterator[PusherConfig]:
+        return await self.get_pushers_by({"app_id": app_id, "pushkey": pushkey})
 
-    def get_pushers_by_user_id(self, user_id):
-        return self.get_pushers_by({"user_name": user_id})
+    async def get_pushers_by_user_id(self, user_id: str) -> Iterator[PusherConfig]:
+        return await self.get_pushers_by({"user_name": user_id})
 
-    async def get_pushers_by(self, keyvalues):
+    async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
         ret = await self.db_pool.simple_select_list(
             "pushers",
             keyvalues,
@@ -83,7 +99,7 @@ class PusherWorkerStore(SQLBaseStore):
         )
         return self._decode_pushers_rows(ret)
 
-    async def get_all_pushers(self):
+    async def get_all_pushers(self) -> Iterator[PusherConfig]:
         def get_pushers(txn):
             txn.execute("SELECT * FROM pushers")
             rows = self.db_pool.cursor_to_dict(txn)
@@ -159,14 +175,16 @@ class PusherWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=1, max_entries=15000)
-    async def get_if_user_has_pusher(self, user_id):
+    async def get_if_user_has_pusher(self, user_id: str):
         # This only exists for the cachedList decorator
         raise NotImplementedError()
 
     @cachedList(
         cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
     )
-    async def get_if_users_have_pushers(self, user_ids):
+    async def get_if_users_have_pushers(
+        self, user_ids: Iterable[str]
+    ) -> Dict[str, bool]:
         rows = await self.db_pool.simple_select_many_batch(
             table="pushers",
             column="user_name",
@@ -224,7 +242,7 @@ class PusherWorkerStore(SQLBaseStore):
         return bool(updated)
 
     async def update_pusher_failing_since(
-        self, app_id, pushkey, user_id, failing_since
+        self, app_id: str, pushkey: str, user_id: str, failing_since: Optional[int]
     ) -> None:
         await self.db_pool.simple_update(
             table="pushers",
@@ -233,7 +251,9 @@ class PusherWorkerStore(SQLBaseStore):
             desc="update_pusher_failing_since",
         )
 
-    async def get_throttle_params_by_room(self, pusher_id):
+    async def get_throttle_params_by_room(
+        self, pusher_id: str
+    ) -> Dict[str, ThrottleParams]:
         res = await self.db_pool.simple_select_list(
             "pusher_throttle",
             {"pusher": pusher_id},
@@ -243,43 +263,44 @@ class PusherWorkerStore(SQLBaseStore):
 
         params_by_room = {}
         for row in res:
-            params_by_room[row["room_id"]] = {
-                "last_sent_ts": row["last_sent_ts"],
-                "throttle_ms": row["throttle_ms"],
-            }
+            params_by_room[row["room_id"]] = ThrottleParams(
+                row["last_sent_ts"], row["throttle_ms"],
+            )
 
         return params_by_room
 
-    async def set_throttle_params(self, pusher_id, room_id, params) -> None:
+    async def set_throttle_params(
+        self, pusher_id: str, room_id: str, params: ThrottleParams
+    ) -> None:
         # no need to lock because `pusher_throttle` has a primary key on
         # (pusher, room_id) so simple_upsert will retry
         await self.db_pool.simple_upsert(
             "pusher_throttle",
             {"pusher": pusher_id, "room_id": room_id},
-            params,
+            {"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
             desc="set_throttle_params",
             lock=False,
         )
 
 
 class PusherStore(PusherWorkerStore):
-    def get_pushers_stream_token(self):
+    def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
     async def add_pusher(
         self,
-        user_id,
-        access_token,
-        kind,
-        app_id,
-        app_display_name,
-        device_display_name,
-        pushkey,
-        pushkey_ts,
-        lang,
-        data,
-        last_stream_ordering,
-        profile_tag="",
+        user_id: str,
+        access_token: Optional[int],
+        kind: str,
+        app_id: str,
+        app_display_name: str,
+        device_display_name: str,
+        pushkey: str,
+        pushkey_ts: int,
+        lang: Optional[str],
+        data: Optional[JsonDict],
+        last_stream_ordering: int,
+        profile_tag: str = "",
     ) -> None:
         async with self._pushers_id_gen.get_next() as stream_id:
             # no need to lock because `pushers` has a unique key on
@@ -311,16 +332,16 @@ class PusherStore(PusherWorkerStore):
                 # invalidate, since we the user might not have had a pusher before
                 await self.db_pool.runInteraction(
                     "add_pusher",
-                    self._invalidate_cache_and_stream,
+                    self._invalidate_cache_and_stream,  # type: ignore
                     self.get_if_user_has_pusher,
                     (user_id,),
                 )
 
     async def delete_pusher_by_app_id_pushkey_user_id(
-        self, app_id, pushkey, user_id
+        self, app_id: str, pushkey: str, user_id: str
     ) -> None:
         def delete_pusher_txn(txn, stream_id):
-            self._invalidate_cache_and_stream(
+            self._invalidate_cache_and_stream(  # type: ignore
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
 

+ 2 - 2
synapse/storage/util/id_generators.py

@@ -153,12 +153,12 @@ class StreamIdGenerator:
 
         return _AsyncCtxManagerWrapper(manager())
 
-    def get_current_token(self):
+    def get_current_token(self) -> int:
         """Returns the maximum stream id such that all stream ids less than or
         equal to it have been successfully persisted.
 
         Returns:
-            int
+            The maximum stream id.
         """
         with self._lock:
             if self._unfinished_ids:

+ 3 - 3
tests/push/test_email.py

@@ -209,7 +209,7 @@ class EmailPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        last_stream_ordering = pushers[0]["last_stream_ordering"]
+        last_stream_ordering = pushers[0].last_stream_ordering
 
         # Advance time a bit, so the pusher will register something has happened
         self.pump(10)
@@ -220,7 +220,7 @@ class EmailPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+        self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
 
         # One email was attempted to be sent
         self.assertEqual(len(self.email_attempts), 1)
@@ -238,4 +238,4 @@ class EmailPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+        self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)

+ 5 - 5
tests/push/test_http.py

@@ -144,7 +144,7 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        last_stream_ordering = pushers[0]["last_stream_ordering"]
+        last_stream_ordering = pushers[0].last_stream_ordering
 
         # Advance time a bit, so the pusher will register something has happened
         self.pump()
@@ -155,7 +155,7 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertEqual(last_stream_ordering, pushers[0]["last_stream_ordering"])
+        self.assertEqual(last_stream_ordering, pushers[0].last_stream_ordering)
 
         # One push was attempted to be sent -- it'll be the first message
         self.assertEqual(len(self.push_attempts), 1)
@@ -176,8 +176,8 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
-        last_stream_ordering = pushers[0]["last_stream_ordering"]
+        self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
+        last_stream_ordering = pushers[0].last_stream_ordering
 
         # Now it'll try and send the second push message, which will be the second one
         self.assertEqual(len(self.push_attempts), 2)
@@ -198,7 +198,7 @@ class HTTPPusherTests(HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertTrue(pushers[0]["last_stream_ordering"] > last_stream_ordering)
+        self.assertTrue(pushers[0].last_stream_ordering > last_stream_ordering)
 
     def test_sends_high_priority_for_encrypted(self):
         """

+ 1 - 1
tests/rest/admin/test_user.py

@@ -766,7 +766,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         )
         pushers = list(pushers)
         self.assertEqual(len(pushers), 1)
-        self.assertEqual("@bob:test", pushers[0]["user_name"])
+        self.assertEqual("@bob:test", pushers[0].user_name)
 
     @override_config(
         {