upgrade_db_to_v0.6.0.py 9.7 KB


  1. from synapse.storage import SCHEMA_VERSION, read_schema
  2. from synapse.storage._base import SQLBaseStore
  3. from synapse.storage.signatures import SignatureStore
  4. from synapse.storage.event_federation import EventFederationStore
  5. from syutil.base64util import encode_base64, decode_base64
  6. from synapse.crypto.event_signing import compute_event_signature
  7. from synapse.events.builder import EventBuilder
  8. from synapse.events.utils import prune_event
  9. from synapse.crypto.event_signing import check_event_content_hash
  10. from syutil.crypto.jsonsign import (
  11. verify_signed_json, SignatureVerifyException,
  12. )
  13. from syutil.crypto.signing_key import decode_verify_key_bytes
  14. from syutil.jsonutil import encode_canonical_json
  15. import argparse
  16. # import dns.resolver
  17. import hashlib
  18. import httplib
  19. import json
  20. import sqlite3
  21. import syutil
  22. import urllib2
# Schema delta v10, applied by reinsert_events() before events are rewritten.
# It creates the event_json table: one row per event_id holding the event's
# room_id, its internal metadata and its full JSON. It also stamps the
# database with PRAGMA user_version = 10. The IF NOT EXISTS clauses make the
# CREATE statements safe to re-run on an already-upgraded database.
# internal_metadata is declared with no type, which SQLite accepts (such a
# column gets BLOB affinity).
delta_sql = """
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
internal_metadata NOT NULL,
json BLOB NOT NULL,
CONSTRAINT ev_j_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
PRAGMA user_version = 10;
"""
  35. class Store(object):
  36. _get_event_signatures_txn = SignatureStore.__dict__["_get_event_signatures_txn"]
  37. _get_event_content_hashes_txn = SignatureStore.__dict__["_get_event_content_hashes_txn"]
  38. _get_event_reference_hashes_txn = SignatureStore.__dict__["_get_event_reference_hashes_txn"]
  39. _get_prev_event_hashes_txn = SignatureStore.__dict__["_get_prev_event_hashes_txn"]
  40. _get_prev_events_and_state = EventFederationStore.__dict__["_get_prev_events_and_state"]
  41. _get_auth_events = EventFederationStore.__dict__["_get_auth_events"]
  42. cursor_to_dict = SQLBaseStore.__dict__["cursor_to_dict"]
  43. _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
  44. _simple_select_list_txn = SQLBaseStore.__dict__["_simple_select_list_txn"]
  45. _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
  46. def _generate_event_json(self, txn, rows):
  47. events = []
  48. for row in rows:
  49. d = dict(row)
  50. d.pop("stream_ordering", None)
  51. d.pop("topological_ordering", None)
  52. d.pop("processed", None)
  53. if "origin_server_ts" not in d:
  54. d["origin_server_ts"] = d.pop("ts", 0)
  55. else:
  56. d.pop("ts", 0)
  57. d.pop("prev_state", None)
  58. d.update(json.loads(d.pop("unrecognized_keys")))
  59. d["sender"] = d.pop("user_id")
  60. d["content"] = json.loads(d["content"])
  61. if "age_ts" not in d:
  62. # For compatibility
  63. d["age_ts"] = d.get("origin_server_ts", 0)
  64. d.setdefault("unsigned", {})["age_ts"] = d.pop("age_ts")
  65. outlier = d.pop("outlier", False)
  66. # d.pop("membership", None)
  67. d.pop("state_hash", None)
  68. d.pop("replaces_state", None)
  69. b = EventBuilder(d)
  70. b.internal_metadata.outlier = outlier
  71. events.append(b)
  72. for i, ev in enumerate(events):
  73. signatures = self._get_event_signatures_txn(
  74. txn, ev.event_id,
  75. )
  76. ev.signatures = {
  77. n: {
  78. k: encode_base64(v) for k, v in s.items()
  79. }
  80. for n, s in signatures.items()
  81. }
  82. hashes = self._get_event_content_hashes_txn(
  83. txn, ev.event_id,
  84. )
  85. ev.hashes = {
  86. k: encode_base64(v) for k, v in hashes.items()
  87. }
  88. prevs = self._get_prev_events_and_state(txn, ev.event_id)
  89. ev.prev_events = [
  90. (e_id, h)
  91. for e_id, h, is_state in prevs
  92. if is_state == 0
  93. ]
  94. # ev.auth_events = self._get_auth_events(txn, ev.event_id)
  95. hashes = dict(ev.auth_events)
  96. for e_id, hash in ev.prev_events:
  97. if e_id in hashes and not hash:
  98. hash.update(hashes[e_id])
  99. #
  100. # if hasattr(ev, "state_key"):
  101. # ev.prev_state = [
  102. # (e_id, h)
  103. # for e_id, h, is_state in prevs
  104. # if is_state == 1
  105. # ]
  106. return [e.build() for e in events]
# Single module-level Store instance. reinsert_events() uses it for
# cursor_to_dict, _generate_event_json and _simple_insert_txn.
store = Store()

