lock.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344
  1. # Copyright 2021 Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from types import TracebackType
  16. from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
  17. from weakref import WeakValueDictionary
  18. from twisted.internet.interfaces import IReactorCore
  19. from synapse.metrics.background_process_metrics import wrap_as_background_process
  20. from synapse.storage._base import SQLBaseStore
  21. from synapse.storage.database import (
  22. DatabasePool,
  23. LoggingDatabaseConnection,
  24. LoggingTransaction,
  25. )
  26. from synapse.util import Clock
  27. from synapse.util.stringutils import random_string
  28. if TYPE_CHECKING:
  29. from synapse.server import HomeServer
  30. logger = logging.getLogger(__name__)
  31. # How often to renew an acquired lock by updating the `last_renewed_ts` time in
  32. # the lock table.
  33. _RENEWAL_INTERVAL_MS = 30 * 1000
  34. # How long before an acquired lock times out.
  35. _LOCK_TIMEOUT_MS = 2 * 60 * 1000
  36. class LockStore(SQLBaseStore):
  37. """Provides a best effort distributed lock between worker instances.
  38. Locks are identified by a name and key. A lock is acquired by inserting into
  39. the `worker_locks` table if a) there is no existing row for the name/key or
  40. b) the existing row has a `last_renewed_ts` older than `_LOCK_TIMEOUT_MS`.
  41. When a lock is taken out the instance inserts a random `token`, the instance
  42. that holds that token holds the lock until it drops (or times out).
  43. The instance that holds the lock should regularly update the
  44. `last_renewed_ts` column with the current time.
  45. """
  46. def __init__(
  47. self,
  48. database: DatabasePool,
  49. db_conn: LoggingDatabaseConnection,
  50. hs: "HomeServer",
  51. ):
  52. super().__init__(database, db_conn, hs)
  53. self._reactor = hs.get_reactor()
  54. self._instance_name = hs.get_instance_id()
  55. # A map from `(lock_name, lock_key)` to the token of any locks that we
  56. # think we currently hold.
  57. self._live_tokens: WeakValueDictionary[
  58. Tuple[str, str], Lock
  59. ] = WeakValueDictionary()
  60. # When we shut down we want to remove the locks. Technically this can
  61. # lead to a race, as we may drop the lock while we are still processing.
  62. # However, a) it should be a small window, b) the lock is best effort
  63. # anyway and c) we want to really avoid leaking locks when we restart.
  64. hs.get_reactor().addSystemEventTrigger(
  65. "before",
  66. "shutdown",
  67. self._on_shutdown,
  68. )
  69. self._acquiring_locks: Set[Tuple[str, str]] = set()
  70. @wrap_as_background_process("LockStore._on_shutdown")
  71. async def _on_shutdown(self) -> None:
  72. """Called when the server is shutting down"""
  73. logger.info("Dropping held locks due to shutdown")
  74. # We need to take a copy of the tokens dict as dropping the locks will
  75. # cause the dictionary to change.
  76. locks = dict(self._live_tokens)
  77. for lock in locks.values():
  78. await lock.release()
  79. logger.info("Dropped locks due to shutdown")
  80. async def try_acquire_lock(self, lock_name: str, lock_key: str) -> Optional["Lock"]:
  81. """Try to acquire a lock for the given name/key. Will return an async
  82. context manager if the lock is successfully acquired, which *must* be
  83. used (otherwise the lock will leak).
  84. """
  85. if (lock_name, lock_key) in self._acquiring_locks:
  86. return None
  87. try:
  88. self._acquiring_locks.add((lock_name, lock_key))
  89. return await self._try_acquire_lock(lock_name, lock_key)
  90. finally:
  91. self._acquiring_locks.discard((lock_name, lock_key))
  92. async def _try_acquire_lock(
  93. self, lock_name: str, lock_key: str
  94. ) -> Optional["Lock"]:
  95. """Try to acquire a lock for the given name/key. Will return an async
  96. context manager if the lock is successfully acquired, which *must* be
  97. used (otherwise the lock will leak).
  98. """
  99. # Check if this process has taken out a lock and if it's still valid.
  100. lock = self._live_tokens.get((lock_name, lock_key))
  101. if lock and await lock.is_still_valid():
  102. return None
  103. now = self._clock.time_msec()
  104. token = random_string(6)
  105. def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
  106. # We take out the lock if either a) there is no row for the lock
  107. # already, b) the existing row has timed out, or c) the row is
  108. # for this instance (which means the process got killed and
  109. # restarted)
  110. sql = """
  111. INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
  112. VALUES (?, ?, ?, ?, ?)
  113. ON CONFLICT (lock_name, lock_key)
  114. DO UPDATE
  115. SET
  116. token = EXCLUDED.token,
  117. instance_name = EXCLUDED.instance_name,
  118. last_renewed_ts = EXCLUDED.last_renewed_ts
  119. WHERE
  120. worker_locks.last_renewed_ts < ?
  121. OR worker_locks.instance_name = EXCLUDED.instance_name
  122. """
  123. txn.execute(
  124. sql,
  125. (
  126. lock_name,
  127. lock_key,
  128. self._instance_name,
  129. token,
  130. now,
  131. now - _LOCK_TIMEOUT_MS,
  132. ),
  133. )
  134. # We only acquired the lock if we inserted or updated the table.
  135. return bool(txn.rowcount)
  136. did_lock = await self.db_pool.runInteraction(
  137. "try_acquire_lock",
  138. _try_acquire_lock_txn,
  139. # We can autocommit here as we're executing a single query, this
  140. # will avoid serialization errors.
  141. db_autocommit=True,
  142. )
  143. if not did_lock:
  144. return None
  145. lock = Lock(
  146. self._reactor,
  147. self._clock,
  148. self,
  149. lock_name=lock_name,
  150. lock_key=lock_key,
  151. token=token,
  152. )
  153. self._live_tokens[(lock_name, lock_key)] = lock
  154. return lock
  155. async def _is_lock_still_valid(
  156. self, lock_name: str, lock_key: str, token: str
  157. ) -> bool:
  158. """Checks whether this instance still holds the lock."""
  159. last_renewed_ts = await self.db_pool.simple_select_one_onecol(
  160. table="worker_locks",
  161. keyvalues={
  162. "lock_name": lock_name,
  163. "lock_key": lock_key,
  164. "token": token,
  165. },
  166. retcol="last_renewed_ts",
  167. allow_none=True,
  168. desc="is_lock_still_valid",
  169. )
  170. return (
  171. last_renewed_ts is not None
  172. and self._clock.time_msec() - _LOCK_TIMEOUT_MS < last_renewed_ts
  173. )
  174. async def _renew_lock(self, lock_name: str, lock_key: str, token: str) -> None:
  175. """Attempt to renew the lock if we still hold it."""
  176. await self.db_pool.simple_update(
  177. table="worker_locks",
  178. keyvalues={
  179. "lock_name": lock_name,
  180. "lock_key": lock_key,
  181. "token": token,
  182. },
  183. updatevalues={"last_renewed_ts": self._clock.time_msec()},
  184. desc="renew_lock",
  185. )
  186. async def _drop_lock(self, lock_name: str, lock_key: str, token: str) -> None:
  187. """Attempt to drop the lock, if we still hold it"""
  188. await self.db_pool.simple_delete(
  189. table="worker_locks",
  190. keyvalues={
  191. "lock_name": lock_name,
  192. "lock_key": lock_key,
  193. "token": token,
  194. },
  195. desc="drop_lock",
  196. )
  197. self._live_tokens.pop((lock_name, lock_key), None)
  198. class Lock:
  199. """An async context manager that manages an acquired lock, ensuring it is
  200. regularly renewed and dropping it when the context manager exits.
  201. The lock object has an `is_still_valid` method which can be used to
  202. double-check the lock is still valid, if e.g. processing work in a loop.
  203. For example:
  204. lock = await self.store.try_acquire_lock(...)
  205. if not lock:
  206. return
  207. async with lock:
  208. for item in work:
  209. await process(item)
  210. if not await lock.is_still_valid():
  211. break
  212. """
  213. def __init__(
  214. self,
  215. reactor: IReactorCore,
  216. clock: Clock,
  217. store: LockStore,
  218. lock_name: str,
  219. lock_key: str,
  220. token: str,
  221. ) -> None:
  222. self._reactor = reactor
  223. self._clock = clock
  224. self._store = store
  225. self._lock_name = lock_name
  226. self._lock_key = lock_key
  227. self._token = token
  228. self._looping_call = clock.looping_call(
  229. self._renew, _RENEWAL_INTERVAL_MS, store, lock_name, lock_key, token
  230. )
  231. self._dropped = False
  232. @staticmethod
  233. @wrap_as_background_process("Lock._renew")
  234. async def _renew(
  235. store: LockStore,
  236. lock_name: str,
  237. lock_key: str,
  238. token: str,
  239. ) -> None:
  240. """Renew the lock.
  241. Note: this is a static method, rather than using self.*, so that we
  242. don't end up with a reference to `self` in the reactor, which would stop
  243. this from being cleaned up if we dropped the context manager.
  244. """
  245. await store._renew_lock(lock_name, lock_key, token)
  246. async def is_still_valid(self) -> bool:
  247. """Check if the lock is still held by us"""
  248. return await self._store._is_lock_still_valid(
  249. self._lock_name, self._lock_key, self._token
  250. )
  251. async def __aenter__(self) -> None:
  252. if self._dropped:
  253. raise Exception("Cannot reuse a Lock object")
  254. async def __aexit__(
  255. self,
  256. _exctype: Optional[Type[BaseException]],
  257. _excinst: Optional[BaseException],
  258. _exctb: Optional[TracebackType],
  259. ) -> bool:
  260. await self.release()
  261. return False
  262. async def release(self) -> None:
  263. """Release the lock.
  264. This is automatically called when using the lock as a context manager.
  265. """
  266. if self._dropped:
  267. return
  268. if self._looping_call.running:
  269. self._looping_call.stop()
  270. await self._store._drop_lock(self._lock_name, self._lock_key, self._token)
  271. self._dropped = True
  272. def __del__(self) -> None:
  273. if not self._dropped:
  274. # We should not be dropped without the lock being released (unless
  275. # we're shutting down), but if we are then let's at least stop
  276. # renewing the lock.
  277. if self._looping_call.running:
  278. self._looping_call.stop()
  279. if self._reactor.running:
  280. logger.error(
  281. "Lock for (%s, %s) dropped without being released",
  282. self._lock_name,
  283. self._lock_key,
  284. )