浏览代码

Add type hints to `synapse.storage.databases.main.client_ips` (#10972)

Sean Quah 2 年之前
父节点
当前提交
36224e056a
共有 5 个文件被更改,包括 121 次插入45 次删除
  1. 1 0
      changelog.d/10972.misc
  2. 4 0
      mypy.ini
  3. 13 2
      synapse/handlers/device.py
  4. 3 3
      synapse/module_api/__init__.py
  5. 100 40
      synapse/storage/databases/main/client_ips.py

+ 1 - 0
changelog.d/10972.misc

@@ -0,0 +1 @@
+Add type hints to `synapse.storage.databases.main.client_ips`.

+ 4 - 0
mypy.ini

@@ -53,6 +53,7 @@ files =
   synapse/storage/_base.py,
   synapse/storage/background_updates.py,
   synapse/storage/databases/main/appservice.py,
+  synapse/storage/databases/main/client_ips.py,
   synapse/storage/databases/main/events.py,
   synapse/storage/databases/main/keys.py,
   synapse/storage/databases/main/pusher.py,
@@ -108,6 +109,9 @@ disallow_untyped_defs = True
 [mypy-synapse.state.*]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.client_ips]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.util.*]
 disallow_untyped_defs = True
 

+ 13 - 2
synapse/handlers/device.py

@@ -14,7 +14,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+)
 
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -595,7 +606,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
 
 def _update_device_from_client_ips(
-    device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict]
+    device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})

+ 3 - 3
synapse/module_api/__init__.py

@@ -773,9 +773,9 @@ class ModuleApi:
             # Sanitize some of the data. We don't want to return tokens.
             return [
                 UserIpAndAgent(
-                    ip=str(data["ip"]),
-                    user_agent=str(data["user_agent"]),
-                    last_seen=int(data["last_seen"]),
+                    ip=data["ip"],
+                    user_agent=data["user_agent"],
+                    last_seen=data["last_seen"],
                 )
                 for data in raw_data
             ]

+ 100 - 40
synapse/storage/databases/main/client_ips.py

@@ -13,14 +13,26 @@
 # limitations under the License.
 
 import logging
-from typing import Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Union, cast
+
+from typing_extensions import TypedDict
 
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
-from synapse.types import UserID
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_tuple_comparison_clause,
+)
+from synapse.storage.databases.main.monthly_active_users import MonthlyActiveUsersStore
+from synapse.storage.types import Connection
+from synapse.types import JsonDict, UserID
 from synapse.util.caches.lrucache import LruCache
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 # Number of msec of granularity to store the user IP 'last seen' time. Smaller
@@ -29,8 +41,31 @@ logger = logging.getLogger(__name__)
 LAST_SEEN_GRANULARITY = 120 * 1000
 
 
+class DeviceLastConnectionInfo(TypedDict):
+    """Metadata for the last connection seen for a user and device combination"""
+
+    # These types must match the columns in the `devices` table
+    user_id: str
+    device_id: str
+
+    ip: Optional[str]
+    user_agent: Optional[str]
+    last_seen: Optional[int]
+
+
+class LastConnectionInfo(TypedDict):
+    """Metadata for the last connection seen for an access token and IP combination"""
+
+    # These types must match the columns in the `user_ips` table
+    access_token: str
+    ip: str
+
+    user_agent: str
+    last_seen: int
+
+
 class ClientIpBackgroundUpdateStore(SQLBaseStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.db_pool.updates.register_background_index_update(
@@ -81,8 +116,10 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             "devices_last_seen", self._devices_last_seen_update
         )
 
-    async def _remove_user_ip_nonunique(self, progress, batch_size):
-        def f(conn):
+    async def _remove_user_ip_nonunique(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def f(conn: LoggingDatabaseConnection) -> None:
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
             txn.close()
@@ -93,14 +130,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
         )
         return 1
 
-    async def _analyze_user_ip(self, progress, batch_size):
+    async def _analyze_user_ip(self, progress: JsonDict, batch_size: int) -> int:
         # Background update to analyze user_ips table before we run the
         # deduplication background update. The table may not have been analyzed
         # for ages due to the table locks.
         #
         # This will lock out the naive upserts to user_ips while it happens, but
         # the analyze should be quick (28GB table takes ~10s)
-        def user_ips_analyze(txn):
+        def user_ips_analyze(txn: LoggingTransaction) -> None:
             txn.execute("ANALYZE user_ips")
 
         await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
@@ -109,16 +146,16 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
         return 1
 
-    async def _remove_user_ip_dupes(self, progress, batch_size):
+    async def _remove_user_ip_dupes(self, progress: JsonDict, batch_size: int) -> int:
         # This works function works by scanning the user_ips table in batches
         # based on `last_seen`. For each row in a batch it searches the rest of
         # the table to see if there are any duplicates, if there are then they
         # are removed and replaced with a suitable row.
 
         # Fetch the start of the batch
-        begin_last_seen = progress.get("last_seen", 0)
+        begin_last_seen: int = progress.get("last_seen", 0)
 
-        def get_last_seen(txn):
+        def get_last_seen(txn: LoggingTransaction) -> Optional[int]:
             txn.execute(
                 """
                 SELECT last_seen FROM user_ips
@@ -129,7 +166,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 """,
                 (begin_last_seen, batch_size),
             )
-            row = txn.fetchone()
+            row = cast(Optional[Tuple[int]], txn.fetchone())
             if row:
                 return row[0]
             else:
@@ -149,7 +186,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             end_last_seen,
         )
 
