|
@@ -17,7 +17,17 @@
|
|
|
import itertools
|
|
|
import logging
|
|
|
from collections import OrderedDict, namedtuple
|
|
|
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
|
|
|
+from typing import (
|
|
|
+ TYPE_CHECKING,
|
|
|
+ Any,
|
|
|
+ Dict,
|
|
|
+ Generator,
|
|
|
+ Iterable,
|
|
|
+ List,
|
|
|
+ Optional,
|
|
|
+ Set,
|
|
|
+ Tuple,
|
|
|
+)
|
|
|
|
|
|
import attr
|
|
|
from prometheus_client import Counter
|
|
@@ -33,9 +43,10 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause
|
|
|
from synapse.storage.database import DatabasePool, LoggingTransaction
|
|
|
from synapse.storage.databases.main.search import SearchEntry
|
|
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
|
|
+from synapse.storage.util.sequence import build_sequence_generator
|
|
|
from synapse.types import StateMap, get_domain_from_id
|
|
|
from synapse.util import json_encoder
|
|
|
-from synapse.util.iterutils import batch_iter
|
|
|
+from synapse.util.iterutils import batch_iter, sorted_topologically
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
from synapse.server import HomeServer
|
|
@@ -89,6 +100,14 @@ class PersistEventsStore:
|
|
|
self._clock = hs.get_clock()
|
|
|
self._instance_name = hs.get_instance_name()
|
|
|
|
|
|
+ def get_chain_id_txn(txn):
|
|
|
+ txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
|
|
|
+ return txn.fetchone()[0]
|
|
|
+
|
|
|
+ self._event_chain_id_gen = build_sequence_generator(
|
|
|
+ db.engine, get_chain_id_txn, "event_auth_chain_id"
|
|
|
+ )
|
|
|
+
|
|
|
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
|
|
self.is_mine_id = hs.is_mine_id
|
|
|
|
|
@@ -366,6 +385,36 @@ class PersistEventsStore:
|
|
|
# Insert into event_to_state_groups.
|
|
|
self._store_event_state_mappings_txn(txn, events_and_contexts)
|
|
|
|
|
|
+ self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts])
|
|
|
+
|
|
|
+ # _store_rejected_events_txn filters out any events which were
|
|
|
+ # rejected, and returns the filtered list.
|
|
|
+ events_and_contexts = self._store_rejected_events_txn(
|
|
|
+ txn, events_and_contexts=events_and_contexts
|
|
|
+ )
|
|
|
+
|
|
|
+ # From this point onwards the events are only ones that weren't
|
|
|
+ # rejected.
|
|
|
+
|
|
|
+ self._update_metadata_tables_txn(
|
|
|
+ txn,
|
|
|
+ events_and_contexts=events_and_contexts,
|
|
|
+ all_events_and_contexts=all_events_and_contexts,
|
|
|
+ backfilled=backfilled,
|
|
|
+ )
|
|
|
+
|
|
|
+ # We call this last as it assumes we've inserted the events into
|
|
|
+ # room_memberships, where applicable.
|
|
|
+ self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
|
|
|
+
|
|
|
+ def _persist_event_auth_chain_txn(
|
|
|
+ self, txn: LoggingTransaction, events: List[EventBase],
|
|
|
+ ) -> None:
|
|
|
+
|
|
|
+        # We only care about state events, so return early if there are none.
|
|
|
+ if not any(e.is_state() for e in events):
|
|
|
+ return
|
|
|
+
|
|
|
# We want to store event_auth mappings for rejected events, as they're
|
|
|
# used in state res v2.
|
|
|
# This is only necessary if the rejected event appears in an accepted
|
|
@@ -381,31 +430,357 @@ class PersistEventsStore:
|
|
|
"room_id": event.room_id,
|
|
|
"auth_id": auth_id,
|
|
|
}
|
|
|
- for event, _ in events_and_contexts
|
|
|
+ for event in events
|
|
|
for auth_id in event.auth_event_ids()
|
|
|
if event.is_state()
|
|
|
],
|
|
|
)
|
|
|
|
|
|
- # _store_rejected_events_txn filters out any events which were
|
|
|
- # rejected, and returns the filtered list.
|
|
|
- events_and_contexts = self._store_rejected_events_txn(
|
|
|
- txn, events_and_contexts=events_and_contexts
|
|
|
+ # We now calculate chain ID/sequence numbers for any state events we're
|
|
|
+ # persisting. We ignore out of band memberships as we're not in the room
|
|
|
+ # and won't have their auth chain (we'll fix it up later if we join the
|
|
|
+ # room).
|
|
|
+ #
|
|
|
+ # See: docs/auth_chain_difference_algorithm.md
|
|
|
+
|
|
|
+ # We ignore legacy rooms that we aren't filling the chain cover index
|
|
|
+ # for.
|
|
|
+ rows = self.db_pool.simple_select_many_txn(
|
|
|
+ txn,
|
|
|
+ table="rooms",
|
|
|
+ column="room_id",
|
|
|
+ iterable={event.room_id for event in events if event.is_state()},
|
|
|
+ keyvalues={},
|
|
|
+ retcols=("room_id", "has_auth_chain_index"),
|
|
|
)
|
|
|
+ rooms_using_chain_index = {
|
|
|
+ row["room_id"] for row in rows if row["has_auth_chain_index"]
|
|
|
+ }
|
|
|
|
|
|
- # From this point onwards the events are only ones that weren't
|
|
|
- # rejected.
|
|
|
+ state_events = {
|
|
|
+ event.event_id: event
|
|
|
+ for event in events
|
|
|
+ if event.is_state() and event.room_id in rooms_using_chain_index
|
|
|
+ }
|
|
|
|
|
|
- self._update_metadata_tables_txn(
|
|
|
+ if not state_events:
|
|
|
+ return
|
|
|
+
|
|
|
+ # Map from event ID to chain ID/sequence number.
|
|
|
+ chain_map = {} # type: Dict[str, Tuple[int, int]]
|
|
|
+
|
|
|
+ # We need to know the type/state_key and auth events of the events we're
|
|
|
+ # calculating chain IDs for. We don't rely on having the full Event
|
|
|
+ # instances as we'll potentially be pulling more events from the DB and
|
|
|
+ # we don't need the overhead of fetching/parsing the full event JSON.
|
|
|
+ event_to_types = {
|
|
|
+ e.event_id: (e.type, e.state_key) for e in state_events.values()
|
|
|
+ }
|
|
|
+ event_to_auth_chain = {
|
|
|
+ e.event_id: e.auth_event_ids() for e in state_events.values()
|
|
|
+ }
|
|
|
+
|
|
|
+ # Set of event IDs to calculate chain ID/seq numbers for.
|
|
|
+ events_to_calc_chain_id_for = set(state_events)
|
|
|
+
|
|
|
+ # We check if there are any events that need to be handled in the rooms
|
|
|
+ # we're looking at. These should just be out of band memberships, where
|
|
|
+ # we didn't have the auth chain when we first persisted.
|
|
|
+ rows = self.db_pool.simple_select_many_txn(
|
|
|
txn,
|
|
|
- events_and_contexts=events_and_contexts,
|
|
|
- all_events_and_contexts=all_events_and_contexts,
|
|
|
- backfilled=backfilled,
|
|
|
+ table="event_auth_chain_to_calculate",
|
|
|
+ keyvalues={},
|
|
|
+ column="room_id",
|
|
|
+ iterable={e.room_id for e in state_events.values()},
|
|
|
+ retcols=("event_id", "type", "state_key"),
|
|
|
)
|
|
|
+ for row in rows:
|
|
|
+ event_id = row["event_id"]
|
|
|
+ event_type = row["type"]
|
|
|
+ state_key = row["state_key"]
|
|
|
+
|
|
|
+ # (We could pull out the auth events for all rows at once using
|
|
|
+ # simple_select_many, but this case happens rarely and almost always
|
|
|
+ # with a single row.)
|
|
|
+ auth_events = self.db_pool.simple_select_onecol_txn(
|
|
|
+ txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
|
|
|
+ )
|
|
|
|
|
|
- # We call this last as it assumes we've inserted the events into
|
|
|
- # room_memberships, where applicable.
|
|
|
- self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
|
|
|
+ events_to_calc_chain_id_for.add(event_id)
|
|
|
+ event_to_types[event_id] = (event_type, state_key)
|
|
|
+ event_to_auth_chain[event_id] = auth_events
|
|
|
+
|
|
|
+ # First we get the chain ID and sequence numbers for the events'
|
|
|
+ # auth events (that aren't also currently being persisted).
|
|
|
+ #
|
|
|
+        # Note that there is an edge case here where we might not have
|
|
|
+ # calculated chains and sequence numbers for events that were "out
|
|
|
+ # of band". We handle this case by fetching the necessary info and
|
|
|
+ # adding it to the set of events to calculate chain IDs for.
|
|
|
+
|
|
|
+ missing_auth_chains = {
|
|
|
+ a_id
|
|
|
+ for auth_events in event_to_auth_chain.values()
|
|
|
+ for a_id in auth_events
|
|
|
+ if a_id not in events_to_calc_chain_id_for
|
|
|
+ }
|
|
|
+
|
|
|
+ # We loop here in case we find an out of band membership and need to
|
|
|
+ # fetch their auth event info.
|
|
|
+ while missing_auth_chains:
|
|
|
+ sql = """
|
|
|
+ SELECT event_id, events.type, state_key, chain_id, sequence_number
|
|
|
+ FROM events
|
|
|
+ INNER JOIN state_events USING (event_id)
|
|
|
+ LEFT JOIN event_auth_chains USING (event_id)
|
|
|
+ WHERE
|
|
|
+ """
|
|
|
+ clause, args = make_in_list_sql_clause(
|
|
|
+ txn.database_engine, "event_id", missing_auth_chains,
|
|
|
+ )
|
|
|
+ txn.execute(sql + clause, args)
|
|
|
+
|
|
|
+ missing_auth_chains.clear()
|
|
|
+
|
|
|
+ for auth_id, event_type, state_key, chain_id, sequence_number in txn:
|
|
|
+ event_to_types[auth_id] = (event_type, state_key)
|
|
|
+
|
|
|
+ if chain_id is None:
|
|
|
+ # No chain ID, so the event was persisted out of band.
|
|
|
+                    # We add it to the list of events to calculate auth chains for.
|
|
|
+
|
|
|
+ events_to_calc_chain_id_for.add(auth_id)
|
|
|
+
|
|
|
+ event_to_auth_chain[
|
|
|
+ auth_id
|
|
|
+ ] = self.db_pool.simple_select_onecol_txn(
|
|
|
+ txn,
|
|
|
+ "event_auth",
|
|
|
+ keyvalues={"event_id": auth_id},
|
|
|
+ retcol="auth_id",
|
|
|
+ )
|
|
|
+
|
|
|
+ missing_auth_chains.update(
|
|
|
+ e
|
|
|
+ for e in event_to_auth_chain[auth_id]
|
|
|
+ if e not in event_to_types
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ chain_map[auth_id] = (chain_id, sequence_number)
|
|
|
+
|
|
|
+        # Now we check if we have any events where we don't have the auth chain,
|
|
|
+ # this should only be out of band memberships.
|
|
|
+ for event_id in sorted_topologically(event_to_auth_chain, event_to_auth_chain):
|
|
|
+ for auth_id in event_to_auth_chain[event_id]:
|
|
|
+ if (
|
|
|
+ auth_id not in chain_map
|
|
|
+ and auth_id not in events_to_calc_chain_id_for
|
|
|
+ ):
|
|
|
+ events_to_calc_chain_id_for.discard(event_id)
|
|
|
+
|
|
|
+ # If this is an event we're trying to persist we add it to
|
|
|
+ # the list of events to calculate chain IDs for next time
|
|
|
+ # around. (Otherwise we will have already added it to the
|
|
|
+ # table).
|
|
|
+ event = state_events.get(event_id)
|
|
|
+ if event:
|
|
|
+ self.db_pool.simple_insert_txn(
|
|
|
+ txn,
|
|
|
+ table="event_auth_chain_to_calculate",
|
|
|
+ values={
|
|
|
+ "event_id": event.event_id,
|
|
|
+ "room_id": event.room_id,
|
|
|
+ "type": event.type,
|
|
|
+ "state_key": event.state_key,
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ # We stop checking the event's auth events since we've
|
|
|
+ # discarded it.
|
|
|
+ break
|
|
|
+
|
|
|
+ if not events_to_calc_chain_id_for:
|
|
|
+ return
|
|
|
+
|
|
|
+ # We now calculate the chain IDs/sequence numbers for the events. We
|
|
|
+ # do this by looking at the chain ID and sequence number of any auth
|
|
|
+ # event with the same type/state_key and incrementing the sequence
|
|
|
+ # number by one. If there was no match or the chain ID/sequence
|
|
|
+ # number is already taken we generate a new chain.
|
|
|
+ #
|
|
|
+ # We need to do this in a topologically sorted order as we want to
|
|
|
+ # generate chain IDs/sequence numbers of an event's auth events
|
|
|
+ # before the event itself.
|
|
|
+ chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
|
|
|
+ new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
|
|
|
+ for event_id in sorted_topologically(
|
|
|
+ events_to_calc_chain_id_for, event_to_auth_chain
|
|
|
+ ):
|
|
|
+ existing_chain_id = None
|
|
|
+ for auth_id in event_to_auth_chain[event_id]:
|
|
|
+ if event_to_types.get(event_id) == event_to_types.get(auth_id):
|
|
|
+ existing_chain_id = chain_map[auth_id]
|
|
|
+ break
|
|
|
+
|
|
|
+ new_chain_tuple = None
|
|
|
+ if existing_chain_id:
|
|
|
+ # We found a chain ID/sequence number candidate, check its
|
|
|
+ # not already taken.
|
|
|
+ proposed_new_id = existing_chain_id[0]
|
|
|
+ proposed_new_seq = existing_chain_id[1] + 1
|
|
|
+ if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
|
|
|
+ already_allocated = self.db_pool.simple_select_one_onecol_txn(
|
|
|
+ txn,
|
|
|
+ table="event_auth_chains",
|
|
|
+ keyvalues={
|
|
|
+ "chain_id": proposed_new_id,
|
|
|
+ "sequence_number": proposed_new_seq,
|
|
|
+ },
|
|
|
+ retcol="event_id",
|
|
|
+ allow_none=True,
|
|
|
+ )
|
|
|
+ if already_allocated:
|
|
|
+ # Mark it as already allocated so we don't need to hit
|
|
|
+ # the DB again.
|
|
|
+ chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
|
|
|
+ else:
|
|
|
+ new_chain_tuple = (
|
|
|
+ proposed_new_id,
|
|
|
+ proposed_new_seq,
|
|
|
+ )
|
|
|
+
|
|
|
+ if not new_chain_tuple:
|
|
|
+ new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1)
|
|
|
+
|
|
|
+ chains_tuples_allocated.add(new_chain_tuple)
|
|
|
+
|
|
|
+ chain_map[event_id] = new_chain_tuple
|
|
|
+ new_chain_tuples[event_id] = new_chain_tuple
|
|
|
+
|
|
|
+ self.db_pool.simple_insert_many_txn(
|
|
|
+ txn,
|
|
|
+ table="event_auth_chains",
|
|
|
+ values=[
|
|
|
+ {"event_id": event_id, "chain_id": c_id, "sequence_number": seq}
|
|
|
+ for event_id, (c_id, seq) in new_chain_tuples.items()
|
|
|
+ ],
|
|
|
+ )
|
|
|
+
|
|
|
+ self.db_pool.simple_delete_many_txn(
|
|
|
+ txn,
|
|
|
+ table="event_auth_chain_to_calculate",
|
|
|
+ keyvalues={},
|
|
|
+ column="event_id",
|
|
|
+ iterable=new_chain_tuples,
|
|
|
+ )
|
|
|
+
|
|
|
+ # Now we need to calculate any new links between chains caused by
|
|
|
+ # the new events.
|
|
|
+ #
|
|
|
+ # Links are pairs of chain ID/sequence numbers such that for any
|
|
|
+ # event A (CA, SA) and any event B (CB, SB), B is in A's auth chain
|
|
|
+ # if and only if there is at least one link (CA, S1) -> (CB, S2)
|
|
|
+ # where SA >= S1 and S2 >= SB.
|
|
|
+ #
|
|
|
+ # We try and avoid adding redundant links to the table, e.g. if we
|
|
|
+        # have two links between two chains which both start/end at the same
|
|
|
+        # sequence number (or cross) then one can be safely dropped.
|
|
|
+ #
|
|
|
+ # To calculate new links we look at every new event and:
|
|
|
+ # 1. Fetch the chain ID/sequence numbers of its auth events,
|
|
|
+ # discarding any that are reachable by other auth events, or
|
|
|
+ # that have the same chain ID as the event.
|
|
|
+ # 2. For each retained auth event we:
|
|
|
+ # a. Add a link from the event's to the auth event's chain
|
|
|
+ # ID/sequence number; and
|
|
|
+ # b. Add a link from the event to every chain reachable by the
|
|
|
+ # auth event.
|
|
|
+
|
|
|
+ # Step 1, fetch all existing links from all the chains we've seen
|
|
|
+ # referenced.
|
|
|
+ chain_links = _LinkMap()
|
|
|
+ rows = self.db_pool.simple_select_many_txn(
|
|
|
+ txn,
|
|
|
+ table="event_auth_chain_links",
|
|
|
+ column="origin_chain_id",
|
|
|
+ iterable={chain_id for chain_id, _ in chain_map.values()},
|
|
|
+ keyvalues={},
|
|
|
+ retcols=(
|
|
|
+ "origin_chain_id",
|
|
|
+ "origin_sequence_number",
|
|
|
+ "target_chain_id",
|
|
|
+ "target_sequence_number",
|
|
|
+ ),
|
|
|
+ )
|
|
|
+ for row in rows:
|
|
|
+ chain_links.add_link(
|
|
|
+ (row["origin_chain_id"], row["origin_sequence_number"]),
|
|
|
+ (row["target_chain_id"], row["target_sequence_number"]),
|
|
|
+ new=False,
|
|
|
+ )
|
|
|
+
|
|
|
+        # We do this in topological order to avoid adding redundant links.
|
|
|
+ for event_id in sorted_topologically(
|
|
|
+ events_to_calc_chain_id_for, event_to_auth_chain
|
|
|
+ ):
|
|
|
+ chain_id, sequence_number = chain_map[event_id]
|
|
|
+
|
|
|
+ # Filter out auth events that are reachable by other auth
|
|
|
+ # events. We do this by looking at every permutation of pairs of
|
|
|
+ # auth events (A, B) to check if B is reachable from A.
|
|
|
+ reduction = {
|
|
|
+ a_id
|
|
|
+ for a_id in event_to_auth_chain[event_id]
|
|
|
+ if chain_map[a_id][0] != chain_id
|
|
|
+ }
|
|
|
+ for start_auth_id, end_auth_id in itertools.permutations(
|
|
|
+ event_to_auth_chain[event_id], r=2,
|
|
|
+ ):
|
|
|
+ if chain_links.exists_path_from(
|
|
|
+ chain_map[start_auth_id], chain_map[end_auth_id]
|
|
|
+ ):
|
|
|
+ reduction.discard(end_auth_id)
|
|
|
+
|
|
|
+ # Step 2, figure out what the new links are from the reduced
|
|
|
+ # list of auth events.
|
|
|
+ for auth_id in reduction:
|
|
|
+ auth_chain_id, auth_sequence_number = chain_map[auth_id]
|
|
|
+
|
|
|
+ # Step 2a, add link between the event and auth event
|
|
|
+ chain_links.add_link(
|
|
|
+ (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
|
|
|
+ )
|
|
|
+
|
|
|
+ # Step 2b, add a link to chains reachable from the auth
|
|
|
+ # event.
|
|
|
+ for target_id, target_seq in chain_links.get_links_from(
|
|
|
+ (auth_chain_id, auth_sequence_number)
|
|
|
+ ):
|
|
|
+ if target_id == chain_id:
|
|
|
+ continue
|
|
|
+
|
|
|
+ chain_links.add_link(
|
|
|
+ (chain_id, sequence_number), (target_id, target_seq)
|
|
|
+ )
|
|
|
+
|
|
|
+ self.db_pool.simple_insert_many_txn(
|
|
|
+ txn,
|
|
|
+ table="event_auth_chain_links",
|
|
|
+ values=[
|
|
|
+ {
|
|
|
+ "origin_chain_id": source_id,
|
|
|
+ "origin_sequence_number": source_seq,
|
|
|
+ "target_chain_id": target_id,
|
|
|
+ "target_sequence_number": target_seq,
|
|
|
+ }
|
|
|
+ for (
|
|
|
+ source_id,
|
|
|
+ source_seq,
|
|
|
+ target_id,
|
|
|
+ target_seq,
|
|
|
+ ) in chain_links.get_additions()
|
|
|
+ ],
|
|
|
+ )
|
|
|
|
|
|
def _persist_transaction_ids_txn(
|
|
|
self,
|
|
@@ -1521,3 +1896,131 @@ class PersistEventsStore:
|
|
|
if not ev.internal_metadata.is_outlier()
|
|
|
],
|
|
|
)
|
|
|
+
|
|
|
+
|
|
|
+@attr.s(slots=True)
|
|
|
+class _LinkMap:
|
|
|
+ """A helper type for tracking links between chains.
|
|
|
+ """
|
|
|
+
|
|
|
+ # Stores the set of links as nested maps: source chain ID -> target chain ID
|
|
|
+ # -> source sequence number -> target sequence number.
|
|
|
+ maps = attr.ib(type=Dict[int, Dict[int, Dict[int, int]]], factory=dict)
|
|
|
+
|
|
|
+ # Stores the links that have been added (with new set to true), as tuples of
|
|
|
+ # `(source chain ID, source sequence no, target chain ID, target sequence no.)`
|
|
|
+ additions = attr.ib(type=Set[Tuple[int, int, int, int]], factory=set)
|
|
|
+
|
|
|
+ def add_link(
|
|
|
+ self,
|
|
|
+ src_tuple: Tuple[int, int],
|
|
|
+ target_tuple: Tuple[int, int],
|
|
|
+ new: bool = True,
|
|
|
+ ) -> bool:
|
|
|
+ """Add a new link between two chains, ensuring no redundant links are added.
|
|
|
+
|
|
|
+ New links should be added in topological order.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ src_tuple: The chain ID/sequence number of the source of the link.
|
|
|
+ target_tuple: The chain ID/sequence number of the target of the link.
|
|
|
+ new: Whether this is a "new" link, i.e. should it be returned
|
|
|
+ by `get_additions`.
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ True if a link was added, false if the given link was dropped as redundant
|
|
|
+ """
|
|
|
+ src_chain, src_seq = src_tuple
|
|
|
+ target_chain, target_seq = target_tuple
|
|
|
+
|
|
|
+ current_links = self.maps.setdefault(src_chain, {}).setdefault(target_chain, {})
|
|
|
+
|
|
|
+ assert src_chain != target_chain
|
|
|
+
|
|
|
+ if new:
|
|
|
+ # Check if the new link is redundant
|
|
|
+ for current_seq_src, current_seq_target in current_links.items():
|
|
|
+                # If a link "crosses" another link then it's redundant. For example
|
|
|
+ # in the following link 1 (L1) is redundant, as any event reachable
|
|
|
+ # via L1 is *also* reachable via L2.
|
|
|
+ #
|
|
|
+ # Chain A Chain B
|
|
|
+ # | |
|
|
|
+ # L1 |------ |
|
|
|
+ # | | |
|
|
|
+ # L2 |---- | -->|
|
|
|
+ # | | |
|
|
|
+ # | |--->|
|
|
|
+ # | |
|
|
|
+ # | |
|
|
|
+ #
|
|
|
+ # So we only need to keep links which *do not* cross, i.e. links
|
|
|
+ # that both start and end above or below an existing link.
|
|
|
+ #
|
|
|
+ # Note, since we add links in topological ordering we should never
|
|
|
+ # see `src_seq` less than `current_seq_src`.
|
|
|
+
|
|
|
+ if current_seq_src <= src_seq and target_seq <= current_seq_target:
|
|
|
+ # This new link is redundant, nothing to do.
|
|
|
+ return False
|
|
|
+
|
|
|
+ self.additions.add((src_chain, src_seq, target_chain, target_seq))
|
|
|
+
|
|
|
+ current_links[src_seq] = target_seq
|
|
|
+ return True
|
|
|
+
|
|
|
+ def get_links_from(
|
|
|
+ self, src_tuple: Tuple[int, int]
|
|
|
+ ) -> Generator[Tuple[int, int], None, None]:
|
|
|
+ """Gets the chains reachable from the given chain/sequence number.
|
|
|
+
|
|
|
+ Yields:
|
|
|
+ The chain ID and sequence number the link points to.
|
|
|
+ """
|
|
|
+ src_chain, src_seq = src_tuple
|
|
|
+ for target_id, sequence_numbers in self.maps.get(src_chain, {}).items():
|
|
|
+ for link_src_seq, target_seq in sequence_numbers.items():
|
|
|
+ if link_src_seq <= src_seq:
|
|
|
+ yield target_id, target_seq
|
|
|
+
|
|
|
+ def get_links_between(
|
|
|
+ self, source_chain: int, target_chain: int
|
|
|
+ ) -> Generator[Tuple[int, int], None, None]:
|
|
|
+ """Gets the links between two chains.
|
|
|
+
|
|
|
+ Yields:
|
|
|
+ The source and target sequence numbers.
|
|
|
+ """
|
|
|
+
|
|
|
+ yield from self.maps.get(source_chain, {}).get(target_chain, {}).items()
|
|
|
+
|
|
|
+ def get_additions(self) -> Generator[Tuple[int, int, int, int], None, None]:
|
|
|
+ """Gets any newly added links.
|
|
|
+
|
|
|
+ Yields:
|
|
|
+ The source chain ID/sequence number and target chain ID/sequence number
|
|
|
+ """
|
|
|
+
|
|
|
+ for src_chain, src_seq, target_chain, _ in self.additions:
|
|
|
+ target_seq = self.maps.get(src_chain, {}).get(target_chain, {}).get(src_seq)
|
|
|
+ if target_seq is not None:
|
|
|
+ yield (src_chain, src_seq, target_chain, target_seq)
|
|
|
+
|
|
|
+ def exists_path_from(
|
|
|
+ self, src_tuple: Tuple[int, int], target_tuple: Tuple[int, int],
|
|
|
+ ) -> bool:
|
|
|
+ """Checks if there is a path between the source chain ID/sequence and
|
|
|
+ target chain ID/sequence.
|
|
|
+ """
|
|
|
+ src_chain, src_seq = src_tuple
|
|
|
+ target_chain, target_seq = target_tuple
|
|
|
+
|
|
|
+ if src_chain == target_chain:
|
|
|
+ return target_seq <= src_seq
|
|
|
+
|
|
|
+ links = self.get_links_between(src_chain, target_chain)
|
|
|
+ for link_start_seq, link_end_seq in links:
|
|
|
+ if link_start_seq <= src_seq and target_seq <= link_end_seq:
|
|
|
+ return True
|
|
|
+
|
|
|
+ return False
|