123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325 |
- import json
- import logging
- import os
- from io import BytesIO
- from typing import Dict, Optional
- from unittest.mock import MagicMock
- import attr
- import twisted.logger
- from OpenSSL import crypto
- from twisted.internet import address
- from twisted.internet._resolver import SimpleResolverComplexifier
- from twisted.internet.defer import fail, succeed
- from twisted.internet.error import DNSLookupError
- from twisted.internet.interfaces import (
- IHostnameResolver,
- IReactorPluggableNameResolver,
- IResolverSimple,
- )
- from twisted.test.proto_helpers import MemoryReactorClock
- from twisted.web.http import unquote
- from twisted.web.http_headers import Headers
- from twisted.web.server import Request, Site
- from zope.interface import implementer
- from sydent.config import SydentConfig
- from sydent.sydent import Sydent
# PEM-encoded X.509 certificate whose subject CN is "fake.server"; it is loaded
# by FakeChannel.getPeerCertificate so requests appear to arrive over TLS from
# that host.
# Expires on Jan 11 2030 at 17:53:40 GMT
FAKE_SERVER_CERT_PEM = """
-----BEGIN CERTIFICATE-----
MIIDlzCCAn+gAwIBAgIUC8tnJVZ8Cawh5tqr7PCAOfvyGTYwDQYJKoZIhvcNAQEL
BQAwWzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UEAwwLZmFrZS5zZXJ2ZXIw
HhcNMjAwMTE0MTc1MzQwWhcNMzAwMTExMTc1MzQwWjBbMQswCQYDVQQGEwJBVTET
MBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQ
dHkgTHRkMRQwEgYDVQQDDAtmYWtlLnNlcnZlcjCCASIwDQYJKoZIhvcNAQEBBQAD
ggEPADCCAQoCggEBANNzY7YHBLm4uj52ojQc/dfQCoR+63IgjxZ6QdnThhIlOYgE
3y0Ks49bt3GKmAweOFRRKfDhJRKCYfqZTYudMcdsQg696s2HhiTY0SpqO0soXwW4
6kEIxnTy2TqkPjWlsWgGTtbVnKc5pnLs7MaQwLIQfxirqD2znn+9r68WMOJRlzkv
VmrXDXjxKPANJJ9b0PiGrL2SF4QcF3zHk8Tjf24OGRX4JTNwiGraU/VN9rrqSHug
CLWcfZ1mvcav3scvtGfgm4kxcw8K6heiQAc3QAMWIrdWhiunaWpQYgw7euS8lZ/O
C7HZ7YbdoldknWdK8o7HJZmxUP9yW9Pqa3n8p9UCAwEAAaNTMFEwHQYDVR0OBBYE
FHwfTq0Mdk9YKqjyfdYm4v9zRP8nMB8GA1UdIwQYMBaAFHwfTq0Mdk9YKqjyfdYm
4v9zRP8nMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAEPVM5/+
Sj9P/CvNG7F2PxlDQC1/+aVl6ARAz/bZmm7yJnWEleBSwwFLerEQU6KFrgjA243L
qgY6Qf2EYUn1O9jroDg/IumlcQU1H4DXZ03YLKS2bXFGj630Piao547/l4/PaKOP
wSvwDcJlBatKfwjMVl3Al/EcAgUJL8eVosnqHDSINdBuFEc8Kw4LnDSFoTEIx19i
c+DKmtnJNI68wNydLJ3lhSaj4pmsX4PsRqsRzw+jgkPXIG1oGlUDMO3k7UwxfYKR
XkU5mFYkohPTgxv5oYGq2FCOPixkbov7geCEvEUs8m8c8MAm4ErBUzemOAj8KVhE
tWVEpHfT+G7AjA8=
-----END CERTIFICATE-----
"""
def make_sydent(test_config: Optional[dict] = None) -> Sydent:
    """Create a new sydent

    Args:
        test_config: Configuration variables for overriding the default sydent
            config, as a dict mapping section names to dicts of options. The
            caller's dict (and its section dicts) are not modified: defaults
            are filled into a copy.

    Returns:
        A Sydent instance running on a fresh ResolvingMemoryReactorClock, with
        TLS disabled for federation.
    """
    # Copy the top level and every section: the defaults below are filled in
    # with setdefault, which would otherwise write into the caller's (possibly
    # shared, module-level) dicts.
    test_config = {
        section: dict(options) if isinstance(options, dict) else options
        for section, options in (test_config or {}).items()
    }

    # Use an in-memory SQLite database. Note that the database isn't cleaned up between
    # tests, so by default the same database will be used for each test if changed to be
    # a file on disk.
    test_config.setdefault("db", {}).setdefault("db.file", ":memory:")

    # Specify a server name to avoid warnings.
    general_config = test_config.setdefault("general", {})
    general_config.setdefault("server.name", ":test:")
    # Specify the default templates.
    general_config.setdefault(
        "templates.path",
        os.path.join(os.path.dirname(os.path.dirname(__file__)), "res"),
    )

    # Specify a signing key.
    test_config.setdefault("crypto", {}).setdefault(
        "ed25519.signingkey", "ed25519 0 FJi1Rnpj3/otydngacrwddFvwz/dTDsBv62uZDN2fZM"
    )

    reactor = ResolvingMemoryReactorClock()

    sydent_config = SydentConfig()
    sydent_config.parse_config_dict(test_config)

    return Sydent(
        reactor=reactor,
        sydent_config=sydent_config,
        use_tls_for_federation=False,
    )
