# search.py
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2015, 2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from twisted.internet import defer
  16. from .background_updates import BackgroundUpdateStore
  17. from synapse.api.errors import SynapseError
  18. from synapse.storage.engines import PostgresEngine, Sqlite3Engine
  19. import logging
  20. import re
  21. logger = logging.getLogger(__name__)
  22. class SearchStore(BackgroundUpdateStore):
  23. EVENT_SEARCH_UPDATE_NAME = "event_search"
  24. def __init__(self, hs):
  25. super(SearchStore, self).__init__(hs)
  26. self.register_background_update_handler(
  27. self.EVENT_SEARCH_UPDATE_NAME, self._background_reindex_search
  28. )
  29. @defer.inlineCallbacks
  30. def _background_reindex_search(self, progress, batch_size):
  31. target_min_stream_id = progress["target_min_stream_id_inclusive"]
  32. max_stream_id = progress["max_stream_id_exclusive"]
  33. rows_inserted = progress.get("rows_inserted", 0)
  34. INSERT_CLUMP_SIZE = 1000
  35. TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
  36. def reindex_search_txn(txn):
  37. sql = (
  38. "SELECT stream_ordering, event_id FROM events"
  39. " WHERE ? <= stream_ordering AND stream_ordering < ?"
  40. " AND (%s)"
  41. " ORDER BY stream_ordering DESC"
  42. " LIMIT ?"
  43. ) % (" OR ".join("type = '%s'" % (t,) for t in TYPES),)
  44. txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size))
  45. rows = txn.fetchall()
  46. if not rows:
  47. return 0
  48. min_stream_id = rows[-1][0]
  49. event_ids = [row[1] for row in rows]
  50. events = self._get_events_txn(txn, event_ids)
  51. event_search_rows = []
  52. for event in events:
  53. try:
  54. event_id = event.event_id
  55. room_id = event.room_id
  56. content = event.content
  57. if event.type == "m.room.message":
  58. key = "content.body"
  59. value = content["body"]
  60. elif event.type == "m.room.topic":
  61. key = "content.topic"
  62. value = content["topic"]
  63. elif event.type == "m.room.name":
  64. key = "content.name"
  65. value = content["name"]
  66. except (KeyError, AttributeError):
  67. # If the event is missing a necessary field then
  68. # skip over it.
  69. continue
  70. if not isinstance(value, basestring):
  71. # If the event body, name or topic isn't a string
  72. # then skip over it
  73. continue
  74. event_search_rows.append((event_id, room_id, key, value))
  75. if isinstance(self.database_engine, PostgresEngine):
  76. sql = (
  77. "INSERT INTO event_search (event_id, room_id, key, vector)"
  78. " VALUES (?,?,?,to_tsvector('english', ?))"
  79. )
  80. elif isinstance(self.database_engine, Sqlite3Engine):
  81. sql = (
  82. "INSERT INTO event_search (event_id, room_id, key, value)"
  83. " VALUES (?,?,?,?)"
  84. )
  85. else:
  86. # This should be unreachable.
  87. raise Exception("Unrecognized database engine")
  88. for index in range(0, len(event_search_rows), INSERT_CLUMP_SIZE):
  89. clump = event_search_rows[index:index + INSERT_CLUMP_SIZE]
  90. txn.executemany(sql, clump)
  91. progress = {
  92. "target_min_stream_id_inclusive": target_min_stream_id,
  93. "max_stream_id_exclusive": min_stream_id,
  94. "rows_inserted": rows_inserted + len(event_search_rows)
  95. }
  96. self._background_update_progress_txn(
  97. txn, self.EVENT_SEARCH_UPDATE_NAME, progress
  98. )
  99. return len(event_search_rows)
  100. result = yield self.runInteraction(
  101. self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
  102. )
  103. if not result:
  104. yield self._end_background_update(self.EVENT_SEARCH_UPDATE_NAME)
  105. defer.returnValue(result)
  106. @defer.inlineCallbacks
  107. def search_msgs(self, room_ids, search_term, keys):
  108. """Performs a full text search over events with given keys.
  109. Args:
  110. room_ids (list): List of room ids to search in
  111. search_term (str): Search term to search for
  112. keys (list): List of keys to search in, currently supports
  113. "content.body", "content.name", "content.topic"
  114. Returns:
  115. list of dicts
  116. """
  117. clauses = []
  118. search_query = search_query = _parse_query(self.database_engine, search_term)
  119. args = []
  120. # Make sure we don't explode because the person is in too many rooms.
  121. # We filter the results below regardless.
  122. if len(room_ids) < 500:
  123. clauses.append(
  124. "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
  125. )
  126. args.extend(room_ids)
  127. local_clauses = []
  128. for key in keys:
  129. local_clauses.append("key = ?")
  130. args.append(key)
  131. clauses.append(
  132. "(%s)" % (" OR ".join(local_clauses),)
  133. )
  134. count_args = args
  135. count_clauses = clauses
  136. if isinstance(self.database_engine, PostgresEngine):
  137. sql = (
  138. "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) AS rank,"
  139. " room_id, event_id"
  140. " FROM event_search"
  141. " WHERE vector @@ to_tsquery('english', ?)"
  142. )
  143. args = [search_query, search_query] + args
  144. count_sql = (
  145. "SELECT room_id, count(*) as count FROM event_search"
  146. " WHERE vector @@ to_tsquery('english', ?)"
  147. )
  148. count_args = [search_query] + count_args
  149. elif isinstance(self.database_engine, Sqlite3Engine):
  150. sql = (
  151. "SELECT rank(matchinfo(event_search)) as rank, room_id, event_id"
  152. " FROM event_search"
  153. " WHERE value MATCH ?"
  154. )
  155. args = [search_query] + args
  156. count_sql = (
  157. "SELECT room_id, count(*) as count FROM event_search"
  158. " WHERE value MATCH ?"
  159. )
  160. count_args = [search_term] + count_args
  161. else:
  162. # This should be unreachable.
  163. raise Exception("Unrecognized database engine")
  164. for clause in clauses:
  165. sql += " AND " + clause
  166. for clause in count_clauses:
  167. count_sql += " AND " + clause
  168. # We add an arbitrary limit here to ensure we don't try to pull the
  169. # entire table from the database.
  170. sql += " ORDER BY rank DESC LIMIT 500"
  171. results = yield self._execute(
  172. "search_msgs", self.cursor_to_dict, sql, *args
  173. )
  174. results = filter(lambda row: row["room_id"] in room_ids, results)
  175. events = yield self._get_events([r["event_id"] for r in results])
  176. event_map = {
  177. ev.event_id: ev
  178. for ev in events
  179. }
  180. highlights = None
  181. if isinstance(self.database_engine, PostgresEngine):
  182. highlights = yield self._find_highlights_in_postgres(search_query, events)
  183. count_sql += " GROUP BY room_id"
  184. count_results = yield self._execute(
  185. "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
  186. )
  187. count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
  188. defer.returnValue({
  189. "results": [
  190. {
  191. "event": event_map[r["event_id"]],
  192. "rank": r["rank"],
  193. }
  194. for r in results
  195. if r["event_id"] in event_map
  196. ],
  197. "highlights": highlights,
  198. "count": count,
  199. })
  200. @defer.inlineCallbacks
  201. def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
  202. """Performs a full text search over events with given keys.
  203. Args:
  204. room_id (list): The room_ids to search in
  205. search_term (str): Search term to search for
  206. keys (list): List of keys to search in, currently supports
  207. "content.body", "content.name", "content.topic"
  208. pagination_token (str): A pagination token previously returned
  209. Returns:
  210. list of dicts
  211. """
  212. clauses = []
  213. search_query = search_query = _parse_query(self.database_engine, search_term)
  214. args = []
  215. # Make sure we don't explode because the person is in too many rooms.
  216. # We filter the results below regardless.
  217. if len(room_ids) < 500:
  218. clauses.append(
  219. "room_id IN (%s)" % (",".join(["?"] * len(room_ids)),)
  220. )
  221. args.extend(room_ids)
  222. local_clauses = []
  223. for key in keys:
  224. local_clauses.append("key = ?")
  225. args.append(key)
  226. clauses.append(
  227. "(%s)" % (" OR ".join(local_clauses),)
  228. )
  229. # take copies of the current args and clauses lists, before adding
  230. # pagination clauses to main query.
  231. count_args = list(args)
  232. count_clauses = list(clauses)
  233. if pagination_token:
  234. try:
  235. origin_server_ts, stream = pagination_token.split(",")
  236. origin_server_ts = int(origin_server_ts)
  237. stream = int(stream)
  238. except:
  239. raise SynapseError(400, "Invalid pagination token")
  240. clauses.append(
  241. "(origin_server_ts < ?"
  242. " OR (origin_server_ts = ? AND stream_ordering < ?))"
  243. )
  244. args.extend([origin_server_ts, origin_server_ts, stream])
  245. if isinstance(self.database_engine, PostgresEngine):
  246. sql = (
  247. "SELECT ts_rank_cd(vector, to_tsquery('english', ?)) as rank,"
  248. " origin_server_ts, stream_ordering, room_id, event_id"
  249. " FROM event_search"
  250. " NATURAL JOIN events"
  251. " WHERE vector @@ to_tsquery('english', ?) AND "
  252. )
  253. args = [search_query, search_query] + args
  254. count_sql = (
  255. "SELECT room_id, count(*) as count FROM event_search"
  256. " WHERE vector @@ to_tsquery('english', ?) AND "
  257. )
  258. count_args = [search_query] + count_args
  259. elif isinstance(self.database_engine, Sqlite3Engine):
  260. # We use CROSS JOIN here to ensure we use the right indexes.
  261. # https://sqlite.org/optoverview.html#crossjoin
  262. #
  263. # We want to use the full text search index on event_search to
  264. # extract all possible matches first, then lookup those matches
  265. # in the events table to get the topological ordering. We need
  266. # to use the indexes in this order because sqlite refuses to
  267. # MATCH unless it uses the full text search index
  268. sql = (
  269. "SELECT rank(matchinfo) as rank, room_id, event_id,"
  270. " origin_server_ts, stream_ordering"
  271. " FROM (SELECT key, event_id, matchinfo(event_search) as matchinfo"
  272. " FROM event_search"
  273. " WHERE value MATCH ?"
  274. " )"
  275. " CROSS JOIN events USING (event_id)"
  276. " WHERE "
  277. )
  278. args = [search_query] + args
  279. count_sql = (
  280. "SELECT room_id, count(*) as count FROM event_search"
  281. " WHERE value MATCH ? AND "
  282. )
  283. count_args = [search_term] + count_args
  284. else:
  285. # This should be unreachable.
  286. raise Exception("Unrecognized database engine")
  287. sql += " AND ".join(clauses)
  288. count_sql += " AND ".join(count_clauses)
  289. # We add an arbitrary limit here to ensure we don't try to pull the
  290. # entire table from the database.
  291. sql += " ORDER BY origin_server_ts DESC, stream_ordering DESC LIMIT ?"
  292. args.append(limit)
  293. results = yield self._execute(
  294. "search_rooms", self.cursor_to_dict, sql, *args
  295. )
  296. results = filter(lambda row: row["room_id"] in room_ids, results)
  297. events = yield self._get_events([r["event_id"] for r in results])
  298. event_map = {
  299. ev.event_id: ev
  300. for ev in events
  301. }
  302. highlights = None
  303. if isinstance(self.database_engine, PostgresEngine):
  304. highlights = yield self._find_highlights_in_postgres(search_query, events)
  305. count_sql += " GROUP BY room_id"
  306. count_results = yield self._execute(
  307. "search_rooms_count", self.cursor_to_dict, count_sql, *count_args
  308. )
  309. count = sum(row["count"] for row in count_results if row["room_id"] in room_ids)
  310. defer.returnValue({
  311. "results": [
  312. {
  313. "event": event_map[r["event_id"]],
  314. "rank": r["rank"],
  315. "pagination_token": "%s,%s" % (
  316. r["origin_server_ts"], r["stream_ordering"]
  317. ),
  318. }
  319. for r in results
  320. if r["event_id"] in event_map
  321. ],
  322. "highlights": highlights,
  323. "count": count,
  324. })
  325. def _find_highlights_in_postgres(self, search_query, events):
  326. """Given a list of events and a search term, return a list of words
  327. that match from the content of the event.
  328. This is used to give a list of words that clients can match against to
  329. highlight the matching parts.
  330. Args:
  331. search_query (str)
  332. events (list): A list of events
  333. Returns:
  334. deferred : A set of strings.
  335. """
  336. def f(txn):
  337. highlight_words = set()
  338. for event in events:
  339. # As a hack we simply join values of all possible keys. This is
  340. # fine since we're only using them to find possible highlights.
  341. values = []
  342. for key in ("body", "name", "topic"):
  343. v = event.content.get(key, None)
  344. if v:
  345. values.append(v)
  346. if not values:
  347. continue
  348. value = " ".join(values)
  349. # We need to find some values for StartSel and StopSel that
  350. # aren't in the value so that we can pick results out.
  351. start_sel = "<"
  352. stop_sel = ">"
  353. while start_sel in value:
  354. start_sel += "<"
  355. while stop_sel in value:
  356. stop_sel += ">"
  357. query = "SELECT ts_headline(?, to_tsquery('english', ?), %s)" % (
  358. _to_postgres_options({
  359. "StartSel": start_sel,
  360. "StopSel": stop_sel,
  361. "MaxFragments": "50",
  362. })
  363. )
  364. txn.execute(query, (value, search_query,))
  365. headline, = txn.fetchall()[0]
  366. # Now we need to pick the possible highlights out of the haedline
  367. # result.
  368. matcher_regex = "%s(.*?)%s" % (
  369. re.escape(start_sel),
  370. re.escape(stop_sel),
  371. )
  372. res = re.findall(matcher_regex, headline)
  373. highlight_words.update([r.lower() for r in res])
  374. return highlight_words
  375. return self.runInteraction("_find_highlights", f)
  376. def _to_postgres_options(options_dict):
  377. return "'%s'" % (
  378. ",".join("%s=%s" % (k, v) for k, v in options_dict.items()),
  379. )
  380. def _parse_query(database_engine, search_term):
  381. """Takes a plain unicode string from the user and converts it into a form
  382. that can be passed to database.
  383. We use this so that we can add prefix matching, which isn't something
  384. that is supported by default.
  385. """
  386. # Pull out the individual words, discarding any non-word characters.
  387. results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
  388. if isinstance(database_engine, PostgresEngine):
  389. return " & ".join(result + ":*" for result in results)
  390. elif isinstance(database_engine, Sqlite3Engine):
  391. return " & ".join(result + "*" for result in results)
  392. else:
  393. # This should be unreachable.
  394. raise Exception("Unrecognized database engine")