12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091 |
- # Copyright 2014-2016 OpenMarket Ltd
- # Copyright 2018 New Vector
- # Copyright 2019 Matrix.org Federation C.I.C
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import functools
- import gc
- import hashlib
- import hmac
- import json
- import logging
- import secrets
- import time
- from typing import (
- Any,
- Awaitable,
- Callable,
- ClassVar,
- Dict,
- Generic,
- Iterable,
- List,
- NoReturn,
- Optional,
- Tuple,
- Type,
- TypeVar,
- Union,
- )
- from unittest.mock import Mock, patch
- import canonicaljson
- import signedjson.key
- import unpaddedbase64
- from typing_extensions import Concatenate, ParamSpec, Protocol
- from twisted.internet.defer import Deferred, ensureDeferred
- from twisted.python.failure import Failure
- from twisted.python.threadpool import ThreadPool
- from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
- from twisted.trial import unittest
- from twisted.web.resource import Resource
- from twisted.web.server import Request
- from synapse import events
- from synapse.api.constants import EventTypes
- from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
- from synapse.config._base import Config, RootConfig
- from synapse.config.homeserver import HomeServerConfig
- from synapse.config.server import DEFAULT_ROOM_VERSION
- from synapse.crypto.event_signing import add_hashes_and_signatures
- from synapse.federation.transport.server import TransportLayerServer
- from synapse.http.server import JsonResource
- from synapse.http.site import SynapseRequest, SynapseSite
- from synapse.logging.context import (
- SENTINEL_CONTEXT,
- LoggingContext,
- current_context,
- set_current_context,
- )
- from synapse.rest import RegisterServletsFunc
- from synapse.server import HomeServer
- from synapse.types import JsonDict, Requester, UserID, create_requester
- from synapse.util import Clock
- from synapse.util.httpresourcetree import create_resource_tree
- from tests.server import (
- CustomHeaderType,
- FakeChannel,
- ThreadedMemoryReactorClock,
- get_clock,
- make_request,
- setup_test_homeserver,
- )
- from tests.test_utils import event_injection, setup_awaitable_errors
- from tests.test_utils.logging_setup import setup_logging
- from tests.utils import checked_cast, default_config, setupdb
# Import-time initialisation for the test suite; this runs once when the module
# is first imported.
# NOTE(review): presumably prepares the test database backend — confirm in
# tests.utils.setupdb.
setupdb()
# Configure logging for the test run (helper from tests.test_utils.logging_setup).
setup_logging()

# Generic type variable used by the helpers and decorators in this module.
TV = TypeVar("TV")
# Exception type exposed as the `value` of a `_TypedFailure`.
_ExcType = TypeVar("_ExcType", bound=BaseException, covariant=True)
# Parameter-spec / return / first-argument type variables, used in signatures
# such as `around`'s.
P = ParamSpec("P")
R = TypeVar("R")
S = TypeVar("S")
class _TypedFailure(Generic[_ExcType], Protocol):
    """Extension to twisted.Failure, where the 'value' has a certain type.

    Used as the return annotation of `HomeserverTestCase.get_failure`, so that
    callers see the concrete exception type of the wrapped failure.
    """

    @property
    def value(self) -> _ExcType:
        """The wrapped exception, typed as `_ExcType`."""
        ...
def around(target: TV) -> Callable[[Callable[Concatenate[S, P], R]], None]:
    """Return a decorator that wraps one of ``target``'s methods, CLOS 'around' style.

    The decorated function's ``__name__`` picks the attribute of ``target`` to
    replace. The replacement calls the decorated function with the previous
    attribute value as its first argument, followed by the original call
    arguments. The decorator itself returns None, so the decorated name is not
    meant to be used afterwards.

    Usage::

        @around(self)
        def method_name(orig, *args, **kwargs):
            return orig(*args, **kwargs)
    """

    def _around(wrapper: Callable[Concatenate[S, P], R]) -> None:
        attr_name = wrapper.__name__
        original = getattr(target, attr_name)

        def new(*args: P.args, **kwargs: P.kwargs) -> R:
            # Hand the previously-installed callable to the wrapper.
            return wrapper(original, *args, **kwargs)

        setattr(target, attr_name, new)

    return _around
