Browse Source

Add type hints to `synapse.storage.databases.main.client_ips` (#10972)

Sean Quah 2 years ago
parent
commit
36224e056a

+ 1 - 0
changelog.d/10972.misc

@@ -0,0 +1 @@
+Add type hints to `synapse.storage.databases.main.client_ips`.

+ 4 - 0
mypy.ini

@@ -53,6 +53,7 @@ files =
   synapse/storage/_base.py,
   synapse/storage/_base.py,
   synapse/storage/background_updates.py,
   synapse/storage/background_updates.py,
   synapse/storage/databases/main/appservice.py,
   synapse/storage/databases/main/appservice.py,
+  synapse/storage/databases/main/client_ips.py,
   synapse/storage/databases/main/events.py,
   synapse/storage/databases/main/events.py,
   synapse/storage/databases/main/keys.py,
   synapse/storage/databases/main/keys.py,
   synapse/storage/databases/main/pusher.py,
   synapse/storage/databases/main/pusher.py,
@@ -108,6 +109,9 @@ disallow_untyped_defs = True
 [mypy-synapse.state.*]
 [mypy-synapse.state.*]
 disallow_untyped_defs = True
 disallow_untyped_defs = True
 
 
+[mypy-synapse.storage.databases.main.client_ips]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.util.*]
 [mypy-synapse.storage.util.*]
 disallow_untyped_defs = True
 disallow_untyped_defs = True
 
 

+ 13 - 2
synapse/handlers/device.py

@@ -14,7 +14,18 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 import logging
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+)
 
 
 from synapse.api import errors
 from synapse.api import errors
 from synapse.api.constants import EventTypes
 from synapse.api.constants import EventTypes
@@ -595,7 +606,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
 
 
 
 def _update_device_from_client_ips(
 def _update_device_from_client_ips(
-    device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
+    device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
 ) -> None:
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})

+ 3 - 3
synapse/module_api/__init__.py

@@ -773,9 +773,9 @@ class ModuleApi:
             # Sanitize some of the data. We don't want to return tokens.
             # Sanitize some of the data. We don't want to return tokens.
             return [
             return [
                 UserIpAndAgent(
                 UserIpAndAgent(
-                    ip=str(data["ip"]),
-                    user_agent=str(data["user_agent"]),
-                    last_seen=int(data["last_seen"]),
+                    ip=data["ip"],
+                    user_agent=data["user_agent"],
+                    last_seen=data["last_seen"],
                 )
                 )
                 for data in raw_data
                 for data in raw_data
             ]
             ]

+ 100 - 40
synapse/storage/databases/main/client_ips.py

@@ -13,14 +13,26 @@
 # limitations under the License.
 # limitations under the License.
 
 
 import logging
 import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
+
+from typing_extensions import TypedDict
 
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
-from synapse.types import UserID
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_tuple_comparison_clause,
+)
+from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.types import Connection
+from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.lrucache import LruCache
 
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 # Number of msec of granularity to store the user IP 'last seen' time. Smaller
 # Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -29,8 +41,31 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120 * 1000
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
 
 
+class DeviceLastConnectionInfo(TypedDict):
+    """Metadata for the last connection seen for a user and device combination"""
+
+    # These types must match the columns in the `devices` table
+    user_id: str
+    device_id: str
+
+    ip: Optional[str]
+    user_agent: Optional[str]
+    last_seen: Optional[int]
+
+
+class LastConnectionInfo(TypedDict):
+    """Metadata for the last connection seen for an access token and IP combination"""
+
+    # These types must match the columns in the `user_ips` table
+    access_token: str
+    ip: str
+
+    user_agent: str
+    last_seen: int
+
+
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
         super().__init__(database, db_conn, hs)
 
 
         self.db_pool.updates.register_background_index_update(
         self.db_pool.updates.register_background_index_update(
@@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             "devices_last_seen", self._devices_last_seen_update
             "devices_last_seen", self._devices_last_seen_update
         )
         )
 
 
