123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344 |
- # Copyright 2021 Matrix.org Foundation C.I.C.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import logging
- from types import TracebackType
- from typing import TYPE_CHECKING, Optional, Set, Tuple, Type
- from weakref import WeakValueDictionary
- from twisted.internet.interfaces import IReactorCore
- from synapse.metrics.background_process_metrics import wrap_as_background_process
- from synapse.storage._base import SQLBaseStore
- from synapse.storage.database import (
- DatabasePool,
- LoggingDatabaseConnection,
- LoggingTransaction,
- )
- from synapse.util import Clock
- from synapse.util.stringutils import random_string
- if TYPE_CHECKING:
- from synapse.server import HomeServer
logger = logging.getLogger(__name__)

# Interval, in milliseconds, at which a held lock is renewed: i.e. how often the
# holder bumps the lock's `last_renewed_ts` column in the lock table.
_RENEWAL_INTERVAL_MS = 30_000  # 30 seconds

# Age, in milliseconds, after which a lock that has not been renewed is treated
# as timed out and may be taken over.
_LOCK_TIMEOUT_MS = 120_000  # 2 minutes
class LockStore(SQLBaseStore):
    """Provides a best effort distributed lock between worker instances.

    Locks are identified by a name and key. A lock is acquired by inserting into
    the `worker_locks` table if a) there is no existing row for the name/key,
    b) the existing row has a `last_renewed_ts` older than `_LOCK_TIMEOUT_MS`, or
    c) the existing row was written by this same instance name (which means the
    process got killed and restarted without releasing it).

    When a lock is taken out the instance inserts a random `token`; the instance
    that holds that token holds the lock until it drops it (or it times out).

    The instance that holds the lock should regularly update the
    `last_renewed_ts` column with the current time.
    """

    def __init__(
        self,
        database: DatabasePool,
        db_conn: LoggingDatabaseConnection,
        hs: "HomeServer",
    ):
        super().__init__(database, db_conn, hs)

        self._reactor = hs.get_reactor()
        self._instance_name = hs.get_instance_id()

        # A map from `(lock_name, lock_key)` to the `Lock` object for any locks
        # that we think we currently hold. Values are held weakly, so an entry
        # disappears once its `Lock` object is garbage collected.
        self._live_tokens: WeakValueDictionary[
            Tuple[str, str], Lock
        ] = WeakValueDictionary()

        # When we shut down we want to remove the locks. Technically this can
        # lead to a race, as we may drop the lock while we are still processing.
        # However, a) it should be a small window, b) the lock is best effort
        # anyway and c) we want to really avoid leaking locks when we restart.
        hs.get_reactor().addSystemEventTrigger(
            "before",
            "shutdown",
            self._on_shutdown,
        )

        # `(lock_name, lock_key)` pairs for which an acquisition attempt from
        # this instance is currently in flight; a second concurrent attempt for
        # the same pair returns None immediately instead of racing the first.
        self._acquiring_locks: Set[Tuple[str, str]] = set()

    @wrap_as_background_process("LockStore._on_shutdown")
    async def _on_shutdown(self) -> None:
        """Called when the server is shutting down.

        Releases every lock this instance believes it holds. A failure to
        release one lock is logged and does not stop the remaining locks from
        being released.
        """
        logger.info("Dropping held locks due to shutdown")

        # We need to take a copy of the tokens dict as dropping the locks will
        # cause the dictionary to change. The copy also holds strong references,
        # so no entry vanishes mid-iteration.
        locks = dict(self._live_tokens)
        for (lock_name, lock_key), lock in locks.items():
            try:
                await lock.release()
            except Exception:
                # Best effort: keep going so one bad lock doesn't leak the rest.
                logger.exception(
                    "Failed to drop lock (%s, %s) during shutdown",
                    lock_name,
                    lock_key,
                )

        logger.info("Dropped locks due to shutdown")

    async def try_acquire_lock(self, lock_name: str, lock_key: str) -> Optional["Lock"]:
        """Try to acquire a lock for the given name/key. Will return an async
        context manager if the lock is successfully acquired, which *must* be
        used (otherwise the lock will leak).

        Returns None if the lock could not be acquired, including when another
        acquisition attempt for the same name/key is already in progress on this
        instance.
        """
        key = (lock_name, lock_key)
        if key in self._acquiring_locks:
            return None

        self._acquiring_locks.add(key)
        try:
            return await self._try_acquire_lock(lock_name, lock_key)
        finally:
            self._acquiring_locks.discard(key)

    async def _try_acquire_lock(
        self, lock_name: str, lock_key: str
    ) -> Optional["Lock"]:
        """Try to acquire a lock for the given name/key. Will return an async
        context manager if the lock is successfully acquired, which *must* be
        used (otherwise the lock will leak).
        """
        # Check if this process has taken out a lock and if it's still valid.
        # If the existing lock has lapsed we fall through and try to re-acquire.
        lock = self._live_tokens.get((lock_name, lock_key))
        if lock is not None and await lock.is_still_valid():
            return None

        now = self._clock.time_msec()
        token = random_string(6)

        def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool:
            # We take out the lock if either a) there is no row for the lock
            # already, b) the existing row has timed out, or c) the row is
            # for this instance (which means the process got killed and
            # restarted)
            sql = """
                INSERT INTO worker_locks (lock_name, lock_key, instance_name, token, last_renewed_ts)
                VALUES (?, ?, ?, ?, ?)
                ON CONFLICT (lock_name, lock_key)
                DO UPDATE
                    SET
                        token = EXCLUDED.token,
                        instance_name = EXCLUDED.instance_name,
                        last_renewed_ts = EXCLUDED.last_renewed_ts
                    WHERE
                        worker_locks.last_renewed_ts < ?
                        OR worker_locks.instance_name = EXCLUDED.instance_name
            """
            txn.execute(
                sql,
                (
                    lock_name,
                    lock_key,
                    self._instance_name,
                    token,
                    now,
                    now - _LOCK_TIMEOUT_MS,
                ),
            )

            # We only acquired the lock if we inserted or updated the table.
            return bool(txn.rowcount)

        did_lock = await self.db_pool.runInteraction(
            "try_acquire_lock",
            _try_acquire_lock_txn,
            # We can autocommit here as we're executing a single query, this
            # will avoid serialization errors.
            db_autocommit=True,
        )
        if not did_lock:
            return None

        lock = Lock(
            self._reactor,
            self._clock,
            self,
            lock_name=lock_name,
            lock_key=lock_key,
            token=token,
        )

        self._live_tokens[(lock_name, lock_key)] = lock

        return lock

    async def _is_lock_still_valid(
        self, lock_name: str, lock_key: str, token: str
    ) -> bool:
        """Checks whether this instance still holds the lock.

        True only if a row with this exact `token` still exists and was renewed
        within the last `_LOCK_TIMEOUT_MS`.
        """
        last_renewed_ts = await self.db_pool.simple_select_one_onecol(
            table="worker_locks",
            keyvalues={
                "lock_name": lock_name,
                "lock_key": lock_key,
                "token": token,
            },
            retcol="last_renewed_ts",
            allow_none=True,
            desc="is_lock_still_valid",
        )
        return (
            last_renewed_ts is not None
            and self._clock.time_msec() - _LOCK_TIMEOUT_MS < last_renewed_ts
        )

    async def _renew_lock(self, lock_name: str, lock_key: str, token: str) -> None:
        """Attempt to renew the lock if we still hold it.

        The update is keyed on `token`, so it is a no-op if the lock has been
        taken over since.
        """
        await self.db_pool.simple_update(
            table="worker_locks",
            keyvalues={
                "lock_name": lock_name,
                "lock_key": lock_key,
                "token": token,
            },
            updatevalues={"last_renewed_ts": self._clock.time_msec()},
            desc="renew_lock",
        )

    async def _drop_lock(self, lock_name: str, lock_key: str, token: str) -> None:
        """Attempt to drop the lock, if we still hold it.

        The row is only deleted if it still carries `token`, so dropping a lock
        that has since been taken over leaves the new holder's row intact.
        """
        await self.db_pool.simple_delete(
            table="worker_locks",
            keyvalues={
                "lock_name": lock_name,
                "lock_key": lock_key,
                "token": token,
            },
            desc="drop_lock",
        )

        # Only forget the in-memory entry if it belongs to this token. If this
        # lock lapsed and was re-acquired by this instance, the map now points
        # at the newer `Lock`; removing it would make us lose track of a lock we
        # still hold (it would not be seen by `_try_acquire_lock` nor released
        # on shutdown).
        key = (lock_name, lock_key)
        live_lock = self._live_tokens.get(key)
        if live_lock is not None and live_lock._token == token:
            self._live_tokens.pop(key, None)
