utils.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. # Copyright 2014-2016 OpenMarket Ltd
  2. # Copyright 2018-2019 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import atexit
  16. import hashlib
  17. import os
  18. import time
  19. import uuid
  20. import warnings
  21. from typing import Type
  22. from unittest.mock import Mock, patch
  23. from urllib import parse as urlparse
  24. from twisted.internet import defer
  25. from synapse.api.constants import EventTypes
  26. from synapse.api.errors import CodeMessageException, cs_error
  27. from synapse.api.room_versions import RoomVersions
  28. from synapse.config.database import DatabaseConnectionConfig
  29. from synapse.config.homeserver import HomeServerConfig
  30. from synapse.config.server import DEFAULT_ROOM_VERSION
  31. from synapse.logging.context import current_context, set_current_context
  32. from synapse.server import HomeServer
  33. from synapse.storage import DataStore
  34. from synapse.storage.database import LoggingDatabaseConnection
  35. from synapse.storage.engines import PostgresEngine, create_engine
  36. from synapse.storage.prepare_database import prepare_database
  37. # set this to True to run the tests against postgres instead of sqlite.
  38. #
  39. # When running under postgres, we first create a base database with the name
  40. # POSTGRES_BASE_DB and update it to the current schema. Then, for each test case, we
  41. # create another unique database, using the base database as a template.
  42. USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
  43. LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
  44. POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", None)
  45. POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None)
  46. POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None)
  47. POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
  48. # the dbname we will connect to in order to create the base database.
  49. POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
  50. def setupdb():
  51. # If we're using PostgreSQL, set up the db once
  52. if USE_POSTGRES_FOR_TESTS:
  53. # create a PostgresEngine
  54. db_engine = create_engine({"name": "psycopg2", "args": {}})
  55. # connect to postgres to create the base database.
  56. db_conn = db_engine.module.connect(
  57. user=POSTGRES_USER,
  58. host=POSTGRES_HOST,
  59. password=POSTGRES_PASSWORD,
  60. dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
  61. )
  62. db_conn.autocommit = True
  63. cur = db_conn.cursor()
  64. cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
  65. cur.execute(
  66. "CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' "
  67. "template=template0;" % (POSTGRES_BASE_DB,)
  68. )
  69. cur.close()
  70. db_conn.close()
  71. # Set up in the db
  72. db_conn = db_engine.module.connect(
  73. database=POSTGRES_BASE_DB,
  74. user=POSTGRES_USER,
  75. host=POSTGRES_HOST,
  76. password=POSTGRES_PASSWORD,
  77. )
  78. db_conn = LoggingDatabaseConnection(db_conn, db_engine, "tests")
  79. prepare_database(db_conn, db_engine, None)
  80. db_conn.close()
  81. def _cleanup():
  82. db_conn = db_engine.module.connect(
  83. user=POSTGRES_USER,
  84. host=POSTGRES_HOST,
  85. password=POSTGRES_PASSWORD,
  86. dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
  87. )
  88. db_conn.autocommit = True
  89. cur = db_conn.cursor()
  90. cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
  91. cur.close()
  92. db_conn.close()
  93. atexit.register(_cleanup)
  94. def default_config(name, parse=False):
  95. """
  96. Create a reasonable test config.
  97. """
  98. config_dict = {
  99. "server_name": name,
  100. "send_federation": False,
  101. "media_store_path": "media",
  102. # the test signing key is just an arbitrary ed25519 key to keep the config
  103. # parser happy
  104. "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",
  105. "event_cache_size": 1,
  106. "enable_registration": True,
  107. "enable_registration_captcha": False,
  108. "macaroon_secret_key": "not even a little secret",
  109. "trusted_third_party_id_servers": [],
  110. "password_providers": [],
  111. "worker_replication_url": "",
  112. "worker_app": None,
  113. "block_non_admin_invites": False,
  114. "federation_domain_whitelist": None,
  115. "filter_timeline_limit": 5000,
  116. "user_directory_search_all_users": False,
  117. "user_consent_server_notice_content": None,
  118. "block_events_without_consent_error": None,
  119. "user_consent_at_registration": False,
  120. "user_consent_policy_name": "Privacy Policy",
  121. "media_storage_providers": [],
  122. "autocreate_auto_join_rooms": True,
  123. "auto_join_rooms": [],
  124. "limit_usage_by_mau": False,
  125. "hs_disabled": False,
  126. "hs_disabled_message": "",
  127. "max_mau_value": 50,
  128. "mau_trial_days": 0,
  129. "mau_stats_only": False,
  130. "mau_limits_reserved_threepids": [],
  131. "admin_contact": None,
  132. "rc_message": {"per_second": 10000, "burst_count": 10000},
  133. "rc_registration": {"per_second": 10000, "burst_count": 10000},
  134. "rc_login": {
  135. "address": {"per_second": 10000, "burst_count": 10000},
  136. "account": {"per_second": 10000, "burst_count": 10000},
  137. "failed_attempts": {"per_second": 10000, "burst_count": 10000},
  138. },
  139. "rc_joins": {
  140. "local": {"per_second": 10000, "burst_count": 10000},
  141. "remote": {"per_second": 10000, "burst_count": 10000},
  142. },
  143. "rc_invites": {
  144. "per_room": {"per_second": 10000, "burst_count": 10000},
  145. "per_user": {"per_second": 10000, "burst_count": 10000},
  146. },
  147. "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000},
  148. "saml2_enabled": False,
  149. "public_baseurl": None,
  150. "default_identity_server": None,
  151. "key_refresh_interval": 24 * 60 * 60 * 1000,
  152. "old_signing_keys": {},
  153. "tls_fingerprints": [],
  154. "use_frozen_dicts": False,
  155. # We need a sane default_room_version, otherwise attempts to create
  156. # rooms will fail.
  157. "default_room_version": DEFAULT_ROOM_VERSION,
  158. # disable user directory updates, because they get done in the
  159. # background, which upsets the test runner.
  160. "update_user_directory": False,
  161. "caches": {"global_factor": 1},
  162. "listeners": [{"port": 0, "type": "http"}],
  163. }
  164. if parse:
  165. config = HomeServerConfig()
  166. config.parse_config_dict(config_dict, "", "")
  167. return config
  168. return config_dict
  169. class TestHomeServer(HomeServer):
  170. DATASTORE_CLASS = DataStore
  171. def setup_test_homeserver(
  172. cleanup_func,
  173. name="test",
  174. config=None,
  175. reactor=None,
  176. homeserver_to_use: Type[HomeServer] = TestHomeServer,
  177. **kwargs,
  178. ):
  179. """
  180. Setup a homeserver suitable for running tests against. Keyword arguments
  181. are passed to the Homeserver constructor.
  182. If no datastore is supplied, one is created and given to the homeserver.
  183. Args:
  184. cleanup_func : The function used to register a cleanup routine for
  185. after the test.
  186. Calling this method directly is deprecated: you should instead derive from
  187. HomeserverTestCase.
  188. """
  189. if reactor is None:
  190. from twisted.internet import reactor
  191. if config is None:
  192. config = default_config(name, parse=True)
  193. config.ldap_enabled = False
  194. if "clock" not in kwargs:
  195. kwargs["clock"] = MockClock()
  196. if USE_POSTGRES_FOR_TESTS:
  197. test_db = "synapse_test_%s" % uuid.uuid4().hex
  198. database_config = {
  199. "name": "psycopg2",
  200. "args": {
  201. "database": test_db,
  202. "host": POSTGRES_HOST,
  203. "password": POSTGRES_PASSWORD,
  204. "user": POSTGRES_USER,
  205. "cp_min": 1,
  206. "cp_max": 5,
  207. },
  208. }
  209. else:
  210. database_config = {
  211. "name": "sqlite3",
  212. "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
  213. }
  214. if "db_txn_limit" in kwargs:
  215. database_config["txn_limit"] = kwargs["db_txn_limit"]
  216. database = DatabaseConnectionConfig("master", database_config)
  217. config.database.databases = [database]
  218. db_engine = create_engine(database.config)
  219. # Create the database before we actually try and connect to it, based off
  220. # the template database we generate in setupdb()
  221. if isinstance(db_engine, PostgresEngine):
  222. db_conn = db_engine.module.connect(
  223. database=POSTGRES_BASE_DB,
  224. user=POSTGRES_USER,
  225. host=POSTGRES_HOST,
  226. password=POSTGRES_PASSWORD,
  227. )
  228. db_conn.autocommit = True
  229. cur = db_conn.cursor()
  230. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  231. cur.execute(
  232. "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
  233. )
  234. cur.close()
  235. db_conn.close()
  236. hs = homeserver_to_use(
  237. name,
  238. config=config,
  239. version_string="Synapse/tests",
  240. reactor=reactor,
  241. )
  242. # Install @cache_in_self attributes
  243. for key, val in kwargs.items():
  244. setattr(hs, "_" + key, val)
  245. # Mock TLS
  246. hs.tls_server_context_factory = Mock()
  247. hs.tls_client_options_factory = Mock()
  248. hs.setup()
  249. if homeserver_to_use == TestHomeServer:
  250. hs.setup_background_tasks()
  251. if isinstance(db_engine, PostgresEngine):
  252. database = hs.get_datastores().databases[0]
  253. # We need to do cleanup on PostgreSQL
  254. def cleanup():
  255. import psycopg2
  256. # Close all the db pools
  257. database._db_pool.close()
  258. dropped = False
  259. # Drop the test database
  260. db_conn = db_engine.module.connect(
  261. database=POSTGRES_BASE_DB,
  262. user=POSTGRES_USER,
  263. host=POSTGRES_HOST,
  264. password=POSTGRES_PASSWORD,
  265. )
  266. db_conn.autocommit = True
  267. cur = db_conn.cursor()
  268. # Try a few times to drop the DB. Some things may hold on to the
  269. # database for a few more seconds due to flakiness, preventing
  270. # us from dropping it when the test is over. If we can't drop
  271. # it, warn and move on.
  272. for _ in range(5):
  273. try:
  274. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  275. db_conn.commit()
  276. dropped = True
  277. except psycopg2.OperationalError as e:
  278. warnings.warn(
  279. "Couldn't drop old db: " + str(e), category=UserWarning
  280. )
  281. time.sleep(0.5)
  282. cur.close()
  283. db_conn.close()
  284. if not dropped:
  285. warnings.warn("Failed to drop old DB.", category=UserWarning)
  286. if not LEAVE_DB:
  287. # Register the cleanup hook
  288. cleanup_func(cleanup)
  289. # bcrypt is far too slow to be doing in unit tests
  290. # Need to let the HS build an auth handler and then mess with it
  291. # because AuthHandler's constructor requires the HS, so we can't make one
  292. # beforehand and pass it in to the HS's constructor (chicken / egg)
  293. async def hash(p):
  294. return hashlib.md5(p.encode("utf8")).hexdigest()
  295. hs.get_auth_handler().hash = hash
  296. async def validate_hash(p, h):
  297. return hashlib.md5(p.encode("utf8")).hexdigest() == h
  298. hs.get_auth_handler().validate_hash = validate_hash
  299. return hs
  300. def mock_getRawHeaders(headers=None):
  301. headers = headers if headers is not None else {}
  302. def getRawHeaders(name, default=None):
  303. return headers.get(name, default)
  304. return getRawHeaders
  305. # This is a mock /resource/ not an entire server
  306. class MockHttpResource:
  307. def __init__(self, prefix=""):
  308. self.callbacks = [] # 3-tuple of method/pattern/function
  309. self.prefix = prefix
  310. def trigger_get(self, path):
  311. return self.trigger(b"GET", path, None)
  312. @patch("twisted.web.http.Request")
  313. @defer.inlineCallbacks
  314. def trigger(
  315. self, http_method, path, content, mock_request, federation_auth_origin=None
  316. ):
  317. """Fire an HTTP event.
  318. Args:
  319. http_method : The HTTP method
  320. path : The HTTP path
  321. content : The HTTP body
  322. mock_request : Mocked request to pass to the event so it can get
  323. content.
  324. federation_auth_origin (bytes|None): domain to authenticate as, for federation
  325. Returns:
  326. A tuple of (code, response)
  327. Raises:
  328. KeyError If no event is found which will handle the path.
  329. """
  330. path = self.prefix + path
  331. # annoyingly we return a twisted http request which has chained calls
  332. # to get at the http content, hence mock it here.
  333. mock_content = Mock()
  334. config = {"read.return_value": content}
  335. mock_content.configure_mock(**config)
  336. mock_request.content = mock_content
  337. mock_request.method = http_method.encode("ascii")
  338. mock_request.uri = path.encode("ascii")
  339. mock_request.getClientIP.return_value = "-"
  340. headers = {}
  341. if federation_auth_origin is not None:
  342. headers[b"Authorization"] = [
  343. b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
  344. ]
  345. mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
  346. # return the right path if the event requires it
  347. mock_request.path = path
  348. # add in query params to the right place
  349. try:
  350. mock_request.args = urlparse.parse_qs(path.split("?")[1])
  351. mock_request.path = path.split("?")[0]
  352. path = mock_request.path
  353. except Exception:
  354. pass
  355. if isinstance(path, bytes):
  356. path = path.decode("utf8")
  357. for (method, pattern, func) in self.callbacks:
  358. if http_method != method:
  359. continue
  360. matcher = pattern.match(path)
  361. if matcher:
  362. try:
  363. args = [urlparse.unquote(u) for u in matcher.groups()]
  364. (code, response) = yield defer.ensureDeferred(
  365. func(mock_request, *args)
  366. )
  367. return code, response
  368. except CodeMessageException as e:
  369. return e.code, cs_error(e.msg, code=e.errcode)
  370. raise KeyError("No event can handle %s" % path)
  371. def register_paths(self, method, path_patterns, callback, servlet_name):
  372. for path_pattern in path_patterns:
  373. self.callbacks.append((method, path_pattern, callback))
  374. class MockKey:
  375. alg = "mock_alg"
  376. version = "mock_version"
  377. signature = b"\x9a\x87$"
  378. @property
  379. def verify_key(self):
  380. return self
  381. def sign(self, message):
  382. return self
  383. def verify(self, message, sig):
  384. assert sig == b"\x9a\x87$"
  385. def encode(self):
  386. return b"<fake_encoded_key>"
  387. class MockClock:
  388. now = 1000
  389. def __init__(self):
  390. # list of lists of [absolute_time, callback, expired] in no particular
  391. # order
  392. self.timers = []
  393. self.loopers = []
  394. def time(self):
  395. return self.now
  396. def time_msec(self):
  397. return self.time() * 1000
  398. def call_later(self, delay, callback, *args, **kwargs):
  399. ctx = current_context()
  400. def wrapped_callback():
  401. set_current_context(ctx)
  402. callback(*args, **kwargs)
  403. t = [self.now + delay, wrapped_callback, False]
  404. self.timers.append(t)
  405. return t
  406. def looping_call(self, function, interval, *args, **kwargs):
  407. self.loopers.append([function, interval / 1000.0, self.now, args, kwargs])
  408. def cancel_call_later(self, timer, ignore_errs=False):
  409. if timer[2]:
  410. if not ignore_errs:
  411. raise Exception("Cannot cancel an expired timer")
  412. timer[2] = True
  413. self.timers = [t for t in self.timers if t != timer]
  414. # For unit testing
  415. def advance_time(self, secs):
  416. self.now += secs
  417. timers = self.timers
  418. self.timers = []
  419. for t in timers:
  420. time, callback, expired = t
  421. if expired:
  422. raise Exception("Timer already expired")
  423. if self.now >= time:
  424. t[2] = True
  425. callback()
  426. else:
  427. self.timers.append(t)
  428. for looped in self.loopers:
  429. func, interval, last, args, kwargs = looped
  430. if last + interval < self.now:
  431. func(*args, **kwargs)
  432. looped[2] = self.now
  433. def advance_time_msec(self, ms):
  434. self.advance_time(ms / 1000.0)
  435. def time_bound_deferred(self, d, *args, **kwargs):
  436. # We don't bother timing things out for now.
  437. return d
  438. async def create_room(hs, room_id: str, creator_id: str):
  439. """Creates and persist a creation event for the given room"""
  440. persistence_store = hs.get_storage().persistence
  441. store = hs.get_datastore()
  442. event_builder_factory = hs.get_event_builder_factory()
  443. event_creation_handler = hs.get_event_creation_handler()
  444. await store.store_room(
  445. room_id=room_id,
  446. room_creator_user_id=creator_id,
  447. is_public=False,
  448. room_version=RoomVersions.V1,
  449. )
  450. builder = event_builder_factory.for_room_version(
  451. RoomVersions.V1,
  452. {
  453. "type": EventTypes.Create,
  454. "state_key": "",
  455. "sender": creator_id,
  456. "room_id": room_id,
  457. "content": {},
  458. },
  459. )
  460. event, context = await event_creation_handler.create_new_client_event(builder)
  461. await persistence_store.persist_event(event, context)