utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import atexit
  16. import hashlib
  17. import os
  18. import uuid
  19. from inspect import getcallargs
  20. from mock import Mock, patch
  21. from six.moves.urllib import parse as urlparse
  22. from twisted.internet import defer, reactor
  23. from synapse.api.errors import CodeMessageException, cs_error
  24. from synapse.federation.transport import server
  25. from synapse.http.server import HttpServer
  26. from synapse.server import HomeServer
  27. from synapse.storage import PostgresEngine
  28. from synapse.storage.engines import create_engine
  29. from synapse.storage.prepare_database import (
  30. _get_or_create_schema_state,
  31. _setup_new_database,
  32. prepare_database,
  33. )
  34. from synapse.util.logcontext import LoggingContext
  35. from synapse.util.ratelimitutils import FederationRateLimiter
  36. # set this to True to run the tests against postgres instead of sqlite.
  37. USE_POSTGRES_FOR_TESTS = os.environ.get("SYNAPSE_POSTGRES", False)
  38. POSTGRES_USER = os.environ.get("SYNAPSE_POSTGRES_USER", "postgres")
  39. POSTGRES_BASE_DB = "_synapse_unit_tests_base_%s" % (os.getpid(),)
  40. def setupdb():
  41. # If we're using PostgreSQL, set up the db once
  42. if USE_POSTGRES_FOR_TESTS:
  43. pgconfig = {
  44. "name": "psycopg2",
  45. "args": {
  46. "database": POSTGRES_BASE_DB,
  47. "user": POSTGRES_USER,
  48. "cp_min": 1,
  49. "cp_max": 5,
  50. },
  51. }
  52. config = Mock()
  53. config.password_providers = []
  54. config.database_config = pgconfig
  55. db_engine = create_engine(pgconfig)
  56. db_conn = db_engine.module.connect(user=POSTGRES_USER)
  57. db_conn.autocommit = True
  58. cur = db_conn.cursor()
  59. cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
  60. cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,))
  61. cur.close()
  62. db_conn.close()
  63. # Set up in the db
  64. db_conn = db_engine.module.connect(
  65. database=POSTGRES_BASE_DB, user=POSTGRES_USER
  66. )
  67. cur = db_conn.cursor()
  68. _get_or_create_schema_state(cur, db_engine)
  69. _setup_new_database(cur, db_engine)
  70. db_conn.commit()
  71. cur.close()
  72. db_conn.close()
  73. def _cleanup():
  74. db_conn = db_engine.module.connect(user=POSTGRES_USER)
  75. db_conn.autocommit = True
  76. cur = db_conn.cursor()
  77. cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,))
  78. cur.close()
  79. db_conn.close()
  80. atexit.register(_cleanup)
