import json import logging import os from io import BytesIO from typing import Dict from unittest.mock import MagicMock import attr import twisted.logger from OpenSSL import crypto from twisted.internet import address from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.defer import fail, succeed from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import ( IHostnameResolver, IReactorPluggableNameResolver, IResolverSimple, ) from twisted.test.proto_helpers import MemoryReactorClock from twisted.web.http import unquote from twisted.web.http_headers import Headers from twisted.web.server import Request, Site from zope.interface import implementer from sydent.config import SydentConfig from sydent.sydent import Sydent # Expires on Jan 11 2030 at 17:53:40 GMT FAKE_SERVER_CERT_PEM = """ -----BEGIN CERTIFICATE----- MIIDlzCCAn+gAwIBAgIUC8tnJVZ8Cawh5tqr7PCAOfvyGTYwDQYJKoZIhvcNAQEL BQAwWzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UEAwwLZmFrZS5zZXJ2ZXIw HhcNMjAwMTE0MTc1MzQwWhcNMzAwMTExMTc1MzQwWjBbMQswCQYDVQQGEwJBVTET MBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQ dHkgTHRkMRQwEgYDVQQDDAtmYWtlLnNlcnZlcjCCASIwDQYJKoZIhvcNAQEBBQAD ggEPADCCAQoCggEBANNzY7YHBLm4uj52ojQc/dfQCoR+63IgjxZ6QdnThhIlOYgE 3y0Ks49bt3GKmAweOFRRKfDhJRKCYfqZTYudMcdsQg696s2HhiTY0SpqO0soXwW4 6kEIxnTy2TqkPjWlsWgGTtbVnKc5pnLs7MaQwLIQfxirqD2znn+9r68WMOJRlzkv VmrXDXjxKPANJJ9b0PiGrL2SF4QcF3zHk8Tjf24OGRX4JTNwiGraU/VN9rrqSHug CLWcfZ1mvcav3scvtGfgm4kxcw8K6heiQAc3QAMWIrdWhiunaWpQYgw7euS8lZ/O C7HZ7YbdoldknWdK8o7HJZmxUP9yW9Pqa3n8p9UCAwEAAaNTMFEwHQYDVR0OBBYE FHwfTq0Mdk9YKqjyfdYm4v9zRP8nMB8GA1UdIwQYMBaAFHwfTq0Mdk9YKqjyfdYm 4v9zRP8nMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAEPVM5/+ Sj9P/CvNG7F2PxlDQC1/+aVl6ARAz/bZmm7yJnWEleBSwwFLerEQU6KFrgjA243L qgY6Qf2EYUn1O9jroDg/IumlcQU1H4DXZ03YLKS2bXFGj630Piao547/l4/PaKOP wSvwDcJlBatKfwjMVl3Al/EcAgUJL8eVosnqHDSINdBuFEc8Kw4LnDSFoTEIx19i c+DKmtnJNI68wNydLJ3lhSaj4pmsX4PsRqsRzw+jgkPXIG1oGlUDMO3k7UwxfYKR XkU5mFYkohPTgxv5oYGq2FCOPixkbov7geCEvEUs8m8c8MAm4ErBUzemOAj8KVhE tWVEpHfT+G7AjA8= -----END CERTIFICATE----- """ def make_sydent(test_config={}): """Create a new sydent Args: test_config (dict): any configuration variables for overriding the default sydent config """ # Use an in-memory SQLite database. Note that the database isn't cleaned up between # tests, so by default the same database will be used for each test if changed to be # a file on disk. if "db" not in test_config: test_config["db"] = {"db.file": ":memory:"} else: test_config["db"].setdefault("db.file", ":memory:") reactor = ResolvingMemoryReactorClock() sydent_config = SydentConfig() sydent_config.parse_config_dict(test_config) return Sydent( reactor=reactor, sydent_config=sydent_config, use_tls_for_federation=False, ) @attr.s class FakeChannel: """ A fake Twisted Web Channel (the part that interfaces with the wire). Mostly copied from Synapse's tests framework. """ site = attr.ib(type=Site) _reactor = attr.ib() result = attr.ib(default=attr.Factory(dict)) _producer = None @property def json_body(self): if not self.result: raise Exception("No result yet.") return json.loads(self.result["body"].decode("utf8")) @property def code(self): if not self.result: raise Exception("No result yet.") return int(self.result["code"]) @property def headers(self): if not self.result: raise Exception("No result yet.") h = Headers() for i in self.result["headers"]: h.addRawHeader(*i) return h def writeHeaders(self, version, code, reason, headers): self.result["version"] = version self.result["code"] = code self.result["reason"] = reason self.result["headers"] = headers def write(self, content): assert isinstance(content, bytes), "Should be bytes! " + repr(content) if "body" not in self.result: self.result["body"] = b"" self.result["body"] += content def registerProducer(self, producer, streaming): self._producer = producer self.producerStreaming = streaming def _produce(): if self._producer: self._producer.resumeProducing() self._reactor.callLater(0.1, _produce) if not streaming: self._reactor.callLater(0.0, _produce) def unregisterProducer(self): if self._producer is None: return self._producer = None def requestDone(self, _self): self.result["done"] = True def getPeer(self): # We give an address so that getClientAddress().host returns a non null entry, # causing us to record the MAU return address.IPv4Address("TCP", "127.0.0.1", 3423) def getHost(self): return None @property def transport(self): return self def getPeerCertificate(self): """Returns the hardcoded TLS certificate for fake.server.""" return crypto.load_certificate(crypto.FILETYPE_PEM, FAKE_SERVER_CERT_PEM) class FakeSite: """A fake Twisted Web Site.""" pass def make_request( reactor, method, path, content=b"", access_token=None, request=Request, shorthand=True, federation_auth_origin=None, ): """ Make a web request using the given method and path, feed it the content, and return the Request and the Channel underneath. Mostly Args: reactor (IReactor): The Twisted reactor to use when performing the request. method (bytes or unicode): The HTTP request method ("verb"). path (bytes or unicode): The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such). content (bytes or dict): The body of the request. JSON-encoded, if a dict. access_token (unicode): An access token to use to authenticate the request, None if no access token needs to be included. request (IRequest): The class to use when instantiating the request object. shorthand: Whether to try and be helpful and prefix the given URL with the usual REST API path, if it doesn't contain it. federation_auth_origin (bytes|None): if set to not-None, we will add a fake Authorization header pretenting to be the given server name. Returns: Tuple[synapse.http.site.SynapseRequest, channel] """ if not isinstance(method, bytes): method = method.encode("ascii") if not isinstance(path, bytes): path = path.encode("ascii") # Decorate it to be the full path, if we're using shorthand if shorthand and not path.startswith(b"/_matrix"): path = b"/_matrix/identity/v2/" + path path = path.replace(b"//", b"/") if not path.startswith(b"/"): path = b"/" + path if isinstance(content, dict): content = json.dumps(content) if isinstance(content, str): content = content.encode("utf8") site = FakeSite() channel = FakeChannel(site, reactor) req = request(channel) req.process = lambda: b"" req.content = BytesIO(content) req.postpath = list(map(unquote, path[1:].split(b"/"))) if access_token: req.requestHeaders.addRawHeader( b"Authorization", b"Bearer " + access_token.encode("ascii") ) if federation_auth_origin is not None: req.requestHeaders.addRawHeader( b"Authorization", b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,), ) if content: req.requestHeaders.addRawHeader(b"Content-Type", b"application/json") req.requestReceived(method, path, b"1.1") return req, channel class ToTwistedHandler(logging.Handler): """logging handler which sends the logs to the twisted log""" tx_log = twisted.logger.Logger() def emit(self, record): log_entry = self.format(record) log_level = record.levelname.lower().replace("warning", "warn") self.tx_log.emit( twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry ) def setup_logging(): """Configure the python logging appropriately for the tests. (Logs will end up in _trial_temp.) """ root_logger = logging.getLogger() log_format = "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s" " - %(message)s" handler = ToTwistedHandler() formatter = logging.Formatter(log_format) handler.setFormatter(formatter) root_logger.addHandler(handler) log_level = os.environ.get("SYDENT_TEST_LOG_LEVEL", "ERROR") root_logger.setLevel(log_level) setup_logging() @implementer(IReactorPluggableNameResolver) class ResolvingMemoryReactorClock(MemoryReactorClock): """ A MemoryReactorClock that supports name resolution. """ def __init__(self): lookups = self.lookups = {} # type: Dict[str, str] @implementer(IResolverSimple) class FakeResolver: def getHostByName(self, name, timeout=None): if name not in lookups: return fail(DNSLookupError("OH NO: unknown %s" % (name,))) return succeed(lookups[name]) self.nameResolver = SimpleResolverComplexifier(FakeResolver()) super().__init__() def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver: raise NotImplementedError() class AsyncMock(MagicMock): async def __call__(self, *args, **kwargs): return super().__call__(*args, **kwargs)