1
0

sqlitedb.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014 OpenMarket Ltd
  3. # Copyright 2018 New Vector Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import sqlite3
  17. import logging
  18. import os
  19. logger = logging.getLogger(__name__)
  20. class SqliteDatabase:
  21. def __init__(self, syd):
  22. self.sydent = syd
  23. dbFilePath = self.sydent.cfg.get("db", "db.file")
  24. logger.info("Using DB file %s", dbFilePath)
  25. self.db = sqlite3.connect(dbFilePath)
  26. curVer = self._getSchemaVersion()
  27. # We always run the schema files if the version is zero: either the db is
  28. # completely empty and schema-less or it has the v0 schema, which is safe to
  29. # replay the schema files. The files in the sql directory are the v0 schema, so
  30. # a new installations will start as v0 then be upgraded to the current version.
  31. if curVer == 0:
  32. self._createSchema()
  33. self._upgradeSchema()
  34. def _createSchema(self):
  35. logger.info("Running schema files...")
  36. schemaDir = os.path.dirname(__file__)
  37. c = self.db.cursor()
  38. for f in os.listdir(schemaDir):
  39. if not f.endswith(".sql"):
  40. continue
  41. scriptPath = os.path.join(schemaDir, f)
  42. fp = open(scriptPath, 'r')
  43. try:
  44. c.executescript(fp.read())
  45. except:
  46. logger.error("Error importing %s", f)
  47. raise
  48. fp.close()
  49. c.close()
  50. self.db.commit()
  51. def _upgradeSchema(self):
  52. curVer = self._getSchemaVersion()
  53. if curVer < 1:
  54. cur = self.db.cursor()
  55. # add auto_increment to the primary key of local_threepid_associations to ensure ids are never re-used,
  56. # allow the mxid column to be null to represent the deletion of a binding
  57. # and remove not null constraints on ts, notBefore and notAfter (again for when a binding has been deleted
  58. # and these wouldn't be very meaningful)
  59. logger.info("Migrating schema from v0 to v1")
  60. cur.execute("DROP INDEX IF EXISTS medium_address")
  61. cur.execute("DROP INDEX IF EXISTS local_threepid_medium_address")
  62. cur.execute("ALTER TABLE local_threepid_associations RENAME TO old_local_threepid_associations");
  63. cur.execute(
  64. "CREATE TABLE local_threepid_associations (id integer primary key autoincrement, "
  65. "medium varchar(16) not null, "
  66. "address varchar(256) not null, "
  67. "mxid varchar(256), "
  68. "ts integer, "
  69. "notBefore bigint, "
  70. "notAfter bigint)"
  71. )
  72. cur.execute(
  73. "INSERT INTO local_threepid_associations (medium, address, mxid, ts, notBefore, notAfter) "
  74. "SELECT medium, address, mxid, ts, notBefore, notAfter FROM old_local_threepid_associations"
  75. )
  76. cur.execute(
  77. "CREATE UNIQUE INDEX local_threepid_medium_address on local_threepid_associations(medium, address)"
  78. )
  79. cur.execute("DROP TABLE old_local_threepid_associations")
  80. # same autoincrement for global_threepid_associations (fields stay non-nullable because we don't need
  81. # entries in this table for deletions, we can just delete the rows)
  82. cur.execute("DROP INDEX IF EXISTS global_threepid_medium_address")
  83. cur.execute("DROP INDEX IF EXISTS global_threepid_medium_lower_address")
  84. cur.execute("DROP INDEX IF EXISTS global_threepid_originServer_originId")
  85. cur.execute("DROP INDEX IF EXISTS medium_lower_address")
  86. cur.execute("DROP INDEX IF EXISTS threepid_originServer_originId")
  87. cur.execute("ALTER TABLE global_threepid_associations RENAME TO old_global_threepid_associations");
  88. cur.execute(
  89. "CREATE TABLE IF NOT EXISTS global_threepid_associations "
  90. "(id integer primary key autoincrement, "
  91. "medium varchar(16) not null, "
  92. "address varchar(256) not null, "
  93. "mxid varchar(256) not null, "
  94. "ts integer not null, "
  95. "notBefore bigint not null, "
  96. "notAfter integer not null, "
  97. "originServer varchar(255) not null, "
  98. "originId integer not null, "
  99. "sgAssoc text not null)"
  100. )
  101. cur.execute(
  102. "INSERT INTO global_threepid_associations "
  103. "(medium, address, mxid, ts, notBefore, notAfter, originServer, originId, sgAssoc) "
  104. "SELECT medium, address, mxid, ts, notBefore, notAfter, originServer, originId, sgAssoc "
  105. "FROM old_global_threepid_associations"
  106. )
  107. cur.execute("CREATE INDEX global_threepid_medium_address on global_threepid_associations (medium, address)")
  108. cur.execute(
  109. "CREATE INDEX global_threepid_medium_lower_address on "
  110. "global_threepid_associations (medium, lower(address))"
  111. )
  112. cur.execute(
  113. "CREATE UNIQUE INDEX global_threepid_originServer_originId on "
  114. "global_threepid_associations (originServer, originId)"
  115. )
  116. cur.execute("DROP TABLE old_global_threepid_associations")
  117. self.db.commit()
  118. logger.info("v0 -> v1 schema migration complete")
  119. self._setSchemaVersion(1)
  120. def _getSchemaVersion(self):
  121. cur = self.db.cursor()
  122. res = cur.execute("PRAGMA user_version");
  123. row = cur.fetchone()
  124. return row[0]
  125. def _setSchemaVersion(self, ver):
  126. cur = self.db.cursor()
  127. # NB. pragma doesn't support variable substitution so we
  128. # do it in python (as a decimal so we don't risk SQL injection)
  129. res = cur.execute("PRAGMA user_version = %d" % (ver,));