hashing_metadata.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
# -*- coding: utf-8 -*-
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Actions on the hashing_metadata table, which is defined in the migration
# process in sqlitedb.py
  17. class HashingMetadataStore:
  18. def __init__(self, sydent):
  19. self.sydent = sydent
  20. def get_lookup_pepper(self):
  21. """Return the value of the current lookup pepper from the db
  22. :return: A pepper if it exists in the database, or None if one does
  23. not exist
  24. :rtype: unicode or None
  25. """
  26. cur = self.sydent.db.cursor()
  27. res = cur.execute("select lookup_pepper from hashing_metadata")
  28. row = res.fetchone()
  29. if not row:
  30. return None
  31. pepper = row[0]
  32. # Ensure we're dealing with unicode.
  33. if isinstance(pepper, bytes):
  34. pepper = pepper.decode("UTF-8")
  35. return pepper
  36. def store_lookup_pepper(self, hashing_function, pepper):
  37. """Stores a new lookup pepper in the hashing_metadata db table and rehashes all 3PIDs
  38. :param hashing_function: A function with single input and output strings
  39. :type hashing_function func(str) -> str
  40. :param pepper: The pepper to store in the database
  41. :type pepper: str
  42. """
  43. cur = self.sydent.db.cursor()
  44. # Create or update lookup_pepper
  45. sql = (
  46. 'INSERT OR REPLACE INTO hashing_metadata (id, lookup_pepper) '
  47. 'VALUES (0, ?)'
  48. )
  49. cur.execute(sql, (pepper,))
  50. # Hand the cursor to each rehashing function
  51. # Each function will queue some rehashing db transactions
  52. self._rehash_threepids(cur, hashing_function, pepper, "local_threepid_associations")
  53. self._rehash_threepids(cur, hashing_function, pepper, "global_threepid_associations")
  54. # Commit the queued db transactions so that adding a new pepper and hashing is atomic
  55. self.sydent.db.commit()
  56. def _rehash_threepids(self, cur, hashing_function, pepper, table):
  57. """Rehash 3PIDs of a given table using a given hashing_function and pepper
  58. A database cursor `cur` must be passed to this function. After this function completes,
  59. the calling function should make sure to call self`self.sydent.db.commit()` to commit
  60. the made changes to the database.
  61. :param cur: Database cursor
  62. :type cur:
  63. :param hashing_function: A function with single input and output strings
  64. :type hashing_function func(str) -> str
  65. :param pepper: A pepper to append to the end of the 3PID (after a space) before hashing
  66. :type pepper: str
  67. :param table: The database table to perform the rehashing on
  68. :type table: str
  69. """
  70. # Get count of all 3PID records
  71. # Medium/address combos are marked as UNIQUE in the database
  72. sql = "SELECT COUNT(*) FROM %s" % table
  73. res = cur.execute(sql)
  74. row_count = res.fetchone()
  75. row_count = row_count[0]
  76. # Iterate through each medium, address combo, hash it,
  77. # and store in the db
  78. batch_size = 500
  79. count = 0
  80. while count < row_count:
  81. sql = (
  82. "SELECT medium, address FROM %s ORDER BY id LIMIT %s OFFSET %s" %
  83. (table, batch_size, count)
  84. )
  85. res = cur.execute(sql)
  86. rows = res.fetchall()
  87. for medium, address in rows:
  88. # Skip broken db entry
  89. if not medium or not address:
  90. continue
  91. # Combine the medium, address and pepper together in the
  92. # following form: "address medium pepper"
  93. # According to MSC2134: https://github.com/matrix-org/matrix-doc/pull/2134
  94. combo = "%s %s %s" % (address, medium, pepper)
  95. # Hash the resulting string
  96. result = hashing_function(combo)
  97. # Save the result to the DB
  98. sql = (
  99. "UPDATE %s SET lookup_hash = ? "
  100. "WHERE medium = ? AND address = ?"
  101. % table
  102. )
  103. # Lines up the query to be executed on commit
  104. cur.execute(sql, (result, medium, address))
  105. count += len(rows)