  1. # -*- coding: utf-8 -*-
  2. # Copyright 2015, 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import itertools
  16. import logging
  17. from unpaddedbase64 import decode_base64, encode_base64
  18. from twisted.internet import defer
  19. from synapse.api.constants import EventTypes, Membership
  20. from synapse.api.errors import SynapseError
  21. from synapse.api.filtering import Filter
  22. from synapse.storage.state import StateFilter
  23. from synapse.visibility import filter_events_for_client
  24. from ._base import BaseHandler
  25. logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
    """Performs full-text searches over room events on behalf of a user.

    The main entry point is `search`, which implements the server side of
    the client /search API; `get_old_rooms_from_upgraded_room` walks the
    predecessor chain of upgraded rooms so their history is searched too.
    """

    def __init__(self, hs):
        """
        Args:
            hs: the homeserver object; used here to obtain the event
                client serializer used when building the response.
        """
        super(SearchHandler, self).__init__(hs)
        # Serializer that turns internal event objects into the JSON
        # representation returned to clients.
        self._event_serializer = hs.get_event_client_serializer()
  30. @defer.inlineCallbacks
  31. def get_old_rooms_from_upgraded_room(self, room_id):
  32. """Retrieves room IDs of old rooms in the history of an upgraded room.
  33. We do so by checking the m.room.create event of the room for a
  34. `predecessor` key. If it exists, we add the room ID to our return
  35. list and then check that room for a m.room.create event and so on
  36. until we can no longer find any more previous rooms.
  37. The full list of all found rooms in then returned.
  38. Args:
  39. room_id (str): id of the room to search through.
  40. Returns:
  41. Deferred[iterable[unicode]]: predecessor room ids
  42. """
  43. historical_room_ids = []
  44. while True:
  45. predecessor = yield self.store.get_room_predecessor(room_id)
  46. # If no predecessor, assume we've hit a dead end
  47. if not predecessor:
  48. break
  49. # Add predecessor's room ID
  50. historical_room_ids.append(predecessor["room_id"])
  51. # Scan through the old room for further predecessors
  52. room_id = predecessor["room_id"]
  53. return historical_room_ids
  54. @defer.inlineCallbacks
  55. def search(self, user, content, batch=None):
  56. """Performs a full text search for a user.
  57. Args:
  58. user (UserID)
  59. content (dict): Search parameters
  60. batch (str): The next_batch parameter. Used for pagination.
  61. Returns:
  62. dict to be returned to the client with results of search
  63. """
  64. if not self.hs.config.enable_search:
  65. raise SynapseError(400, "Search is disabled on this homeserver")
  66. batch_group = None
  67. batch_group_key = None
  68. batch_token = None
  69. if batch:
  70. try:
  71. b = decode_base64(batch).decode("ascii")
  72. batch_group, batch_group_key, batch_token = b.split("\n")
  73. assert batch_group is not None
  74. assert batch_group_key is not None
  75. assert batch_token is not None
  76. except Exception:
  77. raise SynapseError(400, "Invalid batch")
  78. logger.info(
  79. "Search batch properties: %r, %r, %r",
  80. batch_group,
  81. batch_group_key,
  82. batch_token,
  83. )
  84. logger.info("Search content: %s", content)
  85. try:
  86. room_cat = content["search_categories"]["room_events"]
  87. # The actual thing to query in FTS
  88. search_term = room_cat["search_term"]
  89. # Which "keys" to search over in FTS query
  90. keys = room_cat.get(
  91. "keys", ["content.body", "content.name", "content.topic"]
  92. )
  93. # Filter to apply to results
  94. filter_dict = room_cat.get("filter", {})
  95. # What to order results by (impacts whether pagination can be doen)
  96. order_by = room_cat.get("order_by", "rank")
  97. # Return the current state of the rooms?
  98. include_state = room_cat.get("include_state", False)
  99. # Include context around each event?
  100. event_context = room_cat.get("event_context", None)
  101. # Group results together? May allow clients to paginate within a
  102. # group
  103. group_by = room_cat.get("groupings", {}).get("group_by", {})
  104. group_keys = [g["key"] for g in group_by]
  105. if event_context is not None:
  106. before_limit = int(event_context.get("before_limit", 5))
  107. after_limit = int(event_context.get("after_limit", 5))
  108. # Return the historic display name and avatar for the senders
  109. # of the events?
  110. include_profile = bool(event_context.get("include_profile", False))
  111. except KeyError:
  112. raise SynapseError(400, "Invalid search query")
  113. if order_by not in ("rank", "recent"):
  114. raise SynapseError(400, "Invalid order by: %r" % (order_by,))
  115. if set(group_keys) - {"room_id", "sender"}:
  116. raise SynapseError(
  117. 400,
  118. "Invalid group by keys: %r"
  119. % (set(group_keys) - {"room_id", "sender"},),
  120. )
  121. search_filter = Filter(filter_dict)
  122. # TODO: Search through left rooms too
  123. rooms = yield self.store.get_rooms_for_user_where_membership_is(
  124. user.to_string(),
  125. membership_list=[Membership.JOIN],
  126. # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
  127. )
  128. room_ids = set(r.room_id for r in rooms)
  129. # If doing a subset of all rooms seearch, check if any of the rooms
  130. # are from an upgraded room, and search their contents as well
  131. if search_filter.rooms:
  132. historical_room_ids = []
  133. for room_id in search_filter.rooms:
  134. # Add any previous rooms to the search if they exist
  135. ids = yield self.get_old_rooms_from_upgraded_room(room_id)
  136. historical_room_ids += ids
  137. # Prevent any historical events from being filtered
  138. search_filter = search_filter.with_room_ids(historical_room_ids)
  139. room_ids = search_filter.filter_rooms(room_ids)
  140. if batch_group == "room_id":
  141. room_ids.intersection_update({batch_group_key})
  142. if not room_ids:
  143. return {
  144. "search_categories": {
  145. "room_events": {"results": [], "count": 0, "highlights": []}
  146. }
  147. }
  148. rank_map = {} # event_id -> rank of event
  149. allowed_events = []
  150. room_groups = {} # Holds result of grouping by room, if applicable
  151. sender_group = {} # Holds result of grouping by sender, if applicable
  152. # Holds the next_batch for the entire result set if one of those exists
  153. global_next_batch = None
  154. highlights = set()
  155. count = None
  156. if order_by == "rank":
  157. search_result = yield self.store.search_msgs(room_ids, search_term, keys)
  158. count = search_result["count"]
  159. if search_result["highlights"]:
  160. highlights.update(search_result["highlights"])
  161. results = search_result["results"]
  162. results_map = {r["event"].event_id: r for r in results}
  163. rank_map.update({r["event"].event_id: r["rank"] for r in results})
  164. filtered_events = search_filter.filter([r["event"] for r in results])
  165. events = yield filter_events_for_client(
  166. self.store, user.to_string(), filtered_events
  167. )
  168. events.sort(key=lambda e: -rank_map[e.event_id])
  169. allowed_events = events[: search_filter.limit()]
  170. for e in allowed_events:
  171. rm = room_groups.setdefault(
  172. e.room_id, {"results": [], "order": rank_map[e.event_id]}
  173. )
  174. rm["results"].append(e.event_id)
  175. s = sender_group.setdefault(
  176. e.sender, {"results": [], "order": rank_map[e.event_id]}
  177. )
  178. s["results"].append(e.event_id)
  179. elif order_by == "recent":
  180. room_events = []
  181. i = 0
  182. pagination_token = batch_token
  183. # We keep looping and we keep filtering until we reach the limit
  184. # or we run out of things.
  185. # But only go around 5 times since otherwise synapse will be sad.
  186. while len(room_events) < search_filter.limit() and i < 5:
  187. i += 1
  188. search_result = yield self.store.search_rooms(
  189. room_ids,
  190. search_term,
  191. keys,
  192. search_filter.limit() * 2,
  193. pagination_token=pagination_token,
  194. )
  195. if search_result["highlights"]:
  196. highlights.update(search_result["highlights"])
  197. count = search_result["count"]
  198. results = search_result["results"]
  199. results_map = {r["event"].event_id: r for r in results}
  200. rank_map.update({r["event"].event_id: r["rank"] for r in results})
  201. filtered_events = search_filter.filter([r["event"] for r in results])
  202. events = yield filter_events_for_client(
  203. self.store, user.to_string(), filtered_events
  204. )
  205. room_events.extend(events)
  206. room_events = room_events[: search_filter.limit()]
  207. if len(results) < search_filter.limit() * 2:
  208. pagination_token = None
  209. break
  210. else:
  211. pagination_token = results[-1]["pagination_token"]
  212. for event in room_events:
  213. group = room_groups.setdefault(event.room_id, {"results": []})
  214. group["results"].append(event.event_id)
  215. if room_events and len(room_events) >= search_filter.limit():
  216. last_event_id = room_events[-1].event_id
  217. pagination_token = results_map[last_event_id]["pagination_token"]
  218. # We want to respect the given batch group and group keys so
  219. # that if people blindly use the top level `next_batch` token
  220. # it returns more from the same group (if applicable) rather
  221. # than reverting to searching all results again.
  222. if batch_group and batch_group_key:
  223. global_next_batch = encode_base64(
  224. (
  225. "%s\n%s\n%s"
  226. % (batch_group, batch_group_key, pagination_token)
  227. ).encode("ascii")
  228. )
  229. else:
  230. global_next_batch = encode_base64(
  231. ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
  232. )
  233. for room_id, group in room_groups.items():
  234. group["next_batch"] = encode_base64(
  235. ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
  236. "ascii"
  237. )
  238. )
  239. allowed_events.extend(room_events)
  240. else:
  241. # We should never get here due to the guard earlier.
  242. raise NotImplementedError()
  243. logger.info("Found %d events to return", len(allowed_events))
  244. # If client has asked for "context" for each event (i.e. some surrounding
  245. # events and state), fetch that
  246. if event_context is not None:
  247. now_token = yield self.hs.get_event_sources().get_current_token()
  248. contexts = {}
  249. for event in allowed_events:
  250. res = yield self.store.get_events_around(
  251. event.room_id, event.event_id, before_limit, after_limit
  252. )
  253. logger.info(
  254. "Context for search returned %d and %d events",
  255. len(res["events_before"]),
  256. len(res["events_after"]),
  257. )
  258. res["events_before"] = yield filter_events_for_client(
  259. self.store, user.to_string(), res["events_before"]
  260. )
  261. res["events_after"] = yield filter_events_for_client(
  262. self.store, user.to_string(), res["events_after"]
  263. )
  264. res["start"] = now_token.copy_and_replace(
  265. "room_key", res["start"]
  266. ).to_string()
  267. res["end"] = now_token.copy_and_replace(
  268. "room_key", res["end"]
  269. ).to_string()
  270. if include_profile:
  271. senders = set(
  272. ev.sender
  273. for ev in itertools.chain(
  274. res["events_before"], [event], res["events_after"]
  275. )
  276. )
  277. if res["events_after"]:
  278. last_event_id = res["events_after"][-1].event_id
  279. else:
  280. last_event_id = event.event_id
  281. state_filter = StateFilter.from_types(
  282. [(EventTypes.Member, sender) for sender in senders]
  283. )
  284. state = yield self.store.get_state_for_event(
  285. last_event_id, state_filter
  286. )
  287. res["profile_info"] = {
  288. s.state_key: {
  289. "displayname": s.content.get("displayname", None),
  290. "avatar_url": s.content.get("avatar_url", None),
  291. }
  292. for s in state.values()
  293. if s.type == EventTypes.Member and s.state_key in senders
  294. }
  295. contexts[event.event_id] = res
  296. else:
  297. contexts = {}
  298. # TODO: Add a limit
  299. time_now = self.clock.time_msec()
  300. for context in contexts.values():
  301. context["events_before"] = (
  302. yield self._event_serializer.serialize_events(
  303. context["events_before"], time_now
  304. )
  305. )
  306. context["events_after"] = (
  307. yield self._event_serializer.serialize_events(
  308. context["events_after"], time_now
  309. )
  310. )
  311. state_results = {}
  312. if include_state:
  313. rooms = set(e.room_id for e in allowed_events)
  314. for room_id in rooms:
  315. state = yield self.state_handler.get_current_state(room_id)
  316. state_results[room_id] = list(state.values())
  317. state_results.values()
  318. # We're now about to serialize the events. We should not make any
  319. # blocking calls after this. Otherwise the 'age' will be wrong
  320. results = []
  321. for e in allowed_events:
  322. results.append(
  323. {
  324. "rank": rank_map[e.event_id],
  325. "result": (
  326. yield self._event_serializer.serialize_event(e, time_now)
  327. ),
  328. "context": contexts.get(e.event_id, {}),
  329. }
  330. )
  331. rooms_cat_res = {
  332. "results": results,
  333. "count": count,
  334. "highlights": list(highlights),
  335. }
  336. if state_results:
  337. s = {}
  338. for room_id, state in state_results.items():
  339. s[room_id] = yield self._event_serializer.serialize_events(
  340. state, time_now
  341. )
  342. rooms_cat_res["state"] = s
  343. if room_groups and "room_id" in group_keys:
  344. rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
  345. if sender_group and "sender" in group_keys:
  346. rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
  347. if global_next_batch:
  348. rooms_cat_res["next_batch"] = global_next_batch
  349. return {"search_categories": {"room_events": rooms_cat_res}}