lock.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. # Copyright 2021 Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from types import TracebackType
  16. from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
  17. from weakref import WeakValueDictionary
  18. from twisted.internet.interfaces import IReactorCore
  19. from synapse.metrics.background_process_metrics import wrap_as_background_process
  20. from synapse.storage._base import SQLBaseStore
  21. from synapse.storage.database import (
  22. DatabasePool,
  23. LoggingDatabaseConnection,
  24. LoggingTransaction,
  25. )
  26. from synapse.storage.engines import PostgresEngine
  27. from synapse.util import Clock
  28. from synapse.util.stringutils import random_string
  29. if TYPE_CHECKING:
  30. from synapse.server import HomeServer
  31. logger = logging.getLogger(__name__)
  32. # How often to renew an acquired lock by updating the `last_renewed_ts` time in
  33. # the lock table.
  34. _RENEWAL_INTERVAL_MS = 30 * 1000
  35. # How long before an acquired lock times out.
  36. _LOCK_TIMEOUT_MS = 2 * 60 * 1000
  37. class LockStore(SQLBaseStore):
  38. """Provides a best effort distributed lock between worker instances.
  39. Locks are identified by a name and key. A lock is acquired by inserting into
  40. the `worker_locks` table if a) there is no existing row for the name/key or
  41. b) the existing row has a `last_renewed_ts` older than `_LOCK_TIMEOUT_MS`.
  42. When a lock is taken out the instance inserts a random `token`, the instance
  43. that holds that token holds the lock until it drops (or times out).
  44. The instance that holds the lock should regularly update the
  45. `last_renewed_ts` column with the current time.
  46. """
  47. def __init__(
  48. self,
  49. database: DatabasePool,
  50. db_conn: LoggingDatabaseConnection,
  51. hs: "HomeServer",
  52. ):
  53. super().__init__(database, db_conn, hs)
  54. self._reactor = hs.get_reactor()
  55. self._instance_name = hs.get_instance_id()
  56. # A map from `(lock_name, lock_key)` to lock that we think we
  57. # currently hold.
  58. self._live_lock_tokens: WeakValueDictionary[
  59. Tuple[str, str], Lock
  60. ] = WeakValueDictionary()
  61. # A map from `(lock_name, lock_key, token)` to read/write lock that we
  62. # think we currently hold. For a given lock_name/lock_key, there can be
  63. # multiple read locks at a time but only one write lock (no mixing read
  64. # and write locks at the same time).
  65. self._live_read_write_lock_tokens: WeakValueDictionary[
  66. Tuple[str, str, str], Lock
  67. ] = WeakValueDictionary()
  68. # When we shut down we want to remove the locks. Technically this can
  69. # lead to a race, as we may drop the lock while we are still processing.
  70. # However, a) it should be a small window, b) the lock is best effort
  71. # anyway and c) we want to really avoid leaking locks when we restart.
  72. hs.get_reactor().addSystemEventTrigger(
  73. "before",
  74. "shutdown",
  75. self._on_shutdown,
  76. )
  77. self._acquiring_locks: Set[Tuple[str, str]] = set()
  78. @wrap_as_background_process("LockStore._on_shutdown")
  79. async def _on_shutdown(self) -> None:
  80. """Called when the server is shutting down"""
  81. logger.info("Dropping held locks due to shutdown")
  82. # We need to take a copy of the locks as dropping the locks will cause
  83. # the dictionary to change.
  84. locks = list(self._live_lock_tokens.values()) + list(
  85. self._live_read_write_lock_tokens.values()
  86. )
  87. for lock in locks:
  88. await lock.release()
  89. logger.info("Dropped locks due to shutdown")
  90. async def try_acquire_lock(self, lock_name: str, lock_key: str) -> Optional["Lock"]:
  91. """Try to acquire a lock for the given name/key. Will return an async
  92. context manager if the lock is successfully acquired, which *must* be
  93. used (otherwise the lock will leak).
  94. """
  95. if (lock_name, lock_key) in self._acquiring_locks:
  96. return None
  97. try:
  98. self._acquiring_locks.add((lock_name, lock_key))
  99. return await self._try_acquire_lock(lock_name, lock_key)
  100. finally:
  101. self._acquiring_locks.discard((lock_name, lock_key))
  102. async def _try_acquire_lock(
  103. self, lock_name: str, lock_key: str
  104. ) -> Optional["Lock"]:
  105. """Try to acquire a lock for the given name/key. Will return an async
  106. context manager if the lock is successfully acquired, which *must* be
  107. used (otherwise the lock will leak).
  108. """
  109. # Check if this process has taken out a lock and if it's still valid.
  110. lock = self._live_lock_tokens.get((lock_name, lock_key))
  111. if lock and await lock.is_still_valid():
  112. return None
  113. now = self._clock.time_msec()
  114. token = random_string(6)
  115. def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
  116. # We take out the lock if either a) there is no row for the lock
  117. # already, b) the existing row has timed out, or c) the row is
  118. # for this instance (which means the process got killed and
  119. # restarted)
  120. sql = """
  121. INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
  122. VALUES (?, ?, ?, ?, ?)
  123. ON CONFLICT (lock_name, lock_key)
  124. DO UPDATE
  125. SET
  126. token = EXCLUDED.token,
  127. instance_name = EXCLUDED.instance_name,
  128. last_renewed_ts = EXCLUDED.last_renewed_ts
  129. WHERE
  130. worker_locks.last_renewed_ts < ?
  131. OR worker_locks.instance_name = EXCLUDED.instance_name
  132. """
  133. txn.execute(
  134. sql,
  135. (
  136. lock_name,
  137. lock_key,
  138. self._instance_name,
  139. token,
  140. now,
  141. now - _LOCK_TIMEOUT_MS,
  142. ),
  143. )
  144. # We only acquired the lock if we inserted or updated the table.
  145. return bool(txn.rowcount)
  146. did_lock = await self.db_pool.runInteraction(
  147. "try_acquire_lock",
  148. _try_acquire_lock_txn,
  149. # We can autocommit here as we're executing a single query, this
  150. # will avoid serialization errors.
  151. db_autocommit=True,
  152. )
  153. if not did_lock:
  154. return None
  155. lock = Lock(
  156. self._reactor,
  157. self._clock,
  158. self,
  159. read_write=False,
  160. lock_name=lock_name,
  161. lock_key=lock_key,
  162. token=token,
  163. )
  164. self._live_lock_tokens[(lock_name, lock_key)] = lock
  165. return lock
  166. async def try_acquire_read_write_lock(
  167. self,
  168. lock_name: str,
  169. lock_key: str,
  170. write: bool,
  171. ) -> Optional["Lock"]:
  172. """Try to acquire a lock for the given name/key. Will return an async
  173. context manager if the lock is successfully acquired, which *must* be
  174. used (otherwise the lock will leak).
  175. """
  176. now = self._clock.time_msec()
  177. token = random_string(6)
  178. def _try_acquire_read_write_lock_txn(txn: LoggingTransaction) -> None:
  179. # We attempt to acquire the lock by inserting into
  180. # `worker_read_write_locks` and seeing if that fails any
  181. # constraints. If it doesn't then we have acquired the lock,
  182. # otherwise we haven't.
  183. #
  184. # Before that though we clear the table of any stale locks.
  185. delete_sql = """
  186. DELETE FROM worker_read_write_locks
  187. WHERE last_renewed_ts < ? AND lock_name = ? AND lock_key = ?;
  188. """
  189. insert_sql = """
  190. INSERT INTO worker_read_write_locks (lock_name, lock_key, write_lock, instance_name, token, last_renewed_ts)
  191. VALUES (?, ?, ?, ?, ?, ?)
  192. """
  193. if isinstance(self.database_engine, PostgresEngine):
  194. # For Postgres we can send these queries at the same time.
  195. txn.execute(
  196. delete_sql + ";" + insert_sql,
  197. (
  198. # DELETE args
  199. now - _LOCK_TIMEOUT_MS,
  200. lock_name,
  201. lock_key,
  202. # UPSERT args
  203. lock_name,
  204. lock_key,
  205. write,
  206. self._instance_name,
  207. token,
  208. now,
  209. ),
  210. )
  211. else:
  212. # For SQLite these need to be two queries.
  213. txn.execute(
  214. delete_sql,
  215. (
  216. now - _LOCK_TIMEOUT_MS,
  217. lock_name,
  218. lock_key,
  219. ),
  220. )
  221. txn.execute(
  222. insert_sql,
  223. (
  224. lock_name,
  225. lock_key,
  226. write,
  227. self._instance_name,
  228. token,
  229. now,
  230. ),
  231. )
  232. return
  233. try:
  234. await self.db_pool.runInteraction(
  235. "try_acquire_read_write_lock",
  236. _try_acquire_read_write_lock_txn,
  237. )
  238. except self.database_engine.module.IntegrityError:
  239. return None
  240. lock = Lock(
  241. self._reactor,
  242. self._clock,
  243. self,
  244. read_write=True,
  245. lock_name=lock_name,
  246. lock_key=lock_key,
  247. token=token,
  248. )
  249. self._live_read_write_lock_tokens[(lock_name, lock_key, token)] = lock
  250. return lock
  251. class Lock:
  252. """An async context manager that manages an acquired lock, ensuring it is
  253. regularly renewed and dropping it when the context manager exits.
  254. The lock object has an `is_still_valid` method which can be used to
  255. double-check the lock is still valid, if e.g. processing work in a loop.
  256. For example:
  257. lock = await self.store.try_acquire_lock(...)
  258. if not lock:
  259. return
  260. async with lock:
  261. for item in work:
  262. await process(item)
  263. if not await lock.is_still_valid():
  264. break
  265. """
  266. def __init__(
  267. self,
  268. reactor: IReactorCore,
  269. clock: Clock,
  270. store: LockStore,
  271. read_write: bool,
  272. lock_name: str,
  273. lock_key: str,
  274. token: str,
  275. ) -> None:
  276. self._reactor = reactor
  277. self._clock = clock
  278. self._store = store
  279. self._read_write = read_write
  280. self._lock_name = lock_name
  281. self._lock_key = lock_key
  282. self._token = token
  283. self._table = "worker_read_write_locks" if read_write else "worker_locks"
  284. self._looping_call = clock.looping_call(
  285. self._renew,
  286. _RENEWAL_INTERVAL_MS,
  287. store,
  288. clock,
  289. read_write,
  290. lock_name,
  291. lock_key,
  292. token,
  293. )
  294. self._dropped = False
  295. @staticmethod
  296. @wrap_as_background_process("Lock._renew")
  297. async def _renew(
  298. store: LockStore,
  299. clock: Clock,
  300. read_write: bool,
  301. lock_name: str,
  302. lock_key: str,
  303. token: str,
  304. ) -> None:
  305. """Renew the lock.
  306. Note: this is a static method, rather than using self.*, so that we
  307. don't end up with a reference to `self` in the reactor, which would stop
  308. this from being cleaned up if we dropped the context manager.
  309. """
  310. table = "worker_read_write_locks" if read_write else "worker_locks"
  311. await store.db_pool.simple_update(
  312. table=table,
  313. keyvalues={
  314. "lock_name": lock_name,
  315. "lock_key": lock_key,
  316. "token": token,
  317. },
  318. updatevalues={"last_renewed_ts": clock.time_msec()},
  319. desc="renew_lock",
  320. )
  321. async def is_still_valid(self) -> bool:
  322. """Check if the lock is still held by us"""
  323. last_renewed_ts = await self._store.db_pool.simple_select_one_onecol(
  324. table=self._table,
  325. keyvalues={
  326. "lock_name": self._lock_name,
  327. "lock_key": self._lock_key,
  328. "token": self._token,
  329. },
  330. retcol="last_renewed_ts",
  331. allow_none=True,
  332. desc="is_lock_still_valid",
  333. )
  334. return (
  335. last_renewed_ts is not None
  336. and self._clock.time_msec() - _LOCK_TIMEOUT_MS < last_renewed_ts
  337. )
  338. async def __aenter__(self) -> None:
  339. if self._dropped:
  340. raise Exception("Cannot reuse a Lock object")
  341. async def __aexit__(
  342. self,
  343. _exctype: Optional[Type[BaseException]],
  344. _excinst: Optional[BaseException],
  345. _exctb: Optional[TracebackType],
  346. ) -> bool:
  347. await self.release()
  348. return False
  349. async def release(self) -> None:
  350. """Release the lock.
  351. This is automatically called when using the lock as a context manager.
  352. """
  353. if self._dropped:
  354. return
  355. if self._looping_call.running:
  356. self._looping_call.stop()
  357. await self._store.db_pool.simple_delete(
  358. table=self._table,
  359. keyvalues={
  360. "lock_name": self._lock_name,
  361. "lock_key": self._lock_key,
  362. "token": self._token,
  363. },
  364. desc="drop_lock",
  365. )
  366. if self._read_write:
  367. self._store._live_read_write_lock_tokens.pop(
  368. (self._lock_name, self._lock_key, self._token), None
  369. )
  370. else:
  371. self._store._live_lock_tokens.pop((self._lock_name, self._lock_key), None)
  372. self._dropped = True
  373. def __del__(self) -> None:
  374. if not self._dropped:
  375. # We should not be dropped without the lock being released (unless
  376. # we're shutting down), but if we are then let's at least stop
  377. # renewing the lock.
  378. if self._looping_call.running:
  379. self._looping_call.stop()
  380. if self._reactor.running:
  381. logger.error(
  382. "Lock for (%s, %s) dropped without being released",
  383. self._lock_name,
  384. self._lock_key,
  385. )