convert_server_keys.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120
  1. import hashlib
  2. import json
  3. import sys
  4. import time
  5. import six
  6. import psycopg2
  7. import yaml
  8. from canonicaljson import encode_canonical_json
  9. from signedjson.key import read_signing_keys
  10. from signedjson.sign import sign_json
  11. from unpaddedbase64 import encode_base64
  12. if six.PY2:
  13. db_type = six.moves.builtins.buffer
  14. else:
  15. db_type = memoryview
  16. def select_v1_keys(connection):
  17. cursor = connection.cursor()
  18. cursor.execute("SELECT server_name, key_id, verify_key FROM server_signature_keys")
  19. rows = cursor.fetchall()
  20. cursor.close()
  21. results = {}
  22. for server_name, key_id, verify_key in rows:
  23. results.setdefault(server_name, {})[key_id] = encode_base64(verify_key)
  24. return results
  25. def select_v1_certs(connection):
  26. cursor = connection.cursor()
  27. cursor.execute("SELECT server_name, tls_certificate FROM server_tls_certificates")
  28. rows = cursor.fetchall()
  29. cursor.close()
  30. results = {}
  31. for server_name, tls_certificate in rows:
  32. results[server_name] = tls_certificate
  33. return results
  34. def select_v2_json(connection):
  35. cursor = connection.cursor()
  36. cursor.execute("SELECT server_name, key_id, key_json FROM server_keys_json")
  37. rows = cursor.fetchall()
  38. cursor.close()
  39. results = {}
  40. for server_name, key_id, key_json in rows:
  41. results.setdefault(server_name, {})[key_id] = json.loads(
  42. str(key_json).decode("utf-8")
  43. )
  44. return results
  45. def convert_v1_to_v2(server_name, valid_until, keys, certificate):
  46. return {
  47. "old_verify_keys": {},
  48. "server_name": server_name,
  49. "verify_keys": {key_id: {"key": key} for key_id, key in keys.items()},
  50. "valid_until_ts": valid_until,
  51. "tls_fingerprints": [fingerprint(certificate)],
  52. }
  53. def fingerprint(certificate):
  54. finger = hashlib.sha256(certificate)
  55. return {"sha256": encode_base64(finger.digest())}
  56. def rows_v2(server, json):
  57. valid_until = json["valid_until_ts"]
  58. key_json = encode_canonical_json(json)
  59. for key_id in json["verify_keys"]:
  60. yield (server, key_id, "-", valid_until, valid_until, db_type(key_json))
  61. def main():
  62. config = yaml.safe_load(open(sys.argv[1]))
  63. valid_until = int(time.time() / (3600 * 24)) * 1000 * 3600 * 24
  64. server_name = config["server_name"]
  65. signing_key = read_signing_keys(open(config["signing_key_path"]))[0]
  66. database = config["database"]
  67. assert database["name"] == "psycopg2", "Can only convert for postgresql"
  68. args = database["args"]
  69. args.pop("cp_max")
  70. args.pop("cp_min")
  71. connection = psycopg2.connect(**args)
  72. keys = select_v1_keys(connection)
  73. certificates = select_v1_certs(connection)
  74. json = select_v2_json(connection)
  75. result = {}
  76. for server in keys:
  77. if server not in json:
  78. v2_json = convert_v1_to_v2(
  79. server, valid_until, keys[server], certificates[server]
  80. )
  81. v2_json = sign_json(v2_json, server_name, signing_key)
  82. result[server] = v2_json
  83. yaml.safe_dump(result, sys.stdout, default_flow_style=False)
  84. rows = [row for server, json in result.items() for row in rows_v2(server, json)]
  85. cursor = connection.cursor()
  86. cursor.executemany(
  87. "INSERT INTO server_keys_json ("
  88. " server_name, key_id, from_server,"
  89. " ts_added_ms, ts_valid_until_ms, key_json"
  90. ") VALUES (%s, %s, %s, %s, %s, %s)",
  91. rows,
  92. )
  93. connection.commit()
  94. if __name__ == "__main__":
  95. main()