Browse Source

Add type hints to `synapse/storage/databases/main/transactions.py` (#11589)

Dirk Klimpel 2 years ago
parent
commit
1847d027e6
3 changed files with 29 additions and 25 deletions
  1. 1 0
      changelog.d/11589.misc
  2. 3 1
      mypy.ini
  3. 25 24
      synapse/storage/databases/main/transactions.py

+ 1 - 0
changelog.d/11589.misc

@@ -0,0 +1 @@
+Add missing type hints to storage classes.

+ 3 - 1
mypy.ini

@@ -41,7 +41,6 @@ exclude = (?x)
    |synapse/storage/databases/main/search.py
    |synapse/storage/databases/main/state.py
    |synapse/storage/databases/main/stats.py
-   |synapse/storage/databases/main/transactions.py
    |synapse/storage/databases/main/user_directory.py
    |synapse/storage/schema/
 
@@ -216,6 +215,9 @@ disallow_untyped_defs = True
 [mypy-synapse.storage.databases.main.state_deltas]
 disallow_untyped_defs = True
 
+[mypy-synapse.storage.databases.main.transactions]
+disallow_untyped_defs = True
+
 [mypy-synapse.storage.databases.main.user_erasure_store]
 disallow_untyped_defs = True
 

+ 25 - 24
synapse/storage/databases/main/transactions.py

@@ -13,9 +13,8 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
 from enum import Enum
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -39,16 +38,6 @@ db_binary_type = memoryview
 logger = logging.getLogger(__name__)
 
 
-_TransactionRow = namedtuple(
-    "_TransactionRow",
-    ("id", "transaction_id", "destination", "ts", "response_code", "response_json"),
-)
-
-_UpdateTransactionRow = namedtuple(
-    "_TransactionRow", ("response_code", "response_json")
-)
-
-
 class DestinationSortOrder(Enum):
     """Enum to define the sorting method used when returning destinations."""
 
@@ -91,7 +80,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         now = self._clock.time_msec()
         month_ago = now - 30 * 24 * 60 * 60 * 1000
 
-        def _cleanup_transactions_txn(txn):
+        def _cleanup_transactions_txn(txn: LoggingTransaction) -> None:
             txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
 
         await self.db_pool.runInteraction(
@@ -121,7 +110,9 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             origin,
         )
 
-    def _get_received_txn_response(self, txn, transaction_id, origin):
+    def _get_received_txn_response(
+        self, txn: LoggingTransaction, transaction_id: str, origin: str
+    ) -> Optional[Tuple[int, JsonDict]]:
         result = self.db_pool.simple_select_one_txn(
             txn,
             table="received_transactions",
@@ -196,7 +187,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         return result
 
     def _get_destination_retry_timings(
-        self, txn, destination: str
+        self, txn: LoggingTransaction, destination: str
     ) -> Optional[DestinationRetryTimings]:
         result = self.db_pool.simple_select_one_txn(
             txn,
@@ -231,7 +222,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         """
 
         if self.database_engine.can_native_upsert:
-            return await self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "set_destination_retry_timings",
                 self._set_destination_retry_timings_native,
                 destination,
@@ -241,7 +232,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
                 db_autocommit=True,  # Safe as its a single upsert
             )
         else:
-            return await self.db_pool.runInteraction(
+            await self.db_pool.runInteraction(
                 "set_destination_retry_timings",
                 self._set_destination_retry_timings_emulated,
                 destination,
@@ -251,8 +242,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             )
 
     def _set_destination_retry_timings_native(
-        self, txn, destination, failure_ts, retry_last_ts, retry_interval
-    ):
+        self,
+        txn: LoggingTransaction,
+        destination: str,
+        failure_ts: Optional[int],
+        retry_last_ts: int,
+        retry_interval: int,
+    ) -> None:
         assert self.database_engine.can_native_upsert
 
         # Upsert retry time interval if retry_interval is zero (i.e. we're
@@ -282,8 +278,13 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
         )
 
     def _set_destination_retry_timings_emulated(
-        self, txn, destination, failure_ts, retry_last_ts, retry_interval
-    ):
+        self,
+        txn: LoggingTransaction,
+        destination: str,
+        failure_ts: Optional[int],
+        retry_last_ts: int,
+        retry_interval: int,
+    ) -> None:
         self.database_engine.lock_table(txn, "destinations")
 
         # We need to be careful here as the data may have changed from under us
@@ -393,7 +394,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             last_successful_stream_ordering: the stream_ordering of the most
                 recent successfully-sent PDU
         """
-        return await self.db_pool.simple_upsert(
+        await self.db_pool.simple_upsert(
             "destinations",
             keyvalues={"destination": destination},
             values={"last_successful_stream_ordering": last_successful_stream_ordering},
@@ -534,7 +535,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             else:
                 order = "ASC"
 
-            args = []
+            args: List[object] = []
             where_statement = ""
             if destination:
                 args.extend(["%" + destination.lower() + "%"])
@@ -543,7 +544,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
             sql_base = f"FROM destinations {where_statement} "
             sql = f"SELECT COUNT(*) as total_destinations {sql_base}"
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             sql = f"""
                 SELECT destination, retry_last_ts, retry_interval, failure_ts,