123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476 |
- # -*- coding: utf-8 -*-
- # Copyright 2019 New Vector Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import logging
- import attr
- from twisted.internet import defer
- from synapse.api.constants import RelationTypes
- from synapse.api.errors import SynapseError
- from synapse.storage._base import SQLBaseStore
- from synapse.storage.stream import generate_pagination_where_clause
- from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
- logger = logging.getLogger(__name__)
@attr.s
class PaginationChunk(object):
    """A single page of results handed back by the relation pagination APIs.

    Attributes:
        chunk (list): The rows that make up this page of results.
        next_batch (Any|None): Token with which to fetch the following page
            of results; None means there are no more results.
        prev_batch (Any|None): Token with which to fetch the preceding page
            of results; None means there are no previous results.
    """

    chunk = attr.ib()
    next_batch = attr.ib(default=None)
    prev_batch = attr.ib(default=None)

    def to_dict(self):
        """Render the chunk as a dict, serialising any tokens to strings.

        Returns:
            dict: Always contains `chunk`; `next_batch` and `prev_batch` are
            only included when the corresponding token is truthy, in which
            case the value is the result of the token's `to_string()`.
        """
        serialised = {"chunk": self.chunk}

        # Same key order as the attributes are declared: next, then prev.
        for field_name in ("next_batch", "prev_batch"):
            token = getattr(self, field_name)
            if token:
                serialised[field_name] = token.to_string()

        return serialised
@attr.s(frozen=True, slots=True)
class RelationPaginationToken(object):
    """Pagination token for relation pagination API.

    As the results are ordered by topological ordering, we can use the
    `topological_ordering` and `stream_ordering` fields of the events at the
    boundaries of the chunk as pagination tokens.

    Attributes:
        topological (int): The topological ordering of the boundary event
        stream (int): The stream ordering of the boundary event.
    """

    topological = attr.ib()
    stream = attr.ib()

    @staticmethod
    def from_string(string):
        """Parse a token from its string form (the inverse of `to_string`).

        Args:
            string (str): Token of the form `<topological>-<stream>`.

        Returns:
            RelationPaginationToken

        Raises:
            SynapseError: (400) if the string is not two integers joined by
                a "-".
        """
        try:
            # `to_string` renders a negative component with a leading "-"
            # (e.g. "5--3"), so a plain `split("-")` would yield too many
            # parts and reject a token we generated ourselves. Searching for
            # the separator from index 1 onwards skips over a sign on the
            # first component and leaves any sign on the second one intact.
            # `str.index` raises ValueError if there is no separator.
            sep = string.index("-", 1)
            return RelationPaginationToken(
                int(string[:sep]), int(string[sep + 1 :])
            )
        except ValueError:
            raise SynapseError(400, "Invalid token")

    def to_string(self):
        """Serialise the token as `<topological>-<stream>`."""
        return "%d-%d" % (self.topological, self.stream)

    def as_tuple(self):
        """Return the token as a `(topological, stream)` tuple."""
        return attr.astuple(self)
@attr.s(frozen=True, slots=True)
class AggregationPaginationToken(object):
    """Pagination token for relation aggregation pagination API.

    As the results are ordered by count and then MAX(stream_ordering) of the
    aggregation groups, we can just use them as our pagination token.

    Attributes:
        count (int): The count of relations in the boundary group.
        stream (int): The MAX stream ordering in the boundary group.
    """

    count = attr.ib()
    stream = attr.ib()

    @staticmethod
    def from_string(string):
        """Parse a token from its string form (the inverse of `to_string`).

        Args:
            string (str): Token of the form `<count>-<stream>`.

        Returns:
            AggregationPaginationToken

        Raises:
            SynapseError: (400) if the string is not two integers joined by
                a "-".
        """
        try:
            # `to_string` renders a negative component with a leading "-"
            # (e.g. "2--7"), so a plain `split("-")` would yield too many
            # parts and reject a token we generated ourselves. Searching for
            # the separator from index 1 onwards skips over a sign on the
            # first component and leaves any sign on the second one intact.
            # `str.index` raises ValueError if there is no separator.
            sep = string.index("-", 1)
            return AggregationPaginationToken(
                int(string[:sep]), int(string[sep + 1 :])
            )
        except ValueError:
            raise SynapseError(400, "Invalid token")

    def to_string(self):
        """Serialise the token as `<count>-<stream>`."""
        return "%d-%d" % (self.count, self.stream)

    def as_tuple(self):
        """Return the token as a `(count, stream)` tuple."""
        return attr.astuple(self)