-    async def _remove_user_ip_nonunique(self, progress, batch_size):
-        def f(conn):
+    async def _remove_user_ip_nonunique(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def f(conn: LoggingDatabaseConnection) -> None:
             txn = conn.cursor()
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
             txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
             txn.close()
             txn.close()
@@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         )
         )
         return 1
         return 1
 
 
-    async def _analyze_user_ip(self, progress, batch_size):
+    async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int:
         # Background update to analyze user_ips table before we run the
         # Background update to analyze user_ips table before we run the
         # deduplication background update. The table may not have been analyzed
         # deduplication background update. The table may not have been analyzed
         # for ages due to the table locks.
         # for ages due to the table locks.
         #
         #
         # This will lock out the naive upserts to user_ips while it happens, but
         # This will lock out the naive upserts to user_ips while it happens, but
         # the analyze should be quick (28GB table takes ~10s)
         # the analyze should be quick (28GB table takes ~10s)
-        def user_ips_analyze(txn):
+        def user_ips_analyze(txn: LoggingTransaction) -> None:
             txn.execute("ANALYZE user_ips")
             txn.execute("ANALYZE user_ips")
 
 
         await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
         await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
@@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
         return 1
         return 1
 
 
-    async def _remove_user_ip_dupes(self, progress, batch_size):
+    async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int:
        # This function works by scanning the user_ips table in batches
        # This function works by scanning the user_ips table in batches
         # based on `last_seen`. For each row in a batch it searches the rest of
         # based on `last_seen`. For each row in a batch it searches the rest of
         # the table to see if there are any duplicates, if there are then they
         # the table to see if there are any duplicates, if there are then they
         # are removed and replaced with a suitable row.
         # are removed and replaced with a suitable row.
 
 
         # Fetch the start of the batch
         # Fetch the start of the batch
-        begin_last_seen = progress.get("last_seen", 0)
+        begin_last_seen: int = progress.get("last_seen", 0)
 
 
-        def get_last_seen(txn):
+        def get_last_seen(txn: LoggingTransaction) -> Optional[int]:
             txn.execute(
             txn.execute(
                 """
                 """
                 SELECT last_seen FROM user_ips
                 SELECT last_seen FROM user_ips
@@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 """,
                 """,
                 (begin_last_seen, batch_size),
                 (begin_last_seen, batch_size),
             )
             )
-            row = txn.fetchone()
+            row = cast(Optional[Tuple[int]], txn.fetchone())
             if row:
             if row:
                 return row[0]
                 return row[0]
             else:
             else:
@@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             end_last_seen,
             end_last_seen,
         )
         )
 
 
-        def remove(txn):
+        def remove(txn: LoggingTransaction) -> None:
             # This works by looking at all entries in the given time span, and
             # This works by looking at all entries in the given time span, and
             # then for each (user_id, access_token, ip) tuple in that range
             # then for each (user_id, access_token, ip) tuple in that range
             # checking for any duplicates in the rest of the table (via a join).
             # checking for any duplicates in the rest of the table (via a join).
@@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
             # Define the search space, which requires handling the last batch in
             # Define the search space, which requires handling the last batch in
             # a different way
             # a different way
+            args: Tuple[int, ...]
             if last:
             if last:
                 clause = "? <= last_seen"
                 clause = "? <= last_seen"
                 args = (begin_last_seen,)
                 args = (begin_last_seen,)
             else:
             else:
+                assert end_last_seen is not None
                 clause = "? <= last_seen AND last_seen < ?"
                 clause = "? <= last_seen AND last_seen < ?"
                 args = (begin_last_seen, end_last_seen)
                 args = (begin_last_seen, end_last_seen)
 
 
@@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 ),
                 ),
                 args,
                 args,
             )
             )
