utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import atexit
  16. import hashlib
  17. import os
  18. import uuid
  19. from inspect import getcallargs
  20. from mock import Mock, patch
  21. from six.moves.urllib import parse as urlparse
  22. from twisted.internet import defer, reactor
  23. from synapse.api.errors import CodeMessageException, cs_error
  24. from synapse.federation.transport import server
  25. from synapse.http.server import HttpServer
  26. from synapse.server import HomeServer
  27. from synapse.storage import PostgresEngine
  28. from synapse.storage.engines import create_engine
  29. from synapse.storage.prepare_database import (
  30. _get_or_create_schema_state,
  31. _setup_new_database,
  32. prepare_database,
  33. )
  34. from synapse.util.logcontext import LoggingContext
  35. from synapse.util.ratelimitutils import FederationRateLimiter
  36. # set this to True to run the tests against postgres instead of sqlite.
  37. USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
  38. POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
  39. POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
  40. def setupdb():
  41. # If we're using PostgreSQL, set up the db once
  42. if USE_POSTGRES_FOR_TESTS:
  43. pgconfig = {
  44. "name": "psycopg2",
  45. "args": {
  46. "database": POSTGRES_BASE_DB,
  47. "user": POSTGRES_USER,
  48. "cp_min": 1,
  49. "cp_max": 5,
  50. },
  51. }
  52. config = Mock()
  53. config.password_providers = []
  54. config.database_config = pgconfig
  55. db_engine = create_engine(pgconfig)
  56. db_conn = db_engine.module.connect(user=POSTGRES_USER)
  57. db_conn.autocommit = True
  58. cur = db_conn.cursor()
  59. cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
  60. cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
  61. cur.close()
  62. db_conn.close()
  63. # Set up in the db
  64. db_conn = db_engine.module.connect(
  65. database=POSTGRES_BASE_DB, user=POSTGRES_USER
  66. )
  67. cur = db_conn.cursor()
  68. _get_or_create_schema_state(cur, db_engine)
  69. _setup_new_database(cur, db_engine)
  70. db_conn.commit()
  71. cur.close()
  72. db_conn.close()
  73. def _cleanup():
  74. db_conn = db_engine.module.connect(user=POSTGRES_USER)
  75. db_conn.autocommit = True
  76. cur = db_conn.cursor()
  77. cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
  78. cur.close()
  79. db_conn.close()
  80. atexit.register(_cleanup)
  81. @defer.inlineCallbacks
  82. def setup_test_homeserver(
  83. cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs
  84. ):
  85. """
  86. Setup a homeserver suitable for running tests against. Keyword arguments
  87. are passed to the Homeserver constructor.
  88. If no datastore is supplied, one is created and given to the homeserver.
  89. Args:
  90. cleanup_func : The function used to register a cleanup routine for
  91. after the test.
  92. """
  93. if reactor is None:
  94. from twisted.internet import reactor
  95. if config is None:
  96. config = Mock()
  97. config.signing_key = [MockKey()]
  98. config.event_cache_size = 1
  99. config.enable_registration = True
  100. config.macaroon_secret_key = "not even a little secret"
  101. config.expire_access_token = False
  102. config.server_name = name
  103. config.trusted_third_party_id_servers = []
  104. config.room_invite_state_types = []
  105. config.password_providers = []
  106. config.worker_replication_url = ""
  107. config.worker_app = None
  108. config.email_enable_notifs = False
  109. config.block_non_admin_invites = False
  110. config.federation_domain_whitelist = None
  111. config.federation_rc_reject_limit = 10
  112. config.federation_rc_sleep_limit = 10
  113. config.federation_rc_sleep_delay = 100
  114. config.federation_rc_concurrent = 10
  115. config.filter_timeline_limit = 5000
  116. config.user_directory_search_all_users = False
  117. config.user_consent_server_notice_content = None
  118. config.block_events_without_consent_error = None
  119. config.media_storage_providers = []
  120. config.auto_join_rooms = []
  121. config.limit_usage_by_mau = False
  122. config.hs_disabled = False
  123. config.hs_disabled_message = ""
  124. config.max_mau_value = 50
  125. config.mau_limits_reserved_threepids = []
  126. config.admin_uri = None
  127. # we need a sane default_room_version, otherwise attempts to create rooms will
  128. # fail.
  129. config.default_room_version = "1"
  130. # disable user directory updates, because they get done in the
  131. # background, which upsets the test runner.
  132. config.update_user_directory = False
  133. config.use_frozen_dicts = True
  134. config.ldap_enabled = False
  135. if "clock" not in kargs:
  136. kargs["clock"] = MockClock()
  137. if USE_POSTGRES_FOR_TESTS:
  138. test_db = "synapse_test_%s" % uuid.uuid4().hex
  139. config.database_config = {
  140. "name": "psycopg2",
  141. "args": {"database": test_db, "cp_min": 1, "cp_max": 5},
  142. }
  143. else:
  144. config.database_config = {
  145. "name": "sqlite3",
  146. "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
  147. }
  148. db_engine = create_engine(config.database_config)
  149. # Create the database before we actually try and connect to it, based off
  150. # the template database we generate in setupdb()
  151. if datastore is None and isinstance(db_engine, PostgresEngine):
  152. db_conn = db_engine.module.connect(
  153. database=POSTGRES_BASE_DB, user=POSTGRES_USER
  154. )
  155. db_conn.autocommit = True
  156. cur = db_conn.cursor()
  157. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  158. cur.execute(
  159. "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
  160. )
  161. cur.close()
  162. db_conn.close()
  163. # we need to configure the connection pool to run the on_new_connection
  164. # function, so that we can test code that uses custom sqlite functions
  165. # (like rank).
  166. config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
  167. if datastore is None:
  168. hs = HomeServer(
  169. name,
  170. config=config,
  171. db_config=config.database_config,
  172. version_string="Synapse/tests",
  173. database_engine=db_engine,
  174. room_list_handler=object(),
  175. tls_server_context_factory=Mock(),
  176. tls_client_options_factory=Mock(),
  177. reactor=reactor,
  178. **kargs
  179. )
  180. # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
  181. # date db
  182. if not isinstance(db_engine, PostgresEngine):
  183. db_conn = hs.get_db_conn()
  184. yield prepare_database(db_conn, db_engine, config)
  185. db_conn.commit()
  186. db_conn.close()
  187. else:
  188. # We need to do cleanup on PostgreSQL
  189. def cleanup():
  190. # Close all the db pools
  191. hs.get_db_pool().close()
  192. # Drop the test database
  193. db_conn = db_engine.module.connect(
  194. database=POSTGRES_BASE_DB, user=POSTGRES_USER
  195. )
  196. db_conn.autocommit = True
  197. cur = db_conn.cursor()
  198. cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
  199. db_conn.commit()
  200. cur.close()
  201. db_conn.close()
  202. # Register the cleanup hook
  203. cleanup_func(cleanup)
  204. hs.setup()
  205. else:
  206. hs = HomeServer(
  207. name,
  208. db_pool=None,
  209. datastore=datastore,
  210. config=config,
  211. version_string="Synapse/tests",
  212. database_engine=db_engine,
  213. room_list_handler=object(),
  214. tls_server_context_factory=Mock(),
  215. tls_client_options_factory=Mock(),
  216. reactor=reactor,
  217. **kargs
  218. )
  219. # bcrypt is far too slow to be doing in unit tests
  220. # Need to let the HS build an auth handler and then mess with it
  221. # because AuthHandler's constructor requires the HS, so we can't make one
  222. # beforehand and pass it in to the HS's constructor (chicken / egg)
  223. hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest()
  224. hs.get_auth_handler().validate_hash = (
  225. lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h
  226. )
  227. fed = kargs.get("resource_for_federation", None)
  228. if fed:
  229. server.register_servlets(
  230. hs,
  231. resource=fed,
  232. authenticator=server.Authenticator(hs),
  233. ratelimiter=FederationRateLimiter(
  234. hs.get_clock(),
  235. window_size=hs.config.federation_rc_window_size,
  236. sleep_limit=hs.config.federation_rc_sleep_limit,
  237. sleep_msec=hs.config.federation_rc_sleep_delay,
  238. reject_limit=hs.config.federation_rc_reject_limit,
  239. concurrent_requests=hs.config.federation_rc_concurrent,
  240. ),
  241. )
  242. defer.returnValue(hs)
  243. def get_mock_call_args(pattern_func, mock_func):
  244. """ Return the arguments the mock function was called with interpreted
  245. by the pattern functions argument list.
  246. """
  247. invoked_args, invoked_kargs = mock_func.call_args
  248. return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
  249. def mock_getRawHeaders(headers=None):
  250. headers = headers if headers is not None else {}
  251. def getRawHeaders(name, default=None):
  252. return headers.get(name, default)
  253. return getRawHeaders
  254. # This is a mock /resource/ not an entire server
  255. class MockHttpResource(HttpServer):
  256. def __init__(self, prefix=""):
  257. self.callbacks = [] # 3-tuple of method/pattern/function
  258. self.prefix = prefix
  259. def trigger_get(self, path):
  260. return self.trigger(b"GET", path, None)
  261. @patch('twisted.web.http.Request')
  262. @defer.inlineCallbacks
  263. def trigger(self, http_method, path, content, mock_request, federation_auth=False):
  264. """ Fire an HTTP event.
  265. Args:
  266. http_method : The HTTP method
  267. path : The HTTP path
  268. content : The HTTP body
  269. mock_request : Mocked request to pass to the event so it can get
  270. content.
  271. Returns:
  272. A tuple of (code, response)
  273. Raises:
  274. KeyError If no event is found which will handle the path.
  275. """
  276. path = self.prefix + path
  277. # annoyingly we return a twisted http request which has chained calls
  278. # to get at the http content, hence mock it here.
  279. mock_content = Mock()
  280. config = {'read.return_value': content}
  281. mock_content.configure_mock(**config)
  282. mock_request.content = mock_content
  283. mock_request.method = http_method.encode('ascii')
  284. mock_request.uri = path.encode('ascii')
  285. mock_request.getClientIP.return_value = "-"
  286. headers = {}
  287. if federation_auth:
  288. headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
  289. mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
  290. # return the right path if the event requires it
  291. mock_request.path = path
  292. # add in query params to the right place
  293. try:
  294. mock_request.args = urlparse.parse_qs(path.split('?')[1])
  295. mock_request.path = path.split('?')[0]
  296. path = mock_request.path
  297. except Exception:
  298. pass
  299. if isinstance(path, bytes):
  300. path = path.decode('utf8')
  301. for (method, pattern, func) in self.callbacks:
  302. if http_method != method:
  303. continue
  304. matcher = pattern.match(path)
  305. if matcher:
  306. try:
  307. args = [urlparse.unquote(u) for u in matcher.groups()]
  308. (code, response) = yield func(mock_request, *args)
  309. defer.returnValue((code, response))
  310. except CodeMessageException as e:
  311. defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
  312. raise KeyError("No event can handle %s" % path)
  313. def register_paths(self, method, path_patterns, callback):
  314. for path_pattern in path_patterns:
  315. self.callbacks.append((method, path_pattern, callback))
  316. class MockKey(object):
  317. alg = "mock_alg"
  318. version = "mock_version"
  319. signature = b"\x9a\x87$"
  320. @property
  321. def verify_key(self):
  322. return self
  323. def sign(self, message):
  324. return self
  325. def verify(self, message, sig):
  326. assert sig == b"\x9a\x87$"
  327. class MockClock(object):
  328. now = 1000
  329. def __init__(self):
  330. # list of lists of [absolute_time, callback, expired] in no particular
  331. # order
  332. self.timers = []
  333. self.loopers = []
  334. def time(self):
  335. return self.now
  336. def time_msec(self):
  337. return self.time() * 1000
  338. def call_later(self, delay, callback, *args, **kwargs):
  339. current_context = LoggingContext.current_context()
  340. def wrapped_callback():
  341. LoggingContext.thread_local.current_context = current_context
  342. callback(*args, **kwargs)
  343. t = [self.now + delay, wrapped_callback, False]
  344. self.timers.append(t)
  345. return t
  346. def looping_call(self, function, interval):
  347. self.loopers.append([function, interval / 1000., self.now])
  348. def cancel_call_later(self, timer, ignore_errs=False):
  349. if timer[2]:
  350. if not ignore_errs:
  351. raise Exception("Cannot cancel an expired timer")
  352. timer[2] = True
  353. self.timers = [t for t in self.timers if t != timer]
  354. # For unit testing
  355. def advance_time(self, secs):
  356. self.now += secs
  357. timers = self.timers
  358. self.timers = []
  359. for t in timers:
  360. time, callback, expired = t
  361. if expired:
  362. raise Exception("Timer already expired")
  363. if self.now >= time:
  364. t[2] = True
  365. callback()
  366. else:
  367. self.timers.append(t)
  368. for looped in self.loopers:
  369. func, interval, last = looped
  370. if last + interval < self.now:
  371. func()
  372. looped[2] = self.now
  373. def advance_time_msec(self, ms):
  374. self.advance_time(ms / 1000.)
  375. def time_bound_deferred(self, d, *args, **kwargs):
  376. # We don't bother timing things out for now.
  377. return d
  378. def _format_call(args, kwargs):
  379. return ", ".join(
  380. ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
  381. )
  382. class DeferredMockCallable(object):
  383. """A callable instance that stores a set of pending call expectations and
  384. return values for them. It allows a unit test to assert that the given set
  385. of function calls are eventually made, by awaiting on them to be called.
  386. """
  387. def __init__(self):
  388. self.expectations = []
  389. self.calls = []
  390. def __call__(self, *args, **kwargs):
  391. self.calls.append((args, kwargs))
  392. if not self.expectations:
  393. raise ValueError(
  394. "%r has no pending calls to handle call(%s)"
  395. % (self, _format_call(args, kwargs))
  396. )
  397. for (call, result, d) in self.expectations:
  398. if args == call[1] and kwargs == call[2]:
  399. d.callback(None)
  400. return result
  401. failure = AssertionError(
  402. "Was not expecting call(%s)" % (_format_call(args, kwargs))
  403. )
  404. for _, _, d in self.expectations:
  405. try:
  406. d.errback(failure)
  407. except Exception:
  408. pass
  409. raise failure
  410. def expect_call_and_return(self, call, result):
  411. self.expectations.append((call, result, defer.Deferred()))
  412. @defer.inlineCallbacks
  413. def await_calls(self, timeout=1000):
  414. deferred = defer.DeferredList(
  415. [d for _, _, d in self.expectations], fireOnOneErrback=True
  416. )
  417. timer = reactor.callLater(
  418. timeout / 1000,
  419. deferred.errback,
  420. AssertionError(
  421. "%d pending calls left: %s"
  422. % (
  423. len([e for e in self.expectations if not e[2].called]),
  424. [e for e in self.expectations if not e[2].called],
  425. )
  426. ),
  427. )
  428. yield deferred
  429. timer.cancel()
  430. self.calls = []
  431. def assert_had_no_calls(self):
  432. if self.calls:
  433. calls = self.calls
  434. self.calls = []
  435. raise AssertionError(
  436. "Expected not to received any calls, got:\n"
  437. + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
  438. )