utils.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import atexit
  16. import hashlib
  17. import os
  18. import time
  19. import uuid
  20. import warnings
  21. from inspect import getcallargs
  22. from mock import Mock, patch
  23. from six.moves.urllib import parse as urlparse
  24. from twisted.internet import defer, reactor
  25. from synapse.api.constants import EventTypes, RoomVersions
  26. from synapse.api.errors import CodeMessageException, cs_error
  27. from synapse.config.server import ServerConfig
  28. from synapse.federation.transport import server as federation_server
  29. from synapse.http.server import HttpServer
  30. from synapse.server import HomeServer
  31. from synapse.storage import DataStore
  32. from synapse.storage.engines import PostgresEngine, create_engine
  33. from synapse.storage.prepare_database import (
  34. _get_or_create_schema_state,
  35. _setup_new_database,
  36. prepare_database,
  37. )
  38. from synapse.util.logcontext import LoggingContext
  39. from synapse.util.ratelimitutils import FederationRateLimiter
  40. # set this to True to run the tests against postgres instead of sqlite.
  41. USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
  42. LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
  43. POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", None)
  44. POSTGRES_HOST = os.environ.get("SYNAPSE_POSTGRES_HOST", None)
  45. POSTGRES_PASSWORD = os.environ.get("SYNAPSE_POSTGRES_PASSWORD", None)
  46. POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
  47. def setupdb():
  48. # If we're using PostgreSQL, set up the db once
  49. if USE_POSTGRES_FOR_TESTS:
  50. pgconfig = {
  51. "name": "psycopg2",
  52. "args": {
  53. "database": POSTGRES_BASE_DB,
  54. "user": POSTGRES_USER,
  55. "host": POSTGRES_HOST,
  56. "password": POSTGRES_PASSWORD,
  57. "cp_min": 1,
  58. "cp_max": 5,
  59. },
  60. }
  61. config = Mock()
  62. config.password_providers = []
  63. config.database_config = pgconfig
  64. db_engine = create_engine(pgconfig)
  65. db_conn = db_engine.module.connect(
  66. user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD
  67. )
  68. db_conn.autocommit = True
  69. cur = db_conn.cursor()
  70. cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
  71. cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
  72. cur.close()
  73. db_conn.close()
  74. # Set up in the db
  75. db_conn = db_engine.module.connect(
  76. database=POSTGRES_BASE_DB,
  77. user=POSTGRES_USER,
  78. host=POSTGRES_HOST,
  79. password=POSTGRES_PASSWORD,
  80. )
  81. cur = db_conn.cursor()
  82. _get_or_create_schema_state(cur, db_engine)
  83. _setup_new_database(cur, db_engine)
  84. db_conn.commit()
  85. cur.close()
  86. db_conn.close()
  87. def _cleanup():
  88. db_conn = db_engine.module.connect(
  89. user=POSTGRES_USER, host=POSTGRES_HOST, password=POSTGRES_PASSWORD
  90. )
  91. db_conn.autocommit = True
  92. cur = db_conn.cursor()
  93. cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
  94. cur.close()
  95. db_conn.close()
  96. atexit.register(_cleanup)
  97. def default_config(name):
  98. """
  99. Create a reasonable test config.
  100. """
  101. config = Mock()
  102. config.signing_key = [MockKey()]
  103. config.event_cache_size = 1
  104. config.enable_registration = True
  105. config.macaroon_secret_key = "not even a little secret"
  106. config.expire_access_token = False
  107. config.server_name = name
  108. config.trusted_third_party_id_servers = []
  109. config.room_invite_state_types = []
  110. config.password_providers = []
  111. config.worker_replication_url = ""
  112. config.worker_app = None
  113. config.email_enable_notifs = False
  114. config.block_non_admin_invites = False
  115. config.federation_domain_whitelist = None
  116. config.federation_rc_reject_limit = 10
  117. config.federation_rc_sleep_limit = 10
  118. config.federation_rc_sleep_delay = 100
  119. config.federation_rc_concurrent = 10
  120. config.filter_timeline_limit = 5000
  121. config.user_directory_search_all_users = False
  122. config.user_consent_server_notice_content = None
  123. config.block_events_without_consent_error = None
  124. config.user_consent_at_registration = False
  125. config.user_consent_policy_name = "Privacy Policy"
  126. config.media_storage_providers = []
  127. config.autocreate_auto_join_rooms = True
  128. config.auto_join_rooms = []
  129. config.limit_usage_by_mau = False
  130. config.hs_disabled = False
  131. config.hs_disabled_message = ""
  132. config.hs_disabled_limit_type = ""
  133. config.max_mau_value = 50
  134. config.mau_trial_days = 0
  135. config.mau_stats_only = False
  136. config.mau_limits_reserved_threepids = []
  137. config.admin_contact = None
  138. config.rc_messages_per_second = 10000
  139. config.rc_message_burst_count = 10000
  140. config.saml2_enabled = False
  141. config.public_baseurl = None
  142. config.default_identity_server = None
  143. config.key_refresh_interval = 24 * 60 * 60 * 1000
  144. config.old_signing_keys = {}
  145. config.tls_fingerprints = []
  146. config.use_frozen_dicts = False
  147. # we need a sane default_room_version, otherwise attempts to create rooms will
  148. # fail.
  149. config.default_room_version = "1"
  150. # disable user directory updates, because they get done in the
  151. # background, which upsets the test runner.
  152. config.update_user_directory = False
  153. def is_threepid_reserved(threepid):
  154. return ServerConfig.is_threepid_reserved(
  155. config.mau_limits_reserved_threepids, threepid
  156. )
  157. config.is_threepid_reserved.side_effect = is_threepid_reserved
  158. return config
  159. class TestHomeServer(HomeServer):
  160. DATASTORE_CLASS = DataStore
  161. @defer.inlineCallbacks
  162. def setup_test_homeserver(
  163. cleanup_func,
  164. name="test",
  165. datastore=None,
  166. config=None,
  167. reactor=None,
  168. homeserverToUse=TestHomeServer,
  169. **kargs
  170. ):
  171. """
  172. Setup a homeserver suitable for running tests against. Keyword arguments
  173. are passed to the Homeserver constructor.
  174. If no datastore is supplied, one is created and given to the homeserver.
  175. Args:
  176. cleanup_func : The function used to register a cleanup routine for
  177. after the test.
  178. Calling this method directly is deprecated: you should instead derive from
  179. HomeserverTestCase.
  180. """
  181. if reactor is None:
  182. from twisted.internet import reactor
  183. if config is None:
  184. config = default_config(name)
  185. config.ldap_enabled = False
  186. if "clock" not in kargs:
  187. kargs["clock"] = MockClock()
  188. if USE_POSTGRES_FOR_TESTS:
  189. test_db = "synapse_test_%s" % uuid.uuid4().hex
  190. config.database_config = {
  191. "name": "psycopg2",
  192. "args": {
  193. "database": test_db,
  194. "host": POSTGRES_HOST,
  195. "password": POSTGRES_PASSWORD,
  196. "user": POSTGRES_USER,
  197. "cp_min": 1,
  198. "cp_max": 5,
  199. },
  200. }
  201. else:
  202. config.database_config = {
  203. "name": "sqlite3",
  204. "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
  205. }
  206. db_engine = create_engine(config.database_config)
  207. # Create the database before we actually try and connect to it, based off
  208. # the template database we generate in setupdb()
  209. if datastore is None and isinstance(db_engine, PostgresEngine):
  210. db_conn = db_engine.module.connect(
  211. database=POSTGRES_BASE_DB,
  212. user=POSTGRES_USER,
  213. host=POSTGRES_HOST,
  214. password=POSTGRES_PASSWORD,
  215. )
  216. db_conn.autocommit = True
  217. cur = db_conn.cursor()
  218. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  219. cur.execute(
  220. "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
  221. )
  222. cur.close()
  223. db_conn.close()
  224. # we need to configure the connection pool to run the on_new_connection
  225. # function, so that we can test code that uses custom sqlite functions
  226. # (like rank).
  227. config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
  228. if datastore is None:
  229. hs = homeserverToUse(
  230. name,
  231. config=config,
  232. db_config=config.database_config,
  233. version_string="Synapse/tests",
  234. database_engine=db_engine,
  235. room_list_handler=object(),
  236. tls_server_context_factory=Mock(),
  237. tls_client_options_factory=Mock(),
  238. reactor=reactor,
  239. **kargs
  240. )
  241. # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
  242. # date db
  243. if not isinstance(db_engine, PostgresEngine):
  244. db_conn = hs.get_db_conn()
  245. yield prepare_database(db_conn, db_engine, config)
  246. db_conn.commit()
  247. db_conn.close()
  248. else:
  249. # We need to do cleanup on PostgreSQL
  250. def cleanup():
  251. import psycopg2
  252. # Close all the db pools
  253. hs.get_db_pool().close()
  254. dropped = False
  255. # Drop the test database
  256. db_conn = db_engine.module.connect(
  257. database=POSTGRES_BASE_DB,
  258. user=POSTGRES_USER,
  259. host=POSTGRES_HOST,
  260. password=POSTGRES_PASSWORD,
  261. )
  262. db_conn.autocommit = True
  263. cur = db_conn.cursor()
  264. # Try a few times to drop the DB. Some things may hold on to the
  265. # database for a few more seconds due to flakiness, preventing
  266. # us from dropping it when the test is over. If we can't drop
  267. # it, warn and move on.
  268. for x in range(5):
  269. try:
  270. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  271. db_conn.commit()
  272. dropped = True
  273. except psycopg2.OperationalError as e:
  274. warnings.warn(
  275. "Couldn't drop old db: " + str(e), category=UserWarning
  276. )
  277. time.sleep(0.5)
  278. cur.close()
  279. db_conn.close()
  280. if not dropped:
  281. warnings.warn("Failed to drop old DB.", category=UserWarning)
  282. if not LEAVE_DB:
  283. # Register the cleanup hook
  284. cleanup_func(cleanup)
  285. hs.setup()
  286. else:
  287. hs = homeserverToUse(
  288. name,
  289. db_pool=None,
  290. datastore=datastore,
  291. config=config,
  292. version_string="Synapse/tests",
  293. database_engine=db_engine,
  294. room_list_handler=object(),
  295. tls_server_context_factory=Mock(),
  296. tls_client_options_factory=Mock(),
  297. reactor=reactor,
  298. **kargs
  299. )
  300. # bcrypt is far too slow to be doing in unit tests
  301. # Need to let the HS build an auth handler and then mess with it
  302. # because AuthHandler's constructor requires the HS, so we can't make one
  303. # beforehand and pass it in to the HS's constructor (chicken / egg)
  304. hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest()
  305. hs.get_auth_handler().validate_hash = (
  306. lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h
  307. )
  308. fed = kargs.get("resource_for_federation", None)
  309. if fed:
  310. register_federation_servlets(hs, fed)
  311. defer.returnValue(hs)
  312. def register_federation_servlets(hs, resource):
  313. federation_server.register_servlets(
  314. hs,
  315. resource=resource,
  316. authenticator=federation_server.Authenticator(hs),
  317. ratelimiter=FederationRateLimiter(
  318. hs.get_clock(),
  319. window_size=hs.config.federation_rc_window_size,
  320. sleep_limit=hs.config.federation_rc_sleep_limit,
  321. sleep_msec=hs.config.federation_rc_sleep_delay,
  322. reject_limit=hs.config.federation_rc_reject_limit,
  323. concurrent_requests=hs.config.federation_rc_concurrent,
  324. ),
  325. )
  326. def get_mock_call_args(pattern_func, mock_func):
  327. """ Return the arguments the mock function was called with interpreted
  328. by the pattern functions argument list.
  329. """
  330. invoked_args, invoked_kargs = mock_func.call_args
  331. return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
  332. def mock_getRawHeaders(headers=None):
  333. headers = headers if headers is not None else {}
  334. def getRawHeaders(name, default=None):
  335. return headers.get(name, default)
  336. return getRawHeaders
  337. # This is a mock /resource/ not an entire server
  338. class MockHttpResource(HttpServer):
  339. def __init__(self, prefix=""):
  340. self.callbacks = [] # 3-tuple of method/pattern/function
  341. self.prefix = prefix
  342. def trigger_get(self, path):
  343. return self.trigger(b"GET", path, None)
  344. @patch('twisted.web.http.Request')
  345. @defer.inlineCallbacks
  346. def trigger(
  347. self, http_method, path, content, mock_request, federation_auth_origin=None
  348. ):
  349. """ Fire an HTTP event.
  350. Args:
  351. http_method : The HTTP method
  352. path : The HTTP path
  353. content : The HTTP body
  354. mock_request : Mocked request to pass to the event so it can get
  355. content.
  356. federation_auth_origin (bytes|None): domain to authenticate as, for federation
  357. Returns:
  358. A tuple of (code, response)
  359. Raises:
  360. KeyError If no event is found which will handle the path.
  361. """
  362. path = self.prefix + path
  363. # annoyingly we return a twisted http request which has chained calls
  364. # to get at the http content, hence mock it here.
  365. mock_content = Mock()
  366. config = {'read.return_value': content}
  367. mock_content.configure_mock(**config)
  368. mock_request.content = mock_content
  369. mock_request.method = http_method.encode('ascii')
  370. mock_request.uri = path.encode('ascii')
  371. mock_request.getClientIP.return_value = "-"
  372. headers = {}
  373. if federation_auth_origin is not None:
  374. headers[b"Authorization"] = [
  375. b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,)
  376. ]
  377. mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
  378. # return the right path if the event requires it
  379. mock_request.path = path
  380. # add in query params to the right place
  381. try:
  382. mock_request.args = urlparse.parse_qs(path.split('?')[1])
  383. mock_request.path = path.split('?')[0]
  384. path = mock_request.path
  385. except Exception:
  386. pass
  387. if isinstance(path, bytes):
  388. path = path.decode('utf8')
  389. for (method, pattern, func) in self.callbacks:
  390. if http_method != method:
  391. continue
  392. matcher = pattern.match(path)
  393. if matcher:
  394. try:
  395. args = [urlparse.unquote(u) for u in matcher.groups()]
  396. (code, response) = yield func(mock_request, *args)
  397. defer.returnValue((code, response))
  398. except CodeMessageException as e:
  399. defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
  400. raise KeyError("No event can handle %s" % path)
  401. def register_paths(self, method, path_patterns, callback):
  402. for path_pattern in path_patterns:
  403. self.callbacks.append((method, path_pattern, callback))
  404. class MockKey(object):
  405. alg = "mock_alg"
  406. version = "mock_version"
  407. signature = b"\x9a\x87$"
  408. @property
  409. def verify_key(self):
  410. return self
  411. def sign(self, message):
  412. return self
  413. def verify(self, message, sig):
  414. assert sig == b"\x9a\x87$"
  415. def encode(self):
  416. return b"<fake_encoded_key>"
  417. class MockClock(object):
  418. now = 1000
  419. def __init__(self):
  420. # list of lists of [absolute_time, callback, expired] in no particular
  421. # order
  422. self.timers = []
  423. self.loopers = []
  424. def time(self):
  425. return self.now
  426. def time_msec(self):
  427. return self.time() * 1000
  428. def call_later(self, delay, callback, *args, **kwargs):
  429. current_context = LoggingContext.current_context()
  430. def wrapped_callback():
  431. LoggingContext.thread_local.current_context = current_context
  432. callback(*args, **kwargs)
  433. t = [self.now + delay, wrapped_callback, False]
  434. self.timers.append(t)
  435. return t
  436. def looping_call(self, function, interval):
  437. self.loopers.append([function, interval / 1000.0, self.now])
  438. def cancel_call_later(self, timer, ignore_errs=False):
  439. if timer[2]:
  440. if not ignore_errs:
  441. raise Exception("Cannot cancel an expired timer")
  442. timer[2] = True
  443. self.timers = [t for t in self.timers if t != timer]
  444. # For unit testing
  445. def advance_time(self, secs):
  446. self.now += secs
  447. timers = self.timers
  448. self.timers = []
  449. for t in timers:
  450. time, callback, expired = t
  451. if expired:
  452. raise Exception("Timer already expired")
  453. if self.now >= time:
  454. t[2] = True
  455. callback()
  456. else:
  457. self.timers.append(t)
  458. for looped in self.loopers:
  459. func, interval, last = looped
  460. if last + interval < self.now:
  461. func()
  462. looped[2] = self.now
  463. def advance_time_msec(self, ms):
  464. self.advance_time(ms / 1000.0)
  465. def time_bound_deferred(self, d, *args, **kwargs):
  466. # We don't bother timing things out for now.
  467. return d
  468. def _format_call(args, kwargs):
  469. return ", ".join(
  470. ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
  471. )
  472. class DeferredMockCallable(object):
  473. """A callable instance that stores a set of pending call expectations and
  474. return values for them. It allows a unit test to assert that the given set
  475. of function calls are eventually made, by awaiting on them to be called.
  476. """
  477. def __init__(self):
  478. self.expectations = []
  479. self.calls = []
  480. def __call__(self, *args, **kwargs):
  481. self.calls.append((args, kwargs))
  482. if not self.expectations:
  483. raise ValueError(
  484. "%r has no pending calls to handle call(%s)"
  485. % (self, _format_call(args, kwargs))
  486. )
  487. for (call, result, d) in self.expectations:
  488. if args == call[1] and kwargs == call[2]:
  489. d.callback(None)
  490. return result
  491. failure = AssertionError(
  492. "Was not expecting call(%s)" % (_format_call(args, kwargs))
  493. )
  494. for _, _, d in self.expectations:
  495. try:
  496. d.errback(failure)
  497. except Exception:
  498. pass
  499. raise failure
  500. def expect_call_and_return(self, call, result):
  501. self.expectations.append((call, result, defer.Deferred()))
  502. @defer.inlineCallbacks
  503. def await_calls(self, timeout=1000):
  504. deferred = defer.DeferredList(
  505. [d for _, _, d in self.expectations], fireOnOneErrback=True
  506. )
  507. timer = reactor.callLater(
  508. timeout / 1000,
  509. deferred.errback,
  510. AssertionError(
  511. "%d pending calls left: %s"
  512. % (
  513. len([e for e in self.expectations if not e[2].called]),
  514. [e for e in self.expectations if not e[2].called],
  515. )
  516. ),
  517. )
  518. yield deferred
  519. timer.cancel()
  520. self.calls = []
  521. def assert_had_no_calls(self):
  522. if self.calls:
  523. calls = self.calls
  524. self.calls = []
  525. raise AssertionError(
  526. "Expected not to received any calls, got:\n"
  527. + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
  528. )
  529. @defer.inlineCallbacks
  530. def create_room(hs, room_id, creator_id):
  531. """Creates and persist a creation event for the given room
  532. Args:
  533. hs
  534. room_id (str)
  535. creator_id (str)
  536. """
  537. store = hs.get_datastore()
  538. event_builder_factory = hs.get_event_builder_factory()
  539. event_creation_handler = hs.get_event_creation_handler()
  540. builder = event_builder_factory.new(
  541. RoomVersions.V1,
  542. {
  543. "type": EventTypes.Create,
  544. "state_key": "",
  545. "sender": creator_id,
  546. "room_id": room_id,
  547. "content": {},
  548. },
  549. )
  550. event, context = yield event_creation_handler.create_new_client_event(builder)
  551. yield store.persist_event(event, context)