threepid_associations.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014,2017 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from __future__ import absolute_import
  16. from sydent.util import time_msec
  17. from sydent.threepid import ThreepidAssociation
  18. from sydent.threepid.signer import Signer
  19. import logging
  20. logger = logging.getLogger(__name__)
  21. class LocalAssociationStore:
  22. def __init__(self, sydent):
  23. self.sydent = sydent
  24. def addOrUpdateAssociation(self, assoc):
  25. """
  26. Updates an association, or creates one if none exists with these parameters.
  27. :param assoc: The association to create or update.
  28. :type assoc: ThreepidAssociation
  29. """
  30. cur = self.sydent.db.cursor()
  31. # sqlite's support for upserts is atrocious
  32. cur.execute("insert or replace into local_threepid_associations "
  33. "('medium', 'address', 'lookup_hash', 'mxid', 'ts', 'notBefore', 'notAfter')"
  34. " values (?, ?, ?, ?, ?, ?, ?)",
  35. (assoc.medium, assoc.address, assoc.lookup_hash, assoc.mxid, assoc.ts, assoc.not_before, assoc.not_after))
  36. self.sydent.db.commit()
  37. def getAssociationsAfterId(self, afterId, limit=None):
  38. """
  39. Retrieves every association after the given ID.
  40. :param afterId: The ID after which to retrieve associations.
  41. :type afterId: int
  42. :param limit: The maximum number of associations to retrieve, or None if no
  43. limit.
  44. :type limit: int or None
  45. :return: The retrieved associations (in a dict[id, assoc]), and the highest ID
  46. retrieved (or None if no ID thus no association was retrieved).
  47. :rtype: tuple[dict[int, ThreepidAssociation] or int or None]
  48. """
  49. cur = self.sydent.db.cursor()
  50. if afterId is None:
  51. afterId = -1
  52. q = "select id, medium, address, lookup_hash, mxid, ts, notBefore, notAfter from " \
  53. "local_threepid_associations " \
  54. "where id > ? order by id asc"
  55. if limit is not None:
  56. q += " limit ?"
  57. res = cur.execute(q, (afterId, limit))
  58. else:
  59. # No no, no no no no, no no no no, no no, there's no limit.
  60. res = cur.execute(q, (afterId,))
  61. maxId = None
  62. assocs = {}
  63. for row in res.fetchall():
  64. assoc = ThreepidAssociation(row[1], row[2], row[3], row[4], row[5], row[6], row[7])
  65. assocs[row[0]] = assoc
  66. maxId = row[0]
  67. return assocs, maxId
  68. def getSignedAssociationsAfterId(self, afterId, limit=None):
  69. """Get associations after a given ID, and sign them before returning
  70. :param afterId: The ID to return results after (not inclusive)
  71. :type afterId: int
  72. :param limit: The maximum amount of signed associations to return. None for no
  73. limit.
  74. :type limit: int|None
  75. :return: A tuple consisting of a dictionary containing the signed associations
  76. (id: assoc dict) and an int representing the maximum ID (which is None if
  77. there was no association to retrieve).
  78. :rtype: tuple[dict[int, dict[str, any]] or int or None]
  79. """
  80. assocs = {}
  81. (localAssocs, maxId) = self.getAssociationsAfterId(afterId, limit)
  82. signer = Signer(self.sydent)
  83. for localId in localAssocs:
  84. sgAssoc = signer.signedThreePidAssociation(localAssocs[localId])
  85. assocs[localId] = sgAssoc
  86. return assocs, maxId
  87. def removeAssociation(self, threepid, mxid):
  88. """
  89. Delete the association between a 3PID and a MXID, if it exists. If the
  90. association doesn't exist, log and do nothing.
  91. :param threepid: The 3PID of the binding to remove.
  92. :type threepid: dict[unicode, unicode]
  93. :param mxid: The MXID of the binding to remove.
  94. :type mxid: unicode
  95. """
  96. cur = self.sydent.db.cursor()
  97. # check to see if we have any matching associations first.
  98. # We use a REPLACE INTO because we need the resulting row to have
  99. # a new ID (such that we know it's a new change that needs to be
  100. # replicated) so there's no need to insert a deletion row if there's
  101. # nothing to delete.
  102. cur.execute(
  103. "SELECT COUNT(*) FROM local_threepid_associations "
  104. "WHERE medium = ? AND address = ? AND mxid = ?",
  105. (threepid['medium'], threepid['address'], mxid)
  106. )
  107. row = cur.fetchone()
  108. if row[0] > 0:
  109. ts = time_msec()
  110. cur.execute(
  111. "REPLACE INTO local_threepid_associations "
  112. "('medium', 'address', 'mxid', 'ts', 'notBefore', 'notAfter') "
  113. " values (?, ?, NULL, ?, null, null)",
  114. (threepid['medium'], threepid['address'], ts),
  115. )
  116. logger.info(
  117. "Deleting local assoc for %s/%s/%s replaced %d rows",
  118. threepid['medium'], threepid['address'], mxid, cur.rowcount,
  119. )
  120. self.sydent.db.commit()
  121. else:
  122. logger.info(
  123. "No local assoc found for %s/%s/%s",
  124. threepid['medium'], threepid['address'], mxid,
  125. )
  126. # we still consider this successful in the name of idempotency:
  127. # the binding to be deleted is not there, so we're in the desired state.
  128. class GlobalAssociationStore:
  129. def __init__(self, sydent):
  130. self.sydent = sydent
  131. def signedAssociationStringForThreepid(self, medium, address):
  132. """
  133. Retrieve the JSON for the signed association matching the provided 3PID,
  134. if one exists.
  135. :param medium: The medium of the 3PID.
  136. :type medium: unicode
  137. :param address: The address of the 3PID.
  138. :type address: unicode
  139. :return: The signed association, or None if no association was found for this
  140. 3PID.
  141. :rtype: unicode or None
  142. """
  143. cur = self.sydent.db.cursor()
  144. # We treat address as case-insensitive because that's true for all the
  145. # threepids we have currently (we treat the local part of email addresses as
  146. # case insensitive which is technically incorrect). If we someday get a
  147. # case-sensitive threepid, this can change.
  148. res = cur.execute("select sgAssoc from global_threepid_associations where "
  149. "medium = ? and lower(address) = lower(?) and notBefore < ? and notAfter > ? "
  150. "order by ts desc limit 1",
  151. (medium, address, time_msec(), time_msec()))
  152. row = res.fetchone()
  153. if not row:
  154. return None
  155. sgAssocStr = row[0]
  156. return sgAssocStr
  157. def getMxid(self, medium, address):
  158. """
  159. Retrieves the MXID associated with a 3PID.
  160. :param medium: The medium of the 3PID.
  161. :type medium: unicode
  162. :param address: The address of the 3PID.
  163. :type address: unicode
  164. :return: The associated MXID, or None if no MXID is associated with this 3PID.
  165. :rtype: unicode or None
  166. """
  167. cur = self.sydent.db.cursor()
  168. res = cur.execute("select mxid from global_threepid_associations where "
  169. "medium = ? and lower(address) = lower(?) and notBefore < ? and notAfter > ? "
  170. "order by ts desc limit 1",
  171. (medium, address, time_msec(), time_msec()))
  172. row = res.fetchone()
  173. if not row:
  174. return None
  175. return row[0]
  176. def getMxids(self, threepid_tuples):
  177. """Given a list of threepid_tuples, return the same list but with
  178. mxids appended to each tuple for which a match was found in the
  179. database for. Output is ordered by medium, address, timestamp DESC
  180. :param threepid_tuples: List containing (medium, address) tuples
  181. :type threepid_tuples: list[tuple[unicode]]
  182. :return: a list of (medium, address, mxid) tuples
  183. :rtype: list[tuple[unicode]]
  184. """
  185. cur = self.sydent.db.cursor()
  186. cur.execute("CREATE TEMPORARY TABLE tmp_getmxids (medium VARCHAR(16), address VARCHAR(256))")
  187. cur.execute("CREATE INDEX tmp_getmxids_medium_lower_address ON tmp_getmxids (medium, lower(address))")
  188. try:
  189. inserted_cap = 0
  190. while inserted_cap < len(threepid_tuples):
  191. cur.executemany(
  192. "INSERT INTO tmp_getmxids (medium, address) VALUES (?, ?)",
  193. threepid_tuples[inserted_cap:inserted_cap + 500]
  194. )
  195. inserted_cap += 500
  196. res = cur.execute(
  197. # 'notBefore' is the time the association starts being valid, 'notAfter' the the time at which
  198. # it ceases to be valid, so the ts must be greater than 'notBefore' and less than 'notAfter'.
  199. "SELECT gte.medium, gte.address, gte.ts, gte.mxid FROM global_threepid_associations gte "
  200. "JOIN tmp_getmxids ON gte.medium = tmp_getmxids.medium AND lower(gte.address) = lower(tmp_getmxids.address) "
  201. "WHERE gte.notBefore < ? AND gte.notAfter > ? "
  202. "ORDER BY gte.medium, gte.address, gte.ts DESC",
  203. (time_msec(), time_msec())
  204. )
  205. results = []
  206. current = ()
  207. for row in res.fetchall():
  208. # only use the most recent entry for each
  209. # threepid (they're sorted by ts)
  210. if (row[0], row[1]) == current:
  211. continue
  212. current = (row[0], row[1])
  213. results.append((row[0], row[1], row[3]))
  214. finally:
  215. cur.execute("DROP TABLE tmp_getmxids")
  216. return results
  217. def addAssociation(self, assoc, rawSgAssoc, originServer, originId, commit=True):
  218. """
  219. Saves an association received through either a replication push or a local push.
  220. :param assoc: The association to add as a high level object.
  221. :type assoc: sydent.threepid.ThreepidAssociation
  222. :param rawSgAssoc: The original raw bytes of the signed association.
  223. :type rawSgAssoc: dict[str, any]
  224. :param originServer: The name of the server the association was created on.
  225. :type originServer: str
  226. :param originId: The ID of the association on the server the association was
  227. created on.
  228. :type originId: int
  229. :param commit: Whether to commit the database transaction after inserting the
  230. association.
  231. :type commit: bool
  232. """
  233. cur = self.sydent.db.cursor()
  234. cur.execute("insert or ignore into global_threepid_associations "
  235. "(medium, address, lookup_hash, mxid, ts, notBefore, notAfter, originServer, originId, sgAssoc) values "
  236. "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
  237. (assoc.medium, assoc.address, assoc.lookup_hash, assoc.mxid, assoc.ts, assoc.not_before, assoc.not_after,
  238. originServer, originId, rawSgAssoc))
  239. if commit:
  240. self.sydent.db.commit()
  241. def lastIdFromServer(self, server):
  242. """
  243. Retrieves the ID of the last association received from the given peer.
  244. :param server:
  245. :type server: str
  246. :return: The the ID of the last association received from the peer, or None if
  247. no association has ever been received from that peer.
  248. :rtype: int or None
  249. """
  250. cur = self.sydent.db.cursor()
  251. res = cur.execute("select max(originId),count(originId) from global_threepid_associations "
  252. "where originServer = ?", (server,))
  253. row = res.fetchone()
  254. if row[1] == 0:
  255. return None
  256. return row[0]
  257. def removeAssociation(self, medium, address):
  258. """
  259. Removes any association stored for the provided 3PID.
  260. :param medium: The medium for the 3PID.
  261. :type medium: unicode
  262. :param address: The address for the 3PID.
  263. :type address: unicode
  264. """
  265. cur = self.sydent.db.cursor()
  266. cur.execute(
  267. "DELETE FROM global_threepid_associations WHERE "
  268. "medium = ? AND address = ?",
  269. (medium, address),
  270. )
  271. logger.info(
  272. "Deleted %d rows from global associations for %s/%s",
  273. cur.rowcount, medium, address,
  274. )
  275. self.sydent.db.commit()
  276. def retrieveMxidsForHashes(self, addresses):
  277. """Returns a mapping from hash: mxid from a list of given lookup_hash values
  278. :param addresses: An array of lookup_hash values to check against the db
  279. :type addresses: list[unicode]
  280. :returns a dictionary of lookup_hash values to mxids of all discovered matches
  281. :rtype: dict[unicode, unicode]
  282. """
  283. cur = self.sydent.db.cursor()
  284. cur.execute("CREATE TEMPORARY TABLE tmp_retrieve_mxids_for_hashes "
  285. "(lookup_hash VARCHAR)")
  286. cur.execute("CREATE INDEX tmp_retrieve_mxids_for_hashes_lookup_hash ON "
  287. "tmp_retrieve_mxids_for_hashes(lookup_hash)")
  288. results = {}
  289. try:
  290. # Convert list of addresses to list of tuples of addresses
  291. addresses = [(x,) for x in addresses]
  292. inserted_cap = 0
  293. while inserted_cap < len(addresses):
  294. cur.executemany(
  295. "INSERT INTO tmp_retrieve_mxids_for_hashes(lookup_hash) "
  296. "VALUES (?)",
  297. addresses[inserted_cap:inserted_cap + 500]
  298. )
  299. inserted_cap += 500
  300. res = cur.execute(
  301. # 'notBefore' is the time the association starts being valid, 'notAfter' the the time at which
  302. # it ceases to be valid, so the ts must be greater than 'notBefore' and less than 'notAfter'.
  303. "SELECT gta.lookup_hash, gta.mxid FROM global_threepid_associations gta "
  304. "JOIN tmp_retrieve_mxids_for_hashes "
  305. "ON gta.lookup_hash = tmp_retrieve_mxids_for_hashes.lookup_hash "
  306. "WHERE gta.notBefore < ? AND gta.notAfter > ? "
  307. "ORDER BY gta.lookup_hash, gta.mxid, gta.ts",
  308. (time_msec(), time_msec())
  309. )
  310. # Place the results from the query into a dictionary
  311. # Results are sorted from oldest to newest, so if there are multiple mxid's for
  312. # the same lookup hash, only the newest mapping will be returned
  313. for lookup_hash, mxid in res.fetchall():
  314. results[lookup_hash] = mxid
  315. finally:
  316. cur.execute("DROP TABLE tmp_retrieve_mxids_for_hashes")
  317. return results