123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295 |
- # -*- coding: utf-8 -*-
- # Copyright 2014-2016 OpenMarket Ltd
- # Copyright 2018 New Vector
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import logging
- from mock import Mock
- from canonicaljson import json
- import twisted
- import twisted.logger
- from twisted.internet.defer import Deferred
- from twisted.trial import unittest
- from synapse.http.server import JsonResource
- from synapse.server import HomeServer
- from synapse.types import UserID, create_requester
- from synapse.util.logcontext import LoggingContextFilter
- from tests.server import get_clock, make_request, render, setup_test_homeserver
# Set up putting Synapse's logs into Trial's.
rootLogger = logging.getLogger()

# One " - "-separated column per record attribute. "%(request)s" is not a
# standard LogRecord attribute; presumably it is supplied by the
# LoggingContextFilter attached to the handler below -- verify.
log_format = " - ".join(
    (
        "%(asctime)s",
        "%(name)s",
        "%(lineno)d",
        "%(levelname)s",
        "%(request)s",
        "%(message)s",
    )
)
class ToTwistedHandler(logging.Handler):
    """A stdlib logging.Handler which re-emits every formatted record through
    a twisted.logger.Logger."""

    # A single Logger shared by all instances of the handler.
    tx_log = twisted.logger.Logger()

    def emit(self, record):
        """Format `record` and hand it to the Twisted logger.

        Args:
            record (logging.LogRecord): The record to forward.
        """
        log_entry = self.format(record)

        # The stdlib level name "WARNING" is looked up as "warn" on the
        # Twisted side; every other name is only lower-cased.
        level_name = record.levelname.lower().replace('warning', 'warn')
        tx_level = twisted.logger.LogLevel.levelWithName(level_name)

        # Curly braces are swapped for parentheses before emitting --
        # presumably because twisted.logger treats the text as a format
        # string; confirm against the twisted.logger documentation.
        safe_entry = log_entry.replace("{", r"(").replace("}", r")")

        self.tx_log.emit(tx_level, safe_entry)
# Build the formatter first, then attach the formatter and the context
# filter to a fresh handler before registering the handler on the root logger.
formatter = logging.Formatter(log_format)

handler = ToTwistedHandler()
handler.setFormatter(formatter)
handler.addFilter(LoggingContextFilter(request=""))

rootLogger.addHandler(handler)
def around(target):
    """A CLOS-style 'around' modifier: swap a method on the given instance
    for a wrapper which is handed the original as its first argument.

    The name of the decorated function selects which attribute of `target`
    gets wrapped. The wrapper is installed on `target` with setattr; the
    decorator itself returns nothing, so the decorated name ends up bound to
    None in the defining scope.

        @around(self)
        def method_name(orig, *args, **kwargs):
            return orig(*args, **kwargs)

    Args:
        target: The object whose attribute is to be wrapped.

    Returns:
        A decorator taking the wrapping function.
    """

    def _around(code):
        attr_name = code.__name__
        # Capture the current attribute now, so that the wrapper always calls
        # the implementation that was in place when it was installed.
        wrapped = getattr(target, attr_name)

        def new(*args, **kwargs):
            return code(wrapped, *args, **kwargs)

        setattr(target, attr_name, new)

    return _around
class TestCase(unittest.TestCase):
    """A subclass of twisted.trial's TestCase which looks for 'loglevel'
    attributes on both itself and its individual test methods, to override the
    root logger's logging level while that test (case|method) runs."""

    def __init__(self, methodName, *args, **kwargs):
        """Wrap this instance's setUp (and, when needed, tearDown) so that
        the root logger level is adjusted around the test.

        Args:
            methodName (str): Name of the test method this instance runs.
        """
        super(TestCase, self).__init__(methodName, *args, **kwargs)

        method = getattr(self, methodName)

        # A 'loglevel' on the test method wins over one on the test case;
        # with neither present the level defaults to logging.ERROR.
        level = getattr(method, "loglevel", getattr(self, "loglevel", logging.ERROR))

        @around(self)
        def setUp(orig):
            # enable debugging of delayed calls - this means that we get a
            # traceback when a unit test exits leaving things on the reactor.
            twisted.internet.base.DelayedCall.debug = True

            old_level = logging.getLogger().level

            if old_level != level:
                # Only bother wrapping tearDown when there is a level to
                # put back afterwards.
                @around(self)
                def tearDown(orig):
                    ret = orig()
                    logging.getLogger().setLevel(old_level)
                    return ret

            logging.getLogger().setLevel(level)
            return orig()

    def assertObjectHasAttributes(self, attrs, obj):
        """Asserts that the given object has each of the attributes given, and
        that the value of each matches according to assertEqual.

        Args:
            attrs (dict): Maps attribute name to the expected value.
            obj: The object to inspect.

        Raises:
            AssertionError: If an attribute is missing, or if a value differs
                (in which case the message names the offending attribute).
        """
        for key, expected in attrs.items():
            if not hasattr(obj, key):
                raise AssertionError("Expected obj to have a '.%s'" % key)
            try:
                self.assertEqual(expected, getattr(obj, key))
            except AssertionError as e:
                # str(e) rather than e.message: exceptions have no .message
                # attribute on Python 3. The exception type is kept so that
                # the runner still reports a failure rather than an error.
                raise (type(e))(str(e) + " for '.%s'" % key)

    def assert_dict(self, required, actual):
        """Does a partial assert of a dict.

        Args:
            required (dict): The keys and value which MUST be in 'actual'.
            actual (dict): The test result. Extra keys will not be checked.
        """
        for key in required:
            self.assertEqual(
                required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
            )
