utils.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325
  1. import json
  2. import logging
  3. import os
  4. from io import BytesIO
  5. from typing import Dict, Optional
  6. from unittest.mock import MagicMock
  7. import attr
  8. import twisted.logger
  9. from OpenSSL import crypto
  10. from twisted.internet import address
  11. from twisted.internet._resolver import SimpleResolverComplexifier
  12. from twisted.internet.defer import fail, succeed
  13. from twisted.internet.error import DNSLookupError
  14. from twisted.internet.interfaces import (
  15. IHostnameResolver,
  16. IReactorPluggableNameResolver,
  17. IResolverSimple,
  18. )
  19. from twisted.test.proto_helpers import MemoryReactorClock
  20. from twisted.web.http import unquote
  21. from twisted.web.http_headers import Headers
  22. from twisted.web.server import Request, Site
  23. from zope.interface import implementer
  24. from sydent.config import SydentConfig
  25. from sydent.sydent import Sydent
  # PEM-encoded TLS certificate returned by FakeChannel.getPeerCertificate (subject CN
  # "fake.server"). Valid from Jan 14 2020 until Jan 11 2030 at 17:53:40 GMT; tests relying
  # on it will start failing once it expires.
  FAKE_SERVER_CERT_PEM = (
      "\n"
      "-----BEGIN CERTIFICATE-----\n"
      "MIIDlzCCAn+gAwIBAgIUC8tnJVZ8Cawh5tqr7PCAOfvyGTYwDQYJKoZIhvcNAQEL\n"
      "BQAwWzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM\n"
      "GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UEAwwLZmFrZS5zZXJ2ZXIw\n"
      "HhcNMjAwMTE0MTc1MzQwWhcNMzAwMTExMTc1MzQwWjBbMQswCQYDVQQGEwJBVTET\n"
      "MBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQ\n"
      "dHkgTHRkMRQwEgYDVQQDDAtmYWtlLnNlcnZlcjCCASIwDQYJKoZIhvcNAQEBBQAD\n"
      "ggEPADCCAQoCggEBANNzY7YHBLm4uj52ojQc/dfQCoR+63IgjxZ6QdnThhIlOYgE\n"
      "3y0Ks49bt3GKmAweOFRRKfDhJRKCYfqZTYudMcdsQg696s2HhiTY0SpqO0soXwW4\n"
      "6kEIxnTy2TqkPjWlsWgGTtbVnKc5pnLs7MaQwLIQfxirqD2znn+9r68WMOJRlzkv\n"
      "VmrXDXjxKPANJJ9b0PiGrL2SF4QcF3zHk8Tjf24OGRX4JTNwiGraU/VN9rrqSHug\n"
      "CLWcfZ1mvcav3scvtGfgm4kxcw8K6heiQAc3QAMWIrdWhiunaWpQYgw7euS8lZ/O\n"
      "C7HZ7YbdoldknWdK8o7HJZmxUP9yW9Pqa3n8p9UCAwEAAaNTMFEwHQYDVR0OBBYE\n"
      "FHwfTq0Mdk9YKqjyfdYm4v9zRP8nMB8GA1UdIwQYMBaAFHwfTq0Mdk9YKqjyfdYm\n"
      "4v9zRP8nMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAEPVM5/+\n"
      "Sj9P/CvNG7F2PxlDQC1/+aVl6ARAz/bZmm7yJnWEleBSwwFLerEQU6KFrgjA243L\n"
      "qgY6Qf2EYUn1O9jroDg/IumlcQU1H4DXZ03YLKS2bXFGj630Piao547/l4/PaKOP\n"
      "wSvwDcJlBatKfwjMVl3Al/EcAgUJL8eVosnqHDSINdBuFEc8Kw4LnDSFoTEIx19i\n"
      "c+DKmtnJNI68wNydLJ3lhSaj4pmsX4PsRqsRzw+jgkPXIG1oGlUDMO3k7UwxfYKR\n"
      "XkU5mFYkohPTgxv5oYGq2FCOPixkbov7geCEvEUs8m8c8MAm4ErBUzemOAj8KVhE\n"
      "tWVEpHfT+G7AjA8=\n"
      "-----END CERTIFICATE-----\n"
  )
  def make_sydent(test_config: Optional[dict] = None) -> Sydent:
      """Create a new Sydent instance for use in tests.

      Args:
          test_config: Configuration variables for overriding the default sydent
              config, as a dict of sections, each mapping option name to value (e.g.
              ``{"general": {"server.name": "example.com"}}``). The dict is deep-copied
              before defaults are filled in, so the caller's object is left unmodified.

      Returns:
          A Sydent built on a fresh ResolvingMemoryReactorClock, with TLS for
          federation disabled.
      """
      import copy

      # Fill defaults into a private copy: mutating the caller's (possibly shared)
      # dict would leak these defaults into any later use of that dict.
      test_config = {} if test_config is None else copy.deepcopy(test_config)

      # Use an in-memory SQLite database by default. If this is ever changed to a file
      # on disk, note that the database isn't cleaned up between tests, so every test
      # would end up sharing the same database.
      test_config.setdefault("db", {}).setdefault("db.file", ":memory:")

      # Specify a server name to avoid warnings.
      general_config = test_config.setdefault("general", {})
      general_config.setdefault("server.name", ":test:")

      # Specify the default templates: the "res" directory two levels above this file.
      general_config.setdefault(
          "templates.path",
          os.path.join(os.path.dirname(os.path.dirname(__file__)), "res"),
      )

      # Specify a signing key.
      test_config.setdefault("crypto", {}).setdefault(
          "ed25519.signingkey", "ed25519 0 FJi1Rnpj3/otydngacrwddFvwz/dTDsBv62uZDN2fZM"
      )

      reactor = ResolvingMemoryReactorClock()

      sydent_config = SydentConfig()
      sydent_config.parse_config_dict(test_config)

      return Sydent(
          reactor=reactor,
          sydent_config=sydent_config,
          use_tls_for_federation=False,
      )
  @attr.s
  class FakeChannel:
      """
      In-memory stand-in for a Twisted Web channel (the side of a request that talks to
      the wire). Everything the request sends is captured in ``result``. Mostly copied
      from Synapse's tests framework.
      """

      site = attr.ib(type=Site)
      _reactor = attr.ib()
      result = attr.ib(default=attr.Factory(dict))
      _producer = None

      def _require_result(self):
          """Return ``self.result``, raising if the request hasn't written anything yet."""
          if not self.result:
              raise Exception("No result yet.")
          return self.result

      @property
      def json_body(self):
          """The response body, decoded as UTF-8 and parsed as JSON."""
          raw_body = self._require_result()["body"]
          return json.loads(raw_body.decode("utf8"))

      @property
      def code(self):
          """The HTTP status code of the response, as an int."""
          return int(self._require_result()["code"])

      @property
      def headers(self):
          """The response headers, rebuilt into a Twisted ``Headers`` object."""
          raw_headers = self._require_result()["headers"]
          collected = Headers()
          for raw_header in raw_headers:
              collected.addRawHeader(*raw_header)
          return collected

      def writeHeaders(self, version, code, reason, headers):
          """Record the status line and headers sent by the request."""
          self.result.update(
              version=version,
              code=code,
              reason=reason,
              headers=headers,
          )

      def write(self, content):
          """Append a chunk of response body (which must be bytes) to ``result["body"]``."""
          assert isinstance(content, bytes), "Should be bytes! " + repr(content)
          self.result["body"] = self.result.get("body", b"") + content

      def registerProducer(self, producer, streaming):
          """Attach ``producer``; non-streaming (pull) producers are polled via the reactor."""
          self._producer = producer
          self.producerStreaming = streaming

          def _poll():
              # Keep asking the producer for more data every 0.1s until it is
              # unregistered (which clears self._producer).
              if not self._producer:
                  return
              self._producer.resumeProducing()
              self._reactor.callLater(0.1, _poll)

          if streaming:
              return
          self._reactor.callLater(0.0, _poll)

      def unregisterProducer(self):
          """Detach the current producer, if there is one."""
          if self._producer is not None:
              self._producer = None

      def requestDone(self, _self):
          """Flag the request as finished in ``result``."""
          self.result["done"] = True

      def getPeer(self):
          # Report a concrete peer address so that getClientAddress().host is non-null,
          # which is what causes the MAU to be recorded.
          return address.IPv4Address("TCP", "127.0.0.1", 3423)

      def getHost(self):
          return None

      @property
      def transport(self):
          # The channel doubles as its own transport.
          return self

      def getPeerCertificate(self):
          """Return the hard-coded TLS certificate for fake.server."""
          return crypto.load_certificate(crypto.FILETYPE_PEM, FAKE_SERVER_CERT_PEM)
  class FakeSite:
      """Empty placeholder standing in for a Twisted Web Site in tests."""
  def make_request(
      reactor,
      site,
      method,
      path,
      content=b"",
      access_token=None,
      request=Request,
      shorthand=True,
      federation_auth_origin=None,
  ):
      """
      Make a web request using the given method and path, feed it the content, and
      return the Request and the Channel underneath. Mostly copied from Synapse's
      tests framework.

      Args:
          reactor (IReactor): The Twisted reactor to use when performing the request.
          site (Site): The site object stored on the FakeChannel.
          method (bytes or unicode): The HTTP request method ("verb").
          path (bytes or unicode): The HTTP path, suitably URL encoded (e.g.
              escaped UTF-8 & spaces and such).
          content (bytes, unicode or dict): The body of the request. JSON-encoded if
              a dict, UTF-8 encoded if unicode.
          access_token (unicode or bytes): An access token to use to authenticate the
              request, None if no access token needs to be included.
          request (IRequest): The class to use when instantiating the request object.
          shorthand: Whether to try and be helpful and prefix the given URL with the
              identity service API path (``/_matrix/identity/v2/``), if it doesn't
              already start with ``/_matrix``.
          federation_auth_origin (bytes, unicode or None): if set to not-None, we will
              add a fake Authorization header pretending to be the given server name.

      Returns:
          Tuple[Request, FakeChannel]: the request object (an instance of ``request``)
          and the fake channel it writes its response to.
      """
      if not isinstance(method, bytes):
          method = method.encode("ascii")
      if not isinstance(path, bytes):
          path = path.encode("ascii")

      # Decorate it to be the full path, if we're using shorthand
      if shorthand and not path.startswith(b"/_matrix"):
          path = b"/_matrix/identity/v2/" + path
          # Collapse the double slash produced when the given path starts with "/".
          path = path.replace(b"//", b"/")

      if not path.startswith(b"/"):
          path = b"/" + path

      if isinstance(content, dict):
          content = json.dumps(content)
      if isinstance(content, str):
          content = content.encode("utf8")

      channel = FakeChannel(site, reactor)

      req = request(channel)
      req.content = BytesIO(content)
      req.postpath = list(map(unquote, path[1:].split(b"/")))

      if access_token:
          # Accept the token as text or bytes.
          if not isinstance(access_token, bytes):
              access_token = access_token.encode("ascii")
          req.requestHeaders.addRawHeader(b"Authorization", b"Bearer " + access_token)

      if federation_auth_origin is not None:
          # Accept the origin as text or bytes; bytes %-formatting requires bytes.
          if not isinstance(federation_auth_origin, bytes):
              federation_auth_origin = federation_auth_origin.encode("ascii")
          req.requestHeaders.addRawHeader(
              b"Authorization",
              b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
          )

      # Any non-empty body is labelled as JSON.
      if content:
          req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")

      req.requestReceived(method, path, b"1.1")

      return req, channel
  class ToTwistedHandler(logging.Handler):
      """logging handler which forwards stdlib log records to the twisted log"""

      tx_log = twisted.logger.Logger()

      @staticmethod
      def _twisted_level_name(levelno):
          """Map a stdlib logging level number to the name of a Twisted LogLevel.

          Works by threshold rather than by level name, so NOTSET and custom levels
          (e.g. ones registered with logging.addLevelName) map to the nearest standard
          level at or below them instead of failing the Twisted lookup.
          """
          if levelno >= logging.CRITICAL:
              return "critical"
          if levelno >= logging.ERROR:
              return "error"
          if levelno >= logging.WARNING:
              return "warn"
          if levelno >= logging.INFO:
              return "info"
          return "debug"

      def emit(self, record):
          try:
              log_entry = self.format(record)
              log_level = twisted.logger.LogLevel.levelWithName(
                  self._twisted_level_name(record.levelno)
              )
              self.tx_log.emit(log_level, "{entry}", entry=log_entry)
          except Exception:
              # Follow the logging.Handler convention: report the failure through
              # handleError instead of raising into the code that emitted the log.
              self.handleError(record)
  def setup_logging():
      """Configure the python logging appropriately for the tests.

      Attaches a ToTwistedHandler to the root logger (logs will end up in _trial_temp).
      The root level is read from the SYDENT_TEST_LOG_LEVEL environment variable and
      defaults to ERROR when it is unset or empty. Level names are case-insensitive,
      and a numeric level such as "10" is also accepted.
      """
      root_logger = logging.getLogger()

      log_format = "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
      handler = ToTwistedHandler()
      handler.setFormatter(logging.Formatter(log_format))
      root_logger.addHandler(handler)

      log_level = os.environ.get("SYDENT_TEST_LOG_LEVEL", "").strip() or "ERROR"
      # Logger.setLevel only understands upper-case level names or ints, so normalise
      # e.g. "debug" -> "DEBUG" and "10" -> 10 before handing the value over.
      root_logger.setLevel(int(log_level) if log_level.isdigit() else log_level.upper())
  # Configure test logging as a side effect of importing this module.
  setup_logging()
  @implementer(IReactorPluggableNameResolver)
  class ResolvingMemoryReactorClock(MemoryReactorClock):
      """
      A MemoryReactorClock that supports name resolution.

      Hostnames are resolved from the ``lookups`` dict (hostname -> address string);
      a name that is not present fails with DNSLookupError.
      """

      def __init__(self):
          lookups = self.lookups = {}  # type: Dict[str, str]

          @implementer(IResolverSimple)
          class FakeResolver:
              def getHostByName(self, name, timeout=None):
                  # Consulted at call time, so entries added to the reactor's lookups
                  # dict after construction are honoured.
                  if name not in lookups:
                      return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
                  return succeed(lookups[name])

          self.nameResolver = SimpleResolverComplexifier(FakeResolver())
          super().__init__()

      def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
          """Install ``resolver`` as the reactor's name resolver.

          Args:
              resolver: The resolver to use for subsequent name lookups.

          Returns:
              The previously installed resolver, as required by
              IReactorPluggableNameResolver.
          """
          previous, self.nameResolver = self.nameResolver, resolver
          return previous
  class AsyncMock(MagicMock):
      """A MagicMock whose calls are awaitable.

      Calling it returns a coroutine. The underlying MagicMock call happens only when
      that coroutine is awaited. That call records the call and applies
      ``return_value`` / ``side_effect``.
      """

      async def __call__(self, *args, **kwargs):
          mock_result = super().__call__(*args, **kwargs)
          return mock_result