|
@@ -19,18 +19,23 @@ import urllib
|
|
|
import urlparse
|
|
|
|
|
|
from mock import Mock, patch
|
|
|
-from twisted.enterprise.adbapi import ConnectionPool
|
|
|
from twisted.internet import defer, reactor
|
|
|
|
|
|
from synapse.api.errors import CodeMessageException, cs_error
|
|
|
from synapse.federation.transport import server
|
|
|
from synapse.http.server import HttpServer
|
|
|
from synapse.server import HomeServer
|
|
|
+from synapse.storage import PostgresEngine
|
|
|
from synapse.storage.engines import create_engine
|
|
|
from synapse.storage.prepare_database import prepare_database
|
|
|
from synapse.util.logcontext import LoggingContext
|
|
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
|
|
|
|
|
+# set this to True to run the tests against postgres instead of sqlite.
|
|
|
+# It requires you to have a local postgres database called synapse_test, within
|
|
|
+# which ALL TABLES WILL BE DROPPED
|
|
|
+USE_POSTGRES_FOR_TESTS = False
|
|
|
+
|
|
|
|
|
|
@defer.inlineCallbacks
|
|
|
def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
|
@@ -60,30 +65,62 @@ def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
|
|
|
config.update_user_directory = False
|
|
|
|
|
|
config.use_frozen_dicts = True
|
|
|
- config.database_config = {"name": "sqlite3"}
|
|
|
config.ldap_enabled = False
|
|
|
|
|
|
if "clock" not in kargs:
|
|
|
kargs["clock"] = MockClock()
|
|
|
|
|
|
+ if USE_POSTGRES_FOR_TESTS:
|
|
|
+ config.database_config = {
|
|
|
+ "name": "psycopg2",
|
|
|
+ "args": {
|
|
|
+ "database": "synapse_test",
|
|
|
+ "cp_min": 1,
|
|
|
+ "cp_max": 5,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ else:
|
|
|
+ config.database_config = {
|
|
|
+ "name": "sqlite3",
|
|
|
+ "args": {
|
|
|
+ "database": ":memory:",
|
|
|
+ "cp_min": 1,
|
|
|
+ "cp_max": 1,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ db_engine = create_engine(config.database_config)
|
|
|
+
|
|
|
+ # we need to configure the connection pool to run the on_new_connection
|
|
|
+ # function, so that we can test code that uses custom sqlite functions
|
|
|
+ # (like rank).
|
|
|
+ config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
|
|
|
+
|
|
|
if datastore is None:
|
|
|
- db_pool = SQLiteMemoryDbPool()
|
|
|
- yield db_pool.prepare()
|
|
|
hs = HomeServer(
|
|
|
- name, db_pool=db_pool, config=config,
|
|
|
+ name, config=config,
|
|
|
+ db_config=config.database_config,
|
|
|
version_string="Synapse/tests",
|
|
|
- database_engine=create_engine(config.database_config),
|
|
|
- get_db_conn=db_pool.get_db_conn,
|
|
|
+ database_engine=db_engine,
|
|
|
room_list_handler=object(),
|
|
|
tls_server_context_factory=Mock(),
|
|
|
**kargs
|
|
|
)
|
|
|
+ db_conn = hs.get_db_conn()
|
|
|
+ # make sure that the database is empty
|
|
|
+ if isinstance(db_engine, PostgresEngine):
|
|
|
+ cur = db_conn.cursor()
|
|
|
+ cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
|
|
|
+ rows = cur.fetchall()
|
|
|
+ for r in rows:
|
|
|
+ cur.execute("DROP TABLE %s CASCADE" % r[0])
|
|
|
+ yield prepare_database(db_conn, db_engine, config)
|
|
|
hs.setup()
|
|
|
else:
|
|
|
hs = HomeServer(
|
|
|
name, db_pool=None, datastore=datastore, config=config,
|
|
|
version_string="Synapse/tests",
|
|
|
- database_engine=create_engine(config.database_config),
|
|
|
+ database_engine=db_engine,
|
|
|
room_list_handler=object(),
|
|
|
tls_server_context_factory=Mock(),
|
|
|
**kargs
|
|
@@ -302,34 +339,6 @@ class MockClock(object):
|
|
|
return d
|
|
|
|
|
|
|
|
|
-class SQLiteMemoryDbPool(ConnectionPool, object):
|
|
|
- def __init__(self):
|
|
|
- super(SQLiteMemoryDbPool, self).__init__(
|
|
|
- "sqlite3", ":memory:",
|
|
|
- cp_min=1,
|
|
|
- cp_max=1,
|
|
|
- )
|
|
|
-
|
|
|
- self.config = Mock()
|
|
|
- self.config.password_providers = []
|
|
|
- self.config.database_config = {"name": "sqlite3"}
|
|
|
-
|
|
|
- def prepare(self):
|
|
|
- engine = self.create_engine()
|
|
|
- return self.runWithConnection(
|
|
|
- lambda conn: prepare_database(conn, engine, self.config)
|
|
|
- )
|
|
|
-
|
|
|
- def get_db_conn(self):
|
|
|
- conn = self.connect()
|
|
|
- engine = self.create_engine()
|
|
|
- prepare_database(conn, engine, self.config)
|
|
|
- return conn
|
|
|
-
|
|
|
- def create_engine(self):
|
|
|
- return create_engine(self.config.database_config)
|
|
|
-
|
|
|
-
|
|
|
def _format_call(args, kwargs):
|
|
|
return ", ".join(
|
|
|
["%r" % (a) for a in args] +
|