receipts.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. # Copyright 2015, 2016 OpenMarket Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
  16. from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
  17. from synapse.appservice import ApplicationService
  18. from synapse.streams import EventSource
  19. from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id
  20. if TYPE_CHECKING:
  21. from synapse.server import HomeServer
# Module-level logger, named after this module (`__name__`).
logger = logging.getLogger(__name__)
  23. class ReceiptsHandler:
  24. def __init__(self, hs: "HomeServer"):
  25. self.notifier = hs.get_notifier()
  26. self.server_name = hs.config.server.server_name
  27. self.store = hs.get_datastores().main
  28. self.event_auth_handler = hs.get_event_auth_handler()
  29. self.hs = hs
  30. # We only need to poke the federation sender explicitly if its on the
  31. # same instance. Other federation sender instances will get notified by
  32. # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
  33. # in the receipts stream.
  34. self.federation_sender = None
  35. if hs.should_send_federation():
  36. self.federation_sender = hs.get_federation_sender()
  37. # If we can handle the receipt EDUs we do so, otherwise we route them
  38. # to the appropriate worker.
  39. if hs.get_instance_name() in hs.config.worker.writers.receipts:
  40. hs.get_federation_registry().register_edu_handler(
  41. "m.receipt", self._received_remote_receipt
  42. )
  43. else:
  44. hs.get_federation_registry().register_instances_for_edu(
  45. "m.receipt",
  46. hs.config.worker.writers.receipts,
  47. )
  48. self.clock = self.hs.get_clock()
  49. self.state = hs.get_state_handler()
  50. async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
  51. """Called when we receive an EDU of type m.receipt from a remote HS."""
  52. receipts = []
  53. for room_id, room_values in content.items():
  54. # If we're not in the room just ditch the event entirely. This is
  55. # probably an old server that has come back and thinks we're still in
  56. # the room (or we've been rejoined to the room by a state reset).
  57. is_in_room = await self.event_auth_handler.check_host_in_room(
  58. room_id, self.server_name
  59. )
  60. if not is_in_room:
  61. logger.info(
  62. "Ignoring receipt for room %r from server %s as we're not in the room",
  63. room_id,
  64. origin,
  65. )
  66. continue
  67. for receipt_type, users in room_values.items():
  68. for user_id, user_values in users.items():
  69. if get_domain_from_id(user_id) != origin:
  70. logger.info(
  71. "Received receipt for user %r from server %s, ignoring",
  72. user_id,
  73. origin,
  74. )
  75. continue
  76. receipts.append(
  77. ReadReceipt(
  78. room_id=room_id,
  79. receipt_type=receipt_type,
  80. user_id=user_id,
  81. event_ids=user_values["event_ids"],
  82. data=user_values.get("data", {}),
  83. )
  84. )
  85. await self._handle_new_receipts(receipts)
  86. async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
  87. """Takes a list of receipts, stores them and informs the notifier."""
  88. min_batch_id: Optional[int] = None
  89. max_batch_id: Optional[int] = None
  90. for receipt in receipts:
  91. res = await self.store.insert_receipt(
  92. receipt.room_id,
  93. receipt.receipt_type,
  94. receipt.user_id,
  95. receipt.event_ids,
  96. receipt.data,
  97. )
  98. if not res:
  99. # res will be None if this read receipt is 'old'
  100. continue
  101. stream_id, max_persisted_id = res
  102. if min_batch_id is None or stream_id < min_batch_id:
  103. min_batch_id = stream_id
  104. if max_batch_id is None or max_persisted_id > max_batch_id:
  105. max_batch_id = max_persisted_id
  106. # Either both of these should be None or neither.
  107. if min_batch_id is None or max_batch_id is None:
  108. # no new receipts
  109. return False
  110. affected_room_ids = list({r.room_id for r in receipts})
  111. self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
  112. # Note that the min here shouldn't be relied upon to be accurate.
  113. await self.hs.get_pusherpool().on_new_receipts(
  114. min_batch_id, max_batch_id, affected_room_ids
  115. )
  116. return True
  117. async def received_client_receipt(
  118. self, room_id: str, receipt_type: str, user_id: str, event_id: str, hidden: bool
  119. ) -> None:
  120. """Called when a client tells us a local user has read up to the given
  121. event_id in the room.
  122. """
  123. receipt = ReadReceipt(
  124. room_id=room_id,
  125. receipt_type=receipt_type,
  126. user_id=user_id,
  127. event_ids=[event_id],
  128. data={"ts": int(self.clock.time_msec()), "hidden": hidden},
  129. )
  130. is_new = await self._handle_new_receipts([receipt])
  131. if not is_new:
  132. return
  133. if self.federation_sender and not (
  134. self.hs.config.experimental.msc2285_enabled and hidden
  135. ):
  136. await self.federation_sender.send_read_receipt(receipt)
  137. class ReceiptEventSource(EventSource[int, JsonDict]):
  138. def __init__(self, hs: "HomeServer"):
  139. self.store = hs.get_datastores().main
  140. self.config = hs.config
  141. @staticmethod
  142. def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]:
  143. visible_events = []
  144. # filter out hidden receipts the user shouldn't see
  145. for event in events:
  146. content = event.get("content", {})
  147. new_event = event.copy()
  148. new_event["content"] = {}
  149. for event_id in content.keys():
  150. event_content = content.get(event_id, {})
  151. m_read = event_content.get(ReceiptTypes.READ, {})
  152. # If m_read is missing copy over the original event_content as there is nothing to process here
  153. if not m_read:
  154. new_event["content"][event_id] = event_content.copy()
  155. continue
  156. new_users = {}
  157. for rr_user_id, user_rr in m_read.items():
  158. try:
  159. hidden = user_rr.get("hidden")
  160. except AttributeError:
  161. # Due to https://github.com/matrix-org/synapse/issues/10376
  162. # there are cases where user_rr is a string, in those cases
  163. # we just ignore the read receipt
  164. continue
  165. if hidden is not True or rr_user_id == user_id:
  166. new_users[rr_user_id] = user_rr.copy()
  167. # If hidden has a value replace hidden with the correct prefixed key
  168. if hidden is not None:
  169. new_users[rr_user_id].pop("hidden")
  170. new_users[rr_user_id][
  171. ReadReceiptEventFields.MSC2285_HIDDEN
  172. ] = hidden
  173. # Set new users unless empty
  174. if len(new_users.keys()) > 0:
  175. new_event["content"][event_id] = {ReceiptTypes.READ: new_users}
  176. # Append new_event to visible_events unless empty
  177. if len(new_event["content"].keys()) > 0:
  178. visible_events.append(new_event)
  179. return visible_events
  180. async def get_new_events(
  181. self,
  182. user: UserID,
  183. from_key: int,
  184. limit: Optional[int],
  185. room_ids: Iterable[str],
  186. is_guest: bool,
  187. explicit_room_id: Optional[str] = None,
  188. ) -> Tuple[List[JsonDict], int]:
  189. from_key = int(from_key)
  190. to_key = self.get_current_key()
  191. if from_key == to_key:
  192. return [], to_key
  193. events = await self.store.get_linearized_receipts_for_rooms(
  194. room_ids, from_key=from_key, to_key=to_key
  195. )
  196. if self.config.experimental.msc2285_enabled:
  197. events = ReceiptEventSource.filter_out_hidden(events, user.to_string())
  198. return events, to_key
  199. async def get_new_events_as(
  200. self, from_key: int, service: ApplicationService
  201. ) -> Tuple[List[JsonDict], int]:
  202. """Returns a set of new read receipt events that an appservice
  203. may be interested in.
  204. Args:
  205. from_key: the stream position at which events should be fetched from
  206. service: The appservice which may be interested
  207. Returns:
  208. A two-tuple containing the following:
  209. * A list of json dictionaries derived from read receipts that the
  210. appservice may be interested in.
  211. * The current read receipt stream token.
  212. """
  213. from_key = int(from_key)
  214. to_key = self.get_current_key()
  215. if from_key == to_key:
  216. return [], to_key
  217. # Fetch all read receipts for all rooms, up to a limit of 100. This is ordered
  218. # by most recent.
  219. rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
  220. from_key=from_key, to_key=to_key
  221. )
  222. # Then filter down to rooms that the AS can read
  223. events = []
  224. for room_id, event in rooms_to_events.items():
  225. if not await service.matches_user_in_member_list(room_id, self.store):
  226. continue
  227. events.append(event)
  228. return events, to_key
  229. def get_current_key(self, direction: str = "f") -> int:
  230. return self.store.get_max_receipt_stream_id()