_TConfig = TypeVar("_TConfig", Config, RootConfig)


def deepcopy_config(config: _TConfig) -> _TConfig:
    """Copy a config object, recursing into nested `Config` sections.

    A fresh instance of ``config``'s class is built: from ``config_files`` for a
    `RootConfig`, or from ``root`` otherwise. Every instance attribute except
    dunder names and ``root`` is then transferred. Nested `Config` values are
    copied recursively; all other values are shared with the source, not copied.
    """
    copied: _TConfig
    if isinstance(config, RootConfig):
        copied = config.__class__(config.config_files)  # type: ignore[arg-type]
    else:
        # NOTE(review): nested sections are rebuilt with the *source* object's
        # `root` (and `root` itself is skipped below), so a copied section's
        # `root` still points at the original config tree — confirm intended.
        copied = config.__class__(config.root)

    for name in config.__dict__:
        if name == "root" or name.startswith("__"):
            continue
        value = getattr(config, name)
        setattr(
            copied,
            name,
            deepcopy_config(value) if isinstance(value, Config) else value,
        )

    return copied
@functools.lru_cache(maxsize=8)
def _parse_config_dict(config: str) -> RootConfig:
    """Parse a JSON-encoded homeserver config dict into a `HomeServerConfig`.

    The argument is a JSON string rather than a dict so that it is hashable and
    can serve as the `lru_cache` key. The returned object is shared between
    cache hits, which is why `make_homeserver_config_obj` deep-copies it.
    """
    parsed = HomeServerConfig()
    config_dict = json.loads(config)
    parsed.parse_config_dict(config_dict, "", "")
    return parsed
def make_homeserver_config_obj(config: Dict[str, Any]) -> RootConfig:
    """Creates a :class:`HomeServerConfig` instance with the given configuration dict.

    This is equivalent to::

        config_obj = HomeServerConfig()
        config_obj.parse_config_dict(config, "", "")

    but parsed configs are cached and deep-copied on the way out, so the whole
    configuration does not have to be validated every time.
    """
    # Sorting keys makes equal dicts serialise identically, i.e. share a cache entry.
    cache_key = json.dumps(config, sort_keys=True)
    cached = _parse_config_dict(cache_key)
    return deepcopy_config(cached)
class TestCase(unittest.TestCase):
    """A subclass of twisted.trial's TestCase which looks for 'loglevel'
    attributes on both itself and its individual test methods, to override the
    root logger's logging level while that test (case|method) runs.

    Each test is also required to start in the sentinel logcontext, runs with the
    garbage collector disabled, and has the logcontext reset to the sentinel when
    it finishes.
    """

    def __init__(self, methodName: str):
        super().__init__(methodName)

        method = getattr(self, methodName)

        # A `loglevel` on the test method takes precedence over one on the class.
        level = getattr(method, "loglevel", getattr(self, "loglevel", None))

        @around(self)
        def setUp(orig: Callable[[], R]) -> R:
            # if we're not starting in the sentinel logcontext, then to be honest
            # all future bets are off.
            if current_context():
                self.fail(
                    "Test starting with non-sentinel logging context %s"
                    % (current_context(),)
                )

            # Disable GC for duration of test. See below for why.
            gc.disable()

            old_level = logging.getLogger().level
            if level is not None and old_level != level:

                # Put the root logger's level back once the test has torn down.
                @around(self)
                def tearDown(orig: Callable[[], R]) -> R:
                    ret = orig()
                    logging.getLogger().setLevel(old_level)
                    return ret

                logging.getLogger().setLevel(level)

            # Trial messes with the warnings configuration, thus this has to be
            # done in the context of an individual TestCase.
            self.addCleanup(setup_awaitable_errors())

            return orig()

        # We want to force a GC to workaround problems with deferreds leaking
        # logcontexts when they are GCed (see the logcontext docs).
        #
        # The easiest way to do this would be to do a full GC after each test
        # run, but that is very expensive. Instead, we disable GC (above) for
        # the duration of the test and only run a gen-0 GC, which is a lot
        # quicker. This doesn't clean up everything, since the TestCase
        # instance still holds references to objects created during the test,
        # such as HomeServers, so we do a full GC every so often.
        @around(self)
        def tearDown(orig: Callable[[], R]) -> R:
            try:
                ret = orig()
                gc.collect(0)
                # Run a full GC every 50 gen-0 GCs.
                gen0_stats = gc.get_stats()[0]
                gen0_collections = gen0_stats["collections"]
                if gen0_collections % 50 == 0:
                    gc.collect()
            finally:
                # Always undo setUp's gc.disable() and reset the logcontext, even
                # if the wrapped tearDown raised, so later tests are not affected.
                gc.enable()
                set_current_context(SENTINEL_CONTEXT)

            return ret

    def assertObjectHasAttributes(self, attrs: Dict[str, object], obj: object) -> None:
        """Asserts that the given object has each of the attributes given, and
        that the value of each matches according to assertEqual."""
        for key, expected in attrs.items():
            if not hasattr(obj, key):
                raise AssertionError("Expected obj to have a '.%s'" % key)
            try:
                self.assertEqual(expected, getattr(obj, key))
            except AssertionError as e:
                # Re-raise naming the attribute, chaining the original failure.
                raise (type(e))(f"Assert error for '.{key}':") from e

    def assert_dict(self, required: dict, actual: dict) -> None:
        """Does a partial assert of a dict.

        Args:
            required: The keys and value which MUST be in 'actual'.
            actual: The test result. Extra keys will not be checked.
        """
        for key in required:
            self.assertEqual(
                required[key], actual[key], msg="%s mismatch. %s" % (key, actual)
            )