def DEBUG(target):
    """Decorator which tags `target` with a .loglevel of logging.DEBUG.

    Works on either a TestCase class or a single test method; the same object
    is handed back after the attribute has been set on it.
    """
    setattr(target, "loglevel", logging.DEBUG)
    return target
class HomeserverTestCase(TestCase):
    """
    A base TestCase that reduces boilerplate for HomeServer-using test cases.

    Attributes:
        servlets (list[function]): List of servlet registration function.
        user_id (str): The user ID to assume if auth is hijacked.
        hijack_auth (bool): Whether to hijack auth to return the user specified
        in user_id.
    """

    servlets = []
    hijack_auth = True

    def setUp(self):
        """
        Build the homeserver via make_homeserver, register the servlets on a
        JsonResource, optionally replace the auth lookups with ones returning
        a fixed user, and finally call prepare.
        """
        self.reactor, self.clock = get_clock()
        self._hs_args = dict(clock=self.clock, reactor=self.reactor)
        self.hs = self.make_homeserver(self.reactor, self.clock)

        if self.hs is None:
            raise Exception("No homeserver returned from make_homeserver.")

        if not isinstance(self.hs, HomeServer):
            raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))

        # Register the resources
        self.resource = JsonResource(self.hs)

        for register_servlet in self.servlets:
            register_servlet(self.hs, self.resource)

        if hasattr(self, "user_id"):
            from tests.rest.client.v1.utils import RestHelper

            self.helper = RestHelper(self.hs, self.resource, self.user_id)

            if self.hijack_auth:
                # Both replacements read self.helper.auth_user_id at call
                # time, not at definition time.

                def get_user_by_access_token(token=None, allow_guest=False):
                    return {
                        "user": UserID.from_string(self.helper.auth_user_id),
                        "token_id": 1,
                        "is_guest": False,
                    }

                def get_user_by_req(request, allow_guest=False, rights="access"):
                    return create_requester(
                        UserID.from_string(self.helper.auth_user_id), 1, False, None
                    )

                self.hs.get_auth().get_user_by_req = get_user_by_req
                self.hs.get_auth().get_user_by_access_token = get_user_by_access_token
                self.hs.get_auth().get_access_token_from_request = Mock(
                    return_value="1234"
                )

        if hasattr(self, "prepare"):
            self.prepare(self.reactor, self.clock, self.hs)

    def make_homeserver(self, reactor, clock):
        """
        Make and return a homeserver.

        Function to be overridden in subclasses; this base version always
        raises NotImplementedError.

        Args:
            reactor: A Twisted Reactor, or something that pretends to be one.
            clock (synapse.util.Clock): The Clock, associated with the reactor.

        Returns:
            A homeserver (synapse.server.HomeServer) suitable for testing.
        """
        raise NotImplementedError()

    def prepare(self, reactor, clock, homeserver):
        """
        Prepare for the test.  This involves things like mocking out parts of
        the homeserver, or building test data common across the whole test
        suite.

        Function to optionally be overridden in subclasses; this base version
        does nothing.

        Args:
            reactor: A Twisted Reactor, or something that pretends to be one.
            clock (synapse.util.Clock): The Clock, associated with the reactor.
            homeserver (synapse.server.HomeServer): The HomeServer to test
            against.
        """

    def make_request(self, method, path, content=b""):
        """
        Create a SynapseRequest at the path using the method and containing the
        given content.

        Args:
            method (bytes/unicode): The HTTP request method ("verb").
            path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
            escaped UTF-8 & spaces and such).
            content (bytes or dict): The body of the request. JSON-encoded, if
            a dict.

        Returns:
            A synapse.http.site.SynapseRequest.
        """
        body = content
        if isinstance(body, dict):
            # dicts are serialised to JSON and sent as UTF-8 bytes
            body = json.dumps(body).encode('utf8')

        return make_request(method, path, body)

    def render(self, request):
        """
        Render a request against the resources registered by the test class's
        servlets.

        Args:
            request (synapse.http.site.SynapseRequest): The request to render.
        """
        render(request, self.resource, self.reactor)

    def setup_test_homeserver(self, *args, **kwargs):
        """
        Set up the test homeserver, meant to be called by the overridable
        make_homeserver. It automatically passes through the test class's
        clock & reactor, which take precedence over any 'clock' or 'reactor'
        keyword supplied by the caller.

        Args:
            See tests.utils.setup_test_homeserver.

        Returns:
            synapse.server.HomeServer
        """
        # Copy first so the caller's dict is untouched, then let the test's
        # own clock/reactor win.
        merged_kwargs = dict(kwargs, **self._hs_args)
        return setup_test_homeserver(self.addCleanup, *args, **merged_kwargs)

    def pump(self, by=0.0):
        """
        Pump the reactor enough that Deferreds will fire.

        Args:
            by (float): The amount passed to the reactor for each of the 100
            pump steps.
        """
        increments = [by] * 100
        self.reactor.pump(increments)

    def get_success(self, d):
        """
        Return the result of `d`.

        Anything that is not a Deferred is handed straight back. A Deferred
        has the reactor pumped first and is then unwrapped with
        successResultOf.
        """
        if isinstance(d, Deferred):
            self.pump()
            return self.successResultOf(d)
        return d
|