@defer.inlineCallbacks
def setup_test_homeserver(
    cleanup_func, name="test", datastore=None, config=None, reactor=None, **kargs
):
    """
    Setup a homeserver suitable for running tests against. Keyword arguments
    are passed to the Homeserver constructor.

    If no datastore is supplied, one is created and given to the homeserver.

    Args:
        cleanup_func : The function used to register a cleanup routine for
            after the test. It is only called (with a zero-argument cleanup
            function) when running against PostgreSQL without a supplied
            datastore.
        name (str): server name passed to HomeServer; also used as the mock
            config's server_name when no config is given.
        datastore: if given, it is passed to HomeServer (with db_pool=None)
            and no database is created, prepared or cleaned up here.
        config: homeserver config. When None, a Mock populated with test
            defaults is built. Its database_config is overwritten in every
            case.
        reactor: reactor passed to HomeServer; defaults to the global
            twisted reactor.
        **kargs: extra HomeServer constructor arguments. "clock" defaults to
            a MockClock; if "resource_for_federation" is present, the
            federation servlets are also registered on it.

    Returns:
        Deferred: fires with the constructed HomeServer.
    """
    if reactor is None:
        from twisted.internet import reactor

    if config is None:
        config = Mock()
        config.signing_key = [MockKey()]
        config.event_cache_size = 1
        config.enable_registration = True
        config.macaroon_secret_key = "not even a little secret"
        config.expire_access_token = False
        config.server_name = name
        config.trusted_third_party_id_servers = []
        config.room_invite_state_types = []
        config.password_providers = []
        config.worker_replication_url = ""
        config.worker_app = None
        config.email_enable_notifs = False
        config.block_non_admin_invites = False
        config.federation_domain_whitelist = None
        config.federation_rc_reject_limit = 10
        config.federation_rc_sleep_limit = 10
        config.federation_rc_sleep_delay = 100
        config.federation_rc_concurrent = 10
        config.filter_timeline_limit = 5000
        config.user_directory_search_all_users = False
        config.user_consent_server_notice_content = None
        config.block_events_without_consent_error = None
        config.media_storage_providers = []
        config.auto_join_rooms = []
        config.limit_usage_by_mau = False
        config.hs_disabled = False
        config.hs_disabled_message = ""
        config.max_mau_value = 50
        config.mau_limits_reserved_threepids = []

        # we need a sane default_room_version, otherwise attempts to create rooms will
        # fail.
        config.default_room_version = "1"

        # disable user directory updates, because they get done in the
        # background, which upsets the test runner.
        config.update_user_directory = False

        config.use_frozen_dicts = True
        config.ldap_enabled = False

    if "clock" not in kargs:
        kargs["clock"] = MockClock()

    if USE_POSTGRES_FOR_TESTS:
        # Each call gets its own uniquely-named database. test_db is only
        # bound on this path; its later uses are guarded by
        # isinstance(db_engine, PostgresEngine), which presumably holds
        # exactly when the "psycopg2" config below is in effect.
        test_db = "synapse_test_%s" % uuid.uuid4().hex

        config.database_config = {
            "name": "psycopg2",
            "args": {"database": test_db, "cp_min": 1, "cp_max": 5},
        }
    else:
        # A single connection (cp_max 1) to a private in-memory SQLite db.
        config.database_config = {
            "name": "sqlite3",
            "args": {"database": ":memory:", "cp_min": 1, "cp_max": 1},
        }

    db_engine = create_engine(config.database_config)

    # Create the database before we actually try and connect to it, based off
    # the template database we generate in setupdb()
    if datastore is None and isinstance(db_engine, PostgresEngine):
        db_conn = db_engine.module.connect(
            database=POSTGRES_BASE_DB, user=POSTGRES_USER
        )
        db_conn.autocommit = True
        cur = db_conn.cursor()
        cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
        cur.execute(
            "CREATE DATABASE %s WITH TEMPLATE %s;" % (test_db, POSTGRES_BASE_DB)
        )
        cur.close()
        db_conn.close()

    # we need to configure the connection pool to run the on_new_connection
    # function, so that we can test code that uses custom sqlite functions
    # (like rank).
    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection

    if datastore is None:
        hs = HomeServer(
            name,
            config=config,
            db_config=config.database_config,
            version_string="Synapse/tests",
            database_engine=db_engine,
            room_list_handler=object(),
            tls_server_context_factory=Mock(),
            tls_client_options_factory=Mock(),
            reactor=reactor,
            **kargs
        )

        # Prepare the DB on SQLite -- PostgreSQL is a copy of an already up to
        # date db
        if not isinstance(db_engine, PostgresEngine):
            db_conn = hs.get_db_conn()
            yield prepare_database(db_conn, db_engine, config)
            db_conn.commit()
            db_conn.close()

        else:
            # We need to do cleanup on PostgreSQL
            def cleanup():
                # Close all the db pools
                hs.get_db_pool().close()

                # Drop the test database
                db_conn = db_engine.module.connect(
                    database=POSTGRES_BASE_DB, user=POSTGRES_USER
                )
                db_conn.autocommit = True
                cur = db_conn.cursor()
                cur.execute("DROP DATABASE IF EXISTS %s;" % (test_db,))
                db_conn.commit()
                cur.close()
                db_conn.close()

            # Register the cleanup hook
            cleanup_func(cleanup)

        hs.setup()
    else:
        # Caller-supplied datastore: no connection pool and no database setup.
        hs = HomeServer(
            name,
            db_pool=None,
            datastore=datastore,
            config=config,
            version_string="Synapse/tests",
            database_engine=db_engine,
            room_list_handler=object(),
            tls_server_context_factory=Mock(),
            tls_client_options_factory=Mock(),
            reactor=reactor,
            **kargs
        )

    # bcrypt is far too slow to be doing in unit tests
    # Need to let the HS build an auth handler and then mess with it
    # because AuthHandler's constructor requires the HS, so we can't make one
    # beforehand and pass it in to the HS's constructor (chicken / egg)
    hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode('utf8')).hexdigest()
    hs.get_auth_handler().validate_hash = (
        lambda p, h: hashlib.md5(p.encode('utf8')).hexdigest() == h
    )

    fed = kargs.get("resource_for_federation", None)
    if fed:
        # NOTE(review): the default mock config above never sets
        # federation_rc_window_size, so when config was None this reads an
        # auto-created Mock attribute -- confirm FederationRateLimiter copes
        # with that, or set an explicit value in the mock config.
        server.register_servlets(
            hs,
            resource=fed,
            authenticator=server.Authenticator(hs),
            ratelimiter=FederationRateLimiter(
                hs.get_clock(),
                window_size=hs.config.federation_rc_window_size,
                sleep_limit=hs.config.federation_rc_sleep_limit,
                sleep_msec=hs.config.federation_rc_sleep_delay,
                reject_limit=hs.config.federation_rc_reject_limit,
                concurrent_requests=hs.config.federation_rc_concurrent,
            ),
        )

    defer.returnValue(hs)
  242. def get_mock_call_args(pattern_func, mock_func):
  243. """ Return the arguments the mock function was called with interpreted
  244. by the pattern functions argument list.
  245. """
  246. invoked_args, invoked_kargs = mock_func.call_args
  247. return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
  248. def mock_getRawHeaders(headers=None):
  249. headers = headers if headers is not None else {}
  250. def getRawHeaders(name, default=None):
  251. return headers.get(name, default)
  252. return getRawHeaders
  253. # This is a mock /resource/ not an entire server
  254. class MockHttpResource(HttpServer):
  255. def __init__(self, prefix=""):
  256. self.callbacks = [] # 3-tuple of method/pattern/function
  257. self.prefix = prefix
  258. def trigger_get(self, path):
  259. return self.trigger(b"GET", path, None)
  260. @patch('twisted.web.http.Request')
  261. @defer.inlineCallbacks
  262. def trigger(self, http_method, path, content, mock_request, federation_auth=False):
  263. """ Fire an HTTP event.
  264. Args:
  265. http_method : The HTTP method
  266. path : The HTTP path
  267. content : The HTTP body
  268. mock_request : Mocked request to pass to the event so it can get
  269. content.
  270. Returns:
  271. A tuple of (code, response)
  272. Raises:
  273. KeyError If no event is found which will handle the path.
  274. """
  275. path = self.prefix + path
  276. # annoyingly we return a twisted http request which has chained calls
  277. # to get at the http content, hence mock it here.
  278. mock_content = Mock()
  279. config = {'read.return_value': content}
  280. mock_content.configure_mock(**config)
  281. mock_request.content = mock_content
  282. mock_request.method = http_method.encode('ascii')
  283. mock_request.uri = path.encode('ascii')
  284. mock_request.getClientIP.return_value = "-"
  285. headers = {}
  286. if federation_auth:
  287. headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
  288. mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
  289. # return the right path if the event requires it
  290. mock_request.path = path
  291. # add in query params to the right place
  292. try:
  293. mock_request.args = urlparse.parse_qs(path.split('?')[1])
  294. mock_request.path = path.split('?')[0]
  295. path = mock_request.path
  296. except Exception:
  297. pass
  298. if isinstance(path, bytes):
  299. path = path.decode('utf8')
  300. for (method, pattern, func) in self.callbacks:
  301. if http_method != method:
  302. continue
  303. matcher = pattern.match(path)
  304. if matcher:
  305. try:
  306. args = [urlparse.unquote(u) for u in matcher.groups()]
  307. (code, response) = yield func(mock_request, *args)
  308. defer.returnValue((code, response))
  309. except CodeMessageException as e:
  310. defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
  311. raise KeyError("No event can handle %s" % path)
  312. def register_paths(self, method, path_patterns, callback):
  313. for path_pattern in path_patterns:
  314. self.callbacks.append((method, path_pattern, callback))
  315. class MockKey(object):
  316. alg = "mock_alg"
  317. version = "mock_version"
  318. signature = b"\x9a\x87$"
  319. @property
  320. def verify_key(self):
  321. return self
  322. def sign(self, message):
  323. return self
  324. def verify(self, message, sig):
  325. assert sig == b"\x9a\x87$"
  326. class MockClock(object):
  327. now = 1000
  328. def __init__(self):
  329. # list of lists of [absolute_time, callback, expired] in no particular
  330. # order
  331. self.timers = []
  332. self.loopers = []
  333. def time(self):
  334. return self.now
  335. def time_msec(self):
  336. return self.time() * 1000
  337. def call_later(self, delay, callback, *args, **kwargs):
  338. current_context = LoggingContext.current_context()
  339. def wrapped_callback():
  340. LoggingContext.thread_local.current_context = current_context
  341. callback(*args, **kwargs)
  342. t = [self.now + delay, wrapped_callback, False]
  343. self.timers.append(t)
  344. return t
  345. def looping_call(self, function, interval):
  346. self.loopers.append([function, interval / 1000., self.now])
  347. def cancel_call_later(self, timer, ignore_errs=False):
  348. if timer[2]:
  349. if not ignore_errs:
  350. raise Exception("Cannot cancel an expired timer")
  351. timer[2] = True
  352. self.timers = [t for t in self.timers if t != timer]
  353. # For unit testing
  354. def advance_time(self, secs):
  355. self.now += secs
  356. timers = self.timers
  357. self.timers = []
  358. for t in timers:
  359. time, callback, expired = t
  360. if expired:
  361. raise Exception("Timer already expired")
  362. if self.now >= time:
  363. t[2] = True
  364. callback()
  365. else:
  366. self.timers.append(t)
  367. for looped in self.loopers:
  368. func, interval, last = looped
  369. if last + interval < self.now:
  370. func()
  371. looped[2] = self.now
  372. def advance_time_msec(self, ms):
  373. self.advance_time(ms / 1000.)
  374. def time_bound_deferred(self, d, *args, **kwargs):
  375. # We don't bother timing things out for now.
  376. return d
  377. def _format_call(args, kwargs):
  378. return ", ".join(
  379. ["%r" % (a) for a in args] + ["%s=%r" % (k, v) for k, v in kwargs.items()]
  380. )
  381. class DeferredMockCallable(object):
  382. """A callable instance that stores a set of pending call expectations and
  383. return values for them. It allows a unit test to assert that the given set
  384. of function calls are eventually made, by awaiting on them to be called.
  385. """
  386. def __init__(self):
  387. self.expectations = []
  388. self.calls = []
  389. def __call__(self, *args, **kwargs):
  390. self.calls.append((args, kwargs))
  391. if not self.expectations:
  392. raise ValueError(
  393. "%r has no pending calls to handle call(%s)"
  394. % (self, _format_call(args, kwargs))
  395. )
  396. for (call, result, d) in self.expectations:
  397. if args == call[1] and kwargs == call[2]:
  398. d.callback(None)
  399. return result
  400. failure = AssertionError(
  401. "Was not expecting call(%s)" % (_format_call(args, kwargs))
  402. )
  403. for _, _, d in self.expectations:
  404. try:
  405. d.errback(failure)
  406. except Exception:
  407. pass
  408. raise failure
  409. def expect_call_and_return(self, call, result):
  410. self.expectations.append((call, result, defer.Deferred()))
  411. @defer.inlineCallbacks
  412. def await_calls(self, timeout=1000):
  413. deferred = defer.DeferredList(
  414. [d for _, _, d in self.expectations], fireOnOneErrback=True
  415. )
  416. timer = reactor.callLater(
  417. timeout / 1000,
  418. deferred.errback,
  419. AssertionError(
  420. "%d pending calls left: %s"
  421. % (
  422. len([e for e in self.expectations if not e[2].called]),
  423. [e for e in self.expectations if not e[2].called],
  424. )
  425. ),
  426. )
  427. yield deferred
  428. timer.cancel()
  429. self.calls = []
  430. def assert_had_no_calls(self):
  431. if self.calls:
  432. calls = self.calls
  433. self.calls = []
  434. raise AssertionError(
  435. "Expected not to received any calls, got:\n"
  436. + "\n".join(["call(%s)" % _format_call(c[0], c[1]) for c in calls])
  437. )