utils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import atexit
  16. import hashlib
  17. import os
  18. import uuid
  19. from inspect import getcallargs
  20. from mock import Mock, patch
  21. from six.moves.urllib import parse as urlparse
  22. from twisted.internet import defer, reactor
  23. from synapse.api.constants import EventTypes
  24. from synapse.api.errors import CodeMessageException, cs_error
  25. from synapse.config.server import ServerConfig
  26. from synapse.federation.transport import server
  27. from synapse.http.server import HttpServer
  28. from synapse.server import HomeServer
  29. from synapse.storage import DataStore
  30. from synapse.storage.engines import PostgresEngine, create_engine
  31. from synapse.storage.prepare_database import (
  32. _get_or_create_schema_state,
  33. _setup_new_database,
  34. prepare_database,
  35. )
  36. from synapse.util.logcontext import LoggingContext
  37. from synapse.util.ratelimitutils import FederationRateLimiter
  38. # set this to True to run the tests against postgres instead of sqlite.
  39. USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
  40. LEAVE_DB = os.environ.get("SYNAPSE_LEAVE_DB", False)
  41. POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
  42. POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
  43. def setupdb():
  44. # If we're using PostgreSQL, set up the db once
  45. if USE_POSTGRES_FOR_TESTS:
  46. pgconfig = {
  47. "name": "psycopg2",
  48. "args": {
  49. "database": POSTGRES_BASE_DB,
  50. "user": POSTGRES_USER,
  51. "cp_min": 1,
  52. "cp_max": 5,
  53. },
  54. }
  55. config = Mock()
  56. config.password_providers = []
  57. config.database_config = pgconfig
  58. db_engine = create_engine(pgconfig)
  59. db_conn = db_engine.module.connect(user=POSTGRES_USER)
  60. db_conn.autocommit = True
  61. cur = db_conn.cursor()
  62. cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
  63. cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
  64. cur.close()
  65. db_conn.close()
  66. # Set up in the db
  67. db_conn = db_engine.module.connect(
  68. database=POSTGRES_BASE_DB, user=POSTGRES_USER
  69. )
  70. cur = db_conn.cursor()
  71. _get_or_create_schema_state(cur, db_engine)
  72. _setup_new_database(cur, db_engine)
  73. db_conn.commit()
  74. cur.close()
  75. db_conn.close()
  76. def _cleanup():
  77. db_conn = db_engine.module.connect(user=POSTGRES_USER)
  78. db_conn.autocommit = True
  79. cur = db_conn.cursor()
  80. cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
  81. cur.close()
  82. db_conn.close()
  83. atexit.register(_cleanup)
  84. class TestHomeServer(HomeServer):
  85. DATASTORE_CLASS = DataStore
  86. @defer.inlineCallbacks
  87. def setup_test_homeserver(
  88. cleanup_func, name="test", datastore=None, config=None, reactor=None,
  89. homeserverToUse=TestHomeServer, **kargs
  90. ):
  91. """
  92. Setup a homeserver suitable for running tests against. Keyword arguments
  93. are passed to the Homeserver constructor.
  94. If no datastore is supplied, one is created and given to the homeserver.
  95. Args:
  96. cleanup_func : The function used to register a cleanup routine for
  97. after the test.
  98. """
  99. if reactor is None:
  100. from twisted.internet import reactor
  101. if config is None:
  102. config = Mock()
  103. config.signing_key = [MockKey()]
  104. config.event_cache_size = 1
  105. config.enable_registration = True
  106. config.macaroon_secret_key = "not even a little secret"
  107. config.expire_access_token = False
  108. config.server_name = name
  109. config.trusted_third_party_id_servers = []
  110. config.room_invite_state_types = []
  111. config.password_providers = []
  112. config.worker_replication_url = ""
  113. config.worker_app = None
  114. config.email_enable_notifs = False
  115. config.block_non_admin_invites = False
  116. config.federation_domain_whitelist = None
  117. config.federation_rc_reject_limit = 10
  118. config.federation_rc_sleep_limit = 10
  119. config.federation_rc_sleep_delay = 100
  120. config.federation_rc_concurrent = 10
  121. config.filter_timeline_limit = 5000
  122. config.user_directory_search_all_users = False
  123. config.user_consent_server_notice_content = None
  124. config.block_events_without_consent_error = None
  125. config.media_storage_providers = []
  126. config.auto_join_rooms = []
  127. config.limit_usage_by_mau = False
  128. config.hs_disabled = False
  129. config.hs_disabled_message = ""
  130. config.hs_disabled_limit_type = ""
  131. config.max_mau_value = 50
  132. config.mau_limits_reserved_threepids = []
  133. config.admin_contact = None
  134. config.rc_messages_per_second = 10000
  135. config.rc_message_burst_count = 10000
  136. # we need a sane default_room_version, otherwise attempts to create rooms will
  137. # fail.
  138. config.default_room_version = "1"
  139. # disable user directory updates, because they get done in the
  140. # background, which upsets the test runner.
  141. config.update_user_directory = False
  142. def is_threepid_reserved(threepid):
  143. return ServerConfig.is_threepid_reserved(config, threepid)
  144. config.is_threepid_reserved.side_effect = is_threepid_reserved
  145. config.use_frozen_dicts = True
  146. config.ldap_enabled = False
  147. if "clock" not in kargs:
  148. kargs["clock"] = MockClock()
  149. if USE_POSTGRES_FOR_TESTS:
  150. test_db = "synapse_test_%s" % uuid.uuid4().hex
  151. config.database_config = {
  152. "name": "psycopg2",
  153. "args": {"database": test_db, "cp_min": 1, "cp_max": 5},
  154. }
  155. else:
  156. config.database_config = {
  157. "name": "sqlite3",
  158. "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
  159. }
  160. db_engine = create_engine(config.database_config)
  161. # Create the database before we actually try and connect to it, based off
  162. # the template database we generate in setupdb()
  163. if datastore is None and isinstance(db_engine, PostgresEngine):
  164. db_conn = db_engine.module.connect(
  165. database=POSTGRES_BASE_DB, user=POSTGRES_USER
  166. )
  167. db_conn.autocommit = True
  168. cur = db_conn.cursor()
  169. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  170. cur.execute(
  171. "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
  172. )
  173. cur.close()
  174. db_conn.close()
  175. # we need to configure the connection pool to run the on_new_connection
  176. # function, so that we can test code that uses custom sqlite functions
  177. # (like rank).
  178. config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
  179. if datastore is None:
  180. hs = homeserverToUse(
  181. name,
  182. config=config,
  183. db_config=config.database_config,
  184. version_string="Synapse/tests",
  185. database_engine=db_engine,
  186. room_list_handler=object(),
  187. tls_server_context_factory=Mock(),
  188. tls_client_options_factory=Mock(),
  189. reactor=reactor,
  190. **kargs
  191. )
  192. # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
  193. # date db
  194. if not isinstance(db_engine, PostgresEngine):
  195. db_conn = hs.get_db_conn()
  196. yield prepare_database(db_conn, db_engine, config)
  197. db_conn.commit()
  198. db_conn.close()
  199. else:
  200. # We need to do cleanup on PostgreSQL
  201. def cleanup():
  202. # Close all the db pools
  203. hs.get_db_pool().close()
  204. # Drop the test database
  205. db_conn = db_engine.module.connect(
  206. database=POSTGRES_BASE_DB, user=POSTGRES_USER
  207. )
  208. db_conn.autocommit = True
  209. cur = db_conn.cursor()
  210. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  211. db_conn.commit()
  212. cur.close()
  213. db_conn.close()
  214. if not LEAVE_DB:
  215. # Register the cleanup hook
  216. cleanup_func(cleanup)
  217. hs.setup()
  218. else:
  219. hs = homeserverToUse(
  220. name,
  221. db_pool=None,
  222. datastore=datastore,
  223. config=config,
  224. version_string="Synapse/tests",
  225. database_engine=db_engine,
  226. room_list_handler=object(),
  227. tls_server_context_factory=Mock(),
  228. tls_client_options_factory=Mock(),
  229. reactor=reactor,
  230. **kargs
  231. )
  232. # bcrypt is far too slow to be doing in unit tests
  233. # Need to let the HS build an auth handler and then mess with it
  234. # because AuthHandler's constructor requires the HS, so we can't make one
  235. # beforehand and pass it in to the HS's constructor (chicken / egg)
  236. hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest()
  237. hs.get_auth_handler().validate_hash = (
  238. lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h
  239. )
  240. fed = kargs.get("resource_for_federation", None)
  241. if fed:
  242. server.register_servlets(
  243. hs,
  244. resource=fed,
  245. authenticator=server.Authenticator(hs),
  246. ratelimiter=FederationRateLimiter(
  247. hs.get_clock(),
  248. window_size=hs.config.federation_rc_window_size,
  249. sleep_limit=hs.config.federation_rc_sleep_limit,
  250. sleep_msec=hs.config.federation_rc_sleep_delay,
  251. reject_limit=hs.config.federation_rc_reject_limit,
  252. concurrent_requests=hs.config.federation_rc_concurrent,
  253. ),
  254. )
  255. defer.returnValue(hs)
  256. def get_mock_call_args(pattern_func, mock_func):
  257. """ Return the arguments the mock function was called with interpreted
  258. by the pattern functions argument list.
  259. """
  260. invoked_args, invoked_kargs = mock_func.call_args
  261. return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
  262. def mock_getRawHeaders(headers=None):
  263. headers = headers if headers is not None else {}
  264. def getRawHeaders(name, default=None):
  265. return headers.get(name, default)
  266. return getRawHeaders
  267. # This is a mock /resource/ not an entire server
  268. class MockHttpResource(HttpServer):
  269. def __init__(self, prefix=""):
  270. self.callbacks = [] # 3-tuple of method/pattern/function
  271. self.prefix = prefix
  272. def trigger_get(self, path):
  273. return self.trigger(b"GET", path, None)
  274. @patch('twisted.web.http.Request')
  275. @defer.inlineCallbacks
  276. def trigger(
  277. self, http_method, path, content, mock_request,
  278. federation_auth_origin=None,
  279. ):
  280. """ Fire an HTTP event.
  281. Args:
  282. http_method : The HTTP method
  283. path : The HTTP path
  284. content : The HTTP body
  285. mock_request : Mocked request to pass to the event so it can get
  286. content.
  287. federation_auth_origin (bytes|None): domain to authenticate as, for federation
  288. Returns:
  289. A tuple of (code, response)
  290. Raises:
  291. KeyError If no event is found which will handle the path.
  292. """
  293. path = self.prefix + path
  294. # annoyingly we return a twisted http request which has chained calls
  295. # to get at the http content, hence mock it here.
  296. mock_content = Mock()
  297. config = {'read.return_value': content}
  298. mock_content.configure_mock(**config)
  299. mock_request.content = mock_content
  300. mock_request.method = http_method.encode('ascii')
  301. mock_request.uri = path.encode('ascii')
  302. mock_request.getClientIP.return_value = "-"
  303. headers = {}
  304. if federation_auth_origin is not None:
  305. headers[b"Authorization"] = [
  306. b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin, )
  307. ]
  308. mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
  309. # return the right path if the event requires it
  310. mock_request.path = path
  311. # add in query params to the right place
  312. try:
  313. mock_request.args = urlparse.parse_qs(path.split('?')[1])
  314. mock_request.path = path.split('?')[0]
  315. path = mock_request.path
  316. except Exception:
  317. pass
  318. if isinstance(path, bytes):
  319. path = path.decode('utf8')
  320. for (method, pattern, func) in self.callbacks:
  321. if http_method != method:
  322. continue
  323. matcher = pattern.match(path)
  324. if matcher:
  325. try:
  326. args = [urlparse.unquote(u) for u in matcher.groups()]
  327. (code, response) = yield func(mock_request, *args)
  328. defer.returnValue((code, response))
  329. except CodeMessageException as e:
  330. defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
  331. raise KeyError("No event can handle %s" % path)
  332. def register_paths(self, method, path_patterns, callback):
  333. for path_pattern in path_patterns:
  334. self.callbacks.append((method, path_pattern, callback))
  335. class MockKey(object):
  336. alg = "mock_alg"
  337. version = "mock_version"
  338. signature = b"\x9a\x87$"
  339. @property
  340. def verify_key(self):
  341. return self
  342. def sign(self, message):
  343. return self
  344. def verify(self, message, sig):
  345. assert sig == b"\x9a\x87$"
  346. class MockClock(object):
  347. now = 1000
  348. def __init__(self):
  349. # list of lists of [absolute_time, callback, expired] in no particular
  350. # order
  351. self.timers = []
  352. self.loopers = []
  353. def time(self):
  354. return self.now
  355. def time_msec(self):
  356. return self.time() * 1000
  357. def call_later(self, delay, callback, *args, **kwargs):
  358. current_context = LoggingContext.current_context()
  359. def wrapped_callback():
  360. LoggingContext.thread_local.current_context = current_context
  361. callback(*args, **kwargs)
  362. t = [self.now + delay, wrapped_callback, False]
  363. self.timers.append(t)
  364. return t
  365. def looping_call(self, function, interval):
  366. self.loopers.append([function, interval / 1000., self.now])
  367. def cancel_call_later(self, timer, ignore_errs=False):
  368. if timer[2]:
  369. if not ignore_errs:
  370. raise Exception("Cannot cancel an expired timer")
  371. timer[2] = True
  372. self.timers = [t for t in self.timers if t != timer]
  373. # For unit testing
  374. def advance_time(self, secs):
  375. self.now += secs
  376. timers = self.timers
  377. self.timers = []
  378. for t in timers:
  379. time, callback, expired = t
  380. if expired:
  381. raise Exception("Timer already expired")
  382. if self.now >= time:
  383. t[2] = True
  384. callback()
  385. else:
  386. self.timers.append(t)
  387. for looped in self.loopers:
  388. func, interval, last = looped
  389. if last + interval < self.now:
  390. func()
  391. looped[2] = self.now
  392. def advance_time_msec(self, ms):
  393. self.advance_time(ms / 1000.)
  394. def time_bound_deferred(self, d, *args, **kwargs):
  395. # We don't bother timing things out for now.
  396. return d
  397. def _format_call(args, kwargs):
  398. return ", ".join(
  399. ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
  400. )
  401. class DeferredMockCallable(object):
  402. """A callable instance that stores a set of pending call expectations and
  403. return values for them. It allows a unit test to assert that the given set
  404. of function calls are eventually made, by awaiting on them to be called.
  405. """
  406. def __init__(self):
  407. self.expectations = []
  408. self.calls = []
  409. def __call__(self, *args, **kwargs):
  410. self.calls.append((args, kwargs))
  411. if not self.expectations:
  412. raise ValueError(
  413. "%r has no pending calls to handle call(%s)"
  414. % (self, _format_call(args, kwargs))
  415. )
  416. for (call, result, d) in self.expectations:
  417. if args == call[1] and kwargs == call[2]:
  418. d.callback(None)
  419. return result
  420. failure = AssertionError(
  421. "Was not expecting call(%s)" % (_format_call(args, kwargs))
  422. )
  423. for _, _, d in self.expectations:
  424. try:
  425. d.errback(failure)
  426. except Exception:
  427. pass
  428. raise failure
  429. def expect_call_and_return(self, call, result):
  430. self.expectations.append((call, result, defer.Deferred()))
  431. @defer.inlineCallbacks
  432. def await_calls(self, timeout=1000):
  433. deferred = defer.DeferredList(
  434. [d for _, _, d in self.expectations], fireOnOneErrback=True
  435. )
  436. timer = reactor.callLater(
  437. timeout / 1000,
  438. deferred.errback,
  439. AssertionError(
  440. "%d pending calls left: %s"
  441. % (
  442. len([e for e in self.expectations if not e[2].called]),
  443. [e for e in self.expectations if not e[2].called],
  444. )
  445. ),
  446. )
  447. yield deferred
  448. timer.cancel()
  449. self.calls = []
  450. def assert_had_no_calls(self):
  451. if self.calls:
  452. calls = self.calls
  453. self.calls = []
  454. raise AssertionError(
  455. "Expected not to received any calls, got:\n"
  456. + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
  457. )
  458. @defer.inlineCallbacks
  459. def create_room(hs, room_id, creator_id):
  460. """Creates and persist a creation event for the given room
  461. Args:
  462. hs
  463. room_id (str)
  464. creator_id (str)
  465. """
  466. store = hs.get_datastore()
  467. event_builder_factory = hs.get_event_builder_factory()
  468. event_creation_handler = hs.get_event_creation_handler()
  469. builder = event_builder_factory.new({
  470. "type": EventTypes.Create,
  471. "state_key": "",
  472. "sender": creator_id,
  473. "room_id": room_id,
  474. "content": {},
  475. })
  476. event, context = yield event_creation_handler.create_new_client_event(
  477. builder
  478. )
  479. yield store.persist_event(event, context)