@attr.s
class FakeChannel:
    """
    A fake Twisted Web Channel (the part that interfaces with the
    wire). Mostly copied from Synapse's tests framework.

    Everything the request writes is captured in ``result``: ``version``,
    ``code``, ``reason`` and ``headers`` from writeHeaders, the concatenated
    ``body`` from write, and ``done`` once requestDone is called.
    """

    site = attr.ib(type=Site)
    _reactor = attr.ib()
    result = attr.ib(default=attr.Factory(dict))
    _producer = None

    @property
    def json_body(self):
        """The response body, decoded as UTF-8 and parsed as JSON."""
        if not self.result:
            raise Exception("No result yet.")
        return json.loads(self.result["body"].decode("utf8"))

    @property
    def code(self):
        """The response status code as an int."""
        if not self.result:
            raise Exception("No result yet.")
        return int(self.result["code"])

    @property
    def headers(self):
        """The response headers as a new twisted Headers object."""
        if not self.result:
            raise Exception("No result yet.")
        raw_headers = self.result["headers"]
        # Depending on the Twisted version, writeHeaders receives either a
        # Headers object or a list of (name, value) byte pairs; accept both.
        if isinstance(raw_headers, Headers):
            return raw_headers.copy()
        h = Headers()
        for name, value in raw_headers:
            h.addRawHeader(name, value)
        return h

    def writeHeaders(self, version, code, reason, headers):
        self.result["version"] = version
        self.result["code"] = code
        self.result["reason"] = reason
        self.result["headers"] = headers

    def write(self, content):
        assert isinstance(content, bytes), "Should be bytes! " + repr(content)

        if "body" not in self.result:
            self.result["body"] = b""

        self.result["body"] += content

    def registerProducer(self, producer, streaming):
        self._producer = producer
        self.producerStreaming = streaming

        def _produce():
            # Keep pulling from a non-streaming producer every 0.1s until it
            # is unregistered.
            if self._producer:
                self._producer.resumeProducing()
                self._reactor.callLater(0.1, _produce)

        if not streaming:
            self._reactor.callLater(0.0, _produce)

    def unregisterProducer(self):
        if self._producer is None:
            return

        self._producer = None

    def requestDone(self, _self):
        self.result["done"] = True

    def getPeer(self):
        # We give an address so that getClientAddress().host returns a non null entry,
        # causing us to record the MAU
        return address.IPv4Address("TCP", "127.0.0.1", 3423)

    def getHost(self):
        return None

    @property
    def transport(self):
        # The channel doubles as its own transport (getPeer/getHost/etc.).
        return self

    def getPeerCertificate(self):
        """Returns the hardcoded TLS certificate for fake.server."""
        return crypto.load_certificate(crypto.FILETYPE_PEM, FAKE_SERVER_CERT_PEM)
class FakeSite:
    """Stand-in for a Twisted Web Site; it deliberately has no behaviour."""