class RelationsWorkerStore(SQLBaseStore):
    """Query methods for event relations (annotations, references, edits)."""

    @cached(tree=True)
    def get_relations_for_event(
        self,
        event_id,
        relation_type=None,
        event_type=None,
        aggregation_key=None,
        limit=5,
        direction="b",
        from_token=None,
        to_token=None,
    ):
        """Get a list of relations for an event, ordered by topological ordering.

        Args:
            event_id (str): Fetch events that relate to this event ID.
            relation_type (str|None): Only fetch events with this relation
                type, if given.
            event_type (str|None): Only fetch events with this event type, if
                given.
            aggregation_key (str|None): Only fetch events with this aggregation
                key, if given.
            limit (int): Only fetch the most recent `limit` events.
            direction (str): Whether to fetch the most recent first (`"b"`) or
                the oldest first (`"f"`).
            from_token (RelationPaginationToken|None): Fetch rows from the given
                token, or from the start if None.
            to_token (RelationPaginationToken|None): Fetch rows up to the given
                token, or up to the end if None.

        Returns:
            Deferred[PaginationChunk]: List of event IDs that match relations
            requested. The rows are of the form `{"event_id": "..."}`.
        """

        where_clause = ["relates_to_id = ?"]
        where_args = [event_id]

        if relation_type is not None:
            where_clause.append("relation_type = ?")
            where_args.append(relation_type)

        if event_type is not None:
            where_clause.append("type = ?")
            where_args.append(event_type)

        # NOTE: this is a truthiness check, so an empty-string key is treated
        # the same as no key at all.
        if aggregation_key:
            where_clause.append("aggregation_key = ?")
            where_args.append(aggregation_key)

        pagination_clause = generate_pagination_where_clause(
            direction=direction,
            column_names=("topological_ordering", "stream_ordering"),
            from_token=attr.astuple(from_token) if from_token else None,
            to_token=attr.astuple(to_token) if to_token else None,
            engine=self.database_engine,
        )

        if pagination_clause:
            where_clause.append(pagination_clause)

        order = "DESC" if direction == "b" else "ASC"

        sql = """
            SELECT event_id, topological_ordering, stream_ordering
            FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE %s
            ORDER BY topological_ordering %s, stream_ordering %s
            LIMIT ?
        """ % (
            " AND ".join(where_clause),
            order,
            order,
        )

        def _get_recent_references_for_event_txn(txn):
            # Ask for one row more than requested so that we can tell whether
            # there is anything left to paginate through.
            txn.execute(sql, where_args + [limit + 1])

            last_topo_id = None
            last_stream_id = None
            events = []
            for row in txn:
                events.append({"event_id": row[0]})
                last_topo_id = row[1]
                last_stream_id = row[2]

            next_batch = None
            # Compare against None explicitly: 0 is a valid integer ordering
            # but is falsy, and a truthiness check would silently drop the
            # token even though more rows exist.
            if (
                len(events) > limit
                and last_topo_id is not None
                and last_stream_id is not None
            ):
                next_batch = RelationPaginationToken(last_topo_id, last_stream_id)

            return PaginationChunk(
                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
            )

        return self.runInteraction(
            "get_recent_references_for_event", _get_recent_references_for_event_txn
        )

    @cached(tree=True)
    def get_aggregation_groups_for_event(
        self,
        event_id,
        event_type=None,
        limit=5,
        direction="b",
        from_token=None,
        to_token=None,
    ):
        """Get a list of annotations on the event, grouped by event type and
        aggregation key, sorted by count.

        This is used e.g. to get the what and how many reactions have happened
        on an event.

        Args:
            event_id (str): Fetch events that relate to this event ID.
            event_type (str|None): Only fetch events with this event type, if
                given.
            limit (int): Only fetch the `limit` groups.
            direction (str): Whether to fetch the highest count first (`"b"`) or
                the lowest count first (`"f"`).
            from_token (AggregationPaginationToken|None): Fetch rows from the
                given token, or from the start if None.
            to_token (AggregationPaginationToken|None): Fetch rows up to the
                given token, or up to the end if None.

        Returns:
            Deferred[PaginationChunk]: List of groups of annotations that
            match. Each row is a dict with `type`, `key` and `count` fields.
        """

        where_clause = ["relates_to_id = ?", "relation_type = ?"]
        where_args = [event_id, RelationTypes.ANNOTATION]

        if event_type:
            where_clause.append("type = ?")
            where_args.append(event_type)

        # The pagination tokens are (count, max stream ordering) pairs, i.e.
        # aggregates, so the bounds have to go in a HAVING clause rather than
        # the WHERE clause.
        having_clause = generate_pagination_where_clause(
            direction=direction,
            column_names=("COUNT(*)", "MAX(stream_ordering)"),
            from_token=attr.astuple(from_token) if from_token else None,
            to_token=attr.astuple(to_token) if to_token else None,
            engine=self.database_engine,
        )

        order = "DESC" if direction == "b" else "ASC"

        if having_clause:
            having_clause = "HAVING " + having_clause
        else:
            having_clause = ""

        sql = """
            SELECT type, aggregation_key, COUNT(*), MAX(stream_ordering)
            FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE {where_clause}
            GROUP BY relation_type, type, aggregation_key
            {having_clause}
            ORDER BY COUNT(*) {order}, MAX(stream_ordering) {order}
            LIMIT ?
        """.format(
            where_clause=" AND ".join(where_clause),
            order=order,
            having_clause=having_clause,
        )

        def _get_aggregation_groups_for_event_txn(txn):
            # Ask for one row more than requested so that we can tell whether
            # there is anything left to paginate through.
            txn.execute(sql, where_args + [limit + 1])

            next_batch = None
            events = []
            for row in txn:
                events.append({"type": row[0], "key": row[1], "count": row[2]})
                next_batch = AggregationPaginationToken(row[2], row[3])

            # Only hand out a token if the extra row was actually returned.
            if len(events) <= limit:
                next_batch = None

            return PaginationChunk(
                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
            )

        return self.runInteraction(
            "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
        )

    @cachedInlineCallbacks()
    def get_applicable_edit(self, event_id):
        """Get the most recent edit (if any) that has happened for the given
        event.

        Correctly handles checking whether edits were allowed to happen.

        Args:
            event_id (str): The original event ID

        Returns:
            Deferred[EventBase|None]: Returns the most recent edit, if any.
        """

        # We only allow edits for `m.room.message` events that have the same sender
        # and event type. We can't assert these things during regular event auth so
        # we have to do the checks post hoc.

        # Fetches latest edit that has the same type and sender as the
        # original, and is an `m.room.message`.
        sql = """
            SELECT edit.event_id FROM events AS edit
            INNER JOIN event_relations USING (event_id)
            INNER JOIN events AS original ON
                original.event_id = relates_to_id
                AND edit.type = original.type
                AND edit.sender = original.sender
            WHERE
                relates_to_id = ?
                AND relation_type = ?
                AND edit.type = 'm.room.message'
            ORDER by edit.origin_server_ts DESC, edit.event_id DESC
            LIMIT 1
        """

        def _get_applicable_edit_txn(txn):
            txn.execute(sql, (event_id, RelationTypes.REPLACE))
            row = txn.fetchone()
            if row:
                return row[0]

        edit_id = yield self.runInteraction(
            "get_applicable_edit", _get_applicable_edit_txn
        )

        if not edit_id:
            # No applicable edit: the deferred resolves to None.
            return

        edit_event = yield self.get_event(edit_id, allow_none=True)
        defer.returnValue(edit_event)

    def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
        """Check if a user has already annotated an event with the same key
        (e.g. already liked an event).

        Args:
            parent_id (str): The event being annotated
            event_type (str): The event type of the annotation
            aggregation_key (str): The aggregation key of the annotation
            sender (str): The sender of the annotation

        Returns:
            Deferred[bool]
        """

        sql = """
            SELECT 1 FROM event_relations
            INNER JOIN events USING (event_id)
            WHERE
                relates_to_id = ?
                AND relation_type = ?
                AND type = ?
                AND sender = ?
                AND aggregation_key = ?
            LIMIT 1;
        """

        def _get_if_user_has_annotated_event(txn):
            txn.execute(
                sql,
                (
                    parent_id,
                    RelationTypes.ANNOTATION,
                    event_type,
                    sender,
                    aggregation_key,
                ),
            )

            # Any row at all means a matching annotation already exists.
            return bool(txn.fetchone())

        return self.runInteraction(
            "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
        )
