invite_tokens.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2015 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import time
  16. class JoinTokenStore(object):
  17. def __init__(self, sydent):
  18. self.sydent = sydent
  def storeToken(self, medium, address, roomId, sender, token, originServer=None, originId=None, commit=True):
      """Stores an invite token.

      If the token came from replication and its ``originId`` is not newer
      than the last id already stored for ``originServer``, it is skipped.

      :param medium: The medium of the token.
      :type medium: str
      :param address: The address of the token.
      :type address: str
      :param roomId: The room ID this token is tied to.
      :type roomId: str
      :param sender: The sender of the invite.
      :type sender: str
      :param token: The token itself.
      :type token: str
      :param originServer: The server this invite originated from (if
          coming from replication).
      :type originServer: str, None
      :param originId: The id of the token in the DB of originServer. Used
          for determining if we've already received a token or not.
      :type originId: int, None
      :param commit: Whether DB changes should be committed by this
          function (or an external one).
      :type commit: bool
      """
      if originId and originServer:
          # Skip tokens at or below the highest origin_id already stored for
          # this server; they have been processed before.
          last_processed_id = self.getLastTokenIdFromServer(originServer)
          if int(originId) <= int(last_processed_id):
              # Imported locally: this module does not define a module-level
              # logger.
              import logging
              logging.getLogger(__name__).info(
                  "We have already seen token ID %s from %s. Ignoring.",
                  originId, originServer,
              )
              return
      cur = self.sydent.db.cursor()
      # Plain (unquoted) column identifiers; single-quoted names are string
      # literals in standard SQL and only work through an SQLite quirk.
      cur.execute(
          "INSERT INTO invite_tokens"
          " (medium, address, room_id, sender, token, received_ts, origin_server, origin_id)"
          " VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
          (medium, address, roomId, sender, token, int(time.time()), originServer, originId),
      )
      if commit:
          self.sydent.db.commit()
  def getTokens(self, medium, address):
      """Look up every invite token stored for a 3PID.

      :param medium: The medium of the 3PID.
      :type medium: str
      :param address: The address of the 3PID.
      :type address: str
      :returns a list of dicts with keys ``medium``, ``address``,
          ``room_id``, ``sender`` and ``token``; empty when nothing matches.
      :rtype: list[Dict[str, str]]
      """
      cursor = self.sydent.db.cursor()
      result = cursor.execute(
          "SELECT medium, address, room_id, sender, token FROM invite_tokens"
          " WHERE medium = ? AND address = ?",
          (medium, address,)
      )
      return [
          {
              "medium": row_medium,
              "address": row_address,
              "room_id": row_room_id,
              "sender": row_sender,
              "token": row_token,
          }
          for row_medium, row_address, row_room_id, row_sender, row_token
          in result.fetchall()
      ]
  def getInviteTokensAfterId(self, afterId, limit):
      """Retrieves max `limit` locally-created invite tokens after a given DB id.

      Only tokens whose ``origin_id`` is NULL (i.e. not stored with a
      replication origin) are returned, in ascending id order.

      :param afterId: A database id to act as an offset. Tokens after this
          id are returned.
      :type afterId: int
      :param limit: Max amount of database rows to return, or None for no
          limit.
      :type limit: int, None
      :returns a tuple consisting of a dict of invite tokens (with key
          being the token's DB id) and the maximum DB id that was extracted.
          Otherwise returns ({}, None) if no tokens are found.
      :rtype: Tuple[Dict[int, Dict], int|None]
      """
      # A negative LIMIT means "no upper bound" in SQLite, so map None to -1
      # rather than binding NULL.
      sql_limit = -1 if limit is None else limit
      cur = self.sydent.db.cursor()
      # ``IS NULL`` is required: ``origin_id = NULL`` evaluates to NULL and
      # never matches a row. ORDER BY id makes LIMIT take the lowest ids, so
      # the last row read carries the maximum id.
      res = cur.execute(
          "SELECT id, medium, address, room_id, sender, token FROM invite_tokens"
          " WHERE id > ? AND origin_id IS NULL ORDER BY id LIMIT ?",
          (afterId, sql_limit),
      )
      rows = res.fetchall()
      # Dict of "id": {content}
      invite_tokens = {}
      maxId = None
      for row in rows:
          maxId, medium, address, room_id, sender, token = row
          invite_tokens[maxId] = {
              "origin_id": maxId,
              "medium": medium,
              "address": address,
              "room_id": room_id,
              "sender": sender,
              "token": token,
          }
      return (invite_tokens, maxId)
  def getLastTokenIdFromServer(self, server):
      """Find the highest origin id among invite tokens replicated from a server.

      :param server: The name of the origin server.
      :type server: str
      :returns the largest ``origin_id`` stored for ``server``, or 0 when no
          token from that server has been stored.
      :rtype: int
      """
      query = (
          "select max(origin_id), count(origin_id) from invite_tokens"
          " where origin_server = ?"
      )
      highest_id, matched = self.sydent.db.cursor().execute(query, (server,)).fetchone()
      return 0 if matched == 0 else highest_id
  def markTokensAsSent(self, medium, address):
      """Stamp every invite token for a 3PID with the current time as sent.

      Sets ``sent_ts`` (whole seconds since the epoch) on all matching rows
      and commits immediately.

      :param medium: The medium of the token.
      :type medium: str
      :param address: The address of the token.
      :type address: str
      """
      now_seconds = int(time.time())
      self.sydent.db.cursor().execute(
          "UPDATE invite_tokens SET sent_ts = ? WHERE medium = ? AND address = ?",
          (now_seconds, medium, address,)
      )
      self.sydent.db.commit()
  def storeEphemeralPublicKey(self, publicKey, persistenceTs=None, originServer=None, originId=None, commit=True):
      """Stores an ephemeral public key in the database.

      If the key came from replication and its ``originId`` is not newer
      than the last id already stored for ``originServer``, it is skipped.

      :param publicKey: the ephemeral public key to store.
      :type publicKey: str
      :param persistenceTs: Value stored in ``persistence_ts``. When not
          given (or falsy), the current time in whole seconds is used.
      :type persistenceTs: int, None
      :param originServer: the server this key was received from (if
          retrieved through replication).
      :type originServer: str, None
      :param originId: The id of the key in the DB of originServer. Used
          for determining if we've already received a key or not.
      :type originId: int, None
      :param commit: Whether DB changes should be committed by this
          function (or an external one).
      :type commit: bool
      """
      if originId and originServer:
          # Skip keys at or below the highest origin_id already stored for
          # this server; they have been processed before.
          last_processed_id = self.getLastEphemeralPublicKeyIdFromServer(originServer)
          if int(originId) <= int(last_processed_id):
              # Imported locally: this module does not define a module-level
              # logger.
              import logging
              logging.getLogger(__name__).info(
                  "We have already seen key ID %s from %s. Ignoring.",
                  originId, originServer,
              )
              return
      if not persistenceTs:
          persistenceTs = int(time.time())
      cur = self.sydent.db.cursor()
      cur.execute(
          "INSERT INTO ephemeral_public_keys"
          " (public_key, persistence_ts, origin_server, origin_id)"
          " VALUES (?, ?, ?, ?)",
          (publicKey, persistenceTs, originServer, originId),
      )
      if commit:
          self.sydent.db.commit()
  def validateEphemeralPublicKey(self, publicKey):
      """Record a verification of an ephemeral public key.

      Increments ``verify_count`` for the matching row(s) and commits.

      :param publicKey: An ephemeral public key.
      :type publicKey: str
      :returns True if at least one stored key matched and was updated,
          False otherwise.
      :rtype: bool
      """
      db = self.sydent.db
      cursor = db.cursor()
      cursor.execute(
          "UPDATE ephemeral_public_keys"
          " SET verify_count = verify_count + 1"
          " WHERE public_key = ?",
          (publicKey,)
      )
      db.commit()
      updated_rows = cursor.rowcount
      return updated_rows > 0
  def getEphemeralPublicKeysAfterId(self, afterId, limit):
      """Retrieves max `limit` locally-created ephemeral public keys after a given DB id.

      Only keys whose ``origin_id`` is NULL (i.e. not stored with a
      replication origin) are returned, in ascending id order.

      :param afterId: A database id to act as an offset. Keys after this id
          are returned.
      :type afterId: int
      :param limit: Max amount of database rows to return, or None for no
          limit.
      :type limit: int, None
      :returns a tuple consisting of a dict of ephemeral public keys (with
          key being the key's DB id) and the maximum table id that was
          extracted. Otherwise returns ({}, None) if no keys are found.
      :rtype: Tuple[Dict[int, Dict], int|None]
      """
      # A negative LIMIT means "no upper bound" in SQLite, so map None to -1
      # rather than binding NULL.
      sql_limit = -1 if limit is None else limit
      cur = self.sydent.db.cursor()
      # ``IS NULL`` is required: ``origin_id = NULL`` evaluates to NULL and
      # never matches a row. ORDER BY id makes LIMIT take the lowest ids, so
      # the last row read carries the maximum id.
      res = cur.execute(
          "SELECT id, public_key, verify_count, persistence_ts FROM ephemeral_public_keys"
          " WHERE id > ? AND origin_id IS NULL ORDER BY id LIMIT ?",
          (afterId, sql_limit),
      )
      rows = res.fetchall()
      # Dict of "id": {content}
      ephemeral_keys = {}
      maxId = None
      for row in rows:
          maxId, public_key, verify_count, persistence_ts = row
          ephemeral_keys[maxId] = {
              "public_key": public_key,
              "verify_count": verify_count,
              "persistence_ts": persistence_ts,
          }
      return (ephemeral_keys, maxId)
  def getLastEphemeralPublicKeyIdFromServer(self, server):
      """Find the highest origin id among ephemeral keys replicated from a server.

      :param server: The name of the origin server.
      :type server: str
      :returns the largest ``origin_id`` stored for ``server``, or 0 when no
          key from that server has been stored.
      :rtype: int
      """
      row = self.sydent.db.cursor().execute(
          "select max(origin_id),count(origin_id) from ephemeral_public_keys"
          " where origin_server = ?", (server,)
      ).fetchone()
      has_keys = bool(row) and row[1] != 0
      return row[0] if has_keys else 0
  def getSenderForToken(self, token):
      """Look up who sent the invite identified by a token.

      :param token: The invite token.
      :type token: str
      :returns the ``sender`` column of the first matching row, or None if
          no stored invite has this token.
      :rtype: str, None
      """
      cursor = self.sydent.db.cursor()
      first_row = cursor.execute(
          "SELECT sender FROM invite_tokens WHERE token = ?",
          (token,)
      ).fetchone()
      return first_row[0] if first_row is not None else None