utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import hashlib
  16. from inspect import getcallargs
  17. from mock import Mock, patch
  18. from six.moves.urllib import parse as urlparse
  19. from twisted.internet import defer, reactor
  20. from synapse.api.errors import CodeMessageException, cs_error
  21. from synapse.federation.transport import server
  22. from synapse.http.server import HttpServer
  23. from synapse.server import HomeServer
  24. from synapse.storage import PostgresEngine
  25. from synapse.storage.engines import create_engine
  26. from synapse.storage.prepare_database import prepare_database
  27. from synapse.util.logcontext import LoggingContext
  28. from synapse.util.ratelimitutils import FederationRateLimiter
# Set this to True to run the tests against postgres instead of sqlite.
# It requires you to have a local postgres database called synapse_test, within
# which ALL TABLES WILL BE DROPPED.
USE_POSTGRES_FOR_TESTS = False
  33. @defer.inlineCallbacks
  34. def setup_test_homeserver(name="test", datastore=None, config=None, reactor=None,
  35. **kargs):
  36. """Setup a homeserver suitable for running tests against. Keyword arguments
  37. are passed to the Homeserver constructor. If no datastore is supplied a
  38. datastore backed by an in-memory sqlite db will be given to the HS.
  39. """
  40. if reactor is None:
  41. from twisted.internet import reactor
  42. if config is None:
  43. config = Mock()
  44. config.signing_key = [MockKey()]
  45. config.event_cache_size = 1
  46. config.enable_registration = True
  47. config.macaroon_secret_key = "not even a little secret"
  48. config.expire_access_token = False
  49. config.server_name = name
  50. config.trusted_third_party_id_servers = []
  51. config.room_invite_state_types = []
  52. config.password_providers = []
  53. config.worker_replication_url = ""
  54. config.worker_app = None
  55. config.email_enable_notifs = False
  56. config.block_non_admin_invites = False
  57. config.federation_domain_whitelist = None
  58. config.federation_rc_reject_limit = 10
  59. config.federation_rc_sleep_limit = 10
  60. config.federation_rc_sleep_delay = 100
  61. config.federation_rc_concurrent = 10
  62. config.filter_timeline_limit = 5000
  63. config.user_directory_search_all_users = False
  64. config.user_consent_server_notice_content = None
  65. config.block_events_without_consent_error = None
  66. config.media_storage_providers = []
  67. config.auto_join_rooms = []
  68. # disable user directory updates, because they get done in the
  69. # background, which upsets the test runner.
  70. config.update_user_directory = False
  71. config.use_frozen_dicts = True
  72. config.ldap_enabled = False
  73. if "clock" not in kargs:
  74. kargs["clock"] = MockClock()
  75. if USE_POSTGRES_FOR_TESTS:
  76. config.database_config = {
  77. "name": "psycopg2",
  78. "args": {
  79. "database": "synapse_test",
  80. "cp_min": 1,
  81. "cp_max": 5,
  82. },
  83. }
  84. else:
  85. config.database_config = {
  86. "name": "sqlite3",
  87. "args": {
  88. "database": ":memory:",
  89. "cp_min": 1,
  90. "cp_max": 1,
  91. },
  92. }
  93. db_engine = create_engine(config.database_config)
  94. # we need to configure the connection pool to run the on_new_connection
  95. # function, so that we can test code that uses custom sqlite functions
  96. # (like rank).
  97. config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
  98. if datastore is None:
  99. hs = HomeServer(
  100. name, config=config,
  101. db_config=config.database_config,
  102. version_string="Synapse/tests",
  103. database_engine=db_engine,
  104. room_list_handler=object(),
  105. tls_server_context_factory=Mock(),
  106. reactor=reactor,
  107. **kargs
  108. )
  109. db_conn = hs.get_db_conn()
  110. # make sure that the database is empty
  111. if isinstance(db_engine, PostgresEngine):
  112. cur = db_conn.cursor()
  113. cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
  114. rows = cur.fetchall()
  115. for r in rows:
  116. cur.execute("DROP TABLE %s CASCADE" % r[0])
  117. yield prepare_database(db_conn, db_engine, config)
  118. hs.setup()
  119. else:
  120. hs = HomeServer(
  121. name, db_pool=None, datastore=datastore, config=config,
  122. version_string="Synapse/tests",
  123. database_engine=db_engine,
  124. room_list_handler=object(),
  125. tls_server_context_factory=Mock(),
  126. reactor=reactor,
  127. **kargs
  128. )
  129. # bcrypt is far too slow to be doing in unit tests
  130. # Need to let the HS build an auth handler and then mess with it
  131. # because AuthHandler's constructor requires the HS, so we can't make one
  132. # beforehand and pass it in to the HS's constructor (chicken / egg)
  133. hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
  134. hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
  135. fed = kargs.get("resource_for_federation", None)
  136. if fed:
  137. server.register_servlets(
  138. hs,
  139. resource=fed,
  140. authenticator=server.Authenticator(hs),
  141. ratelimiter=FederationRateLimiter(
  142. hs.get_clock(),
  143. window_size=hs.config.federation_rc_window_size,
  144. sleep_limit=hs.config.federation_rc_sleep_limit,
  145. sleep_msec=hs.config.federation_rc_sleep_delay,
  146. reject_limit=hs.config.federation_rc_reject_limit,
  147. concurrent_requests=hs.config.federation_rc_concurrent
  148. ),
  149. )
  150. defer.returnValue(hs)
  151. def get_mock_call_args(pattern_func, mock_func):
  152. """ Return the arguments the mock function was called with interpreted
  153. by the pattern functions argument list.
  154. """
  155. invoked_args, invoked_kargs = mock_func.call_args
  156. return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
  157. def mock_getRawHeaders(headers=None):
  158. headers = headers if headers is not None else {}
  159. def getRawHeaders(name, default=None):
  160. return headers.get(name, default)
  161. return getRawHeaders
  162. # This is a mock /resource/ not an entire server
  163. class MockHttpResource(HttpServer):
  164. def __init__(self, prefix=""):
  165. self.callbacks = [] # 3-tuple of method/pattern/function
  166. self.prefix = prefix
  167. def trigger_get(self, path):
  168. return self.trigger(b"GET", path, None)
  169. @patch('twisted.web.http.Request')
  170. @defer.inlineCallbacks
  171. def trigger(self, http_method, path, content, mock_request, federation_auth=False):
  172. """ Fire an HTTP event.
  173. Args:
  174. http_method : The HTTP method
  175. path : The HTTP path
  176. content : The HTTP body
  177. mock_request : Mocked request to pass to the event so it can get
  178. content.
  179. Returns:
  180. A tuple of (code, response)
  181. Raises:
  182. KeyError If no event is found which will handle the path.
  183. """
  184. path = self.prefix + path
  185. # annoyingly we return a twisted http request which has chained calls
  186. # to get at the http content, hence mock it here.
  187. mock_content = Mock()
  188. config = {'read.return_value': content}
  189. mock_content.configure_mock(**config)
  190. mock_request.content = mock_content
  191. mock_request.method = http_method
  192. mock_request.uri = path
  193. mock_request.getClientIP.return_value = "-"
  194. headers = {}
  195. if federation_auth:
  196. headers[b"Authorization"] = [b"X-Matrix origin=test,key=,sig="]
  197. mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
  198. # return the right path if the event requires it
  199. mock_request.path = path
  200. # add in query params to the right place
  201. try:
  202. mock_request.args = urlparse.parse_qs(path.split('?')[1])
  203. mock_request.path = path.split('?')[0]
  204. path = mock_request.path
  205. except Exception:
  206. pass
  207. if isinstance(path, bytes):
  208. path = path.decode('utf8')
  209. for (method, pattern, func) in self.callbacks:
  210. if http_method != method:
  211. continue
  212. matcher = pattern.match(path)
  213. if matcher:
  214. try:
  215. args = [
  216. urlparse.unquote(u)
  217. for u in matcher.groups()
  218. ]
  219. (code, response) = yield func(
  220. mock_request,
  221. *args
  222. )
  223. defer.returnValue((code, response))
  224. except CodeMessageException as e:
  225. defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
  226. raise KeyError("No event can handle %s" % path)
  227. def register_paths(self, method, path_patterns, callback):
  228. for path_pattern in path_patterns:
  229. self.callbacks.append((method, path_pattern, callback))
  230. class MockKey(object):
  231. alg = "mock_alg"
  232. version = "mock_version"
  233. signature = b"\x9a\x87$"
  234. @property
  235. def verify_key(self):
  236. return self
  237. def sign(self, message):
  238. return self
  239. def verify(self, message, sig):
  240. assert sig == b"\x9a\x87$"
  241. class MockClock(object):
  242. now = 1000
  243. def __init__(self):
  244. # list of lists of [absolute_time, callback, expired] in no particular
  245. # order
  246. self.timers = []
  247. self.loopers = []
  248. def time(self):
  249. return self.now
  250. def time_msec(self):
  251. return self.time() * 1000
  252. def call_later(self, delay, callback, *args, **kwargs):
  253. current_context = LoggingContext.current_context()
  254. def wrapped_callback():
  255. LoggingContext.thread_local.current_context = current_context
  256. callback(*args, **kwargs)
  257. t = [self.now + delay, wrapped_callback, False]
  258. self.timers.append(t)
  259. return t
  260. def looping_call(self, function, interval):
  261. self.loopers.append([function, interval / 1000., self.now])
  262. def cancel_call_later(self, timer, ignore_errs=False):
  263. if timer[2]:
  264. if not ignore_errs:
  265. raise Exception("Cannot cancel an expired timer")
  266. timer[2] = True
  267. self.timers = [t for t in self.timers if t != timer]
  268. # For unit testing
  269. def advance_time(self, secs):
  270. self.now += secs
  271. timers = self.timers
  272. self.timers = []
  273. for t in timers:
  274. time, callback, expired = t
  275. if expired:
  276. raise Exception("Timer already expired")
  277. if self.now >= time:
  278. t[2] = True
  279. callback()
  280. else:
  281. self.timers.append(t)
  282. for looped in self.loopers:
  283. func, interval, last = looped
  284. if last + interval < self.now:
  285. func()
  286. looped[2] = self.now
  287. def advance_time_msec(self, ms):
  288. self.advance_time(ms / 1000.)
  289. def time_bound_deferred(self, d, *args, **kwargs):
  290. # We don't bother timing things out for now.
  291. return d
  292. def _format_call(args, kwargs):
  293. return ", ".join(
  294. ["%r" % (a) for a in args] +
  295. ["%s=%r" % (k, v) for k, v in kwargs.items()]
  296. )
  297. class DeferredMockCallable(object):
  298. """A callable instance that stores a set of pending call expectations and
  299. return values for them. It allows a unit test to assert that the given set
  300. of function calls are eventually made, by awaiting on them to be called.
  301. """
  302. def __init__(self):
  303. self.expectations = []
  304. self.calls = []
  305. def __call__(self, *args, **kwargs):
  306. self.calls.append((args, kwargs))
  307. if not self.expectations:
  308. raise ValueError("%r has no pending calls to handle call(%s)" % (
  309. self, _format_call(args, kwargs))
  310. )
  311. for (call, result, d) in self.expectations:
  312. if args == call[1] and kwargs == call[2]:
  313. d.callback(None)
  314. return result
  315. failure = AssertionError("Was not expecting call(%s)" % (
  316. _format_call(args, kwargs)
  317. ))
  318. for _, _, d in self.expectations:
  319. try:
  320. d.errback(failure)
  321. except Exception:
  322. pass
  323. raise failure
  324. def expect_call_and_return(self, call, result):
  325. self.expectations.append((call, result, defer.Deferred()))
  326. @defer.inlineCallbacks
  327. def await_calls(self, timeout=1000):
  328. deferred = defer.DeferredList(
  329. [d for _, _, d in self.expectations],
  330. fireOnOneErrback=True
  331. )
  332. timer = reactor.callLater(
  333. timeout / 1000,
  334. deferred.errback,
  335. AssertionError("%d pending calls left: %s" % (
  336. len([e for e in self.expectations if not e[2].called]),
  337. [e for e in self.expectations if not e[2].called]
  338. ))
  339. )
  340. yield deferred
  341. timer.cancel()
  342. self.calls = []
  343. def assert_had_no_calls(self):
  344. if self.calls:
  345. calls = self.calls
  346. self.calls = []
  347. raise AssertionError(
  348. "Expected not to received any calls, got:\n" + "\n".join([
  349. "call(%s)" % _format_call(c[0], c[1]) for c in calls
  350. ])
  351. )