events.py 82 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. # Copyright 2018-2019 New Vector Ltd
  4. # Copyright 2019 The Matrix.org Foundation C.I.C.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import itertools
  18. import logging
  19. from collections import OrderedDict, namedtuple
  20. from typing import (
  21. TYPE_CHECKING,
  22. Any,
  23. Dict,
  24. Generator,
  25. Iterable,
  26. List,
  27. Optional,
  28. Set,
  29. Tuple,
  30. )
  31. import attr
  32. from prometheus_client import Counter
  33. import synapse.metrics
  34. from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
  35. from synapse.api.room_versions import RoomVersions
  36. from synapse.crypto.event_signing import compute_event_reference_hash
  37. from synapse.events import EventBase # noqa: F401
  38. from synapse.events.snapshot import EventContext # noqa: F401
  39. from synapse.logging.utils import log_function
  40. from synapse.storage._base import db_to_json, make_in_list_sql_clause
  41. from synapse.storage.database import DatabasePool, LoggingTransaction
  42. from synapse.storage.databases.main.search import SearchEntry
  43. from synapse.storage.types import Connection
  44. from synapse.storage.util.id_generators import MultiWriterIdGenerator
  45. from synapse.storage.util.sequence import SequenceGenerator
  46. from synapse.types import StateMap, get_domain_from_id
  47. from synapse.util import json_encoder
  48. from synapse.util.iterutils import batch_iter, sorted_topologically
  49. if TYPE_CHECKING:
  50. from synapse.server import HomeServer
  51. from synapse.storage.databases.main import DataStore
  logger = logging.getLogger(__name__)

  # Incremented by the number of events in every batch handed to
  # `_persist_events_and_state_updates`.
  persist_event_counter = Counter(
      name="synapse_storage_events_persisted_events",
      documentation="",
  )

  # Per-event counter, labelled with the event type and with where the event
  # originated: "local" (via an application service or a local client) or
  # "remote" (labelled with the sender's server name).
  event_counter = Counter(
      name="synapse_storage_events_persisted_events_sep",
      documentation="",
      labelnames=["type", "origin_type", "origin_entity"],
  )

  # Pairs an event with a `redacted_event` value.
  # NOTE(review): presumably the entry type of the event cache; it is not used
  # in this part of the file - confirm against its readers.
  _EventCacheEntry = namedtuple("_EventCacheEntry", ["event", "redacted_event"])
  @attr.attrs(slots=True)
  class DeltaState:
      """The changes to make to the `current_state_events` table for a room.

      Attributes:
          to_delete: (type, state_key) pairs to remove from the current state.
          to_insert: Map of state to upsert into the current state.
          no_longer_in_room: Whether the server is no longer in the room, in
              which case the room should e.g. be removed from the
              `current_state_events` table. Defaults to False.
      """

      # Keys to drop from the room's current state.
      to_delete = attr.attrib(type=List[Tuple[str, str]])
      # State entries to add or overwrite in the room's current state.
      to_insert = attr.attrib(type=StateMap[str])
      # Set when the server has left the room entirely.
      no_longer_in_room = attr.attrib(default=False, type=bool)
  72. class PersistEventsStore:
  73. """Contains all the functions for writing events to the database.
  74. Should only be instantiated on one process (when using a worker mode setup).
  75. Note: This is not part of the `DataStore` mixin.
  76. """
  def __init__(
      self,
      hs: "HomeServer",
      db: DatabasePool,
      main_data_store: "DataStore",
      db_conn: Connection,
  ):
      """Set up the event-persistence store.

      Args:
          hs: The homeserver; its clock, instance name, config and
              `is_mine_id` are captured here.
          db: The database pool that events are written through.
          main_data_store: The main data store. Its stream ID generators are
              referenced (not copied) by this object.
          db_conn: A database connection. Not used by this constructor.

      Raises:
          AssertionError: if this instance is not one of the configured event
              writers.
      """
      store = main_data_store

      self.hs = hs
      self.db_pool = db
      self.store = store
      self.database_engine = db.engine

      self._clock = hs.get_clock()
      self._instance_name = hs.get_instance_name()
      self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
      self.is_mine_id = hs.is_mine_id

      # These ID generators would ideally live here, but other ID generators
      # are chained off them, so they stay on the main store and are only
      # referenced from this object.
      self._backfill_id_gen = store._backfill_id_gen  # type: MultiWriterIdGenerator
      self._stream_id_gen = store._stream_id_gen  # type: MultiWriterIdGenerator

      # Persisting events is only allowed on instances configured to write them.
      writer_instances = hs.config.worker.writers.events
      assert (
          hs.get_instance_name() in writer_instances
      ), "Can only instantiate EventsStore on master"
  async def _persist_events_and_state_updates(
      self,
      events_and_contexts: List[Tuple[EventBase, EventContext]],
      current_state_for_room: Dict[str, StateMap[str]],
      state_delta_for_room: Dict[str, DeltaState],
      new_forward_extremeties: Dict[str, List[str]],
      backfilled: bool = False,
  ) -> None:
      """Persist a set of events alongside updates to the current state and
      forward extremities tables.

      Args:
          events_and_contexts: The events to persist, each paired with its
              context. Each event is assigned a stream ordering here.
          current_state_for_room: Map from room_id to the current state of
              the room based on forward extremities; used to prefill the
              `get_current_state_ids` cache.
          state_delta_for_room: Map from room_id to the delta to apply to
              room state.
          new_forward_extremeties: Map from room_id to list of event IDs
              that are the new forward extremities of the room; also used to
              prefill the `get_latest_event_ids_in_room` cache. (The
              misspelt parameter name is part of the interface.)
          backfilled: True if the events were backfilled, in which case the
              stream orderings come from the backfill ID generator.

      Returns:
          Resolves when the events have been persisted.
      """

      def _origin_labels(event, context):
          """Return the (origin_type, origin_entity) metric labels for an event."""
          if context.app_service:
              return "local", context.app_service.id
          if self.hs.is_mine_id(event.sender):
              return "local", "*client*"
          return "remote", get_domain_from_id(event.sender)

      # Stream orderings are allocated as late as possible. Notifications for
      # an event are only sent once every event with a lower stream ordering
      # has been persisted, so time spent inside the `async with` block below
      # delays notifications for all later events.
      #
      # Allocating them after the state deltas etc. have been calculated is
      # safe, because only the *persistence* of the events needs protecting.
      # That guarantees a query of the form "fetch events since X" never
      # returns events and stream positions beyond events that are still in
      # flight; otherwise a follow-up "fetch events since Y" would miss them.
      #
      # Note: Multiple instances of this function cannot be in flight at
      # the same time for the same room.
      id_gen = self._backfill_id_gen if backfilled else self._stream_id_gen
      stream_ordering_manager = id_gen.get_next_mult(len(events_and_contexts))

      async with stream_ordering_manager as stream_orderings:
          for (event, _context), stream in zip(events_and_contexts, stream_orderings):
              event.internal_metadata.stream_ordering = stream

          await self.db_pool.runInteraction(
              "persist_events",
              self._persist_events_txn,
              events_and_contexts=events_and_contexts,
              backfilled=backfilled,
              state_delta_for_room=state_delta_for_room,
              new_forward_extremeties=new_forward_extremeties,
          )

          persist_event_counter.inc(len(events_and_contexts))

          if not backfilled:
              # Backfilled events have negative stream orderings, so they must
              # not move the event_persisted_position metric.
              last_event = events_and_contexts[-1][0]
              synapse.metrics.event_persisted_position.set(
                  last_event.internal_metadata.stream_ordering
              )

          for event, context in events_and_contexts:
              origin_type, origin_entity = _origin_labels(event, context)
              event_counter.labels(event.type, origin_type, origin_entity).inc()

          for room_id, new_state in current_state_for_room.items():
              self.store.get_current_state_ids.prefill((room_id,), new_state)

          for room_id, latest_event_ids in new_forward_extremeties.items():
              self.store.get_latest_event_ids_in_room.prefill(
                  (room_id,), list(latest_event_ids)
              )
  async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
      """Filter `event_ids` down to those that are prev_events of existing events.

      An event only counts as "existing" here if it is not an outlier, has not
      been rejected, and is not marked as soft-failed in its internal metadata.

      Args:
          event_ids: Event IDs to filter. They are looked up in batches of 100.

      Returns:
          The matching event IDs. There is one entry per matching
          `event_edges` row, so the list may contain duplicates.
      """
      found = []  # type: List[str]

      def _get_events_which_are_prevs_txn(txn, batch):
          sql = """
              SELECT prev_event_id, internal_metadata
              FROM event_edges
                  INNER JOIN events USING (event_id)
                  LEFT JOIN rejections USING (event_id)
                  LEFT JOIN event_json USING (event_id)
              WHERE
                  NOT events.outlier
                  AND rejections.event_id IS NULL
                  AND
          """
          in_clause, in_args = make_in_list_sql_clause(
              self.database_engine, "prev_event_id", batch
          )
          txn.execute(sql + in_clause, in_args)

          for prev_event_id, internal_metadata in txn:
              # Edges from soft-failed events are treated as nonexistent.
              if not db_to_json(internal_metadata).get("soft_failed"):
                  found.append(prev_event_id)

      for chunk in batch_iter(event_ids, 100):
          await self.db_pool.runInteraction(
              "_get_events_which_are_prevs", _get_events_which_are_prevs_txn, chunk
          )

      return found
  async def _get_prevs_before_rejected(self, event_ids: Iterable[str]) -> Set[str]:
      """Get soft-failed ancestors to remove from the extremities.

      Starting from `event_ids`, each event that is soft-failed or rejected
      has its prev events collected. Every newly collected prev event is then
      checked the same way, walking up the prev-event graph until no further
      soft-failed/rejected events are found. Edges of outlier events are
      not followed.

      This is used to find extremities that are ancestors of new events, but
      are separated from them by soft-failed events.

      Args:
          event_ids: Events to start from. Note that these must have already
              been persisted. They are processed in batches of 100.

      Returns:
          The collected prev events. Any soft-failed/rejected events reached
          during the walk are included, since they are themselves prev events.
      """
      # Every prev event collected so far, shared across all batches. It also
      # stops the walk from revisiting an event.
      existing_prevs = set()

      def _get_prevs_before_rejected_txn(txn, batch):
          pending = batch

          while pending:
              sql = """
                  SELECT
                      event_id, prev_event_id, internal_metadata,
                      rejections.event_id IS NOT NULL
                  FROM event_edges
                      INNER JOIN events USING (event_id)
                      LEFT JOIN rejections USING (event_id)
                      LEFT JOIN event_json USING (event_id)
                  WHERE
                      NOT events.outlier
                      AND
              """
              in_clause, in_args = make_in_list_sql_clause(
                  self.database_engine, "event_id", pending
              )
              txn.execute(sql + in_clause, in_args)

              pending = []
              for _event_id, prev_event_id, metadata, rejected in txn:
                  if prev_event_id in existing_prevs:
                      continue

                  soft_failed = db_to_json(metadata).get("soft_failed")
                  if not (soft_failed or rejected):
                      continue

                  pending.append(prev_event_id)
                  existing_prevs.add(prev_event_id)

      for chunk in batch_iter(event_ids, 100):
          await self.db_pool.runInteraction(
              "_get_prevs_before_rejected", _get_prevs_before_rejected_txn, chunk
          )

      return existing_prevs
  @log_function
  def _persist_events_txn(
      self,
      txn: LoggingTransaction,
      events_and_contexts: List[Tuple[EventBase, EventContext]],
      backfilled: bool,
      state_delta_for_room: Optional[Dict[str, DeltaState]] = None,
      new_forward_extremeties: Optional[Dict[str, List[str]]] = None,
  ):
      """Insert some number of room events into the necessary database tables.

      Rejected events still go through the steps that run before
      `_store_rejected_events_txn`: event storage, transaction IDs,
      event_to_state_groups mappings and auth chains. That call then filters
      them out, so only accepted events reach `_update_metadata_tables_txn`.
      Anything reading the earlier tables needs to check whether an event
      was rejected.

      Args:
          txn: The database transaction to run in.
          events_and_contexts: Events to persist, in stream order. Stream
              orderings must already be assigned; the first and last events
              are asserted to have one.
          backfilled: True if the events were backfilled.
          state_delta_for_room: The current-state delta for each room.
              Defaults to an empty mapping.
          new_forward_extremeties: The new forward extremities for each room.
              For each room, a list of the event ids which are the forward
              extremities. Defaults to an empty mapping. (The misspelt name
              is kept for compatibility with existing callers.)
      """
      # Build fresh defaults on each call instead of sharing one mutable `{}`
      # default object between every call of this method.
      if state_delta_for_room is None:
          state_delta_for_room = {}
      if new_forward_extremeties is None:
          new_forward_extremeties = {}

      all_events_and_contexts = events_and_contexts

      min_stream_order = events_and_contexts[0][0].internal_metadata.stream_ordering
      max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering

      # stream orderings should have been assigned by now
      assert min_stream_order
      assert max_stream_order

      self._update_forward_extremities_txn(
          txn,
          new_forward_extremities=new_forward_extremeties,
          max_stream_order=max_stream_order,
      )

      # Ensure that we don't have the same event twice.
      events_and_contexts = self._filter_events_and_contexts_for_duplicates(
          events_and_contexts
      )

      self._update_room_depths_txn(
          txn, events_and_contexts=events_and_contexts, backfilled=backfilled
      )

      # _update_outliers_txn filters out any events which have already been
      # persisted, and returns the filtered list.
      events_and_contexts = self._update_outliers_txn(
          txn, events_and_contexts=events_and_contexts
      )

      # From this point onwards the events are only events that we haven't
      # seen before.

      self._store_event_txn(txn, events_and_contexts=events_and_contexts)

      self._persist_transaction_ids_txn(txn, events_and_contexts)

      # Insert into event_to_state_groups.
      self._store_event_state_mappings_txn(txn, events_and_contexts)

      self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])

      # _store_rejected_events_txn filters out any events which were
      # rejected, and returns the filtered list.
      events_and_contexts = self._store_rejected_events_txn(
          txn, events_and_contexts=events_and_contexts
      )

      # From this point onwards the events are only ones that weren't
      # rejected.

      self._update_metadata_tables_txn(
          txn,
          events_and_contexts=events_and_contexts,
          all_events_and_contexts=all_events_and_contexts,
          backfilled=backfilled,
      )

      # We call this last as it assumes we've inserted the events into
      # room_memberships, where applicable.
      self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
  def _persist_event_auth_chain_txn(
      self,
      txn: LoggingTransaction,
      events: List[EventBase],
  ) -> None:
      """Store auth mappings and chain cover index data for state events.

      Only the state events in `events` are considered. Their auth event
      mappings are written to `event_auth`. For those in rooms flagged with
      `has_auth_chain_index`, chain IDs/sequence numbers are then calculated
      via `_add_chain_cover_index`.

      Args:
          txn: The database transaction to run in.
          events: The events being persisted (state and non-state).
      """
      state_evs = [event for event in events if event.is_state()]

      # Only state events matter here, so return early if there are none.
      if not state_evs:
          return

      # We want to store event_auth mappings for rejected events, as they're
      # used in state res v2. This is only necessary if the rejected event
      # appears in an accepted event's auth chain, but it's easier for now
      # just to store them (and it doesn't take much storage compared to
      # storing the entire event anyway).
      self.db_pool.simple_insert_many_txn(
          txn,
          table="event_auth",
          values=[
              {
                  "event_id": event.event_id,
                  "room_id": event.room_id,
                  "auth_id": auth_id,
              }
              for event in state_evs
              for auth_id in event.auth_event_ids()
          ],
      )

      # Next, chain ID/sequence numbers are calculated for the state events
      # being persisted. Out of band memberships whose auth chain we lack are
      # deferred by `_add_chain_cover_index` and fixed up later.
      #
      # See: docs/auth_chain_difference_algorithm.md

      # Legacy rooms that aren't filling the chain cover index are skipped.
      rows = self.db_pool.simple_select_many_txn(
          txn,
          table="rooms",
          column="room_id",
          iterable={event.room_id for event in state_evs},
          keyvalues={},
          retcols=("room_id", "has_auth_chain_index"),
      )
      rooms_using_chain_index = {
          row["room_id"] for row in rows if row["has_auth_chain_index"]
      }

      state_events = {
          event.event_id: event
          for event in state_evs
          if event.room_id in rooms_using_chain_index
      }
      if not state_events:
          return

      # Only the type/state_key, auth events and room of each event are passed
      # on, rather than full Event instances. More events may be pulled from
      # the DB, and the overhead of fetching/parsing full event JSON for them
      # isn't needed.
      event_to_types = {}  # type: Dict[str, Tuple[str, str]]
      event_to_auth_chain = {}  # type: Dict[str, List[str]]
      event_to_room_id = {}  # type: Dict[str, str]
      for event_id, event in state_events.items():
          event_to_types[event_id] = (event.type, event.state_key)
          event_to_auth_chain[event_id] = event.auth_event_ids()
          event_to_room_id[event_id] = event.room_id

      self._add_chain_cover_index(
          txn,
          self.db_pool,
          self.store.event_chain_id_gen,
          event_to_room_id,
          event_to_types,
          event_to_auth_chain,
      )
  407. @classmethod
  408. def _add_chain_cover_index(
  409. cls,
  410. txn,
  411. db_pool: DatabasePool,
  412. event_chain_id_gen: SequenceGenerator,
  413. event_to_room_id: Dict[str, str],
  414. event_to_types: Dict[str, Tuple[str, str]],
  415. event_to_auth_chain: Dict[str, List[str]],
  416. ) -> None:
  417. """Calculate the chain cover index for the given events.
  418. Args:
  419. event_to_room_id: Event ID to the room ID of the event
  420. event_to_types: Event ID to type and state_key of the event
  421. event_to_auth_chain: Event ID to list of auth event IDs of the
  422. event (events with no auth events can be excluded).
  423. """
  424. # Map from event ID to chain ID/sequence number.
  425. chain_map = {} # type: Dict[str, Tuple[int, int]]
  426. # Set of event IDs to calculate chain ID/seq numbers for.
  427. events_to_calc_chain_id_for = set(event_to_room_id)
  428. # We check if there are any events that need to be handled in the rooms
  429. # we're looking at. These should just be out of band memberships, where
  430. # we didn't have the auth chain when we first persisted.
  431. rows = db_pool.simple_select_many_txn(
  432. txn,
  433. table="event_auth_chain_to_calculate",
  434. keyvalues={},
  435. column="room_id",
  436. iterable=set(event_to_room_id.values()),
  437. retcols=("event_id", "type", "state_key"),
  438. )
  439. for row in rows:
  440. event_id = row["event_id"]
  441. event_type = row["type"]
  442. state_key = row["state_key"]
  443. # (We could pull out the auth events for all rows at once using
  444. # simple_select_many, but this case happens rarely and almost always
  445. # with a single row.)
  446. auth_events = db_pool.simple_select_onecol_txn(
  447. txn,
  448. "event_auth",
  449. keyvalues={"event_id": event_id},
  450. retcol="auth_id",
  451. )
  452. events_to_calc_chain_id_for.add(event_id)
  453. event_to_types[event_id] = (event_type, state_key)
  454. event_to_auth_chain[event_id] = auth_events
  455. # First we get the chain ID and sequence numbers for the events'
  456. # auth events (that aren't also currently being persisted).
  457. #
  458. # Note that there there is an edge case here where we might not have
  459. # calculated chains and sequence numbers for events that were "out
  460. # of band". We handle this case by fetching the necessary info and
  461. # adding it to the set of events to calculate chain IDs for.
  462. missing_auth_chains = {
  463. a_id
  464. for auth_events in event_to_auth_chain.values()
  465. for a_id in auth_events
  466. if a_id not in events_to_calc_chain_id_for
  467. }
  468. # We loop here in case we find an out of band membership and need to
  469. # fetch their auth event info.
  470. while missing_auth_chains:
  471. sql = """
  472. SELECT event_id, events.type, state_key, chain_id, sequence_number
  473. FROM events
  474. INNER JOIN state_events USING (event_id)
  475. LEFT JOIN event_auth_chains USING (event_id)
  476. WHERE
  477. """
  478. clause, args = make_in_list_sql_clause(
  479. txn.database_engine,
  480. "event_id",
  481. missing_auth_chains,
  482. )
  483. txn.execute(sql + clause, args)
  484. missing_auth_chains.clear()
  485. for auth_id, event_type, state_key, chain_id, sequence_number in txn:
  486. event_to_types[auth_id] = (event_type, state_key)
  487. if chain_id is None:
  488. # No chain ID, so the event was persisted out of band.
  489. # We add to list of events to calculate auth chains for.
  490. events_to_calc_chain_id_for.add(auth_id)
  491. event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn(
  492. txn,
  493. "event_auth",
  494. keyvalues={"event_id": auth_id},
  495. retcol="auth_id",
  496. )
  497. missing_auth_chains.update(
  498. e
  499. for e in event_to_auth_chain[auth_id]
  500. if e not in event_to_types
  501. )
  502. else:
  503. chain_map[auth_id] = (chain_id, sequence_number)
  504. # Now we check if we have any events where we don't have auth chain,
  505. # this should only be out of band memberships.
  506. for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
  507. for auth_id in event_to_auth_chain[event_id]:
  508. if (
  509. auth_id not in chain_map
  510. and auth_id not in events_to_calc_chain_id_for
  511. ):
  512. events_to_calc_chain_id_for.discard(event_id)
  513. # If this is an event we're trying to persist we add it to
  514. # the list of events to calculate chain IDs for next time
  515. # around. (Otherwise we will have already added it to the
  516. # table).
  517. room_id = event_to_room_id.get(event_id)
  518. if room_id:
  519. e_type, state_key = event_to_types[event_id]
  520. db_pool.simple_insert_txn(
  521. txn,
  522. table="event_auth_chain_to_calculate",
  523. values={
  524. "event_id": event_id,
  525. "room_id": room_id,
  526. "type": e_type,
  527. "state_key": state_key,
  528. },
  529. )
  530. # We stop checking the event's auth events since we've
  531. # discarded it.
  532. break
  533. if not events_to_calc_chain_id_for:
  534. return
  535. # Allocate chain ID/sequence numbers to each new event.
  536. new_chain_tuples = cls._allocate_chain_ids(
  537. txn,
  538. db_pool,
  539. event_chain_id_gen,
  540. event_to_room_id,
  541. event_to_types,
  542. event_to_auth_chain,
  543. events_to_calc_chain_id_for,
  544. chain_map,
  545. )
  546. chain_map.update(new_chain_tuples)
  547. db_pool.simple_insert_many_txn(
  548. txn,
  549. table="event_auth_chains",
  550. values=[
  551. {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
  552. for event_id, (c_id, seq) in new_chain_tuples.items()
  553. ],
  554. )
  555. db_pool.simple_delete_many_txn(
  556. txn,
  557. table="event_auth_chain_to_calculate",
  558. keyvalues={},
  559. column="event_id",
  560. iterable=new_chain_tuples,
  561. )
  562. # Now we need to calculate any new links between chains caused by
  563. # the new events.
  564. #
  565. # Links are pairs of chain ID/sequence numbers such that for any
  566. # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
  567. # if and only if there is at least one link (CA, S1) -> (CB, S2)
  568. # where SA >= S1 and S2 >= SB.
  569. #
  570. # We try and avoid adding redundant links to the table, e.g. if we
  571. # have two links between two chains which both start/end at the
  572. # sequence number event (or cross) then one can be safely dropped.
  573. #
  574. # To calculate new links we look at every new event and:
  575. # 1. Fetch the chain ID/sequence numbers of its auth events,
  576. # discarding any that are reachable by other auth events, or
  577. # that have the same chain ID as the event.
  578. # 2. For each retained auth event we:
  579. # a. Add a link from the event's to the auth event's chain
  580. # ID/sequence number; and
  581. # b. Add a link from the event to every chain reachable by the
  582. # auth event.
  583. # Step 1, fetch all existing links from all the chains we've seen
  584. # referenced.
  585. chain_links = _LinkMap()
  586. rows = db_pool.simple_select_many_txn(
  587. txn,
  588. table="event_auth_chain_links",
  589. column="origin_chain_id",
  590. iterable={chain_id for chain_id, _ in chain_map.values()},
  591. keyvalues={},
  592. retcols=(
  593. "origin_chain_id",
  594. "origin_sequence_number",
  595. "target_chain_id",
  596. "target_sequence_number",
  597. ),
  598. )
  599. for row in rows:
  600. chain_links.add_link(
  601. (row["origin_chain_id"], row["origin_sequence_number"]),
  602. (row["target_chain_id"], row["target_sequence_number"]),
  603. new=False,
  604. )
  605. # We do this in toplogical order to avoid adding redundant links.
  606. for event_id in sorted_topologically(
  607. events_to_calc_chain_id_for, event_to_auth_chain
  608. ):
  609. chain_id, sequence_number = chain_map[event_id]
  610. # Filter out auth events that are reachable by other auth
  611. # events. We do this by looking at every permutation of pairs of
  612. # auth events (A, B) to check if B is reachable from A.
  613. reduction = {
  614. a_id
  615. for a_id in event_to_auth_chain.get(event_id, [])
  616. if chain_map[a_id][0] != chain_id
  617. }
  618. for start_auth_id, end_auth_id in itertools.permutations(
  619. event_to_auth_chain.get(event_id, []),
  620. r=2,
  621. ):
  622. if chain_links.exists_path_from(
  623. chain_map[start_auth_id], chain_map[end_auth_id]
  624. ):
  625. reduction.discard(end_auth_id)
  626. # Step 2, figure out what the new links are from the reduced
  627. # list of auth events.
  628. for auth_id in reduction:
  629. auth_chain_id, auth_sequence_number = chain_map[auth_id]
  630. # Step 2a, add link between the event and auth event
  631. chain_links.add_link(
  632. (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
  633. )
  634. # Step 2b, add a link to chains reachable from the auth
  635. # event.
  636. for target_id, target_seq in chain_links.get_links_from(
  637. (auth_chain_id, auth_sequence_number)
  638. ):
  639. if target_id == chain_id:
  640. continue
  641. chain_links.add_link(
  642. (chain_id, sequence_number), (target_id, target_seq)
  643. )
  644. db_pool.simple_insert_many_txn(
  645. txn,
  646. table="event_auth_chain_links",
  647. values=[
  648. {
  649. "origin_chain_id": source_id,
  650. "origin_sequence_number": source_seq,
  651. "target_chain_id": target_id,
  652. "target_sequence_number": target_seq,
  653. }
  654. for (
  655. source_id,
  656. source_seq,
  657. target_id,
  658. target_seq,
  659. ) in chain_links.get_additions()
  660. ],
  661. )
    @staticmethod
    def _allocate_chain_ids(
        txn,
        db_pool: DatabasePool,
        event_chain_id_gen: SequenceGenerator,
        event_to_room_id: Dict[str, str],
        event_to_types: Dict[str, Tuple[str, str]],
        event_to_auth_chain: Dict[str, List[str]],
        events_to_calc_chain_id_for: Set[str],
        chain_map: Dict[str, Tuple[int, int]],
    ) -> Dict[str, Tuple[int, int]]:
        """Allocates, but does not persist, chain ID/sequence numbers for the
        events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
        for info on args)

        Each new event tries to continue the chain of the first of its auth
        events that has the same (type, state_key), taking that auth event's
        sequence number plus one. If there is no such auth event, or that
        chain already has an event at or beyond the proposed sequence number,
        the event starts a new chain at sequence number 1.

        Note: `event_to_room_id` is not used by this function's body, and
        `chain_map` is only read here (the caller merges the result in).

        Returns:
            A map from event ID to its newly allocated
            (chain ID, sequence number) pair.
        """

        # We now calculate the chain IDs/sequence numbers for the events. We do
        # this by looking at the chain ID and sequence number of any auth event
        # with the same type/state_key and incrementing the sequence number by
        # one. If there was no match or the chain ID/sequence number is already
        # taken we generate a new chain.
        #
        # We try to reduce the number of times that we hit the database by
        # batching up calls, to make this more efficient when persisting large
        # numbers of state events (e.g. during joins).
        #
        # We do this by:
        #   1. Calculating for each event which auth event will be used to
        #      inherit the chain ID, i.e. converting the auth chain graph to a
        #      tree that we can allocate chains on. We also keep track of which
        #      existing chain IDs have been referenced.
        #   2. Fetching the max allocated sequence number for each referenced
        #      existing chain ID, generating a map from chain ID to the max
        #      allocated sequence number.
        #   3. Iterating over the tree and allocating a chain ID/seq no. to the
        #      new event, by incrementing the sequence number from the
        #      referenced event's chain ID/seq no. and checking that the
        #      incremented sequence number hasn't already been allocated (by
        #      looking in the map generated in the previous step). We generate a
        #      new chain if the sequence number has already been allocated.
        #

        # Chain IDs (already in `chain_map`) that some new event will try to
        # extend.
        existing_chains = set()  # type: Set[int]
        # (event ID, auth event ID to inherit the chain from, or None).
        tree = []  # type: List[Tuple[str, Optional[str]]]

        # We need to do this in a topologically sorted order as we want to
        # generate chain IDs/sequence numbers of an event's auth events before
        # the event itself.
        for event_id in sorted_topologically(
            events_to_calc_chain_id_for, event_to_auth_chain
        ):
            for auth_id in event_to_auth_chain.get(event_id, []):
                if event_to_types.get(event_id) == event_to_types.get(auth_id):
                    # Only auth events already present in `chain_map` name an
                    # existing chain whose max sequence number must be fetched;
                    # other auth events are resolved via `new_chain_tuples` in
                    # the allocation loop below.
                    existing_chain_id = chain_map.get(auth_id)
                    if existing_chain_id:
                        existing_chains.add(existing_chain_id[0])

                    tree.append((event_id, auth_id))
                    break
            else:
                # No auth event shares this event's (type, state_key): it will
                # start a new chain.
                tree.append((event_id, None))

        # Fetch the current max sequence number for each existing referenced chain.
        sql = """
            SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
            WHERE %s
            GROUP BY chain_id
        """
        clause, args = make_in_list_sql_clause(
            db_pool.engine, "chain_id", existing_chains
        )
        txn.execute(sql % (clause,), args)

        chain_to_max_seq_no = {row[0]: row[1] for row in txn}  # type: Dict[Any, int]

        # Allocate the new events chain ID/sequence numbers.
        #
        # To reduce the number of calls to the database we don't allocate a
        # chain ID number in the loop, instead we use a temporary `object()` for
        # each new chain ID. Once we've done the loop we generate the necessary
        # number of new chain IDs in one call, replacing all temporary
        # objects with real allocated chain IDs.

        unallocated_chain_ids = set()  # type: Set[object]
        new_chain_tuples = {}  # type: Dict[str, Tuple[Any, int]]
        for event_id, auth_event_id in tree:
            # If we reference an auth_event_id we fetch the allocated chain ID,
            # either from the existing `chain_map` or the newly generated
            # `new_chain_tuples` map.
            existing_chain_id = None
            if auth_event_id:
                existing_chain_id = new_chain_tuples.get(auth_event_id)
                if not existing_chain_id:
                    existing_chain_id = chain_map[auth_event_id]

            new_chain_tuple = None  # type: Optional[Tuple[Any, int]]
            if existing_chain_id:
                # We found a chain ID/sequence number candidate, check its
                # not already taken.
                proposed_new_id = existing_chain_id[0]
                proposed_new_seq = existing_chain_id[1] + 1
                if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
                    new_chain_tuple = (
                        proposed_new_id,
                        proposed_new_seq,
                    )

            # If we need to start a new chain we allocate a temporary chain ID.
            if not new_chain_tuple:
                new_chain_tuple = (object(), 1)
                unallocated_chain_ids.add(new_chain_tuple[0])

            new_chain_tuples[event_id] = new_chain_tuple
            # Record the allocation so that later events in this batch cannot
            # be given the same (chain ID, sequence number).
            chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]

        # Generate new chain IDs for all unallocated chain IDs.
        newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
            txn, len(unallocated_chain_ids)
        )

        # Map from potentially temporary chain ID to real chain ID
        chain_id_to_allocated_map = dict(
            zip(unallocated_chain_ids, newly_allocated_chain_ids)
        )  # type: Dict[Any, int]
        # Existing chain IDs map to themselves.
        chain_id_to_allocated_map.update((c, c) for c in existing_chains)

        return {
            event_id: (chain_id_to_allocated_map[chain_id], seq)
            for event_id, (chain_id, seq) in new_chain_tuples.items()
        }
  778. def _persist_transaction_ids_txn(
  779. self,
  780. txn: LoggingTransaction,
  781. events_and_contexts: List[Tuple[EventBase, EventContext]],
  782. ):
  783. """Persist the mapping from transaction IDs to event IDs (if defined)."""
  784. to_insert = []
  785. for event, _ in events_and_contexts:
  786. token_id = getattr(event.internal_metadata, "token_id", None)
  787. txn_id = getattr(event.internal_metadata, "txn_id", None)
  788. if token_id and txn_id:
  789. to_insert.append(
  790. {
  791. "event_id": event.event_id,
  792. "room_id": event.room_id,
  793. "user_id": event.sender,
  794. "token_id": token_id,
  795. "txn_id": txn_id,
  796. "inserted_ts": self._clock.time_msec(),
  797. }
  798. )
  799. if to_insert:
  800. self.db_pool.simple_insert_many_txn(
  801. txn,
  802. table="event_txn_id",
  803. values=to_insert,
  804. )
    def _update_current_state_txn(
        self,
        txn: LoggingTransaction,
        state_delta_by_room: Dict[str, DeltaState],
        stream_id: int,
    ):
        """Apply the given current-state changes for each room.

        For every room this records the change in `current_state_delta_stream`,
        updates `current_state_events`, refreshes `local_current_membership`
        for local users whose membership entry changed, and schedules cache
        invalidations.

        Args:
            txn: the database transaction to run the queries on.
            state_delta_by_room: map from room ID to the state changes for that
                room. `to_delete` is iterated as (type, state_key) pairs and
                `to_insert` maps (type, state_key) to an event ID. When
                `no_longer_in_room` is set, every `current_state_events` row
                for the room is deleted instead of applying the delta to that
                table.
            stream_id: the stream ID written on every
                `current_state_delta_stream` row and passed to the
                `_curr_state_delta_stream_cache` invalidation.
        """
        for room_id, delta_state in state_delta_by_room.items():
            to_delete = delta_state.to_delete
            to_insert = delta_state.to_insert

            if delta_state.no_longer_in_room:
                # Server is no longer in the room so we delete the room from
                # current_state_events, being careful we've already updated the
                # rooms.room_version column (which gets populated in a
                # background task).
                self._upsert_room_version_txn(txn, room_id)

                # Before deleting we populate the current_state_delta_stream
                # so that async background tasks get told what happened.
                # Each existing row is recorded as removed: the new event_id
                # is NULL and the old event_id becomes prev_event_id.
                sql = """
                    INSERT INTO current_state_delta_stream
                        (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                    SELECT ?, ?, room_id, type, state_key, null, event_id
                        FROM current_state_events
                        WHERE room_id = ?
                """
                txn.execute(sql, (stream_id, self._instance_name, room_id))

                self.db_pool.simple_delete_txn(
                    txn,
                    table="current_state_events",
                    keyvalues={"room_id": room_id},
                )
            else:
                # We're still in the room, so we update the current state as normal.

                # First we add entries to the current_state_delta_stream. We
                # do this before updating the current_state_events table so
                # that we can use it to calculate the `prev_event_id`. (This
                # allows us to not have to pull out the existing state
                # unnecessarily).
                #
                # The stream_id for the update is chosen to be the minimum of the stream_ids
                # for the batch of the events that we are persisting; that means we do not
                # end up in a situation where workers see events before the
                # current_state_delta updates.
                #
                sql = """
                    INSERT INTO current_state_delta_stream
                    (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                    SELECT ?, ?, ?, ?, ?, ?, (
                        SELECT event_id FROM current_state_events
                        WHERE room_id = ? AND type = ? AND state_key = ?
                    )
                """
                # Keys with no entry in `to_insert` (pure deletions) get a
                # NULL event_id.
                txn.execute_batch(
                    sql,
                    (
                        (
                            stream_id,
                            self._instance_name,
                            room_id,
                            etype,
                            state_key,
                            to_insert.get((etype, state_key)),
                            room_id,
                            etype,
                            state_key,
                        )
                        for etype, state_key in itertools.chain(to_delete, to_insert)
                    ),
                )
                # Now we actually update the current_state_events table.
                # Both deleted keys and keys about to be replaced are removed;
                # the replacements are inserted below.

                txn.execute_batch(
                    "DELETE FROM current_state_events"
                    " WHERE room_id = ? AND type = ? AND state_key = ?",
                    (
                        (room_id, etype, state_key)
                        for etype, state_key in itertools.chain(to_delete, to_insert)
                    ),
                )

                # We include the membership in the current state table, hence we do
                # a lookup when we insert. This assumes that all events have already
                # been inserted into room_memberships.
                txn.execute_batch(
                    """INSERT INTO current_state_events
                        (room_id, type, state_key, event_id, membership)
                    VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
                    """,
                    [
                        (room_id, key[0], key[1], ev_id, ev_id)
                        for key, ev_id in to_insert.items()
                    ],
                )

            # We now update `local_current_membership`. We do this regardless
            # of whether we're still in the room or not to handle the case where
            # e.g. we just got banned (where we need to record that fact here).

            # Note: Do we really want to delete rows here (that we do not
            # subsequently reinsert below)? While technically correct it means
            # we have no record of the fact the user *was* a member of the
            # room but got, say, state reset out of it.
            if to_delete or to_insert:
                # Only membership events whose state_key is a local user.
                txn.execute_batch(
                    "DELETE FROM local_current_membership"
                    " WHERE room_id = ? AND user_id = ?",
                    (
                        (room_id, state_key)
                        for etype, state_key in itertools.chain(to_delete, to_insert)
                        if etype == EventTypes.Member and self.is_mine_id(state_key)
                    ),
                )

            if to_insert:
                # Re-insert the new membership of local users, looking the
                # membership up from room_memberships as above.
                txn.execute_batch(
                    """INSERT INTO local_current_membership
                        (room_id, user_id, event_id, membership)
                    VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
                    """,
                    [
                        (room_id, key[1], ev_id, ev_id)
                        for key, ev_id in to_insert.items()
                        if key[0] == EventTypes.Member and self.is_mine_id(key[1])
                    ],
                )

            txn.call_after(
                self.store._curr_state_delta_stream_cache.entity_has_changed,
                room_id,
                stream_id,
            )

            # Invalidate the various caches

            # Figure out the changes of membership to invalidate the
            # `get_rooms_for_user` cache.
            # We find out which membership events we may have deleted
            # and which we have added, then we invalidate the caches for all
            # those users.
            members_changed = {
                state_key
                for ev_type, state_key in itertools.chain(to_delete, to_insert)
                if ev_type == EventTypes.Member
            }

            for member in members_changed:
                txn.call_after(
                    self.store.get_rooms_for_user_with_stream_ordering.invalidate,
                    (member,),
                )

            self.store._invalidate_state_caches_and_stream(
                txn, room_id, members_changed
            )
  948. def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
  949. """Update the room version in the database based off current state
  950. events.
  951. This is used when we're about to delete current state and we want to
  952. ensure that the `rooms.room_version` column is up to date.
  953. """
  954. sql = """
  955. SELECT json FROM event_json
  956. INNER JOIN current_state_events USING (room_id, event_id)
  957. WHERE room_id = ? AND type = ? AND state_key = ?
  958. """
  959. txn.execute(sql, (room_id, EventTypes.Create, ""))
  960. row = txn.fetchone()
  961. if row:
  962. event_json = db_to_json(row[0])
  963. content = event_json.get("content", {})
  964. creator = content.get("creator")
  965. room_version_id = content.get("room_version", RoomVersions.V1.identifier)
  966. self.db_pool.simple_upsert_txn(
  967. txn,
  968. table="rooms",
  969. keyvalues={"room_id": room_id},
  970. values={"room_version": room_version_id},
  971. insertion_values={"is_public": False, "creator": creator},
  972. )
  973. def _update_forward_extremities_txn(
  974. self, txn, new_forward_extremities, max_stream_order
  975. ):
  976. for room_id, new_extrem in new_forward_extremities.items():
  977. self.db_pool.simple_delete_txn(
  978. txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
  979. )
  980. txn.call_after(
  981. self.store.get_latest_event_ids_in_room.invalidate, (room_id,)
  982. )
  983. self.db_pool.simple_insert_many_txn(
  984. txn,
  985. table="event_forward_extremities",
  986. values=[
  987. {"event_id": ev_id, "room_id": room_id}
  988. for room_id, new_extrem in new_forward_extremities.items()
  989. for ev_id in new_extrem
  990. ],
  991. )
  992. # We now insert into stream_ordering_to_exterm a mapping from room_id,
  993. # new stream_ordering to new forward extremeties in the room.
  994. # This allows us to later efficiently look up the forward extremeties
  995. # for a room before a given stream_ordering
  996. self.db_pool.simple_insert_many_txn(
  997. txn,
  998. table="stream_ordering_to_exterm",
  999. values=[
  1000. {
  1001. "room_id": room_id,
  1002. "event_id": event_id,
  1003. "stream_ordering": max_stream_order,
  1004. }
  1005. for room_id, new_extrem in new_forward_extremities.items()
  1006. for event_id in new_extrem
  1007. ],
  1008. )
  1009. @classmethod
  1010. def _filter_events_and_contexts_for_duplicates(
  1011. cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
  1012. ) -> List[Tuple[EventBase, EventContext]]:
  1013. """Ensure that we don't have the same event twice.
  1014. Pick the earliest non-outlier if there is one, else the earliest one.
  1015. Args:
  1016. events_and_contexts (list[(EventBase, EventContext)]):
  1017. Returns:
  1018. list[(EventBase, EventContext)]: filtered list
  1019. """
  1020. new_events_and_contexts = (
  1021. OrderedDict()
  1022. ) # type: OrderedDict[str, Tuple[EventBase, EventContext]]
  1023. for event, context in events_and_contexts:
  1024. prev_event_context = new_events_and_contexts.get(event.event_id)
  1025. if prev_event_context:
  1026. if not event.internal_metadata.is_outlier():
  1027. if prev_event_context[0].internal_metadata.is_outlier():
  1028. # To ensure correct ordering we pop, as OrderedDict is
  1029. # ordered by first insertion.
  1030. new_events_and_contexts.pop(event.event_id, None)
  1031. new_events_and_contexts[event.event_id] = (event, context)
  1032. else:
  1033. new_events_and_contexts[event.event_id] = (event, context)
  1034. return list(new_events_and_contexts.values())
  1035. def _update_room_depths_txn(
  1036. self,
  1037. txn,
  1038. events_and_contexts: List[Tuple[EventBase, EventContext]],
  1039. backfilled: bool,
  1040. ):
  1041. """Update min_depth for each room
  1042. Args:
  1043. txn (twisted.enterprise.adbapi.Connection): db connection
  1044. events_and_contexts (list[(EventBase, EventContext)]): events
  1045. we are persisting
  1046. backfilled (bool): True if the events were backfilled
  1047. """
  1048. depth_updates = {} # type: Dict[str, int]
  1049. for event, context in events_and_contexts:
  1050. # Remove the any existing cache entries for the event_ids
  1051. txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
  1052. if not backfilled:
  1053. txn.call_after(
  1054. self.store._events_stream_cache.entity_has_changed,
  1055. event.room_id,
  1056. event.internal_metadata.stream_ordering,
  1057. )
  1058. if not event.internal_metadata.is_outlier() and not context.rejected:
  1059. depth_updates[event.room_id] = max(
  1060. event.depth, depth_updates.get(event.room_id, event.depth)
  1061. )
  1062. for room_id, depth in depth_updates.items():
  1063. self._update_min_depth_for_room_txn(txn, room_id, depth)
  1064. def _update_outliers_txn(self, txn, events_and_contexts):
  1065. """Update any outliers with new event info.
  1066. This turns outliers into ex-outliers (unless the new event was
  1067. rejected).
  1068. Args:
  1069. txn (twisted.enterprise.adbapi.Connection): db connection
  1070. events_and_contexts (list[(EventBase, EventContext)]): events
  1071. we are persisting
  1072. Returns:
  1073. list[(EventBase, EventContext)] new list, without events which
  1074. are already in the events table.
  1075. """
  1076. txn.execute(
  1077. "SELECT event_id, outlier FROM events WHERE event_id in (%s)"
  1078. % (",".join(["?"] * len(events_and_contexts)),),
  1079. [event.event_id for event, _ in events_and_contexts],
  1080. )
  1081. have_persisted = {event_id: outlier for event_id, outlier in txn}
  1082. to_remove = set()
  1083. for event, context in events_and_contexts:
  1084. if event.event_id not in have_persisted:
  1085. continue
  1086. to_remove.add(event)
  1087. if context.rejected:
  1088. # If the event is rejected then we don't care if the event
  1089. # was an outlier or not.
  1090. continue
  1091. outlier_persisted = have_persisted[event.event_id]
  1092. if not event.internal_metadata.is_outlier() and outlier_persisted:
  1093. # We received a copy of an event that we had already stored as
  1094. # an outlier in the database. We now have some state at that
  1095. # so we need to update the state_groups table with that state.
  1096. # insert into event_to_state_groups.
  1097. try:
  1098. self._store_event_state_mappings_txn(txn, ((event, context),))
  1099. except Exception:
  1100. logger.exception("")
  1101. raise
  1102. metadata_json = json_encoder.encode(event.internal_metadata.get_dict())
  1103. sql = "UPDATE event_json SET internal_metadata = ? WHERE event_id = ?"
  1104. txn.execute(sql, (metadata_json, event.event_id))
  1105. # Add an entry to the ex_outlier_stream table to replicate the
  1106. # change in outlier status to our workers.
  1107. stream_order = event.internal_metadata.stream_ordering
  1108. state_group_id = context.state_group
  1109. self.db_pool.simple_insert_txn(
  1110. txn,
  1111. table="ex_outlier_stream",
  1112. values={
  1113. "event_stream_ordering": stream_order,
  1114. "event_id": event.event_id,
  1115. "state_group": state_group_id,
  1116. "instance_name": self._instance_name,
  1117. },
  1118. )
  1119. sql = "UPDATE events SET outlier = ? WHERE event_id = ?"
  1120. txn.execute(sql, (False, event.event_id))
  1121. # Update the event_backward_extremities table now that this
  1122. # event isn't an outlier any more.
  1123. self._update_backward_extremeties(txn, [event])
  1124. return [ec for ec in events_and_contexts if ec[0] not in to_remove]
  1125. def _store_event_txn(self, txn, events_and_contexts):
  1126. """Insert new events into the event, event_json, redaction and
  1127. state_events tables.
  1128. Args:
  1129. txn (twisted.enterprise.adbapi.Connection): db connection
  1130. events_and_contexts (list[(EventBase, EventContext)]): events
  1131. we are persisting
  1132. """
  1133. if not events_and_contexts:
  1134. # nothing to do here
  1135. return
  1136. def event_dict(event):
  1137. d = event.get_dict()
  1138. d.pop("redacted", None)
  1139. d.pop("redacted_because", None)
  1140. return d
  1141. self.db_pool.simple_insert_many_txn(
  1142. txn,
  1143. table="event_json",
  1144. values=[
  1145. {
  1146. "event_id": event.event_id,
  1147. "room_id": event.room_id,
  1148. "internal_metadata": json_encoder.encode(
  1149. event.internal_metadata.get_dict()
  1150. ),
  1151. "json": json_encoder.encode(event_dict(event)),
  1152. "format_version": event.format_version,
  1153. }
  1154. for event, _ in events_and_contexts
  1155. ],
  1156. )
  1157. self.db_pool.simple_insert_many_txn(
  1158. txn,
  1159. table="events",
  1160. values=[
  1161. {
  1162. "instance_name": self._instance_name,
  1163. "stream_ordering": event.internal_metadata.stream_ordering,
  1164. "topological_ordering": event.depth,
  1165. "depth": event.depth,
  1166. "event_id": event.event_id,
  1167. "room_id": event.room_id,
  1168. "type": event.type,
  1169. "processed": True,
  1170. "outlier": event.internal_metadata.is_outlier(),
  1171. "origin_server_ts": int(event.origin_server_ts),
  1172. "received_ts": self._clock.time_msec(),
  1173. "sender": event.sender,
  1174. "contains_url": (
  1175. "url" in event.content and isinstance(event.content["url"], str)
  1176. ),
  1177. }
  1178. for event, _ in events_and_contexts
  1179. ],
  1180. )
  1181. for event, _ in events_and_contexts:
  1182. if not event.internal_metadata.is_redacted():
  1183. # If we're persisting an unredacted event we go and ensure
  1184. # that we mark any redactions that reference this event as
  1185. # requiring censoring.
  1186. self.db_pool.simple_update_txn(
  1187. txn,
  1188. table="redactions",
  1189. keyvalues={"redacts": event.event_id},
  1190. updatevalues={"have_censored": False},
  1191. )
  1192. state_events_and_contexts = [
  1193. ec for ec in events_and_contexts if ec[0].is_state()
  1194. ]
  1195. state_values = []
  1196. for event, context in state_events_and_contexts:
  1197. vals = {
  1198. "event_id": event.event_id,
  1199. "room_id": event.room_id,
  1200. "type": event.type,
  1201. "state_key": event.state_key,
  1202. }
  1203. # TODO: How does this work with backfilling?
  1204. if hasattr(event, "replaces_state"):
  1205. vals["prev_state"] = event.replaces_state
  1206. state_values.append(vals)
  1207. self.db_pool.simple_insert_many_txn(
  1208. txn, table="state_events", values=state_values
  1209. )
  1210. def _store_rejected_events_txn(self, txn, events_and_contexts):
  1211. """Add rows to the 'rejections' table for received events which were
  1212. rejected
  1213. Args:
  1214. txn (twisted.enterprise.adbapi.Connection): db connection
  1215. events_and_contexts (list[(EventBase, EventContext)]): events
  1216. we are persisting
  1217. Returns:
  1218. list[(EventBase, EventContext)] new list, without the rejected
  1219. events.
  1220. """
  1221. # Remove the rejected events from the list now that we've added them
  1222. # to the events table and the events_json table.
  1223. to_remove = set()
  1224. for event, context in events_and_contexts:
  1225. if context.rejected:
  1226. # Insert the event_id into the rejections table
  1227. self._store_rejections_txn(txn, event.event_id, context.rejected)
  1228. to_remove.add(event)
  1229. return [ec for ec in events_and_contexts if ec[0] not in to_remove]
  1230. def _update_metadata_tables_txn(
  1231. self, txn, events_and_contexts, all_events_and_contexts, backfilled
  1232. ):
  1233. """Update all the miscellaneous tables for new events
  1234. Args:
  1235. txn (twisted.enterprise.adbapi.Connection): db connection
  1236. events_and_contexts (list[(EventBase, EventContext)]): events
  1237. we are persisting
  1238. all_events_and_contexts (list[(EventBase, EventContext)]): all
  1239. events that we were going to persist. This includes events
  1240. we've already persisted, etc, that wouldn't appear in
  1241. events_and_context.
  1242. backfilled (bool): True if the events were backfilled
  1243. """
  1244. # Insert all the push actions into the event_push_actions table.
  1245. self._set_push_actions_for_event_and_users_txn(
  1246. txn,
  1247. events_and_contexts=events_and_contexts,
  1248. all_events_and_contexts=all_events_and_contexts,
  1249. )
  1250. if not events_and_contexts:
  1251. # nothing to do here
  1252. return
  1253. for event, context in events_and_contexts:
  1254. if event.type == EventTypes.Redaction and event.redacts is not None:
  1255. # Remove the entries in the event_push_actions table for the
  1256. # redacted event.
  1257. self._remove_push_actions_for_event_id_txn(
  1258. txn, event.room_id, event.redacts
  1259. )
  1260. # Remove from relations table.
  1261. self._handle_redaction(txn, event.redacts)
  1262. # Update the event_forward_extremities, event_backward_extremities and
  1263. # event_edges tables.
  1264. self._handle_mult_prev_events(
  1265. txn, events=[event for event, _ in events_and_contexts]
  1266. )
  1267. for event, _ in events_and_contexts:
  1268. if event.type == EventTypes.Name:
  1269. # Insert into the event_search table.
  1270. self._store_room_name_txn(txn, event)
  1271. elif event.type == EventTypes.Topic:
  1272. # Insert into the event_search table.
  1273. self._store_room_topic_txn(txn, event)
  1274. elif event.type == EventTypes.Message:
  1275. # Insert into the event_search table.
  1276. self._store_room_message_txn(txn, event)
  1277. elif event.type == EventTypes.Redaction and event.redacts is not None:
  1278. # Insert into the redactions table.
  1279. self._store_redaction(txn, event)
  1280. elif event.type == EventTypes.Retention:
  1281. # Update the room_retention table.
  1282. self._store_retention_policy_for_room_txn(txn, event)
  1283. self._handle_event_relations(txn, event)
  1284. # Store the labels for this event.
  1285. labels = event.content.get(EventContentFields.LABELS)
  1286. if labels:
  1287. self.insert_labels_for_event_txn(
  1288. txn, event.event_id, labels, event.room_id, event.depth
  1289. )
  1290. if self._ephemeral_messages_enabled:
  1291. # If there's an expiry timestamp on the event, store it.
  1292. expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
  1293. if isinstance(expiry_ts, int) and not event.is_state():
  1294. self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
  1295. # Insert into the room_memberships table.
  1296. self._store_room_members_txn(
  1297. txn,
  1298. [
  1299. event
  1300. for event, _ in events_and_contexts
  1301. if event.type == EventTypes.Member
  1302. ],
  1303. backfilled=backfilled,
  1304. )
  1305. # Insert event_reference_hashes table.
  1306. self._store_event_reference_hashes_txn(
  1307. txn, [event for event, _ in events_and_contexts]
  1308. )
  1309. # Prefill the event cache
  1310. self._add_to_cache(txn, events_and_contexts)
  1311. def _add_to_cache(self, txn, events_and_contexts):
  1312. to_prefill = []
  1313. rows = []
  1314. N = 200
  1315. for i in range(0, len(events_and_contexts), N):
  1316. ev_map = {e[0].event_id: e[0] for e in events_and_contexts[i : i + N]}
  1317. if not ev_map:
  1318. break
  1319. sql = (
  1320. "SELECT "
  1321. " e.event_id as event_id, "
  1322. " r.redacts as redacts,"
  1323. " rej.event_id as rejects "
  1324. " FROM events as e"
  1325. " LEFT JOIN rejections as rej USING (event_id)"
  1326. " LEFT JOIN redactions as r ON e.event_id = r.redacts"
  1327. " WHERE "
  1328. )
  1329. clause, args = make_in_list_sql_clause(
  1330. self.database_engine, "e.event_id", list(ev_map)
  1331. )
  1332. txn.execute(sql + clause, args)
  1333. rows = self.db_pool.cursor_to_dict(txn)
  1334. for row in rows:
  1335. event = ev_map[row["event_id"]]
  1336. if not row["rejects"] and not row["redacts"]:
  1337. to_prefill.append(
  1338. _EventCacheEntry(event=event, redacted_event=None)
  1339. )
  1340. def prefill():
  1341. for cache_entry in to_prefill:
  1342. self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
  1343. txn.call_after(prefill)
  1344. def _store_redaction(self, txn, event):
  1345. # invalidate the cache for the redacted event
  1346. txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
  1347. self.db_pool.simple_insert_txn(
  1348. txn,
  1349. table="redactions",
  1350. values={
  1351. "event_id": event.event_id,
  1352. "redacts": event.redacts,
  1353. "received_ts": self._clock.time_msec(),
  1354. },
  1355. )
  1356. def insert_labels_for_event_txn(
  1357. self, txn, event_id, labels, room_id, topological_ordering
  1358. ):
  1359. """Store the mapping between an event's ID and its labels, with one row per
  1360. (event_id, label) tuple.
  1361. Args:
  1362. txn (LoggingTransaction): The transaction to execute.
  1363. event_id (str): The event's ID.
  1364. labels (list[str]): A list of text labels.
  1365. room_id (str): The ID of the room the event was sent to.
  1366. topological_ordering (int): The position of the event in the room's topology.
  1367. """
  1368. return self.db_pool.simple_insert_many_txn(
  1369. txn=txn,
  1370. table="event_labels",
  1371. values=[
  1372. {
  1373. "event_id": event_id,
  1374. "label": label,
  1375. "room_id": room_id,
  1376. "topological_ordering": topological_ordering,
  1377. }
  1378. for label in labels
  1379. ],
  1380. )
  1381. def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
  1382. """Save the expiry timestamp associated with a given event ID.
  1383. Args:
  1384. txn (LoggingTransaction): The database transaction to use.
  1385. event_id (str): The event ID the expiry timestamp is associated with.
  1386. expiry_ts (int): The timestamp at which to expire (delete) the event.
  1387. """
  1388. return self.db_pool.simple_insert_txn(
  1389. txn=txn,
  1390. table="event_expiry",
  1391. values={"event_id": event_id, "expiry_ts": expiry_ts},
  1392. )
  1393. def _store_event_reference_hashes_txn(self, txn, events):
  1394. """Store a hash for a PDU
  1395. Args:
  1396. txn (cursor):
  1397. events (list): list of Events.
  1398. """
  1399. vals = []
  1400. for event in events:
  1401. ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
  1402. vals.append(
  1403. {
  1404. "event_id": event.event_id,
  1405. "algorithm": ref_alg,
  1406. "hash": memoryview(ref_hash_bytes),
  1407. }
  1408. )
  1409. self.db_pool.simple_insert_many_txn(
  1410. txn, table="event_reference_hashes", values=vals
  1411. )
  1412. def _store_room_members_txn(self, txn, events, backfilled):
  1413. """Store a room member in the database."""
  1414. def str_or_none(val: Any) -> Optional[str]:
  1415. return val if isinstance(val, str) else None
  1416. self.db_pool.simple_insert_many_txn(
  1417. txn,
  1418. table="room_memberships",
  1419. values=[
  1420. {
  1421. "event_id": event.event_id,
  1422. "user_id": event.state_key,
  1423. "sender": event.user_id,
  1424. "room_id": event.room_id,
  1425. "membership": event.membership,
  1426. "display_name": str_or_none(event.content.get("displayname")),
  1427. "avatar_url": str_or_none(event.content.get("avatar_url")),
  1428. }
  1429. for event in events
  1430. ],
  1431. )
  1432. for event in events:
  1433. txn.call_after(
  1434. self.store._membership_stream_cache.entity_has_changed,
  1435. event.state_key,
  1436. event.internal_metadata.stream_ordering,
  1437. )
  1438. txn.call_after(
  1439. self.store.get_invited_rooms_for_local_user.invalidate,
  1440. (event.state_key,),
  1441. )
  1442. # We update the local_current_membership table only if the event is
  1443. # "current", i.e., its something that has just happened.
  1444. #
  1445. # This will usually get updated by the `current_state_events` handling,
  1446. # unless its an outlier, and an outlier is only "current" if it's an "out of
  1447. # band membership", like a remote invite or a rejection of a remote invite.
  1448. if (
  1449. self.is_mine_id(event.state_key)
  1450. and not backfilled
  1451. and event.internal_metadata.is_outlier()
  1452. and event.internal_metadata.is_out_of_band_membership()
  1453. ):
  1454. self.db_pool.simple_upsert_txn(
  1455. txn,
  1456. table="local_current_membership",
  1457. keyvalues={"room_id": event.room_id, "user_id": event.state_key},
  1458. values={
  1459. "event_id": event.event_id,
  1460. "membership": event.membership,
  1461. },
  1462. )
  1463. def _handle_event_relations(self, txn, event):
  1464. """Handles inserting relation data during peristence of events
  1465. Args:
  1466. txn
  1467. event (EventBase)
  1468. """
  1469. relation = event.content.get("m.relates_to")
  1470. if not relation:
  1471. # No relations
  1472. return
  1473. rel_type = relation.get("rel_type")
  1474. if rel_type not in (
  1475. RelationTypes.ANNOTATION,
  1476. RelationTypes.REFERENCE,
  1477. RelationTypes.REPLACE,
  1478. ):
  1479. # Unknown relation type
  1480. return
  1481. parent_id = relation.get("event_id")
  1482. if not parent_id:
  1483. # Invalid relation
  1484. return
  1485. aggregation_key = relation.get("key")
  1486. self.db_pool.simple_insert_txn(
  1487. txn,
  1488. table="event_relations",
  1489. values={
  1490. "event_id": event.event_id,
  1491. "relates_to_id": parent_id,
  1492. "relation_type": rel_type,
  1493. "aggregation_key": aggregation_key,
  1494. },
  1495. )
  1496. txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,))
  1497. txn.call_after(
  1498. self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
  1499. )
  1500. if rel_type == RelationTypes.REPLACE:
  1501. txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
  1502. def _handle_redaction(self, txn, redacted_event_id):
  1503. """Handles receiving a redaction and checking whether we need to remove
  1504. any redacted relations from the database.
  1505. Args:
  1506. txn
  1507. redacted_event_id (str): The event that was redacted.
  1508. """
  1509. self.db_pool.simple_delete_txn(
  1510. txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
  1511. )
  1512. def _store_room_topic_txn(self, txn, event):
  1513. if hasattr(event, "content") and "topic" in event.content:
  1514. self.store_event_search_txn(
  1515. txn, event, "content.topic", event.content["topic"]
  1516. )
  1517. def _store_room_name_txn(self, txn, event):
  1518. if hasattr(event, "content") and "name" in event.content:
  1519. self.store_event_search_txn(
  1520. txn, event, "content.name", event.content["name"]
  1521. )
  1522. def _store_room_message_txn(self, txn, event):
  1523. if hasattr(event, "content") and "body" in event.content:
  1524. self.store_event_search_txn(
  1525. txn, event, "content.body", event.content["body"]
  1526. )
  1527. def _store_retention_policy_for_room_txn(self, txn, event):
  1528. if not event.is_state():
  1529. logger.debug("Ignoring non-state m.room.retention event")
  1530. return
  1531. if hasattr(event, "content") and (
  1532. "min_lifetime" in event.content or "max_lifetime" in event.content
  1533. ):
  1534. if (
  1535. "min_lifetime" in event.content
  1536. and not isinstance(event.content.get("min_lifetime"), int)
  1537. ) or (
  1538. "max_lifetime" in event.content
  1539. and not isinstance(event.content.get("max_lifetime"), int)
  1540. ):
  1541. # Ignore the event if one of the value isn't an integer.
  1542. return
  1543. self.db_pool.simple_insert_txn(
  1544. txn=txn,
  1545. table="room_retention",
  1546. values={
  1547. "room_id": event.room_id,
  1548. "event_id": event.event_id,
  1549. "min_lifetime": event.content.get("min_lifetime"),
  1550. "max_lifetime": event.content.get("max_lifetime"),
  1551. },
  1552. )
  1553. self.store._invalidate_cache_and_stream(
  1554. txn, self.store.get_retention_policy_for_room, (event.room_id,)
  1555. )
  1556. def store_event_search_txn(self, txn, event, key, value):
  1557. """Add event to the search table
  1558. Args:
  1559. txn (cursor):
  1560. event (EventBase):
  1561. key (str):
  1562. value (str):
  1563. """
  1564. self.store.store_search_entries_txn(
  1565. txn,
  1566. (
  1567. SearchEntry(
  1568. key=key,
  1569. value=value,
  1570. event_id=event.event_id,
  1571. room_id=event.room_id,
  1572. stream_ordering=event.internal_metadata.stream_ordering,
  1573. origin_server_ts=event.origin_server_ts,
  1574. ),
  1575. ),
  1576. )
  1577. def _set_push_actions_for_event_and_users_txn(
  1578. self, txn, events_and_contexts, all_events_and_contexts
  1579. ):
  1580. """Handles moving push actions from staging table to main
  1581. event_push_actions table for all events in `events_and_contexts`.
  1582. Also ensures that all events in `all_events_and_contexts` are removed
  1583. from the push action staging area.
  1584. Args:
  1585. events_and_contexts (list[(EventBase, EventContext)]): events
  1586. we are persisting
  1587. all_events_and_contexts (list[(EventBase, EventContext)]): all
  1588. events that we were going to persist. This includes events
  1589. we've already persisted, etc, that wouldn't appear in
  1590. events_and_context.
  1591. """
  1592. sql = """
  1593. INSERT INTO event_push_actions (
  1594. room_id, event_id, user_id, actions, stream_ordering,
  1595. topological_ordering, notif, highlight, unread
  1596. )
  1597. SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
  1598. FROM event_push_actions_staging
  1599. WHERE event_id = ?
  1600. """
  1601. if events_and_contexts:
  1602. txn.execute_batch(
  1603. sql,
  1604. (
  1605. (
  1606. event.room_id,
  1607. event.internal_metadata.stream_ordering,
  1608. event.depth,
  1609. event.event_id,
  1610. )
  1611. for event, _ in events_and_contexts
  1612. ),
  1613. )
  1614. for event, _ in events_and_contexts:
  1615. user_ids = self.db_pool.simple_select_onecol_txn(
  1616. txn,
  1617. table="event_push_actions_staging",
  1618. keyvalues={"event_id": event.event_id},
  1619. retcol="user_id",
  1620. )
  1621. for uid in user_ids:
  1622. txn.call_after(
  1623. self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
  1624. (event.room_id, uid),
  1625. )
  1626. # Now we delete the staging area for *all* events that were being
  1627. # persisted.
  1628. txn.execute_batch(
  1629. "DELETE FROM event_push_actions_staging WHERE event_id = ?",
  1630. ((event.event_id,) for event, _ in all_events_and_contexts),
  1631. )
  1632. def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
  1633. # Sad that we have to blow away the cache for the whole room here
  1634. txn.call_after(
  1635. self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many,
  1636. (room_id,),
  1637. )
  1638. txn.execute(
  1639. "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?",
  1640. (room_id, event_id),
  1641. )
  1642. def _store_rejections_txn(self, txn, event_id, reason):
  1643. self.db_pool.simple_insert_txn(
  1644. txn,
  1645. table="rejections",
  1646. values={
  1647. "event_id": event_id,
  1648. "reason": reason,
  1649. "last_check": self._clock.time_msec(),
  1650. },
  1651. )
  1652. def _store_event_state_mappings_txn(
  1653. self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
  1654. ):
  1655. state_groups = {}
  1656. for event, context in events_and_contexts:
  1657. if event.internal_metadata.is_outlier():
  1658. continue
  1659. # if the event was rejected, just give it the same state as its
  1660. # predecessor.
  1661. if context.rejected:
  1662. state_groups[event.event_id] = context.state_group_before_event
  1663. continue
  1664. state_groups[event.event_id] = context.state_group
  1665. self.db_pool.simple_insert_many_txn(
  1666. txn,
  1667. table="event_to_state_groups",
  1668. values=[
  1669. {"state_group": state_group_id, "event_id": event_id}
  1670. for event_id, state_group_id in state_groups.items()
  1671. ],
  1672. )
  1673. for event_id, state_group_id in state_groups.items():
  1674. txn.call_after(
  1675. self.store._get_state_group_for_event.prefill,
  1676. (event_id,),
  1677. state_group_id,
  1678. )
  1679. def _update_min_depth_for_room_txn(self, txn, room_id, depth):
  1680. min_depth = self.store._get_min_depth_interaction(txn, room_id)
  1681. if min_depth is not None and depth >= min_depth:
  1682. return
  1683. self.db_pool.simple_upsert_txn(
  1684. txn,
  1685. table="room_depth",
  1686. keyvalues={"room_id": room_id},
  1687. values={"min_depth": depth},
  1688. )
  1689. def _handle_mult_prev_events(self, txn, events):
  1690. """
  1691. For the given event, update the event edges table and forward and
  1692. backward extremities tables.
  1693. """
  1694. self.db_pool.simple_insert_many_txn(
  1695. txn,
  1696. table="event_edges",
  1697. values=[
  1698. {
  1699. "event_id": ev.event_id,
  1700. "prev_event_id": e_id,
  1701. "room_id": ev.room_id,
  1702. "is_state": False,
  1703. }
  1704. for ev in events
  1705. for e_id in ev.prev_event_ids()
  1706. ],
  1707. )
  1708. self._update_backward_extremeties(txn, events)
  1709. def _update_backward_extremeties(self, txn, events):
  1710. """Updates the event_backward_extremities tables based on the new/updated
  1711. events being persisted.
  1712. This is called for new events *and* for events that were outliers, but
  1713. are now being persisted as non-outliers.
  1714. Forward extremities are handled when we first start persisting the events.
  1715. """
  1716. events_by_room = {} # type: Dict[str, List[EventBase]]
  1717. for ev in events:
  1718. events_by_room.setdefault(ev.room_id, []).append(ev)
  1719. query = (
  1720. "INSERT INTO event_backward_extremities (event_id, room_id)"
  1721. " SELECT ?, ? WHERE NOT EXISTS ("
  1722. " SELECT 1 FROM event_backward_extremities"
  1723. " WHERE event_id = ? AND room_id = ?"
  1724. " )"
  1725. " AND NOT EXISTS ("
  1726. " SELECT 1 FROM events WHERE event_id = ? AND room_id = ? "
  1727. " AND outlier = ?"
  1728. " )"
  1729. )
  1730. txn.execute_batch(
  1731. query,
  1732. [
  1733. (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
  1734. for ev in events
  1735. for e_id in ev.prev_event_ids()
  1736. if not ev.internal_metadata.is_outlier()
  1737. ],
  1738. )
  1739. query = (
  1740. "DELETE FROM event_backward_extremities"
  1741. " WHERE event_id = ? AND room_id = ?"
  1742. )
  1743. txn.execute_batch(
  1744. query,
  1745. [
  1746. (ev.event_id, ev.room_id)
  1747. for ev in events
  1748. if not ev.internal_metadata.is_outlier()
  1749. ],
  1750. )
  1751. @attr.s(slots=True)
  1752. class _LinkMap:
  1753. """A helper type for tracking links between chains."""
  1754. # Stores the set of links as nested maps: source chain ID -> target chain ID
  1755. # -> source sequence number -> target sequence number.
  1756. maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
  1757. # Stores the links that have been added (with new set to true), as tuples of
  1758. # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
  1759. additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
  1760. def add_link(
  1761. self,
  1762. src_tuple: Tuple[int, int],
  1763. target_tuple: Tuple[int, int],
  1764. new: bool = True,
  1765. ) -> bool:
  1766. """Add a new link between two chains, ensuring no redundant links are added.
  1767. New links should be added in topological order.
  1768. Args:
  1769. src_tuple: The chain ID/sequence number of the source of the link.
  1770. target_tuple: The chain ID/sequence number of the target of the link.
  1771. new: Whether this is a "new" link, i.e. should it be returned
  1772. by `get_additions`.
  1773. Returns:
  1774. True if a link was added, false if the given link was dropped as redundant
  1775. """
  1776. src_chain, src_seq = src_tuple
  1777. target_chain, target_seq = target_tuple
  1778. current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
  1779. assert src_chain != target_chain
  1780. if new:
  1781. # Check if the new link is redundant
  1782. for current_seq_src, current_seq_target in current_links.items():
  1783. # If a link "crosses" another link then its redundant. For example
  1784. # in the following link 1 (L1) is redundant, as any event reachable
  1785. # via L1 is *also* reachable via L2.
  1786. #
  1787. # Chain A Chain B
  1788. # | |
  1789. # L1 |------ |
  1790. # | | |
  1791. # L2 |---- | -->|
  1792. # | | |
  1793. # | |--->|
  1794. # | |
  1795. # | |
  1796. #
  1797. # So we only need to keep links which *do not* cross, i.e. links
  1798. # that both start and end above or below an existing link.
  1799. #
  1800. # Note, since we add links in topological ordering we should never
  1801. # see `src_seq` less than `current_seq_src`.
  1802. if current_seq_src <= src_seq and target_seq <= current_seq_target:
  1803. # This new link is redundant, nothing to do.
  1804. return False
  1805. self.additions.add((src_chain, src_seq, target_chain, target_seq))
  1806. current_links[src_seq] = target_seq
  1807. return True
  1808. def get_links_from(
  1809. self, src_tuple: Tuple[int, int]
  1810. ) -> Generator[Tuple[int, int], None, None]:
  1811. """Gets the chains reachable from the given chain/sequence number.
  1812. Yields:
  1813. The chain ID and sequence number the link points to.
  1814. """
  1815. src_chain, src_seq = src_tuple
  1816. for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
  1817. for link_src_seq, target_seq in sequence_numbers.items():
  1818. if link_src_seq <= src_seq:
  1819. yield target_id, target_seq
  1820. def get_links_between(
  1821. self, source_chain: int, target_chain: int
  1822. ) -> Generator[Tuple[int, int], None, None]:
  1823. """Gets the links between two chains.
  1824. Yields:
  1825. The source and target sequence numbers.
  1826. """
  1827. yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
  1828. def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
  1829. """Gets any newly added links.
  1830. Yields:
  1831. The source chain ID/sequence number and target chain ID/sequence number
  1832. """
  1833. for src_chain, src_seq, target_chain, _ in self.additions:
  1834. target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
  1835. if target_seq is not None:
  1836. yield (src_chain, src_seq, target_chain, target_seq)
  1837. def exists_path_from(
  1838. self,
  1839. src_tuple: Tuple[int, int],
  1840. target_tuple: Tuple[int, int],
  1841. ) -> bool:
  1842. """Checks if there is a path between the source chain ID/sequence and
  1843. target chain ID/sequence.
  1844. """
  1845. src_chain, src_seq = src_tuple
  1846. target_chain, target_seq = target_tuple
  1847. if src_chain == target_chain:
  1848. return target_seq <= src_seq
  1849. links = self.get_links_between(src_chain, target_chain)
  1850. for link_start_seq, link_end_seq in links:
  1851. if link_start_seq <= src_seq and target_seq <= link_end_seq:
  1852. return True
  1853. return False