# search.py
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. import itertools
  16. import logging
  17. from unpaddedbase64 import decode_base64, encode_base64
  18. from twisted.internet import defer
  19. from synapse.api.constants import EventTypes, Membership
  20. from synapse.api.errors import SynapseError
  21. from synapse.api.filtering import Filter
  22. from synapse.events.utils import serialize_event
  23. from synapse.visibility import filter_events_for_client
  24. from ._base import BaseHandler
  25. logger = logging.getLogger(__name__)
  26. class SearchHandler(BaseHandler):
  27. def __init__(self, hs):
  28. super(SearchHandler, self).__init__(hs)
  29. @defer.inlineCallbacks
  30. def search(self, user, content, batch=None):
  31. """Performs a full text search for a user.
  32. Args:
  33. user (UserID)
  34. content (dict): Search parameters
  35. batch (str): The next_batch parameter. Used for pagination.
  36. Returns:
  37. dict to be returned to the client with results of search
  38. """
  39. batch_group = None
  40. batch_group_key = None
  41. batch_token = None
  42. if batch:
  43. try:
  44. b = decode_base64(batch)
  45. batch_group, batch_group_key, batch_token = b.split("\n")
  46. assert batch_group is not None
  47. assert batch_group_key is not None
  48. assert batch_token is not None
  49. except Exception:
  50. raise SynapseError(400, "Invalid batch")
  51. logger.info(
  52. "Search batch properties: %r, %r, %r",
  53. batch_group, batch_group_key, batch_token,
  54. )
  55. logger.info("Search content: %s", content)
  56. try:
  57. room_cat = content["search_categories"]["room_events"]
  58. # The actual thing to query in FTS
  59. search_term = room_cat["search_term"]
  60. # Which "keys" to search over in FTS query
  61. keys = room_cat.get("keys", [
  62. "content.body", "content.name", "content.topic",
  63. ])
  64. # Filter to apply to results
  65. filter_dict = room_cat.get("filter", {})
  66. # What to order results by (impacts whether pagination can be doen)
  67. order_by = room_cat.get("order_by", "rank")
  68. # Return the current state of the rooms?
  69. include_state = room_cat.get("include_state", False)
  70. # Include context around each event?
  71. event_context = room_cat.get(
  72. "event_context", None
  73. )
  74. # Group results together? May allow clients to paginate within a
  75. # group
  76. group_by = room_cat.get("groupings", {}).get("group_by", {})
  77. group_keys = [g["key"] for g in group_by]
  78. if event_context is not None:
  79. before_limit = int(event_context.get(
  80. "before_limit", 5
  81. ))
  82. after_limit = int(event_context.get(
  83. "after_limit", 5
  84. ))
  85. # Return the historic display name and avatar for the senders
  86. # of the events?
  87. include_profile = bool(event_context.get("include_profile", False))
  88. except KeyError:
  89. raise SynapseError(400, "Invalid search query")
  90. if order_by not in ("rank", "recent"):
  91. raise SynapseError(400, "Invalid order by: %r" % (order_by,))
  92. if set(group_keys) - {"room_id", "sender"}:
  93. raise SynapseError(
  94. 400,
  95. "Invalid group by keys: %r" % (set(group_keys) - {"room_id", "sender"},)
  96. )
  97. search_filter = Filter(filter_dict)
  98. # TODO: Search through left rooms too
  99. rooms = yield self.store.get_rooms_for_user_where_membership_is(
  100. user.to_string(),
  101. membership_list=[Membership.JOIN],
  102. # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
  103. )
  104. room_ids = set(r.room_id for r in rooms)
  105. room_ids = search_filter.filter_rooms(room_ids)
  106. if batch_group == "room_id":
  107. room_ids.intersection_update({batch_group_key})
  108. if not room_ids:
  109. defer.returnValue({
  110. "search_categories": {
  111. "room_events": {
  112. "results": [],
  113. "count": 0,
  114. "highlights": [],
  115. }
  116. }
  117. })
  118. rank_map = {} # event_id -> rank of event
  119. allowed_events = []
  120. room_groups = {} # Holds result of grouping by room, if applicable
  121. sender_group = {} # Holds result of grouping by sender, if applicable
  122. # Holds the next_batch for the entire result set if one of those exists
  123. global_next_batch = None
  124. highlights = set()
  125. count = None
  126. if order_by == "rank":
  127. search_result = yield self.store.search_msgs(
  128. room_ids, search_term, keys
  129. )
  130. count = search_result["count"]
  131. if search_result["highlights"]:
  132. highlights.update(search_result["highlights"])
  133. results = search_result["results"]
  134. results_map = {r["event"].event_id: r for r in results}
  135. rank_map.update({r["event"].event_id: r["rank"] for r in results})
  136. filtered_events = search_filter.filter([r["event"] for r in results])
  137. events = yield filter_events_for_client(
  138. self.store, user.to_string(), filtered_events
  139. )
  140. events.sort(key=lambda e: -rank_map[e.event_id])
  141. allowed_events = events[:search_filter.limit()]
  142. for e in allowed_events:
  143. rm = room_groups.setdefault(e.room_id, {
  144. "results": [],
  145. "order": rank_map[e.event_id],
  146. })
  147. rm["results"].append(e.event_id)
  148. s = sender_group.setdefault(e.sender, {
  149. "results": [],
  150. "order": rank_map[e.event_id],
  151. })
  152. s["results"].append(e.event_id)
  153. elif order_by == "recent":
  154. room_events = []
  155. i = 0
  156. pagination_token = batch_token
  157. # We keep looping and we keep filtering until we reach the limit
  158. # or we run out of things.
  159. # But only go around 5 times since otherwise synapse will be sad.
  160. while len(room_events) < search_filter.limit() and i < 5:
  161. i += 1
  162. search_result = yield self.store.search_rooms(
  163. room_ids, search_term, keys, search_filter.limit() * 2,
  164. pagination_token=pagination_token,
  165. )
  166. if search_result["highlights"]:
  167. highlights.update(search_result["highlights"])
  168. count = search_result["count"]
  169. results = search_result["results"]
  170. results_map = {r["event"].event_id: r for r in results}
  171. rank_map.update({r["event"].event_id: r["rank"] for r in results})
  172. filtered_events = search_filter.filter([
  173. r["event"] for r in results
  174. ])
  175. events = yield filter_events_for_client(
  176. self.store, user.to_string(), filtered_events
  177. )
  178. room_events.extend(events)
  179. room_events = room_events[:search_filter.limit()]
  180. if len(results) < search_filter.limit() * 2:
  181. pagination_token = None
  182. break
  183. else:
  184. pagination_token = results[-1]["pagination_token"]
  185. for event in room_events:
  186. group = room_groups.setdefault(event.room_id, {
  187. "results": [],
  188. })
  189. group["results"].append(event.event_id)
  190. if room_events and len(room_events) >= search_filter.limit():
  191. last_event_id = room_events[-1].event_id
  192. pagination_token = results_map[last_event_id]["pagination_token"]
  193. # We want to respect the given batch group and group keys so
  194. # that if people blindly use the top level `next_batch` token
  195. # it returns more from the same group (if applicable) rather
  196. # than reverting to searching all results again.
  197. if batch_group and batch_group_key:
  198. global_next_batch = encode_base64("%s\n%s\n%s" % (
  199. batch_group, batch_group_key, pagination_token
  200. ))
  201. else:
  202. global_next_batch = encode_base64("%s\n%s\n%s" % (
  203. "all", "", pagination_token
  204. ))
  205. for room_id, group in room_groups.items():
  206. group["next_batch"] = encode_base64("%s\n%s\n%s" % (
  207. "room_id", room_id, pagination_token
  208. ))
  209. allowed_events.extend(room_events)
  210. else:
  211. # We should never get here due to the guard earlier.
  212. raise NotImplementedError()
  213. logger.info("Found %d events to return", len(allowed_events))
  214. # If client has asked for "context" for each event (i.e. some surrounding
  215. # events and state), fetch that
  216. if event_context is not None:
  217. now_token = yield self.hs.get_event_sources().get_current_token()
  218. contexts = {}
  219. for event in allowed_events:
  220. res = yield self.store.get_events_around(
  221. event.room_id, event.event_id, before_limit, after_limit
  222. )
  223. logger.info(
  224. "Context for search returned %d and %d events",
  225. len(res["events_before"]), len(res["events_after"]),
  226. )
  227. res["events_before"] = yield filter_events_for_client(
  228. self.store, user.to_string(), res["events_before"]
  229. )
  230. res["events_after"] = yield filter_events_for_client(
  231. self.store, user.to_string(), res["events_after"]
  232. )
  233. res["start"] = now_token.copy_and_replace(
  234. "room_key", res["start"]
  235. ).to_string()
  236. res["end"] = now_token.copy_and_replace(
  237. "room_key", res["end"]
  238. ).to_string()
  239. if include_profile:
  240. senders = set(
  241. ev.sender
  242. for ev in itertools.chain(
  243. res["events_before"], [event], res["events_after"]
  244. )
  245. )
  246. if res["events_after"]:
  247. last_event_id = res["events_after"][-1].event_id
  248. else:
  249. last_event_id = event.event_id
  250. state = yield self.store.get_state_for_event(
  251. last_event_id,
  252. types=[(EventTypes.Member, sender) for sender in senders]
  253. )
  254. res["profile_info"] = {
  255. s.state_key: {
  256. "displayname": s.content.get("displayname", None),
  257. "avatar_url": s.content.get("avatar_url", None),
  258. }
  259. for s in state.values()
  260. if s.type == EventTypes.Member and s.state_key in senders
  261. }
  262. contexts[event.event_id] = res
  263. else:
  264. contexts = {}
  265. # TODO: Add a limit
  266. time_now = self.clock.time_msec()
  267. for context in contexts.values():
  268. context["events_before"] = [
  269. serialize_event(e, time_now)
  270. for e in context["events_before"]
  271. ]
  272. context["events_after"] = [
  273. serialize_event(e, time_now)
  274. for e in context["events_after"]
  275. ]
  276. state_results = {}
  277. if include_state:
  278. rooms = set(e.room_id for e in allowed_events)
  279. for room_id in rooms:
  280. state = yield self.state_handler.get_current_state(room_id)
  281. state_results[room_id] = list(state.values())
  282. state_results.values()
  283. # We're now about to serialize the events. We should not make any
  284. # blocking calls after this. Otherwise the 'age' will be wrong
  285. results = [
  286. {
  287. "rank": rank_map[e.event_id],
  288. "result": serialize_event(e, time_now),
  289. "context": contexts.get(e.event_id, {}),
  290. }
  291. for e in allowed_events
  292. ]
  293. rooms_cat_res = {
  294. "results": results,
  295. "count": count,
  296. "highlights": list(highlights),
  297. }
  298. if state_results:
  299. rooms_cat_res["state"] = {
  300. room_id: [serialize_event(e, time_now) for e in state]
  301. for room_id, state in state_results.items()
  302. }
  303. if room_groups and "room_id" in group_keys:
  304. rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
  305. if sender_group and "sender" in group_keys:
  306. rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
  307. if global_next_batch:
  308. rooms_cat_res["next_batch"] = global_next_batch
  309. defer.returnValue({
  310. "search_categories": {
  311. "room_events": rooms_cat_res
  312. }
  313. })