utils.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. # Copyright 2018-2019 New Vector Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import atexit
  17. import hashlib
  18. import os
  19. import time
  20. import uuid
  21. import warnings
  22. from inspect import getcallargs
  23. from mock import Mock, patch
  24. from six.moves.urllib import parse as urlparse
  25. from twisted.internet import defer, reactor
  26. from synapse.api.constants import EventTypes, RoomVersions
  27. from synapse.api.errors import CodeMessageException, cs_error
  28. from synapse.config.homeserver import HomeServerConfig
  29. from synapse.federation.transport import server as federation_server
  30. from synapse.http.server import HttpServer
  31. from synapse.server import HomeServer
  32. from synapse.storage import DataStore
  33. from synapse.storage.engines import PostgresEngine, create_engine
  34. from synapse.storage.prepare_database import (
  35. _get_or_create_schema_state,
  36. _setup_new_database,
  37. prepare_database,
  38. )
  39. from synapse.util.logcontext import LoggingContext
  40. from synapse.util.ratelimitutils import FederationRateLimiter
  41. # set this to True to run the tests against postgres instead of sqlite.
  42. #
  43. # When running under postgres, we first create a base database with the name
  44. # POSTGRES_BASE_DB and update it to the current schema. Then, for each test case, we
  45. # create another unique database, using the base database as a template.
  46. USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
  47. LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
  48. POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", None)
  49. POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None)
  50. POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None)
  51. POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
  52. # the dbname we will connect to in order to create the base database.
  53. POSTGRES_DBNAME_FOR_INITIAL_CREATE = "postgres"
  54. def setupdb():
  55. # If we're using PostgreSQL, set up the db once
  56. if USE_POSTGRES_FOR_TESTS:
  57. # create a PostgresEngine
  58. db_engine = create_engine({"name": "psycopg2", "args": {}})
  59. # connect to postgres to create the base database.
  60. db_conn = db_engine.module.connect(
  61. user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD,
  62. dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
  63. )
  64. db_conn.autocommit = True
  65. cur = db_conn.cursor()
  66. cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
  67. cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
  68. cur.close()
  69. db_conn.close()
  70. # Set up in the db
  71. db_conn = db_engine.module.connect(
  72. database=POSTGRES_BASE_DB,
  73. user=POSTGRES_USER,
  74. host=POSTGRES_HOST,
  75. password=POSTGRES_PASSWORD,
  76. )
  77. cur = db_conn.cursor()
  78. _get_or_create_schema_state(cur, db_engine)
  79. _setup_new_database(cur, db_engine)
  80. db_conn.commit()
  81. cur.close()
  82. db_conn.close()
  83. def _cleanup():
  84. db_conn = db_engine.module.connect(
  85. user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD,
  86. dbname=POSTGRES_DBNAME_FOR_INITIAL_CREATE,
  87. )
  88. db_conn.autocommit = True
  89. cur = db_conn.cursor()
  90. cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
  91. cur.close()
  92. db_conn.close()
  93. atexit.register(_cleanup)
  94. def default_config(name):
  95. """
  96. Create a reasonable test config.
  97. """
  98. config_dict = {
  99. "server_name": name,
  100. "media_store_path": "media",
  101. "uploads_path": "uploads",
  102. # the test signing key is just an arbitrary ed25519 key to keep the config
  103. # parser happy
  104. "signing_key": "ed25519 a_lPym qvioDNmfExFBRPgdTU+wtFYKq4JfwFRv7sYVgWvmgJg",
  105. }
  106. config = HomeServerConfig()
  107. config.parse_config_dict(config_dict)
  108. # TODO: move this stuff into config_dict or get rid of it
  109. config.event_cache_size = 1
  110. config.enable_registration = True
  111. config.enable_registration_captcha = False
  112. config.macaroon_secret_key = "not even a little secret"
  113. config.expire_access_token = False
  114. config.trusted_third_party_id_servers = []
  115. config.room_invite_state_types = []
  116. config.password_providers = []
  117. config.worker_replication_url = ""
  118. config.worker_app = None
  119. config.email_enable_notifs = False
  120. config.block_non_admin_invites = False
  121. config.federation_domain_whitelist = None
  122. config.federation_rc_reject_limit = 10
  123. config.federation_rc_sleep_limit = 10
  124. config.federation_rc_sleep_delay = 100
  125. config.federation_rc_concurrent = 10
  126. config.filter_timeline_limit = 5000
  127. config.user_directory_search_all_users = False
  128. config.user_consent_server_notice_content = None
  129. config.block_events_without_consent_error = None
  130. config.user_consent_at_registration = False
  131. config.user_consent_policy_name = "Privacy Policy"
  132. config.media_storage_providers = []
  133. config.autocreate_auto_join_rooms = True
  134. config.auto_join_rooms = []
  135. config.limit_usage_by_mau = False
  136. config.hs_disabled = False
  137. config.hs_disabled_message = ""
  138. config.hs_disabled_limit_type = ""
  139. config.max_mau_value = 50
  140. config.mau_trial_days = 0
  141. config.mau_stats_only = False
  142. config.mau_limits_reserved_threepids = []
  143. config.admin_contact = None
  144. config.rc_messages_per_second = 10000
  145. config.rc_message_burst_count = 10000
  146. config.rc_registration.per_second = 10000
  147. config.rc_registration.burst_count = 10000
  148. config.rc_login_address.per_second = 10000
  149. config.rc_login_address.burst_count = 10000
  150. config.rc_login_account.per_second = 10000
  151. config.rc_login_account.burst_count = 10000
  152. config.rc_login_failed_attempts.per_second = 10000
  153. config.rc_login_failed_attempts.burst_count = 10000
  154. config.saml2_enabled = False
  155. config.public_baseurl = None
  156. config.default_identity_server = None
  157. config.key_refresh_interval = 24 * 60 * 60 * 1000
  158. config.old_signing_keys = {}
  159. config.tls_fingerprints = []
  160. config.use_frozen_dicts = False
  161. # we need a sane default_room_version, otherwise attempts to create rooms will
  162. # fail.
  163. config.default_room_version = "1"
  164. # disable user directory updates, because they get done in the
  165. # background, which upsets the test runner.
  166. config.update_user_directory = False
  167. return config
  168. class TestHomeServer(HomeServer):
  169. DATASTORE_CLASS = DataStore
  170. @defer.inlineCallbacks
  171. def setup_test_homeserver(
  172. cleanup_func,
  173. name="test",
  174. datastore=None,
  175. config=None,
  176. reactor=None,
  177. homeserverToUse=TestHomeServer,
  178. **kargs
  179. ):
  180. """
  181. Setup a homeserver suitable for running tests against. Keyword arguments
  182. are passed to the Homeserver constructor.
  183. If no datastore is supplied, one is created and given to the homeserver.
  184. Args:
  185. cleanup_func : The function used to register a cleanup routine for
  186. after the test.
  187. Calling this method directly is deprecated: you should instead derive from
  188. HomeserverTestCase.
  189. """
  190. if reactor is None:
  191. from twisted.internet import reactor
  192. if config is None:
  193. config = default_config(name)
  194. config.ldap_enabled = False
  195. if "clock" not in kargs:
  196. kargs["clock"] = MockClock()
  197. if USE_POSTGRES_FOR_TESTS:
  198. test_db = "synapse_test_%s" % uuid.uuid4().hex
  199. config.database_config = {
  200. "name": "psycopg2",
  201. "args": {
  202. "database": test_db,
  203. "host": POSTGRES_HOST,
  204. "password": POSTGRES_PASSWORD,
  205. "user": POSTGRES_USER,
  206. "cp_min": 1,
  207. "cp_max": 5,
  208. },
  209. }
  210. else:
  211. config.database_config = {
  212. "name": "sqlite3",
  213. "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
  214. }
  215. db_engine = create_engine(config.database_config)
  216. # Create the database before we actually try and connect to it, based off
  217. # the template database we generate in setupdb()
  218. if datastore is None and isinstance(db_engine, PostgresEngine):
  219. db_conn = db_engine.module.connect(
  220. database=POSTGRES_BASE_DB,
  221. user=POSTGRES_USER,
  222. host=POSTGRES_HOST,
  223. password=POSTGRES_PASSWORD,
  224. )
  225. db_conn.autocommit = True
  226. cur = db_conn.cursor()
  227. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  228. cur.execute(
  229. "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
  230. )
  231. cur.close()
  232. db_conn.close()
  233. # we need to configure the connection pool to run the on_new_connection
  234. # function, so that we can test code that uses custom sqlite functions
  235. # (like rank).
  236. config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
  237. if datastore is None:
  238. hs = homeserverToUse(
  239. name,
  240. config=config,
  241. db_config=config.database_config,
  242. version_string="Synapse/tests",
  243. database_engine=db_engine,
  244. tls_server_context_factory=Mock(),
  245. tls_client_options_factory=Mock(),
  246. reactor=reactor,
  247. **kargs
  248. )
  249. # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
  250. # date db
  251. if not isinstance(db_engine, PostgresEngine):
  252. db_conn = hs.get_db_conn()
  253. yield prepare_database(db_conn, db_engine, config)
  254. db_conn.commit()
  255. db_conn.close()
  256. else:
  257. # We need to do cleanup on PostgreSQL
  258. def cleanup():
  259. import psycopg2
  260. # Close all the db pools
  261. hs.get_db_pool().close()
  262. dropped = False
  263. # Drop the test database
  264. db_conn = db_engine.module.connect(
  265. database=POSTGRES_BASE_DB,
  266. user=POSTGRES_USER,
  267. host=POSTGRES_HOST,
  268. password=POSTGRES_PASSWORD,
  269. )
  270. db_conn.autocommit = True
  271. cur = db_conn.cursor()
  272. # Try a few times to drop the DB. Some things may hold on to the
  273. # database for a few more seconds due to flakiness, preventing
  274. # us from dropping it when the test is over. If we can't drop
  275. # it, warn and move on.
  276. for x in range(5):
  277. try:
  278. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  279. db_conn.commit()
  280. dropped = True
  281. except psycopg2.OperationalError as e:
  282. warnings.warn(
  283. "Couldn't drop old db: " + str(e), category=UserWarning
  284. )
  285. time.sleep(0.5)
  286. cur.close()
  287. db_conn.close()
  288. if not dropped:
  289. warnings.warn("Failed to drop old DB.", category=UserWarning)
  290. if not LEAVE_DB:
  291. # Register the cleanup hook
  292. cleanup_func(cleanup)
  293. hs.setup()
  294. if homeserverToUse.__name__ == "TestHomeServer":
  295. hs.setup_master()
  296. else:
  297. hs = homeserverToUse(
  298. name,
  299. db_pool=None,
  300. datastore=datastore,
  301. config=config,
  302. version_string="Synapse/tests",
  303. database_engine=db_engine,
  304. tls_server_context_factory=Mock(),
  305. tls_client_options_factory=Mock(),
  306. reactor=reactor,
  307. **kargs
  308. )
  309. # bcrypt is far too slow to be doing in unit tests
  310. # Need to let the HS build an auth handler and then mess with it
  311. # because AuthHandler's constructor requires the HS, so we can't make one
  312. # beforehand and pass it in to the HS's constructor (chicken / egg)
  313. hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest()
  314. hs.get_auth_handler().validate_hash = (
  315. lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h
  316. )
  317. fed = kargs.get("resource_for_federation", None)
  318. if fed:
  319. register_federation_servlets(hs, fed)
  320. defer.returnValue(hs)
  321. def register_federation_servlets(hs, resource):
  322. federation_server.register_servlets(
  323. hs,
  324. resource=resource,
  325. authenticator=federation_server.Authenticator(hs),
  326. ratelimiter=FederationRateLimiter(
  327. hs.get_clock(),
  328. window_size=hs.config.federation_rc_window_size,
  329. sleep_limit=hs.config.federation_rc_sleep_limit,
  330. sleep_msec=hs.config.federation_rc_sleep_delay,
  331. reject_limit=hs.config.federation_rc_reject_limit,
  332. concurrent_requests=hs.config.federation_rc_concurrent,
  333. ),
  334. )
  335. def get_mock_call_args(pattern_func, mock_func):
  336. """ Return the arguments the mock function was called with interpreted
  337. by the pattern functions argument list.
  338. """
  339. invoked_args, invoked_kargs = mock_func.call_args
  340. return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
  341. def mock_getRawHeaders(headers=None):
  342. headers = headers if headers is not None else {}
  343. def getRawHeaders(name, default=None):
  344. return headers.get(name, default)
  345. return getRawHeaders
  346. # This is a mock /resource/ not an entire server
  347. class MockHttpResource(HttpServer):
  348. def __init__(self, prefix=""):
  349. self.callbacks = [] # 3-tuple of method/pattern/function
  350. self.prefix = prefix
  351. def trigger_get(self, path):
  352. return self.trigger(b"GET", path, None)
  353. @patch('twisted.web.http.Request')
  354. @defer.inlineCallbacks
  355. def trigger(
  356. self, http_method, path, content, mock_request, federation_auth_origin=None
  357. ):
  358. """ Fire an HTTP event.
  359. Args:
  360. http_method : The HTTP method
  361. path : The HTTP path
  362. content : The HTTP body
  363. mock_request : Mocked request to pass to the event so it can get
  364. content.
  365. federation_auth_origin (bytes|None): domain to authenticate as, for federation
  366. Returns:
  367. A tuple of (code, response)
  368. Raises:
  369. KeyError If no event is found which will handle the path.
  370. """
  371. path = self.prefix + path
  372. # annoyingly we return a twisted http request which has chained calls
  373. # to get at the http content, hence mock it here.
  374. mock_content = Mock()
  375. config = {'read.return_value': content}
  376. mock_content.configure_mock(**config)
  377. mock_request.content = mock_content
  378. mock_request.method = http_method.encode('ascii')
  379. mock_request.uri = path.encode('ascii')
  380. mock_request.getClientIP.return_value = "-"
  381. headers = {}
  382. if federation_auth_origin is not None:
  383. headers[b"Authorization"] = [
  384. b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
  385. ]
  386. mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
  387. # return the right path if the event requires it
  388. mock_request.path = path
  389. # add in query params to the right place
  390. try:
  391. mock_request.args = urlparse.parse_qs(path.split('?')[1])
  392. mock_request.path = path.split('?')[0]
  393. path = mock_request.path
  394. except Exception:
  395. pass
  396. if isinstance(path, bytes):
  397. path = path.decode('utf8')
  398. for (method, pattern, func) in self.callbacks:
  399. if http_method != method:
  400. continue
  401. matcher = pattern.match(path)
  402. if matcher:
  403. try:
  404. args = [urlparse.unquote(u) for u in matcher.groups()]
  405. (code, response) = yield func(mock_request, *args)
  406. defer.returnValue((code, response))
  407. except CodeMessageException as e:
  408. defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
  409. raise KeyError("No event can handle %s" % path)
  410. def register_paths(self, method, path_patterns, callback):
  411. for path_pattern in path_patterns:
  412. self.callbacks.append((method, path_pattern, callback))
  413. class MockKey(object):
  414. alg = "mock_alg"
  415. version = "mock_version"
  416. signature = b"\x9a\x87$"
  417. @property
  418. def verify_key(self):
  419. return self
  420. def sign(self, message):
  421. return self
  422. def verify(self, message, sig):
  423. assert sig == b"\x9a\x87$"
  424. def encode(self):
  425. return b"<fake_encoded_key>"
  426. class MockClock(object):
  427. now = 1000
  428. def __init__(self):
  429. # list of lists of [absolute_time, callback, expired] in no particular
  430. # order
  431. self.timers = []
  432. self.loopers = []
  433. def time(self):
  434. return self.now
  435. def time_msec(self):
  436. return self.time() * 1000
  437. def call_later(self, delay, callback, *args, **kwargs):
  438. current_context = LoggingContext.current_context()
  439. def wrapped_callback():
  440. LoggingContext.thread_local.current_context = current_context
  441. callback(*args, **kwargs)
  442. t = [self.now + delay, wrapped_callback, False]
  443. self.timers.append(t)
  444. return t
  445. def looping_call(self, function, interval):
  446. self.loopers.append([function, interval / 1000.0, self.now])
  447. def cancel_call_later(self, timer, ignore_errs=False):
  448. if timer[2]:
  449. if not ignore_errs:
  450. raise Exception("Cannot cancel an expired timer")
  451. timer[2] = True
  452. self.timers = [t for t in self.timers if t != timer]
  453. # For unit testing
  454. def advance_time(self, secs):
  455. self.now += secs
  456. timers = self.timers
  457. self.timers = []
  458. for t in timers:
  459. time, callback, expired = t
  460. if expired:
  461. raise Exception("Timer already expired")
  462. if self.now >= time:
  463. t[2] = True
  464. callback()
  465. else:
  466. self.timers.append(t)
  467. for looped in self.loopers:
  468. func, interval, last = looped
  469. if last + interval < self.now:
  470. func()
  471. looped[2] = self.now
  472. def advance_time_msec(self, ms):
  473. self.advance_time(ms / 1000.0)
  474. def time_bound_deferred(self, d, *args, **kwargs):
  475. # We don't bother timing things out for now.
  476. return d
  477. def _format_call(args, kwargs):
  478. return ", ".join(
  479. ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
  480. )
  481. class DeferredMockCallable(object):
  482. """A callable instance that stores a set of pending call expectations and
  483. return values for them. It allows a unit test to assert that the given set
  484. of function calls are eventually made, by awaiting on them to be called.
  485. """
  486. def __init__(self):
  487. self.expectations = []
  488. self.calls = []
  489. def __call__(self, *args, **kwargs):
  490. self.calls.append((args, kwargs))
  491. if not self.expectations:
  492. raise ValueError(
  493. "%r has no pending calls to handle call(%s)"
  494. % (self, _format_call(args, kwargs))
  495. )
  496. for (call, result, d) in self.expectations:
  497. if args == call[1] and kwargs == call[2]:
  498. d.callback(None)
  499. return result
  500. failure = AssertionError(
  501. "Was not expecting call(%s)" % (_format_call(args, kwargs))
  502. )
  503. for _, _, d in self.expectations:
  504. try:
  505. d.errback(failure)
  506. except Exception:
  507. pass
  508. raise failure
  509. def expect_call_and_return(self, call, result):
  510. self.expectations.append((call, result, defer.Deferred()))
  511. @defer.inlineCallbacks
  512. def await_calls(self, timeout=1000):
  513. deferred = defer.DeferredList(
  514. [d for _, _, d in self.expectations], fireOnOneErrback=True
  515. )
  516. timer = reactor.callLater(
  517. timeout / 1000,
  518. deferred.errback,
  519. AssertionError(
  520. "%d pending calls left: %s"
  521. % (
  522. len([e for e in self.expectations if not e[2].called]),
  523. [e for e in self.expectations if not e[2].called],
  524. )
  525. ),
  526. )
  527. yield deferred
  528. timer.cancel()
  529. self.calls = []
  530. def assert_had_no_calls(self):
  531. if self.calls:
  532. calls = self.calls
  533. self.calls = []
  534. raise AssertionError(
  535. "Expected not to received any calls, got:\n"
  536. + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
  537. )
  538. @defer.inlineCallbacks
  539. def create_room(hs, room_id, creator_id):
  540. """Creates and persist a creation event for the given room
  541. Args:
  542. hs
  543. room_id (str)
  544. creator_id (str)
  545. """
  546. store = hs.get_datastore()
  547. event_builder_factory = hs.get_event_builder_factory()
  548. event_creation_handler = hs.get_event_creation_handler()
  549. builder = event_builder_factory.new(
  550. RoomVersions.V1,
  551. {
  552. "type": EventTypes.Create,
  553. "state_key": "",
  554. "sender": creator_id,
  555. "room_id": room_id,
  556. "content": {},
  557. },
  558. )
  559. event, context = yield event_creation_handler.create_new_client_event(builder)
  560. yield store.persist_event(event, context)