-            res = txn.fetchall()
+            res = cast(
+                List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall()
+            )
 
 
             # We've got some duplicates
             # We've got some duplicates
             for i in res:
             for i in res:
@@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
         return batch_size
         return batch_size
 
 
-    async def _devices_last_seen_update(self, progress, batch_size):
+    async def _devices_last_seen_update(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to insert last seen info into devices table"""
         """Background update to insert last seen info into devices table"""
 
 
-        last_user_id = progress.get("last_user_id", "")
-        last_device_id = progress.get("last_device_id", "")
+        last_user_id: str = progress.get("last_user_id", "")
+        last_device_id: str = progress.get("last_device_id", "")
 
 
-        def _devices_last_seen_update_txn(txn):
+        def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
             # This consists of two queries:
             # This consists of two queries:
             #
             #
             #   1. The sub-query searches for the next N devices and joins
             #   1. The sub-query searches for the next N devices and joins
@@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             #      we'll just end up updating the same device row multiple
             #      we'll just end up updating the same device row multiple
             #      times, which is fine.
             #      times, which is fine.
 
 
+            where_args: List[Union[str, int]]
             where_clause, where_args = make_tuple_comparison_clause(
             where_clause, where_args = make_tuple_comparison_clause(
                 [("user_id", last_user_id), ("device_id", last_device_id)],
                 [("user_id", last_user_id), ("device_id", last_device_id)],
             )
             )
@@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             }
             }
             txn.execute(sql, where_args + [batch_size])
             txn.execute(sql, where_args + [batch_size])
 
 
-            rows = txn.fetchall()
+            rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
             if not rows:
             if not rows:
                 return 0
                 return 0
 
 
@@ -350,7 +394,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
 
 
 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
         super().__init__(database, db_conn, hs)
 
 
         self.user_ips_max_age = hs.config.server.user_ips_max_age
         self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
             self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
 
 
     @wrap_as_background_process("prune_old_user_ips")
     @wrap_as_background_process("prune_old_user_ips")
-    async def _prune_old_user_ips(self):
+    async def _prune_old_user_ips(self) -> None:
         """Removes entries in user IPs older than the configured period."""
         """Removes entries in user IPs older than the configured period."""
 
 
         if self.user_ips_max_age is None:
         if self.user_ips_max_age is None:
@@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             )
             )
         """
         """
 
 
-        timestamp = self.clock.time_msec() - self.user_ips_max_age
+        timestamp = self._clock.time_msec() - self.user_ips_max_age
 
 
-        def _prune_old_user_ips_txn(txn):
+        def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None:
             txn.execute(sql, (timestamp,))
             txn.execute(sql, (timestamp,))
 
 
         await self.db_pool.runInteraction(
         await self.db_pool.runInteraction(
@@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
 
     async def get_last_client_ip_by_device(
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
         self, user_id: str, device_id: Optional[str]
-    ) -> Dict[Tuple[str, str], dict]:
+    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on.
         """For each device_id listed, give the user_ip it was last seen on.
 
 
         The result might be slightly out of date as client IPs are inserted in batches.
         The result might be slightly out of date as client IPs are inserted in batches.
@@ -423,26 +467,32 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
         if device_id is not None:
         if device_id is not None:
             keyvalues["device_id"] = device_id
             keyvalues["device_id"] = device_id
 
 
-        res = await self.db_pool.simple_select_list(
-            table="devices",
-            keyvalues=keyvalues,
-            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+        res = cast(
+            List[DeviceLastConnectionInfo],
+            await self.db_pool.simple_select_list(
+                table="devices",
+                keyvalues=keyvalues,
+                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+            ),
         )
         )
 
 
         return {(d["user_id"], d["device_id"]): d for d in res}
         return {(d["user_id"], d["device_id"]): d for d in res}
 
 
 
 
-class ClientIpStore(ClientIpWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
 
 
-        self.client_ip_last_seen = LruCache(
+        # (user_id, access_token, ip,) -> last_seen
+        self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
             cache_name="client_ip_last_seen", max_size=50000
             cache_name="client_ip_last_seen", max_size=50000
         )
         )
 
 
         super().__init__(database, db_conn, hs)
         super().__init__(database, db_conn, hs)
 
 
         # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
         # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
-        self._batch_row_update = {}
+        self._batch_row_update: Dict[
+            Tuple[str, str, str], Tuple[str, Optional[str], int]
+        ] = {}
 
 
         self._client_ip_looper = self._clock.looping_call(
         self._client_ip_looper = self._clock.looping_call(
             self._update_client_ips_batch, 5 * 1000
             self._update_client_ips_batch, 5 * 1000
@@ -452,8 +502,14 @@ class ClientIpStore(ClientIpWorkerStore):
         )
         )
 
 
     async def insert_client_ip(
     async def insert_client_ip(
-        self, user_id, access_token, ip, user_agent, device_id, now=None
-    ):
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: Optional[str],
+        now: Optional[int] = None,
+    ) -> None:
         if not now:
         if not now:
             now = int(self._clock.time_msec())
             now = int(self._clock.time_msec())
         key = (user_id, access_token, ip)
         key = (user_id, access_token, ip)
@@ -485,7 +541,11 @@ class ClientIpStore(ClientIpWorkerStore):
             "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
             "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
         )
         )
 
 