class RelationsStore(RelationsWorkerStore):
    """Write-side handling of event relations: inserting and removing rows in
    `event_relations` and invalidating the query caches that depend on them.
    """

    def _invalidate_relation_caches_txn(self, txn, parent_id, rel_type):
        """Register invalidation of the relation caches for `parent_id`.

        Args:
            txn
            parent_id (str): The event whose relations have changed.
            rel_type (str): The type of the relation that was added/removed.
        """
        txn.call_after(self.get_relations_for_event.invalidate_many, (parent_id,))
        txn.call_after(
            self.get_aggregation_groups_for_event.invalidate_many, (parent_id,)
        )

        # Only edits can change which edit is applicable.
        if rel_type == RelationTypes.REPLACE:
            txn.call_after(self.get_applicable_edit.invalidate, (parent_id,))

    def _handle_event_relations(self, txn, event):
        """Handles inserting relation data during persistence of events

        Events without a usable `m.relates_to` (missing, not a mapping,
        unknown `rel_type`, or no `event_id`) are ignored.

        Args:
            txn
            event (EventBase)
        """
        # Imported locally so that this block does not depend on a change to
        # the module's import section; falls back for Python 2.
        try:
            from collections.abc import Mapping
        except ImportError:
            from collections import Mapping

        relation = event.content.get("m.relates_to")
        if not relation:
            # No relations
            return

        if not isinstance(relation, Mapping):
            # Event content is free-form, so `m.relates_to` may be e.g. a
            # string or a list; calling `.get` on those would raise and
            # abort the persistence of the event.
            return

        rel_type = relation.get("rel_type")
        if rel_type not in (
            RelationTypes.ANNOTATION,
            RelationTypes.REFERENCE,
            RelationTypes.REPLACE,
        ):
            # Unknown relation type
            return

        parent_id = relation.get("event_id")
        if not parent_id:
            # Invalid relation
            return

        aggregation_key = relation.get("key")

        self._simple_insert_txn(
            txn,
            table="event_relations",
            values={
                "event_id": event.event_id,
                "relates_to_id": parent_id,
                "relation_type": rel_type,
                "aggregation_key": aggregation_key,
            },
        )

        self._invalidate_relation_caches_txn(txn, parent_id, rel_type)

    def _handle_redaction(self, txn, redacted_event_id):
        """Handles receiving a redaction and checking whether we need to remove
        any redacted relations from the database.

        Args:
            txn
            redacted_event_id (str): The event that was redacted.
        """
        # Work out what the redacted event related to *before* deleting the
        # row: once it is gone we can no longer tell which parent's cached
        # results have become stale.
        txn.execute(
            """
                SELECT relates_to_id, relation_type FROM event_relations
                WHERE event_id = ?
            """,
            (redacted_event_id,),
        )
        redacted_relations = list(txn)

        self._simple_delete_txn(
            txn,
            table="event_relations",
            keyvalues={"event_id": redacted_event_id},
        )

        # Mirror the invalidation done on insert, otherwise the cached
        # queries would keep returning the now-deleted relation.
        for parent_id, rel_type in redacted_relations:
            self._invalidate_relation_caches_txn(txn, parent_id, rel_type)
|