def DEBUG(target: TV) -> TV:
    """Decorator that sets ``target.loglevel`` to ``logging.DEBUG``.

    Works on a TestCase class or on an individual test method; `TestCase` reads
    the attribute to override the root logger's level while the test runs.
    The target itself is returned.
    """
    setattr(target, "loglevel", logging.DEBUG)
    return target
def INFO(target: TV) -> TV:
    """Decorator that sets ``target.loglevel`` to ``logging.INFO``.

    Works on a TestCase class or on an individual test method; `TestCase` reads
    the attribute to override the root logger's level while the test runs.
    The target itself is returned.
    """
    setattr(target, "loglevel", logging.INFO)
    return target
def logcontext_clean(target: TV) -> TV:
    """Decorator marking a TestCase or test method as 'logcontext_clean'.

    While the target runs, ``synapse.logging.context.logcontext_error`` is
    patched with a function that raises AssertionError, so any logcontext error
    causes a test failure.
    """

    def _raise_logcontext_error(msg: str) -> NoReturn:
        raise AssertionError("logcontext error: %s" % (msg))

    patcher = patch(
        "synapse.logging.context.logcontext_error", new=_raise_logcontext_error
    )
    return patcher(target)  # type: ignore[call-overload]
class HomeserverTestCase(TestCase):
    """
    A base TestCase that reduces boilerplate for HomeServer-using test cases.

    Defines a setUp method which creates a mock reactor, and instantiates a homeserver
    running on that reactor.

    There are various hooks for modifying the way that the homeserver is instantiated:

    * override make_homeserver, for example by making it pass different parameters into
      setup_test_homeserver.

    * override default_config, to return a modified configuration dictionary for use
      by setup_test_homeserver.

    * On a per-test basis, you can use the @override_config decorator to give a
      dictionary containing additional configuration settings to be added to the basic
      config dict.

    Attributes:
        servlets: List of servlet registration functions.
        user_id (str): The user ID to assume if auth is hijacked.
        hijack_auth: Whether to hijack auth to return the user specified
            in user_id.
        needs_threadpool: Whether to give the test reactor a real ThreadPool,
            started in setUp and stopped as a test cleanup.
    """

    hijack_auth: ClassVar[bool] = True
    needs_threadpool: ClassVar[bool] = False
    servlets: ClassVar[List[RegisterServletsFunc]] = []

    def __init__(self, methodName: str):
        super().__init__(methodName)

        # see if we have any additional config for this test
        method = getattr(self, methodName)
        self._extra_config = getattr(method, "_extra_config", None)

    def setUp(self) -> None:
        """
        Set up the TestCase by calling the homeserver constructor, optionally
        hijacking the authentication system to return a fixed user, and then
        calling the prepare function.
        """
        self.reactor, self.clock = get_clock()
        self._hs_args = {"clock": self.clock, "reactor": self.reactor}
        self.hs = self.make_homeserver(self.reactor, self.clock)

        # Validate what make_homeserver returned before dereferencing it below,
        # so that a bad return value gives these errors rather than an
        # AttributeError.
        if self.hs is None:
            raise Exception("No homeserver returned from make_homeserver.")

        if not isinstance(self.hs, HomeServer):
            raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))

        # Honour the `use_frozen_dicts` config option. We have to do this
        # manually because this is taken care of in the app `start` code, which
        # we don't run. Plus we want to reset it on tearDown.
        events.USE_FROZEN_DICTS = self.hs.config.server.use_frozen_dicts

        # create the root resource, and a site to wrap it.
        self.resource = self.create_test_resource()
        self.site = SynapseSite(
            logger_name="synapse.access.http.fake",
            site_tag=self.hs.config.server.server_name,
            config=self.hs.config.server.listeners[0],
            resource=self.resource,
            server_version_string="1",
            max_request_body_size=4096,
            reactor=self.reactor,
        )

        from tests.rest.client.utils import RestHelper

        self.helper = RestHelper(
            self.hs,
            checked_cast(MemoryReactorClock, self.hs.get_reactor()),
            self.site,
            getattr(self, "user_id", None),
        )

        if hasattr(self, "user_id"):
            if self.hijack_auth:
                assert self.helper.auth_user_id is not None
                token = "some_fake_token"

                # We need a valid token ID to satisfy foreign key constraints.
                token_id = self.get_success(
                    self.hs.get_datastores().main.add_access_token_to_user(
                        self.helper.auth_user_id,
                        token,
                        None,
                        None,
                    )
                )

                # This has to be a function and not just a Mock, because
                # `self.helper.auth_user_id` is temporarily reassigned in some tests
                async def get_requester(*args: Any, **kwargs: Any) -> Requester:
                    assert self.helper.auth_user_id is not None
                    return create_requester(
                        user_id=UserID.from_string(self.helper.auth_user_id),
                        access_token_id=token_id,
                    )

                # Type ignore: mypy doesn't like us assigning to methods.
                self.hs.get_auth().get_user_by_req = get_requester  # type: ignore[assignment]
                self.hs.get_auth().get_user_by_access_token = get_requester  # type: ignore[assignment]
                self.hs.get_auth().get_access_token_from_request = Mock(return_value=token)  # type: ignore[assignment]

        if self.needs_threadpool:
            self.reactor.threadpool = ThreadPool()  # type: ignore[assignment]
            self.addCleanup(self.reactor.threadpool.stop)
            self.reactor.threadpool.start()

        if hasattr(self, "prepare"):
            self.prepare(self.reactor, self.clock, self.hs)

    def tearDown(self) -> None:
        # Reset to not use frozen dicts.
        events.USE_FROZEN_DICTS = False

    def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
        """
        Wait until a Deferred is done, where it's waiting on a real thread.

        Alternately advances the fake reactor by 0.01s and sleeps for 0.01s of
        real time until the Deferred has fired.

        Args:
            deferred: The Deferred to wait on.
            timeout: Maximum wall-clock time to wait, in seconds.

        Raises:
            ValueError: if the Deferred has not fired within `timeout` seconds.
        """
        start_time = time.time()

        while not deferred.called:
            if start_time + timeout < time.time():
                raise ValueError("Timed out waiting for threadpool")
            self.reactor.advance(0.01)
            time.sleep(0.01)

    def wait_for_background_updates(self) -> None:
        """Block until all background database updates have completed."""
        store = self.hs.get_datastores().main
        while not self.get_success(
            store.db_pool.updates.has_completed_background_updates()
        ):
            self.get_success(
                store.db_pool.updates.do_next_background_update(False), by=0.1
            )

    def make_homeserver(
        self, reactor: ThreadedMemoryReactorClock, clock: Clock
    ) -> HomeServer:
        """
        Make and return a homeserver.

        Args:
            reactor: A Twisted Reactor, or something that pretends to be one.
            clock: The Clock, associated with the reactor.

        Returns:
            A homeserver suitable for testing.

        Function to be overridden in subclasses.
        """
        hs = self.setup_test_homeserver()
        return hs

    def create_test_resource(self) -> Resource:
        """
        Create the root resource for the test server.

        The default calls `self.create_resource_dict` and builds the resultant dict
        into a tree.
        """
        root_resource = Resource()
        create_resource_tree(self.create_resource_dict(), root_resource)
        return root_resource

    def create_resource_dict(self) -> Dict[str, Resource]:
        """Create a resource tree for the test server

        A resource tree is a mapping from path to twisted.web.resource.

        The default implementation creates a JsonResource and calls each function in
        `servlets` to register servlets against it. The same JsonResource is
        mounted at both `/_matrix/client` and `/_synapse/admin`.
        """
        servlet_resource = JsonResource(self.hs)
        for servlet in self.servlets:
            servlet(self.hs, servlet_resource)
        return {
            "/_matrix/client": servlet_resource,
            "/_synapse/admin": servlet_resource,
        }

    def default_config(self) -> JsonDict:
        """
        Get a default HomeServer config dict.

        Any config given via the @override_config decorator on the current test
        method is merged in (shallowly, via `dict.update`).
        """
        config = default_config("test")

        # apply any additional config which was specified via the override_config
        # decorator.
        if self._extra_config is not None:
            config.update(self._extra_config)

        return config

    def prepare(
        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
    ) -> None:
        """
        Prepare for the test. This involves things like mocking out parts of
        the homeserver, or building test data common across the whole test
        suite.

        Args:
            reactor: A Twisted Reactor, or something that pretends to be one.
            clock: The Clock, associated with the reactor.
            homeserver: The HomeServer to test against.

        Function to optionally be overridden in subclasses.
        """

    def make_request(
        self,
        method: Union[bytes, str],
        path: Union[bytes, str],
        content: Union[bytes, str, JsonDict] = b"",
        access_token: Optional[str] = None,
        request: Type[Request] = SynapseRequest,
        shorthand: bool = True,
        federation_auth_origin: Optional[bytes] = None,
        content_is_form: bool = False,
        await_result: bool = True,
        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
        client_ip: str = "127.0.0.1",
    ) -> FakeChannel:
        """
        Create a SynapseRequest at the path using the method and containing the
        given content.

        This is a thin wrapper around `tests.server.make_request`, using this test
        case's reactor and site.

        Args:
            method: The HTTP request method ("verb").
            path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces
                and such).
            content: The body of the request. JSON-encoded, if a dict.
            access_token: An access token to authenticate the request with, if any.
            request: The request class to instantiate.
            shorthand: Whether to try and be helpful and prefix the given URL
                with the usual REST API path, if it doesn't contain it.
            federation_auth_origin: if set to not-None, we will add a fake
                Authorization header pretending to be the given server name.
            content_is_form: Whether the content is URL encoded form data. Adds the
                'Content-Type': 'application/x-www-form-urlencoded' header.
            await_result: whether to wait for the request to complete rendering. If
                true (the default), will pump the test reactor until the renderer
                tells the channel the request is finished.
            custom_headers: (name, value) pairs to add as request headers
            client_ip: The IP to use as the requesting IP. Useful for testing
                ratelimiting.

        Returns:
            The FakeChannel object which stores the result of the request.
        """
        return make_request(
            self.reactor,
            self.site,
            method,
            path,
            content,
            access_token,
            request,
            shorthand,
            federation_auth_origin,
            content_is_form,
            await_result,
            custom_headers,
            client_ip,
        )

    def setup_test_homeserver(
        self, name: Optional[str] = None, **kwargs: Any
    ) -> HomeServer:
        """
        Set up the test homeserver, meant to be called by the overridable
        make_homeserver. It automatically passes through the test class's
        clock & reactor.

        Args:
            See tests.utils.setup_test_homeserver.

        Returns:
            synapse.server.HomeServer
        """
        kwargs = dict(kwargs)
        kwargs.update(self._hs_args)
        if "config" not in kwargs:
            config = self.default_config()
        else:
            config = kwargs["config"]

        # The server name can be specified using either the `name` argument or a config
        # override. The `name` argument takes precedence over any config overrides.
        # NOTE: when a `config` dict was passed in, this writes into that dict.
        if name is not None:
            config["server_name"] = name

        # Parse the config from a config dict into a HomeServerConfig
        config_obj = make_homeserver_config_obj(config)
        kwargs["config"] = config_obj

        # The server name in the config is now `name`, if provided, or the `server_name`
        # from a config override, or the default of "test". Whichever it is, we
        # construct a homeserver with a matching name.
        kwargs["name"] = config_obj.server.server_name

        async def run_bg_updates() -> None:
            with LoggingContext("run_bg_updates"):
                self.get_success(stor.db_pool.updates.run_background_updates(False))

        hs = setup_test_homeserver(self.addCleanup, **kwargs)
        stor = hs.get_datastores().main

        # Run the database background updates, when running against "master".
        if hs.__class__.__name__ == "TestHomeServer":
            self.get_success(run_bg_updates())

        return hs

    def pump(self, by: float = 0.0) -> None:
        """
        Pump the reactor enough that Deferreds will fire.

        Advances the reactor 100 times by `by` seconds each time.
        """
        self.reactor.pump([by] * 100)

    def get_success(self, d: Awaitable[TV], by: float = 0.0) -> TV:
        """Drive an awaitable to completion by pumping the reactor, and return
        its result.

        Uses `successResultOf`, which fails the test if the awaitable has not
        succeeded once the reactor has been pumped.
        """
        deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]
        self.pump(by=by)
        return self.successResultOf(deferred)

    def get_failure(
        self, d: Awaitable[Any], exc: Type[_ExcType]
    ) -> _TypedFailure[_ExcType]:
        """
        Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
        """
        deferred: Deferred[Any] = ensureDeferred(d)  # type: ignore[arg-type]
        self.pump()
        return self.failureResultOf(deferred, exc)

    def get_success_or_raise(self, d: Awaitable[TV], by: float = 0.0) -> TV:
        """Drive deferred to completion and return result or raise exception
        on failure.
        """
        deferred: Deferred[TV] = ensureDeferred(d)  # type: ignore[arg-type]

        results: list = []
        deferred.addBoth(results.append)

        self.pump(by=by)

        if not results:
            self.fail(
                "Success result expected on {!r}, found no result instead".format(
                    deferred
                )
            )

        result = results[0]

        if isinstance(result, Failure):
            result.raiseException()

        return result

    def register_user(
        self,
        username: str,
        password: str,
        admin: Optional[bool] = False,
        displayname: Optional[str] = None,
    ) -> str:
        """
        Register a user. Requires the Admin API be registered.

        As a side effect, sets the homeserver's `registration_shared_secret` to
        "shared", which is used to compute the registration MAC.

        Args:
            username: The user part of the new user.
            password: The password of the new user.
            admin: Whether the user should be created as an admin or not.
            displayname: The displayname of the new user.

        Returns:
            The MXID of the new user.
        """
        self.hs.config.registration.registration_shared_secret = "shared"

        # Create the user
        channel = self.make_request("GET", "/_synapse/admin/v1/register")
        self.assertEqual(channel.code, 200, msg=channel.result)
        nonce = channel.json_body["nonce"]

        want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
        nonce_str = b"\x00".join([username.encode("utf8"), password.encode("utf8")])
        if admin:
            nonce_str += b"\x00admin"
        else:
            nonce_str += b"\x00notadmin"

        want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
        want_mac_digest = want_mac.hexdigest()

        body = {
            "nonce": nonce,
            "username": username,
            "displayname": displayname,
            "password": password,
            "admin": admin,
            "mac": want_mac_digest,
            "inhibit_login": True,
        }
        channel = self.make_request("POST", "/_synapse/admin/v1/register", body)
        self.assertEqual(channel.code, 200, channel.json_body)

        user_id = channel.json_body["user_id"]
        return user_id

    def register_appservice_user(
        self,
        username: str,
        appservice_token: str,
    ) -> Tuple[str, str]:
        """Register an appservice user as an application service.
        Requires the client-facing registration API be registered.

        Args:
            username: the user to be registered by an application service.
                Should NOT be a full username, i.e. just "localpart" as opposed to "@localpart:hostname"
            appservice_token: the access token for that application service.

        Raises:
            Fails the test if the request to '/register' does not return 200 OK.

        Returns:
            The MXID of the new user, the device ID of the new user's first device.
        """
        channel = self.make_request(
            "POST",
            "/_matrix/client/r0/register",
            {
                "username": username,
                "type": "m.login.application_service",
            },
            access_token=appservice_token,
        )
        self.assertEqual(channel.code, 200, channel.json_body)
        return channel.json_body["user_id"], channel.json_body["device_id"]

    def login(
        self,
        username: str,
        password: str,
        device_id: Optional[str] = None,
        additional_request_fields: Optional[Dict[str, str]] = None,
        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
    ) -> str:
        """
        Log in a user, and get an access token. Requires the Login API be registered.

        Args:
            username: The user to log in as; sent as the `user` field of an
                `m.login.password` login request.
            password: The user's password.
            device_id: An optional device ID to assign to the new device created during
                login.
            additional_request_fields: A dictionary containing any additional /login
                request fields and their values.
            custom_headers: Custom HTTP headers and values to add to the /login request.

        Returns:
            The access token returned by /login for the new session.
        """
        body = {"type": "m.login.password", "user": username, "password": password}
        if device_id:
            body["device_id"] = device_id
        if additional_request_fields:
            body.update(additional_request_fields)

        channel = self.make_request(
            "POST",
            "/_matrix/client/r0/login",
            body,
            custom_headers=custom_headers,
        )
        self.assertEqual(channel.code, 200, channel.result)

        access_token = channel.json_body["access_token"]
        return access_token

    def create_and_send_event(
        self,
        room_id: str,
        user: UserID,
        soft_failed: bool = False,
        prev_event_ids: Optional[List[str]] = None,
    ) -> str:
        """
        Create and send an event.

        Args:
            room_id: The room to send the event into.
            user: The sender of the event.
            soft_failed: Whether to create a soft failed event or not
            prev_event_ids: Explicitly set the prev events,
                or if None just use the default

        Returns:
            The new event's ID.
        """
        event_creator = self.hs.get_event_creation_handler()
        requester = create_requester(user)

        event, unpersisted_context = self.get_success(
            event_creator.create_event(
                requester,
                {
                    "type": EventTypes.Message,
                    "room_id": room_id,
                    "sender": user.to_string(),
                    "content": {"body": secrets.token_hex(), "msgtype": "m.text"},
                },
                prev_event_ids=prev_event_ids,
            )
        )
        context = self.get_success(unpersisted_context.persist(event))
        if soft_failed:
            event.internal_metadata.soft_failed = True

        self.get_success(
            event_creator.handle_new_client_event(
                requester, events_and_context=[(event, context)]
            )
        )

        return event.event_id

    def inject_room_member(self, room: str, user: str, membership: str) -> None:
        """
        Inject a membership event into a room.

        Deprecated: use event_injection.inject_member_event directly

        Args:
            room: Room ID to inject the event into.
            user: MXID of the user to inject the membership for.
            membership: The membership type.
        """
        self.get_success(
            event_injection.inject_member_event(self.hs, room, user, membership)
        )
