utils.py 13 KB


  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import hashlib
  16. from inspect import getcallargs
  17. import urllib
  18. import urlparse
  19. from mock import Mock, patch
  20. from twisted.internet import defer, reactor
  21. from synapse.api.errors import CodeMessageException, cs_error
  22. from synapse.federation.transport import server
  23. from synapse.http.server import HttpServer
  24. from synapse.server import HomeServer
  25. from synapse.storage import PostgresEngine
  26. from synapse.storage.engines import create_engine
  27. from synapse.storage.prepare_database import prepare_database
  28. from synapse.util.logcontext import LoggingContext
  29. from synapse.util.ratelimitutils import FederationRateLimiter
# set this to True to run the tests against postgres instead of sqlite.
# It requires you to have a local postgres database called synapse_test, within
# which ALL TABLES WILL BE DROPPED
# (setup_test_homeserver reads this flag to choose its database_config.)
USE_POSTGRES_FOR_TESTS = False
@defer.inlineCallbacks
def setup_test_homeserver(name="test", datastore=None, config=None, **kargs):
    """Setup a homeserver suitable for running tests against.

    Args:
        name (str): the server name passed to HomeServer; also used as
            ``config.server_name`` when a default config is built.
        datastore: if given, handed straight to HomeServer (with
            ``db_pool=None``) and no database is prepared. If None, a fresh
            database is created and prepared: in-memory sqlite by default,
            or the local ``synapse_test`` postgres database when
            USE_POSTGRES_FOR_TESTS is set.
        config: a config object; if None a Mock pre-populated with test
            defaults is used.
        **kargs: passed through to the HomeServer constructor. ``clock``
            defaults to a MockClock. If ``resource_for_federation`` is truthy,
            the federation servlets are registered on it.

    Returns:
        Deferred[HomeServer]: the configured homeserver.
    """
    if config is None:
        config = Mock()
        config.signing_key = [MockKey()]
        config.event_cache_size = 1
        config.enable_registration = True
        config.macaroon_secret_key = "not even a little secret"
        config.expire_access_token = False
        config.server_name = name
        config.trusted_third_party_id_servers = []
        config.room_invite_state_types = []
        config.password_providers = []
        config.worker_replication_url = ""
        config.worker_app = None
        config.email_enable_notifs = False
        config.block_non_admin_invites = False
        config.federation_domain_whitelist = None
        config.user_directory_search_all_users = False

        # disable user directory updates, because they get done in the
        # background, which upsets the test runner.
        config.update_user_directory = False

    # NOTE(review): indentation was lost in this copy of the file; these two
    # settings are assumed to apply to caller-supplied configs as well as the
    # default Mock -- confirm against upstream.
    config.use_frozen_dicts = True
    config.ldap_enabled = False

    if "clock" not in kargs:
        kargs["clock"] = MockClock()

    if USE_POSTGRES_FOR_TESTS:
        config.database_config = {
            "name": "psycopg2",
            "args": {
                "database": "synapse_test",
                "cp_min": 1,
                "cp_max": 5,
            },
        }
    else:
        # a single connection, so every query sees the same in-memory db
        config.database_config = {
            "name": "sqlite3",
            "args": {
                "database": ":memory:",
                "cp_min": 1,
                "cp_max": 1,
            },
        }
    db_engine = create_engine(config.database_config)

    # we need to configure the connection pool to run the on_new_connection
    # function, so that we can test code that uses custom sqlite functions
    # (like rank).
    config.database_config["args"]["cp_openfun"] = db_engine.on_new_connection

    if datastore is None:
        hs = HomeServer(
            name, config=config,
            db_config=config.database_config,
            version_string="Synapse/tests",
            database_engine=db_engine,
            room_list_handler=object(),
            tls_server_context_factory=Mock(),
            **kargs
        )
        db_conn = hs.get_db_conn()
        # make sure that the database is empty
        if isinstance(db_engine, PostgresEngine):
            # NOTE(review): this cursor is never closed and no explicit
            # commit is issued here; presumably prepare_database commits the
            # DROPs -- confirm.
            cur = db_conn.cursor()
            cur.execute("SELECT tablename FROM pg_tables where schemaname='public'")
            rows = cur.fetchall()
            for r in rows:
                cur.execute("DROP TABLE %s CASCADE" % r[0])
        yield prepare_database(db_conn, db_engine, config)
        hs.setup()
    else:
        hs = HomeServer(
            name, db_pool=None, datastore=datastore, config=config,
            version_string="Synapse/tests",
            database_engine=db_engine,
            room_list_handler=object(),
            tls_server_context_factory=Mock(),
            **kargs
        )

    # bcrypt is far too slow to be doing in unit tests
    # Need to let the HS build an auth handler and then mess with it
    # because AuthHandler's constructor requires the HS, so we can't make one
    # beforehand and pass it in to the HS's constructor (chicken / egg)
    hs.get_auth_handler().hash = lambda p: hashlib.md5(p).hexdigest()
    hs.get_auth_handler().validate_hash = lambda p, h: hashlib.md5(p).hexdigest() == h

    # optionally mount the federation servlets on a caller-supplied resource
    fed = kargs.get("resource_for_federation", None)
    if fed:
        server.register_servlets(
            hs,
            resource=fed,
            authenticator=server.Authenticator(hs),
            ratelimiter=FederationRateLimiter(
                hs.get_clock(),
                window_size=hs.config.federation_rc_window_size,
                sleep_limit=hs.config.federation_rc_sleep_limit,
                sleep_msec=hs.config.federation_rc_sleep_delay,
                reject_limit=hs.config.federation_rc_reject_limit,
                concurrent_requests=hs.config.federation_rc_concurrent
            ),
        )

    defer.returnValue(hs)
  138. def get_mock_call_args(pattern_func, mock_func):
  139. """ Return the arguments the mock function was called with interpreted
  140. by the pattern functions argument list.
  141. """
  142. invoked_args, invoked_kargs = mock_func.call_args
  143. return getcallargs(pattern_func, *invoked_args, **invoked_kargs)
  144. def mock_getRawHeaders(headers=None):
  145. headers = headers if headers is not None else {}
  146. def getRawHeaders(name, default=None):
  147. return headers.get(name, default)
  148. return getRawHeaders
  149. # This is a mock /resource/ not an entire server
  150. class MockHttpResource(HttpServer):
  151. def __init__(self, prefix=""):
  152. self.callbacks = [] # 3-tuple of method/pattern/function
  153. self.prefix = prefix
  154. def trigger_get(self, path):
  155. return self.trigger("GET", path, None)
  156. @patch('twisted.web.http.Request')
  157. @defer.inlineCallbacks
  158. def trigger(self, http_method, path, content, mock_request, federation_auth=False):
  159. """ Fire an HTTP event.
  160. Args:
  161. http_method : The HTTP method
  162. path : The HTTP path
  163. content : The HTTP body
  164. mock_request : Mocked request to pass to the event so it can get
  165. content.
  166. Returns:
  167. A tuple of (code, response)
  168. Raises:
  169. KeyError If no event is found which will handle the path.
  170. """
  171. path = self.prefix + path
  172. # annoyingly we return a twisted http request which has chained calls
  173. # to get at the http content, hence mock it here.
  174. mock_content = Mock()
  175. config = {'read.return_value': content}
  176. mock_content.configure_mock(**config)
  177. mock_request.content = mock_content
  178. mock_request.method = http_method
  179. mock_request.uri = path
  180. mock_request.getClientIP.return_value = "-"
  181. headers = {}
  182. if federation_auth:
  183. headers["Authorization"] = ["X-Matrix origin=test,key=,sig="]
  184. mock_request.requestHeaders.getRawHeaders = mock_getRawHeaders(headers)
  185. # return the right path if the event requires it
  186. mock_request.path = path
  187. # add in query params to the right place
  188. try:
  189. mock_request.args = urlparse.parse_qs(path.split('?')[1])
  190. mock_request.path = path.split('?')[0]
  191. path = mock_request.path
  192. except Exception:
  193. pass
  194. for (method, pattern, func) in self.callbacks:
  195. if http_method != method:
  196. continue
  197. matcher = pattern.match(path)
  198. if matcher:
  199. try:
  200. args = [
  201. urllib.unquote(u).decode("UTF-8")
  202. for u in matcher.groups()
  203. ]
  204. (code, response) = yield func(
  205. mock_request,
  206. *args
  207. )
  208. defer.returnValue((code, response))
  209. except CodeMessageException as e:
  210. defer.returnValue((e.code, cs_error(e.msg, code=e.errcode)))
  211. raise KeyError("No event can handle %s" % path)
  212. def register_paths(self, method, path_patterns, callback):
  213. for path_pattern in path_patterns:
  214. self.callbacks.append((method, path_pattern, callback))
  215. class MockKey(object):
  216. alg = "mock_alg"
  217. version = "mock_version"
  218. signature = b"\x9a\x87$"
  219. @property
  220. def verify_key(self):
  221. return self
  222. def sign(self, message):
  223. return self
  224. def verify(self, message, sig):
  225. assert sig == b"\x9a\x87$"
  226. class MockClock(object):
  227. now = 1000
  228. def __init__(self):
  229. # list of lists of [absolute_time, callback, expired] in no particular
  230. # order
  231. self.timers = []
  232. self.loopers = []
  233. def time(self):
  234. return self.now
  235. def time_msec(self):
  236. return self.time() * 1000
  237. def call_later(self, delay, callback, *args, **kwargs):
  238. current_context = LoggingContext.current_context()
  239. def wrapped_callback():
  240. LoggingContext.thread_local.current_context = current_context
  241. callback(*args, **kwargs)
  242. t = [self.now + delay, wrapped_callback, False]
  243. self.timers.append(t)
  244. return t
  245. def looping_call(self, function, interval):
  246. self.loopers.append([function, interval / 1000., self.now])
  247. def cancel_call_later(self, timer, ignore_errs=False):
  248. if timer[2]:
  249. if not ignore_errs:
  250. raise Exception("Cannot cancel an expired timer")
  251. timer[2] = True
  252. self.timers = [t for t in self.timers if t != timer]
  253. # For unit testing
  254. def advance_time(self, secs):
  255. self.now += secs
  256. timers = self.timers
  257. self.timers = []
  258. for t in timers:
  259. time, callback, expired = t
  260. if expired:
  261. raise Exception("Timer already expired")
  262. if self.now >= time:
  263. t[2] = True
  264. callback()
  265. else:
  266. self.timers.append(t)
  267. for looped in self.loopers:
  268. func, interval, last = looped
  269. if last + interval < self.now:
  270. func()
  271. looped[2] = self.now
  272. def advance_time_msec(self, ms):
  273. self.advance_time(ms / 1000.)
  274. def time_bound_deferred(self, d, *args, **kwargs):
  275. # We don't bother timing things out for now.
  276. return d
  277. def _format_call(args, kwargs):
  278. return ", ".join(
  279. ["%r" % (a) for a in args] +
  280. ["%s=%r" % (k, v) for k, v in kwargs.items()]
  281. )
  282. class DeferredMockCallable(object):
  283. """A callable instance that stores a set of pending call expectations and
  284. return values for them. It allows a unit test to assert that the given set
  285. of function calls are eventually made, by awaiting on them to be called.
  286. """
  287. def __init__(self):
  288. self.expectations = []
  289. self.calls = []
  290. def __call__(self, *args, **kwargs):
  291. self.calls.append((args, kwargs))
  292. if not self.expectations:
  293. raise ValueError("%r has no pending calls to handle call(%s)" % (
  294. self, _format_call(args, kwargs))
  295. )
  296. for (call, result, d) in self.expectations:
  297. if args == call[1] and kwargs == call[2]:
  298. d.callback(None)
  299. return result
  300. failure = AssertionError("Was not expecting call(%s)" % (
  301. _format_call(args, kwargs)
  302. ))
  303. for _, _, d in self.expectations:
  304. try:
  305. d.errback(failure)
  306. except Exception:
  307. pass
  308. raise failure
  309. def expect_call_and_return(self, call, result):
  310. self.expectations.append((call, result, defer.Deferred()))
  311. @defer.inlineCallbacks
  312. def await_calls(self, timeout=1000):
  313. deferred = defer.DeferredList(
  314. [d for _, _, d in self.expectations],
  315. fireOnOneErrback=True
  316. )
  317. timer = reactor.callLater(
  318. timeout / 1000,
  319. deferred.errback,
  320. AssertionError("%d pending calls left: %s" % (
  321. len([e for e in self.expectations if not e[2].called]),
  322. [e for e in self.expectations if not e[2].called]
  323. ))
  324. )
  325. yield deferred
  326. timer.cancel()
  327. self.calls = []
  328. def assert_had_no_calls(self):
  329. if self.calls:
  330. calls = self.calls
  331. self.calls = []
  332. raise AssertionError(
  333. "Expected not to received any calls, got:\n" + "\n".join([
  334. "call(%s)" % _format_call(c[0], c[1]) for c in calls
  335. ])
  336. )