# NOTE(review): Disabled helper for fetching a remote server's verify keys.
# It tries an SRV lookup on _matrix._tcp.<server>, falls back to port 8448,
# then GETs https://<target>:<port>/_matrix/key/v1 and self-verifies the keys.
# This script never calls it; remote hosts get an empty key set instead, so
# their signatures are not verified. Re-enabling it needs the `dns.resolver`
# import, which is also commented out at the top of the file. The urllib2,
# httplib and decode_verify_key_bytes imports appear to exist only for this
# helper.
# If revived, fix two things first:
#   - urllib2.HTTPError is a subclass of URLError, so its except clause is
#     unreachable.
#   - The bare `except:` swallows everything, including KeyboardInterrupt.
#
# def get_key(server_name):
#     print "Getting keys for: %s" % (server_name,)
#     targets = []
#     if ":" in server_name:
#         target, port = server_name.split(":")
#         targets.append((target, int(port)))
#     try:
#         answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
#         for srv in answers:
#             targets.append((srv.target, srv.port))
#     except dns.resolver.NXDOMAIN:
#         targets.append((server_name, 8448))
#     except:
#         print "Failed to lookup keys for %s" % (server_name,)
#         return {}
#
#     for target, port in targets:
#         url = "https://%s:%i/_matrix/key/v1" % (target, port)
#         try:
#             keys = json.load(urllib2.urlopen(url, timeout=2))
#             verify_keys = {}
#             for key_id, key_base64 in keys["verify_keys"].items():
#                 verify_key = decode_verify_key_bytes(
#                     key_id, decode_base64(key_base64)
#                 )
#                 verify_signed_json(keys, server_name, verify_key)
#                 verify_keys[key_id] = verify_key
#             print "Got keys for: %s" % (server_name,)
#             return verify_keys
#         except urllib2.URLError:
#             pass
#         except urllib2.HTTPError:
#             pass
#         except httplib.HTTPException:
#             pass
#
#     print "Failed to get keys for %s" % (server_name,)
#     return {}
  146. def reinsert_events(cursor, server_name, signing_key):
  147. print "Running delta: v10"
  148. cursor.executescript(delta_sql)
  149. cursor.execute(
  150. "SELECT * FROM events ORDER BY rowid ASC"
  151. )
  152. print "Getting events..."
  153. rows = store.cursor_to_dict(cursor)
  154. events = store._generate_event_json(cursor, rows)
  155. print "Got events from DB."
  156. algorithms = {
  157. "sha256": hashlib.sha256,
  158. }
  159. key_id = "%s:%s" % (signing_key.alg, signing_key.version)
  160. verify_key = signing_key.verify_key
  161. verify_key.alg = signing_key.alg
  162. verify_key.version = signing_key.version
  163. server_keys = {
  164. server_name: {
  165. key_id: verify_key
  166. }
  167. }
  168. i = 0
  169. N = len(events)
  170. for event in events:
  171. if i % 100 == 0:
  172. print "Processed: %d/%d events" % (i,N,)
  173. i += 1
  174. # for alg_name in event.hashes:
  175. # if check_event_content_hash(event, algorithms[alg_name]):
  176. # pass
  177. # else:
  178. # pass
  179. # print "FAIL content hash %s %s" % (alg_name, event.event_id, )
  180. have_own_correctly_signed = False
  181. for host, sigs in event.signatures.items():
  182. pruned = prune_event(event)
  183. for key_id in sigs:
  184. if host not in server_keys:
  185. server_keys[host] = {} # get_key(host)
  186. if key_id in server_keys[host]:
  187. try:
  188. verify_signed_json(
  189. pruned.get_pdu_json(),
  190. host,
  191. server_keys[host][key_id]
  192. )
  193. if host == server_name:
  194. have_own_correctly_signed = True
  195. except SignatureVerifyException:
  196. print "FAIL signature check %s %s" % (
  197. key_id, event.event_id
  198. )
  199. # TODO: Re sign with our own server key
  200. if not have_own_correctly_signed:
  201. sigs = compute_event_signature(event, server_name, signing_key)
  202. event.signatures.update(sigs)
  203. pruned = prune_event(event)
  204. for key_id in event.signatures[server_name]:
  205. verify_signed_json(
  206. pruned.get_pdu_json(),
  207. server_name,
  208. server_keys[server_name][key_id]
  209. )
  210. event_json = encode_canonical_json(
  211. event.get_dict()
  212. ).decode("UTF-8")
  213. metadata_json = encode_canonical_json(
  214. event.internal_metadata.get_dict()
  215. ).decode("UTF-8")
  216. store._simple_insert_txn(
  217. cursor,
  218. table="event_json",
  219. values={
  220. "event_id": event.event_id,
  221. "room_id": event.room_id,
  222. "internal_metadata": metadata_json,
  223. "json": event_json,
  224. },
  225. or_replace=True,
  226. )
  227. def main(database, server_name, signing_key):
  228. conn = sqlite3.connect(database)
  229. cursor = conn.cursor()
  230. # Do other deltas:
  231. cursor.execute("PRAGMA user_version")
  232. row = cursor.fetchone()
  233. if row and row[0]:
  234. user_version = row[0]
  235. # Run every version since after the current version.
  236. for v in range(user_version + 1, 10):
  237. print "Running delta: %d" % (v,)
  238. sql_script = read_schema("delta/v%d" % (v,))
  239. cursor.executescript(sql_script)
  240. reinsert_events(cursor, server_name, signing_key)
  241. conn.commit()
  242. print "Success!"
  243. if __name__ == "__main__":
  244. parser = argparse.ArgumentParser()
  245. parser.add_argument("database")
  246. parser.add_argument("server_name")
  247. parser.add_argument(
  248. "signing_key", type=argparse.FileType('r'),
  249. )
  250. args = parser.parse_args()
  251. signing_key = syutil.crypto.signing_key.read_signing_keys(
  252. args.signing_key
  253. )
  254. main(args.database, args.server_name, signing_key[0])