class FederatingHomeserverTestCase(HomeserverTestCase):
    """
    A federating homeserver, set up to validate incoming federation requests.

    The homeserver under test is pre-loaded with the verify key of a fake remote
    server (`OTHER_SERVER_NAME`), so that requests signed with
    `OTHER_SERVER_SIGNATURE_KEY` pass authentication.
    """

    OTHER_SERVER_NAME = "other.example.com"
    OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        super().prepare(reactor, clock, hs)

        # Seed the key store with the remote server's verify key so the
        # homeserver under test never needs to request it.
        verify_key = signedjson.key.get_verify_key(self.OTHER_SERVER_SIGNATURE_KEY)
        key_id = "%s:%s" % (verify_key.alg, verify_key.version)
        key_json = canonicaljson.encode_canonical_json(
            {
                "verify_keys": {
                    key_id: {
                        "key": signedjson.key.encode_verify_key_base64(verify_key)
                    }
                }
            }
        )

        store = hs.get_datastores().main
        self.get_success(
            store.store_server_keys_json(
                self.OTHER_SERVER_NAME,
                key_id,
                from_server=self.OTHER_SERVER_NAME,
                ts_now_ms=clock.time_msec(),
                ts_expires_ms=clock.time_msec() + 10000,
                key_json_bytes=key_json,
            )
        )

    def create_resource_dict(self) -> Dict[str, Resource]:
        resources = super().create_resource_dict()
        resources["/_matrix/federation"] = TransportLayerServer(self.hs)
        return resources

    def make_signed_federation_request(
        self,
        method: str,
        path: str,
        content: Optional[JsonDict] = None,
        await_result: bool = True,
        custom_headers: Optional[Iterable[CustomHeaderType]] = None,
        client_ip: str = "127.0.0.1",
    ) -> FakeChannel:
        """Make an inbound signed federation request to this server

        The request is signed as if it came from `OTHER_SERVER_NAME`
        ("other.example.com"), which our HS already has the keys for. The
        supplied `custom_headers` are copied, never modified in place.
        """
        headers: List[CustomHeaderType] = (
            [] if custom_headers is None else list(custom_headers)
        )

        auth_header = _auth_header_for_request(
            origin=self.OTHER_SERVER_NAME,
            destination=self.hs.hostname,
            signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
            method=method,
            path=path,
            content=content,
        )
        headers.append(("Authorization", auth_header))

        return make_request(
            self.reactor,
            self.site,
            method=method,
            path=path,
            content="" if content is None else content,
            shorthand=False,
            await_result=await_result,
            custom_headers=headers,
            client_ip=client_ip,
        )

    def add_hashes_and_signatures_from_other_server(
        self,
        event_dict: JsonDict,
        room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
    ) -> JsonDict:
        """Hash and sign `event_dict` in place as `OTHER_SERVER_NAME`.

        Returns:
            The same (now modified) event dict, for convenience.
        """
        add_hashes_and_signatures(
            room_version,
            event_dict,
            signature_name=self.OTHER_SERVER_NAME,
            signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
        )
        return event_dict