class Lock:
    """Async context manager for a lock obtained from `LockStore.try_acquire_lock`.

    While the lock is held, a looping call renews it every
    `_RENEWAL_INTERVAL_MS`. Exiting the context manager (or calling `release`)
    stops the renewals and drops the lock.

    Call `is_still_valid` to double-check that the lock has not been lost, for
    instance between iterations of a long-running loop:

        lock = await self.store.try_acquire_lock(...)
        if not lock:
            return

        async with lock:
            for item in work:
                await process(item)

                if not await lock.is_still_valid():
                    break
    """

    def __init__(
        self,
        reactor: IReactorCore,
        clock: Clock,
        store: LockStore,
        lock_name: str,
        lock_key: str,
        token: str,
    ) -> None:
        self._reactor = reactor
        self._clock = clock
        self._store = store
        self._lock_name, self._lock_key, self._token = lock_name, lock_key, token

        # Periodic renewal. Only the static `_renew` and its explicit arguments
        # are handed to the clock, never `self` (see `_renew`).
        self._looping_call = clock.looping_call(
            self._renew,
            _RENEWAL_INTERVAL_MS,
            store,
            lock_name,
            lock_key,
            token,
        )

        self._dropped = False

    @staticmethod
    @wrap_as_background_process("Lock._renew")
    async def _renew(
        store: LockStore,
        lock_name: str,
        lock_key: str,
        token: str,
    ) -> None:
        """Bump the lock's renewal timestamp via the store.

        Kept as a static method taking explicit arguments so that the looping
        call registered with the reactor holds no reference to the `Lock`
        instance; otherwise an abandoned context manager could never be cleaned
        up.
        """
        await store._renew_lock(lock_name, lock_key, token)

    async def is_still_valid(self) -> bool:
        """Return whether this instance still holds the lock."""
        store = self._store
        return await store._is_lock_still_valid(
            self._lock_name, self._lock_key, self._token
        )

    async def __aenter__(self) -> None:
        if not self._dropped:
            return
        raise Exception("Cannot reuse a Lock object")

    async def __aexit__(
        self,
        _exctype: Optional[Type[BaseException]],
        _excinst: Optional[BaseException],
        _exctb: Optional[TracebackType],
    ) -> bool:
        # Release unconditionally; returning False lets any exception raised in
        # the `async with` body propagate.
        await self.release()
        return False

    async def release(self) -> None:
        """Stop renewing and drop the lock.

        Invoked automatically on exit from `async with`; calling it again once
        the lock has been dropped does nothing.
        """
        if self._dropped:
            return

        renewer = self._looping_call
        if renewer.running:
            renewer.stop()

        await self._store._drop_lock(self._lock_name, self._lock_key, self._token)
        self._dropped = True

    def __del__(self) -> None:
        if self._dropped:
            return

        # Collected without `release` having completed. That should only happen
        # at shutdown; regardless, stop renewing a lock nobody is using.
        renewer = self._looping_call
        if renewer.running:
            renewer.stop()

        # Stay quiet once the reactor has stopped running.
        if not self._reactor.running:
            return

        logger.error(
            "Lock for (%s, %s) dropped without being released",
            self._lock_name,
            self._lock_key,
        )
|