utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
import hashlib
import os
from inspect import getcallargs

from mock import Mock, patch
from six.moves.urllib import parse as urlparse
from twisted.internet import defer, reactor

from synapse.api.errors import CodeMessageException, cs_error
from synapse.federation.transport import server
from synapse.http.server import HttpServer
from synapse.server import HomeServer
from synapse.storage import PostgresEngine
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
from synapse.util.logcontext import LoggingContext
from synapse.util.ratelimitutils import FederationRateLimiter
  29. # set this to True to run the tests against postgres instead of sqlite.
  30. # It requires you to have a local postgres database called synapse_test, within
  31. # which ALL TABLES WILL BE DROPPED
  32. USE_POSTGRES_FOR_TESTS = False
  33. @defer.inlineCallbacks
  34. def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
  35. """Setup a homeserver suitable for running tests against. Keyword arguments
  36. are passed to the Homeserver constructor. If no datastore is supplied a
  37. datastore backed by an in-memory sqlite db will be given to the HS.
  38. """
  39. if config is None:
  40. config = Mock()
  41. config.signing_key = [MockKey()]
  42. config.event_cache_size = 1
  43. config.enable_registration = True
  44. config.macaroon_secret_key = "not even a little secret"
  45. config.expire_access_token = False
  46. config.server_name = name
  47. config.trusted_third_party_id_servers = []
  48. config.room_invite_state_types = []
  49. config.password_providers = []
  50. config.worker_replication_url = ""
  51. config.worker_app = None
  52. config.email_enable_notifs = False
  53. config.block_non_admin_invites = False
  54. config.federation_domain_whitelist = None
  55. config.federation_rc_reject_limit = 10
  56. config.federation_rc_sleep_limit = 10
  57. config.federation_rc_concurrent = 10
  58. config.filter_timeline_limit = 5000
  59. config.user_directory_search_all_users = False
  60. config.user_consent_server_notice_content = None
  61. config.block_events_without_consent_error = None
  62. # disable user directory updates, because they get done in the
  63. # background, which upsets the test runner.
  64. config.update_user_directory = False
  65. config.use_frozen_dicts = True
  66. config.ldap_enabled = False
  67. if "clock" not in kargs:
  68. kargs["clock"] = MockClock()
  69. if USE_POSTGRES_FOR_TESTS:
  70. config.database_config = {
  71. "name": "psycopg2",
  72. "args": {
  73. "database": "synapse_test",
  74. "cp_min": 1,
  75. "cp_max": 5,
  76. },
  77. }
  78. else:
  79. config.database_config = {
  80. "name": "sqlite3",
  81. "args": {
  82. "database": ":memory:",
  83. "cp_min": 1,
  84. "cp_max": 1,
  85. },
  86. }
  87. db_engine = create_engine(config.database_config)
  88. # we need to configure the connection pool to run the on_new_connection
  89. # function, so that we can test code that uses custom sqlite functions
  90. # (like rank).
  91. config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection
  92. if datastore is None:
  93. hs = HomeServer(
  94. name, config=config,
  95. db_config=config.database_config,
  96. version_string="Synapse/tests",
  97. database_engine=db_engine,
  98. room_list_handler=object(),
  99. tls_server_context_factory=Mock(),
  100. **kargs
  101. )
  102. db_conn = hs.get_db_conn()
  103. # make sure that the database is empty
  104. if isinstance(db_engine, PostgresEngine):
  105. cur = db_conn.cursor()
  106. cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
  107. rows = cur.fetchall()
  108. for r in rows:
  109. cur.execute("DROP TABLE %s CASCADE" % r[0])
  110. yield prepare_database(db_conn, db_engine, config)
  111. hs.setup()
  112. else:
  113. hs = HomeServer(
  114. name, db_pool=None, datastore=datastore, config=config,
  115. version_string="Synapse/tests",
  116. database_engine=db_engine,
  117. room_list_handler=object(),
  118. tls_server_context_factory=Mock(),
  119. **kargs
  120. )
  121. # bcrypt is far too slow to be doing in unit tests
  122. # Need to let the HS build an auth handler and then mess with it
  123. # because AuthHandler's constructor requires the HS, so we can't make one
  124. # beforehand and pass it in to the HS's constructor (chicken / egg)
  125. hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
  126. hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h
  127. fed = kargs.get("resource_for_federation", None)
  128. if fed:
  129. server.register_servlets(
  130. hs,
  131. resource=fed,
  132. authenticator=server.Authenticator(hs),
  133. ratelimiter=FederationRateLimiter(
  134. hs.get_clock(),
  135. window_size=hs.config.federation_rc_window_size,
  136. sleep_limit=hs.config.federation_rc_sleep_limit,
  137. sleep_msec=hs.config.federation_rc_sleep_delay,
  138. reject_limit=hs.config.federation_rc_reject_limit,
  139. concurrent_requests=hs.config.federation_rc_concurrent
  140. ),
  141. )
  142. defer.returnValue(hs)
  143. def get_mock_call_args(pattern_func, mock_func):
  144. """ Return the arguments the mock function was called with interpreted
  145. by the pattern functions argument list.
  146. """
  147. invoked_args, invoked_kargs = mock_func.call_args
  148. return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
  149. def mock_getRawHeaders(headers=None):
  150. headers = headers if headers is not None else {}
  151. def getRawHeaders(name, default=None):
  152. return headers.get(name, default)
  153. return getRawHeaders
  154. # This is a mock /resource/ not an entire server
  155. class MockHttpResource(HttpServer):
  156. def __init__(self, prefix=""):
  157. self.callbacks = [] # 3-tuple of method/pattern/function
  158. self.prefix = prefix
  159. def trigger_get(self, path):
  160. return self.trigger("GET", path, None)
  161. @patch('twisted.web.http.Request')
  162. @defer.inlineCallbacks
  163. def trigger(self, http_method, path, content, mock_request, federation_auth=False):
  164. """ Fire an HTTP event.
  165. Args:
  166. http_method : The HTTP method
  167. path : The HTTP path
  168. content : The HTTP body
  169. mock_request : Mocked request to pass to the event so it can get
  170. content.
  171. Returns:
  172. A tuple of (code, response)
  173. Raises:
  174. KeyError If no event is found which will handle the path.
  175. """
  176. path = self.prefix + path
  177. # annoyingly we return a twisted http request which has chained calls
  178. # to get at the http content, hence mock it here.
  179. mock_content = Mock()
  180. config = {'read.return_value': content}
  181. mock_content.configure_mock(**config)
  182. mock_request.content = mock_content
  183. mock_request.method = http_method
  184. mock_request.uri = path
  185. mock_request.getClientIP.return_value = "-"
  186. headers = {}
  187. if federation_auth:
  188. headers[b"Authorization"] = ["X-Matrix origin=test,key=,sig="]
  189. mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
  190. # return the right path if the event requires it
  191. mock_request.path = path
  192. # add in query params to the right place
  193. try:
  194. mock_request.args = urlparse.parse_qs(path.split('?')[1])
  195. mock_request.path = path.split('?')[0]
  196. path = mock_request.path
  197. except Exception:
  198. pass
  199. for (method, pattern, func) in self.callbacks:
  200. if http_method != method:
  201. continue
  202. matcher = pattern.match(path)
  203. if matcher:
  204. try:
  205. args = [
  206. urlparse.unquote(u).decode("UTF-8")
  207. for u in matcher.groups()
  208. ]
  209. (code, response) = yield func(
  210. mock_request,
  211. *args
  212. )
  213. defer.returnValue((code, response))
  214. except CodeMessageException as e:
  215. defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
  216. raise KeyError("No event can handle %s" % path)
  217. def register_paths(self, method, path_patterns, callback):
  218. for path_pattern in path_patterns:
  219. self.callbacks.append((method, path_pattern, callback))
  220. class MockKey(object):
  221. alg = "mock_alg"
  222. version = "mock_version"
  223. signature = b"\x9a\x87$"
  224. @property
  225. def verify_key(self):
  226. return self
  227. def sign(self, message):
  228. return self
  229. def verify(self, message, sig):
  230. assert sig == b"\x9a\x87$"
  231. class MockClock(object):
  232. now = 1000
  233. def __init__(self):
  234. # list of lists of [absolute_time, callback, expired] in no particular
  235. # order
  236. self.timers = []
  237. self.loopers = []
  238. def time(self):
  239. return self.now
  240. def time_msec(self):
  241. return self.time() * 1000
  242. def call_later(self, delay, callback, *args, **kwargs):
  243. current_context = LoggingContext.current_context()
  244. def wrapped_callback():
  245. LoggingContext.thread_local.current_context = current_context
  246. callback(*args, **kwargs)
  247. t = [self.now + delay, wrapped_callback, False]
  248. self.timers.append(t)
  249. return t
  250. def looping_call(self, function, interval):
  251. self.loopers.append([function, interval / 1000., self.now])
  252. def cancel_call_later(self, timer, ignore_errs=False):
  253. if timer[2]:
  254. if not ignore_errs:
  255. raise Exception("Cannot cancel an expired timer")
  256. timer[2] = True
  257. self.timers = [t for t in self.timers if t != timer]
  258. # For unit testing
  259. def advance_time(self, secs):
  260. self.now += secs
  261. timers = self.timers
  262. self.timers = []
  263. for t in timers:
  264. time, callback, expired = t
  265. if expired:
  266. raise Exception("Timer already expired")
  267. if self.now >= time:
  268. t[2] = True
  269. callback()
  270. else:
  271. self.timers.append(t)
  272. for looped in self.loopers:
  273. func, interval, last = looped
  274. if last + interval < self.now:
  275. func()
  276. looped[2] = self.now
  277. def advance_time_msec(self, ms):
  278. self.advance_time(ms / 1000.)
  279. def time_bound_deferred(self, d, *args, **kwargs):
  280. # We don't bother timing things out for now.
  281. return d
  282. def _format_call(args, kwargs):
  283. return ", ".join(
  284. ["%r" % (a) for a in args] +
  285. ["%s=%r" % (k, v) for k, v in kwargs.items()]
  286. )
  287. class DeferredMockCallable(object):
  288. """A callable instance that stores a set of pending call expectations and
  289. return values for them. It allows a unit test to assert that the given set
  290. of function calls are eventually made, by awaiting on them to be called.
  291. """
  292. def __init__(self):
  293. self.expectations = []
  294. self.calls = []
  295. def __call__(self, *args, **kwargs):
  296. self.calls.append((args, kwargs))
  297. if not self.expectations:
  298. raise ValueError("%r has no pending calls to handle call(%s)" % (
  299. self, _format_call(args, kwargs))
  300. )
  301. for (call, result, d) in self.expectations:
  302. if args == call[1] and kwargs == call[2]:
  303. d.callback(None)
  304. return result
  305. failure = AssertionError("Was not expecting call(%s)" % (
  306. _format_call(args, kwargs)
  307. ))
  308. for _, _, d in self.expectations:
  309. try:
  310. d.errback(failure)
  311. except Exception:
  312. pass
  313. raise failure
  314. def expect_call_and_return(self, call, result):
  315. self.expectations.append((call, result, defer.Deferred()))
  316. @defer.inlineCallbacks
  317. def await_calls(self, timeout=1000):
  318. deferred = defer.DeferredList(
  319. [d for _, _, d in self.expectations],
  320. fireOnOneErrback=True
  321. )
  322. timer = reactor.callLater(
  323. timeout / 1000,
  324. deferred.errback,
  325. AssertionError("%d pending calls left: %s" % (
  326. len([e for e in self.expectations if not e[2].called]),
  327. [e for e in self.expectations if not e[2].called]
  328. ))
  329. )
  330. yield deferred
  331. timer.cancel()
  332. self.calls = []
  333. def assert_had_no_calls(self):
  334. if self.calls:
  335. calls = self.calls
  336. self.calls = []
  337. raise AssertionError(
  338. "Expected not to received any calls, got:\n" + "\n".join([
  339. "call(%s)" % _format_call(c[0], c[1]) for c in calls
  340. ])
  341. )