-    def _update_client_ips_batch_txn(self, txn, to_update):
+    def _update_client_ips_batch_txn(
+        self,
+        txn: LoggingTransaction,
+        to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
+    ) -> None:
         if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
         if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
             not self.database_engine.can_native_upsert
             not self.database_engine.can_native_upsert
         ):
         ):
@@ -525,7 +585,7 @@ class ClientIpStore(ClientIpWorkerStore):
 
 
     async def get_last_client_ip_by_device(
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
         self, user_id: str, device_id: Optional[str]
-    ) -> Dict[Tuple[str, str], dict]:
+    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on
         """For each device_id listed, give the user_ip it was last seen on
 
 
         Args:
         Args:
@@ -561,12 +621,12 @@ class ClientIpStore(ClientIpWorkerStore):
 
 
     async def get_user_ip_and_agents(
     async def get_user_ip_and_agents(
         self, user: UserID, since_ts: int = 0
         self, user: UserID, since_ts: int = 0
-    ) -> List[Dict[str, Union[str, int]]]:
+    ) -> List[LastConnectionInfo]:
         """
         """
         Fetch IP/User Agent connection since a given timestamp.
        Fetch IP/User Agent connections since a given timestamp.
        Fetch IP/User Agent connections since a given timestamp.
         """
         user_id = user.to_string()
         user_id = user.to_string()
-        results = {}
+        results: Dict[Tuple[str, str], Tuple[str, int]] = {}
 
 
         for key in self._batch_row_update:
         for key in self._batch_row_update:
             (
             (
@@ -579,7 +639,7 @@ class ClientIpStore(ClientIpWorkerStore):
                 if last_seen >= since_ts:
                 if last_seen >= since_ts:
                     results[(access_token, ip)] = (user_agent, last_seen)
                     results[(access_token, ip)] = (user_agent, last_seen)
 
 
-        def get_recent(txn):
+        def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
             txn.execute(
             txn.execute(
                 """
                 """
                 SELECT access_token, ip, user_agent, last_seen FROM user_ips
                 SELECT access_token, ip, user_agent, last_seen FROM user_ips
@@ -589,7 +649,7 @@ class ClientIpStore(ClientIpWorkerStore):
                 """,
                 """,
                 (since_ts, user_id),
                 (since_ts, user_id),
             )
             )
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, str, int]], txn.fetchall())
 
 
         rows = await self.db_pool.runInteraction(
         rows = await self.db_pool.runInteraction(
             desc="get_user_ip_and_agents", func=get_recent
             desc="get_user_ip_and_agents", func=get_recent