-        def remove(txn):
+        def remove(txn: LoggingTransaction) -> None:
             # This works by looking at all entries in the given time span, and
             # then for each (user_id, access_token, ip) tuple in that range
             # checking for any duplicates in the rest of the table (via a join).
@@ -161,10 +198,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
             # Define the search space, which requires handling the last batch in
             # a different way
+            args: Tuple[int, ...]
             if last:
                 clause = "? <= last_seen"
                 args = (begin_last_seen,)
             else:
+                assert end_last_seen is not None
                 clause = "? <= last_seen AND last_seen < ?"
                 args = (begin_last_seen, end_last_seen)
 
@@ -189,7 +228,9 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
                 ),
                 args,
             )
-            res = txn.fetchall()
+            res = cast(
+                List[Tuple[str, str, str, Optional[str], str, int, int]], txn.fetchall()
+            )
 
             # We've got some duplicates
             for i in res:
@@ -278,13 +319,15 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
         return batch_size
 
-    async def _devices_last_seen_update(self, progress, batch_size):
+    async def _devices_last_seen_update(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to insert last seen info into devices table"""
 
-        last_user_id = progress.get("last_user_id", "")
-        last_device_id = progress.get("last_device_id", "")
+        last_user_id: str = progress.get("last_user_id", "")
+        last_device_id: str = progress.get("last_device_id", "")
 
-        def _devices_last_seen_update_txn(txn):
+        def _devices_last_seen_update_txn(txn: LoggingTransaction) -> int:
             # This consists of two queries:
             #
             #   1. The sub-query searches for the next N devices and joins
@@ -296,6 +339,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             #      we'll just end up updating the same device row multiple
             #      times, which is fine.
 
+            where_args: List[Union[str, int]]
             where_clause, where_args = make_tuple_comparison_clause(
                 [("user_id", last_user_id), ("device_id", last_device_id)],
             )
@@ -319,7 +363,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
             }
             txn.execute(sql, where_args + [batch_size])
 
-            rows = txn.fetchall()
+            rows = cast(List[Tuple[int, str, str, str, str]], txn.fetchall())
             if not rows:
                 return 0
 
@@ -350,7 +394,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
 
 
 class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
         super().__init__(database, db_conn, hs)
 
         self.user_ips_max_age = hs.config.server.user_ips_max_age
@@ -359,7 +403,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
 
     @wrap_as_background_process("prune_old_user_ips")
-    async def _prune_old_user_ips(self):
+    async def _prune_old_user_ips(self) -> None:
         """Removes entries in user IPs older than the configured period."""
 
         if self.user_ips_max_age is None:
@@ -394,9 +438,9 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
             )
         """
 
-        timestamp = self.clock.time_msec() - self.user_ips_max_age
+        timestamp = self._clock.time_msec() - self.user_ips_max_age
 
-        def _prune_old_user_ips_txn(txn):
+        def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None:
             txn.execute(sql, (timestamp,))
 
         await self.db_pool.runInteraction(
@@ -405,7 +449,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
 
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
-    ) -> Dict[Tuple[str, str], dict]:
+    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on.
 
         The result might be slightly out of date as client IPs are inserted in batches.
@@ -423,26 +467,32 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
         if device_id is not None:
             keyvalues["device_id"] = device_id
 
-        res = await self.db_pool.simple_select_list(
-            table="devices",
-            keyvalues=keyvalues,
-            retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+        res = cast(
+            List[DeviceLastConnectionInfo],
+            await self.db_pool.simple_select_list(
+                table="devices",
+                keyvalues=keyvalues,
+                retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
+            ),
         )
 
         return {(d["user_id"], d["device_id"]): d for d in res}
 
 
-class ClientIpStore(ClientIpWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn, hs):
+class ClientIpStore(ClientIpWorkerStore, MonthlyActiveUsersStore):
+    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
 
-        self.client_ip_last_seen = LruCache(
+        # (user_id, access_token, ip,) -> last_seen
+        self.client_ip_last_seen = LruCache[Tuple[str, str, str], int](
             cache_name="client_ip_last_seen", max_size=50000
         )
 
         super().__init__(database, db_conn, hs)
 
         # (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
-        self._batch_row_update = {}
+        self._batch_row_update: Dict[
+            Tuple[str, str, str], Tuple[str, Optional[str], int]
+        ] = {}
 
         self._client_ip_looper = self._clock.looping_call(
             self._update_client_ips_batch, 5 * 1000
@@ -452,8 +502,14 @@ class ClientIpStore(ClientIpWorkerStore):
         )
 
     async def insert_client_ip(
-        self, user_id, access_token, ip, user_agent, device_id, now=None
-    ):
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: Optional[str],
+        now: Optional[int] = None,
+    ) -> None:
         if not now:
             now = int(self._clock.time_msec())
         key = (user_id, access_token, ip)
@@ -485,7 +541,11 @@ class ClientIpStore(ClientIpWorkerStore):
             "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
         )
 
-    def _update_client_ips_batch_txn(self, txn, to_update):
+    def _update_client_ips_batch_txn(
+        self,
+        txn: LoggingTransaction,
+        to_update: Mapping[Tuple[str, str, str], Tuple[str, Optional[str], int]],
+    ) -> None:
         if "user_ips" in self.db_pool._unsafe_to_upsert_tables or (
             not self.database_engine.can_native_upsert
         ):
@@ -525,7 +585,7 @@ class ClientIpStore(ClientIpWorkerStore):
 
     async def get_last_client_ip_by_device(
         self, user_id: str, device_id: Optional[str]
-    ) -> Dict[Tuple[str, str], dict]:
+    ) -> Dict[Tuple[str, str], DeviceLastConnectionInfo]:
         """For each device_id listed, give the user_ip it was last seen on
 
         Args:
@@ -561,12 +621,12 @@ class ClientIpStore(ClientIpWorkerStore):
 
     async def get_user_ip_and_agents(
         self, user: UserID, since_ts: int = 0
-    ) -> List[Dict[str, Union[str, int]]]:
+    ) -> List[LastConnectionInfo]:
         """
         Fetch IP/User Agent connection since a given timestamp.
         """
         user_id = user.to_string()
-        results = {}
+        results: Dict[Tuple[str, str], Tuple[str, int]] = {}
 
         for key in self._batch_row_update:
             (
@@ -579,7 +639,7 @@ class ClientIpStore(ClientIpWorkerStore):
                 if last_seen >= since_ts:
                     results[(access_token, ip)] = (user_agent, last_seen)
 
-        def get_recent(txn):
+        def get_recent(txn: LoggingTransaction) -> List[Tuple[str, str, str, int]]:
             txn.execute(
                 """
                 SELECT access_token, ip, user_agent, last_seen FROM user_ips
@@ -589,7 +649,7 @@ class ClientIpStore(ClientIpWorkerStore):
                 """,
                 (since_ts, user_id),
             )
-            return txn.fetchall()
+            return cast(List[Tuple[str, str, str, int]], txn.fetchall())
 
         rows = await self.db_pool.runInteraction(
             desc="get_user_ip_and_agents", func=get_recent