def _auth_header_for_request(
    origin: str,
    destination: str,
    signing_key: signedjson.key.SigningKey,
    method: str,
    path: str,
    content: Optional[JsonDict],
) -> str:
    """Build a suitable Authorization header for an outgoing federation request.

    The request description (method, uri, destination, origin, and content when
    given) is canonical-JSON encoded and signed with `signing_key`. The result is
    returned as an `X-Matrix` header value with an unpadded-base64 signature.
    """
    description: JsonDict = {
        "method": method,
        "uri": path,
        "destination": destination,
        "origin": origin,
    }
    if content is not None:
        description["content"] = content

    signed = signing_key.sign(canonicaljson.encode_canonical_json(description))
    sig = unpaddedbase64.encode_base64(signed.signature)
    key_id = f"{signing_key.alg}:{signing_key.version}"
    return f"X-Matrix origin={origin},key={key_id},sig={sig}"
def override_config(extra_config: JsonDict) -> Callable[[TV], TV]:
    """A decorator which can be applied to test functions to give additional HS config.

    The dict is stored on the decorated function as `_extra_config`.
    `HomeserverTestCase` picks it up and merges it (shallowly, via `dict.update`)
    into the dict returned by its `default_config`.

    For example:

        class MyTestCase(HomeserverTestCase):
            @override_config({"enable_registration": False, ...})
            def test_foo(self):
                ...

    Args:
        extra_config: Additional config settings to be merged into the default
            config dict before instantiating the test homeserver.

    Returns:
        A decorator that annotates the function and returns it unchanged.
    """

    def decorator(func: TV) -> TV:
        # This attribute is being defined.
        func._extra_config = extra_config  # type: ignore[attr-defined]
        return func

    return decorator
def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
    """A test decorator which will skip the decorated test unless a condition is set

    For example:

        class MyTestCase(TestCase):
            @skip_unless(HAS_FOO, "Cannot test without foo")
            def test_foo(self):
                ...

    Args:
        condition: If false, the test will be skipped (by setting the `skip`
            attribute that twisted.trial honours); if true, the test runs.
        reason: the reason to give for skipping the test

    Returns:
        A decorator that returns the test unchanged apart from the possible
        `skip` attribute.
    """

    def decorator(f: TV) -> TV:
        if not condition:
            f.skip = reason  # type: ignore
        return f

    return decorator
|