typing.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from collections import namedtuple
  17. from typing import List
  18. from twisted.internet import defer
  19. from synapse.api.errors import AuthError, SynapseError
  20. from synapse.logging.context import run_in_background
  21. from synapse.types import UserID, get_domain_from_id
  22. from synapse.util.caches.stream_change_cache import StreamChangeCache
  23. from synapse.util.metrics import Measure
  24. from synapse.util.wheel_timer import WheelTimer
  25. logger = logging.getLogger(__name__)
  26. # A tiny object useful for storing a user's membership in a room, as a mapping
  27. # key
  28. RoomMember = namedtuple("RoomMember", ("room_id", "user_id"))
  29. # How often we expect remote servers to resend us presence.
  30. FEDERATION_TIMEOUT = 60 * 1000
  31. # How often to resend typing across federation.
  32. FEDERATION_PING_INTERVAL = 40 * 1000
  33. class TypingHandler(object):
  34. def __init__(self, hs):
  35. self.store = hs.get_datastore()
  36. self.server_name = hs.config.server_name
  37. self.auth = hs.get_auth()
  38. self.is_mine_id = hs.is_mine_id
  39. self.notifier = hs.get_notifier()
  40. self.state = hs.get_state_handler()
  41. self.hs = hs
  42. self.clock = hs.get_clock()
  43. self.wheel_timer = WheelTimer(bucket_size=5000)
  44. self.federation = hs.get_federation_sender()
  45. hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
  46. hs.get_distributor().observe("user_left_room", self.user_left_room)
  47. self._member_typing_until = {} # clock time we expect to stop
  48. self._member_last_federation_poke = {}
  49. self._latest_room_serial = 0
  50. self._reset()
  51. # caches which room_ids changed at which serials
  52. self._typing_stream_change_cache = StreamChangeCache(
  53. "TypingStreamChangeCache", self._latest_room_serial
  54. )
  55. self.clock.looping_call(self._handle_timeouts, 5000)
  56. def _reset(self):
  57. """
  58. Reset the typing handler's data caches.
  59. """
  60. # map room IDs to serial numbers
  61. self._room_serials = {}
  62. # map room IDs to sets of users currently typing
  63. self._room_typing = {}
  64. def _handle_timeouts(self):
  65. logger.debug("Checking for typing timeouts")
  66. now = self.clock.time_msec()
  67. members = set(self.wheel_timer.fetch(now))
  68. for member in members:
  69. if not self.is_typing(member):
  70. # Nothing to do if they're no longer typing
  71. continue
  72. until = self._member_typing_until.get(member, None)
  73. if not until or until <= now:
  74. logger.info("Timing out typing for: %s", member.user_id)
  75. self._stopped_typing(member)
  76. continue
  77. # Check if we need to resend a keep alive over federation for this
  78. # user.
  79. if self.hs.is_mine_id(member.user_id):
  80. last_fed_poke = self._member_last_federation_poke.get(member, None)
  81. if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
  82. run_in_background(self._push_remote, member=member, typing=True)
  83. # Add a paranoia timer to ensure that we always have a timer for
  84. # each person typing.
  85. self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
  86. def is_typing(self, member):
  87. return member.user_id in self._room_typing.get(member.room_id, [])
  88. @defer.inlineCallbacks
  89. def started_typing(self, target_user, auth_user, room_id, timeout):
  90. target_user_id = target_user.to_string()
  91. auth_user_id = auth_user.to_string()
  92. if not self.is_mine_id(target_user_id):
  93. raise SynapseError(400, "User is not hosted on this homeserver")
  94. if target_user_id != auth_user_id:
  95. raise AuthError(400, "Cannot set another user's typing state")
  96. yield self.auth.check_user_in_room(room_id, target_user_id)
  97. logger.debug("%s has started typing in %s", target_user_id, room_id)
  98. member = RoomMember(room_id=room_id, user_id=target_user_id)
  99. was_present = member.user_id in self._room_typing.get(room_id, set())
  100. now = self.clock.time_msec()
  101. self._member_typing_until[member] = now + timeout
  102. self.wheel_timer.insert(now=now, obj=member, then=now + timeout)
  103. if was_present:
  104. # No point sending another notification
  105. return None
  106. self._push_update(member=member, typing=True)
  107. @defer.inlineCallbacks
  108. def stopped_typing(self, target_user, auth_user, room_id):
  109. target_user_id = target_user.to_string()
  110. auth_user_id = auth_user.to_string()
  111. if not self.is_mine_id(target_user_id):
  112. raise SynapseError(400, "User is not hosted on this homeserver")
  113. if target_user_id != auth_user_id:
  114. raise AuthError(400, "Cannot set another user's typing state")
  115. yield self.auth.check_user_in_room(room_id, target_user_id)
  116. logger.debug("%s has stopped typing in %s", target_user_id, room_id)
  117. member = RoomMember(room_id=room_id, user_id=target_user_id)
  118. self._stopped_typing(member)
  119. @defer.inlineCallbacks
  120. def user_left_room(self, user, room_id):
  121. user_id = user.to_string()
  122. if self.is_mine_id(user_id):
  123. member = RoomMember(room_id=room_id, user_id=user_id)
  124. yield self._stopped_typing(member)
  125. def _stopped_typing(self, member):
  126. if member.user_id not in self._room_typing.get(member.room_id, set()):
  127. # No point
  128. return None
  129. self._member_typing_until.pop(member, None)
  130. self._member_last_federation_poke.pop(member, None)
  131. self._push_update(member=member, typing=False)
  132. def _push_update(self, member, typing):
  133. if self.hs.is_mine_id(member.user_id):
  134. # Only send updates for changes to our own users.
  135. run_in_background(self._push_remote, member, typing)
  136. self._push_update_local(member=member, typing=typing)
  137. @defer.inlineCallbacks
  138. def _push_remote(self, member, typing):
  139. try:
  140. users = yield self.state.get_current_users_in_room(member.room_id)
  141. self._member_last_federation_poke[member] = self.clock.time_msec()
  142. now = self.clock.time_msec()
  143. self.wheel_timer.insert(
  144. now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
  145. )
  146. for domain in {get_domain_from_id(u) for u in users}:
  147. if domain != self.server_name:
  148. logger.debug("sending typing update to %s", domain)
  149. self.federation.build_and_send_edu(
  150. destination=domain,
  151. edu_type="m.typing",
  152. content={
  153. "room_id": member.room_id,
  154. "user_id": member.user_id,
  155. "typing": typing,
  156. },
  157. key=member,
  158. )
  159. except Exception:
  160. logger.exception("Error pushing typing notif to remotes")
  161. @defer.inlineCallbacks
  162. def _recv_edu(self, origin, content):
  163. room_id = content["room_id"]
  164. user_id = content["user_id"]
  165. member = RoomMember(user_id=user_id, room_id=room_id)
  166. # Check that the string is a valid user id
  167. user = UserID.from_string(user_id)
  168. if user.domain != origin:
  169. logger.info(
  170. "Got typing update from %r with bad 'user_id': %r", origin, user_id
  171. )
  172. return
  173. users = yield self.state.get_current_users_in_room(room_id)
  174. domains = {get_domain_from_id(u) for u in users}
  175. if self.server_name in domains:
  176. logger.info("Got typing update from %s: %r", user_id, content)
  177. now = self.clock.time_msec()
  178. self._member_typing_until[member] = now + FEDERATION_TIMEOUT
  179. self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
  180. self._push_update_local(member=member, typing=content["typing"])
  181. def _push_update_local(self, member, typing):
  182. room_set = self._room_typing.setdefault(member.room_id, set())
  183. if typing:
  184. room_set.add(member.user_id)
  185. else:
  186. room_set.discard(member.user_id)
  187. self._latest_room_serial += 1
  188. self._room_serials[member.room_id] = self._latest_room_serial
  189. self._typing_stream_change_cache.entity_has_changed(
  190. member.room_id, self._latest_room_serial
  191. )
  192. self.notifier.on_new_event(
  193. "typing_key", self._latest_room_serial, rooms=[member.room_id]
  194. )
  195. async def get_all_typing_updates(
  196. self, last_id: int, current_id: int, limit: int
  197. ) -> List[dict]:
  198. """Get up to `limit` typing updates between the given tokens, earliest
  199. updates first.
  200. """
  201. if last_id == current_id:
  202. return []
  203. changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
  204. last_id
  205. )
  206. if changed_rooms is None:
  207. changed_rooms = self._room_serials
  208. rows = []
  209. for room_id in changed_rooms:
  210. serial = self._room_serials[room_id]
  211. if last_id < serial <= current_id:
  212. typing = self._room_typing[room_id]
  213. rows.append((serial, room_id, list(typing)))
  214. rows.sort()
  215. return rows[:limit]
  216. def get_current_token(self):
  217. return self._latest_room_serial
  218. class TypingNotificationEventSource(object):
  219. def __init__(self, hs):
  220. self.hs = hs
  221. self.clock = hs.get_clock()
  222. # We can't call get_typing_handler here because there's a cycle:
  223. #
  224. # Typing -> Notifier -> TypingNotificationEventSource -> Typing
  225. #
  226. self.get_typing_handler = hs.get_typing_handler
  227. def _make_event_for(self, room_id):
  228. typing = self.get_typing_handler()._room_typing[room_id]
  229. return {
  230. "type": "m.typing",
  231. "room_id": room_id,
  232. "content": {"user_ids": list(typing)},
  233. }
  234. def get_new_events(self, from_key, room_ids, **kwargs):
  235. with Measure(self.clock, "typing.get_new_events"):
  236. from_key = int(from_key)
  237. handler = self.get_typing_handler()
  238. events = []
  239. for room_id in room_ids:
  240. if room_id not in handler._room_serials:
  241. continue
  242. if handler._room_serials[room_id] <= from_key:
  243. continue
  244. events.append(self._make_event_for(room_id))
  245. return defer.succeed((events, handler._latest_room_serial))
  246. def get_current_key(self):
  247. return self.get_typing_handler()._latest_room_serial