def make_request(
    reactor,
    site,
    method,
    path,
    content=b"",
    access_token=None,
    request=Request,
    shorthand=True,
    federation_auth_origin=None,
):
    """
    Make a web request using the given method and path, feed it the
    content, and return the Request and the Channel underneath.

    Args:
        reactor (IReactor): The Twisted reactor to use when performing the request.
        site: The site attached to the fake channel the request is made on.
        method (bytes or unicode): The HTTP request method ("verb").
        path (bytes or unicode): The HTTP path, suitably URL encoded (e.g.
            escaped UTF-8 & spaces and such).
        content (bytes, unicode, dict or list): The body of the request.
            JSON-encoded if a dict or list; UTF-8 encoded if unicode.
        access_token (unicode or bytes): An access token to use to authenticate
            the request, None if no access token needs to be included.
        request (IRequest): The class to use when instantiating the request object.
        shorthand: Whether to try and be helpful and prefix the given URL
            with the usual REST API path, if it doesn't contain it.
        federation_auth_origin (bytes, unicode or None): if set to not-None, we
            will add a fake Authorization header pretending to be the given
            server name.

    Returns:
        Tuple[request, FakeChannel]: the request object (an instance of
        ``request``) and the fake channel it was made on.
    """
    if not isinstance(method, bytes):
        method = method.encode("ascii")

    if not isinstance(path, bytes):
        path = path.encode("ascii")

    # Decorate it to be the full path, if we're using shorthand. Leading slashes
    # are stripped before prefixing instead of collapsing every "//" afterwards,
    # which would also corrupt query strings such as "?next=https://...".
    if shorthand and not path.startswith(b"/_matrix"):
        path = b"/_matrix/identity/v2/" + path.lstrip(b"/")

    if not path.startswith(b"/"):
        path = b"/" + path

    if isinstance(content, (dict, list)):
        content = json.dumps(content)
    if isinstance(content, str):
        content = content.encode("utf8")

    channel = FakeChannel(site, reactor)
    req = request(channel)
    req.content = BytesIO(content)
    req.postpath = list(map(unquote, path[1:].split(b"/")))

    if access_token:
        if isinstance(access_token, str):
            access_token = access_token.encode("ascii")
        req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)

    if federation_auth_origin is not None:
        if isinstance(federation_auth_origin, str):
            federation_auth_origin = federation_auth_origin.encode("ascii")
        req.requestHeaders.addRawHeader(
            b"Authorization",
            b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
        )

    if content:
        req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")

    req.requestReceived(method, path, b"1.1")

    return req, channel
class ToTwistedHandler(logging.Handler):
    """logging handler which sends the logs to the twisted log"""

    tx_log = twisted.logger.Logger()

    @staticmethod
    def _twisted_level(levelno):
        """Map a stdlib logging level number onto the nearest twisted LogLevel.

        Mapping by number rather than by level name also copes with custom
        levels, whose names LogLevel.levelWithName would reject.
        """
        levels = twisted.logger.LogLevel
        if levelno >= logging.CRITICAL:
            return levels.critical
        if levelno >= logging.ERROR:
            return levels.error
        if levelno >= logging.WARNING:
            return levels.warn
        if levelno >= logging.INFO:
            return levels.info
        return levels.debug

    def emit(self, record):
        """Format ``record`` and re-emit it on the twisted logger."""
        try:
            log_entry = self.format(record)
            self.tx_log.emit(
                self._twisted_level(record.levelno), "{entry}", entry=log_entry
            )
        except Exception:
            # logging.Handler contract: report failures via handleError instead
            # of raising into the code that made the logging call.
            self.handleError(record)
def setup_logging():
    """Configure the python logging appropriately for the tests.

    Attaches a ToTwistedHandler to the root logger, so records are forwarded to
    the twisted log (logs will end up in _trial_temp), and sets the root level
    from the SYDENT_TEST_LOG_LEVEL environment variable, defaulting to ERROR.
    """
    root_logger = logging.getLogger()

    handler = ToTwistedHandler()
    handler.setFormatter(
        logging.Formatter(
            "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
        )
    )
    root_logger.addHandler(handler)

    # Logger.setLevel only accepts upper-case level names or ints, so normalise
    # environment values such as "debug" or "10".
    log_level = os.environ.get("SYDENT_TEST_LOG_LEVEL", "ERROR").strip()
    root_logger.setLevel(int(log_level) if log_level.isdigit() else log_level.upper())


setup_logging()
@implementer(IReactorPluggableNameResolver)
class ResolvingMemoryReactorClock(MemoryReactorClock):
    """
    A MemoryReactorClock whose name resolver answers from an in-memory table.

    Tests populate ``self.lookups`` (hostname -> address); resolving a name
    that is not in the table fails with DNSLookupError.
    """

    def __init__(self):
        table = self.lookups = {}  # type: Dict[str, str]

        @implementer(IResolverSimple)
        class _TableResolver:
            def getHostByName(self, name, timeout=None):
                if name in table:
                    return succeed(table[name])
                return fail(DNSLookupError("OH NO: unknown %s" % (name,)))

        # The resolver is installed before the parent initialiser runs.
        self.nameResolver = SimpleResolverComplexifier(_TableResolver())
        super().__init__()

    def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
        # Swapping resolvers is not supported by this fake reactor.
        raise NotImplementedError()
class AsyncMock(MagicMock):
    """A MagicMock whose calls return coroutines.

    The underlying mock call (argument recording, side_effect, return_value)
    only happens when the returned coroutine is awaited.
    """

    async def __call__(self, *args, **kwargs):
        outcome = super().__call__(*args, **kwargs)
        return outcome
|