utils.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. import json
  2. import logging
  3. import os
  4. from io import BytesIO
  5. from typing import Dict
  6. from unittest.mock import MagicMock
  7. import attr
  8. import twisted.logger
  9. from OpenSSL import crypto
  10. from twisted.internet import address
  11. from twisted.internet._resolver import SimpleResolverComplexifier
  12. from twisted.internet.defer import fail, succeed
  13. from twisted.internet.error import DNSLookupError
  14. from twisted.internet.interfaces import (
  15. IHostnameResolver,
  16. IReactorPluggableNameResolver,
  17. IResolverSimple,
  18. )
  19. from twisted.test.proto_helpers import MemoryReactorClock
  20. from twisted.web.http import unquote
  21. from twisted.web.http_headers import Headers
  22. from twisted.web.server import Request, Site
  23. from zope.interface import implementer
  24. from sydent.config import SydentConfig
  25. from sydent.sydent import Sydent
# Self-signed test certificate whose subject CN is "fake.server"; it is returned
# by FakeChannel.getPeerCertificate() so requests appear to come from that server.
# Expires on Jan 11 2030 at 17:53:40 GMT -- regenerate it before then or TLS-related
# tests relying on it may start failing.
FAKE_SERVER_CERT_PEM = """
-----BEGIN CERTIFICATE-----
MIIDlzCCAn+gAwIBAgIUC8tnJVZ8Cawh5tqr7PCAOfvyGTYwDQYJKoZIhvcNAQEL
BQAwWzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UEAwwLZmFrZS5zZXJ2ZXIw
HhcNMjAwMTE0MTc1MzQwWhcNMzAwMTExMTc1MzQwWjBbMQswCQYDVQQGEwJBVTET
MBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQ
dHkgTHRkMRQwEgYDVQQDDAtmYWtlLnNlcnZlcjCCASIwDQYJKoZIhvcNAQEBBQAD
ggEPADCCAQoCggEBANNzY7YHBLm4uj52ojQc/dfQCoR+63IgjxZ6QdnThhIlOYgE
3y0Ks49bt3GKmAweOFRRKfDhJRKCYfqZTYudMcdsQg696s2HhiTY0SpqO0soXwW4
6kEIxnTy2TqkPjWlsWgGTtbVnKc5pnLs7MaQwLIQfxirqD2znn+9r68WMOJRlzkv
VmrXDXjxKPANJJ9b0PiGrL2SF4QcF3zHk8Tjf24OGRX4JTNwiGraU/VN9rrqSHug
CLWcfZ1mvcav3scvtGfgm4kxcw8K6heiQAc3QAMWIrdWhiunaWpQYgw7euS8lZ/O
C7HZ7YbdoldknWdK8o7HJZmxUP9yW9Pqa3n8p9UCAwEAAaNTMFEwHQYDVR0OBBYE
FHwfTq0Mdk9YKqjyfdYm4v9zRP8nMB8GA1UdIwQYMBaAFHwfTq0Mdk9YKqjyfdYm
4v9zRP8nMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAEPVM5/+
Sj9P/CvNG7F2PxlDQC1/+aVl6ARAz/bZmm7yJnWEleBSwwFLerEQU6KFrgjA243L
qgY6Qf2EYUn1O9jroDg/IumlcQU1H4DXZ03YLKS2bXFGj630Piao547/l4/PaKOP
wSvwDcJlBatKfwjMVl3Al/EcAgUJL8eVosnqHDSINdBuFEc8Kw4LnDSFoTEIx19i
c+DKmtnJNI68wNydLJ3lhSaj4pmsX4PsRqsRzw+jgkPXIG1oGlUDMO3k7UwxfYKR
XkU5mFYkohPTgxv5oYGq2FCOPixkbov7geCEvEUs8m8c8MAm4ErBUzemOAj8KVhE
tWVEpHfT+G7AjA8=
-----END CERTIFICATE-----
"""

# ed25519 signing key that make_sydent() puts into the "crypto" config section
# ("ed25519.signingkey") when a test does not supply its own.
DEFAULT_SIGNING_KEY = "ed25519 0 broXDusfghcDAamylh2RmOEHfPJCi4snha7NNCJKOao"
  52. def make_sydent(test_config={}):
  53. """Create a new sydent
  54. Args:
  55. test_config (dict): any configuration variables for overriding the default sydent
  56. config
  57. """
  58. # Use an in-memory SQLite database. Note that the database isn't cleaned up between
  59. # tests, so by default the same database will be used for each test if changed to be
  60. # a file on disk.
  61. if "db" not in test_config:
  62. test_config["db"] = {"db.file": ":memory:"}
  63. else:
  64. test_config["db"].setdefault("db.file", ":memory:")
  65. # Set a value for the signingkey if it hasn't been set by the test
  66. if "crypto" not in test_config:
  67. test_config["crypto"] = {"ed25519.signingkey": DEFAULT_SIGNING_KEY}
  68. elif "ed25519.signingkey" not in test_config["crypto"]:
  69. test_config["crypto"] = {"ed25519.signingkey": DEFAULT_SIGNING_KEY}
  70. reactor = ResolvingMemoryReactorClock()
  71. sydent_config = SydentConfig()
  72. sydent_config.parse_config_dict(test_config)
  73. return Sydent(
  74. reactor=reactor,
  75. sydent_config=sydent_config,
  76. use_tls_for_federation=False,
  77. )
  78. @attr.s
  79. class FakeChannel:
  80. """
  81. A fake Twisted Web Channel (the part that interfaces with the
  82. wire). Mostly copied from Synapse's tests framework.
  83. """
  84. site = attr.ib(type=Site)
  85. _reactor = attr.ib()
  86. result = attr.ib(default=attr.Factory(dict))
  87. _producer = None
  88. @property
  89. def json_body(self):
  90. if not self.result:
  91. raise Exception("No result yet.")
  92. return json.loads(self.result["body"].decode("utf8"))
  93. @property
  94. def code(self):
  95. if not self.result:
  96. raise Exception("No result yet.")
  97. return int(self.result["code"])
  98. @property
  99. def headers(self):
  100. if not self.result:
  101. raise Exception("No result yet.")
  102. h = Headers()
  103. for i in self.result["headers"]:
  104. h.addRawHeader(*i)
  105. return h
  106. def writeHeaders(self, version, code, reason, headers):
  107. self.result["version"] = version
  108. self.result["code"] = code
  109. self.result["reason"] = reason
  110. self.result["headers"] = headers
  111. def write(self, content):
  112. assert isinstance(content, bytes), "Should be bytes! " + repr(content)
  113. if "body" not in self.result:
  114. self.result["body"] = b""
  115. self.result["body"] += content
  116. def registerProducer(self, producer, streaming):
  117. self._producer = producer
  118. self.producerStreaming = streaming
  119. def _produce():
  120. if self._producer:
  121. self._producer.resumeProducing()
  122. self._reactor.callLater(0.1, _produce)
  123. if not streaming:
  124. self._reactor.callLater(0.0, _produce)
  125. def unregisterProducer(self):
  126. if self._producer is None:
  127. return
  128. self._producer = None
  129. def requestDone(self, _self):
  130. self.result["done"] = True
  131. def getPeer(self):
  132. # We give an address so that getClientAddress().host returns a non null entry,
  133. # causing us to record the MAU
  134. return address.IPv4Address("TCP", "127.0.0.1", 3423)
  135. def getHost(self):
  136. return None
  137. @property
  138. def transport(self):
  139. return self
  140. def getPeerCertificate(self):
  141. """Returns the hardcoded TLS certificate for fake.server."""
  142. return crypto.load_certificate(crypto.FILETYPE_PEM, FAKE_SERVER_CERT_PEM)
  143. class FakeSite:
  144. """A fake Twisted Web Site."""
  145. pass
  146. def make_request(
  147. reactor,
  148. method,
  149. path,
  150. content=b"",
  151. access_token=None,
  152. request=Request,
  153. shorthand=True,
  154. federation_auth_origin=None,
  155. ):
  156. """
  157. Make a web request using the given method and path, feed it the
  158. content, and return the Request and the Channel underneath. Mostly
  159. Args:
  160. reactor (IReactor): The Twisted reactor to use when performing the request.
  161. method (bytes or unicode): The HTTP request method ("verb").
  162. path (bytes or unicode): The HTTP path, suitably URL encoded (e.g.
  163. escaped UTF-8 & spaces and such).
  164. content (bytes or dict): The body of the request. JSON-encoded, if
  165. a dict.
  166. access_token (unicode): An access token to use to authenticate the request,
  167. None if no access token needs to be included.
  168. request (IRequest): The class to use when instantiating the request object.
  169. shorthand: Whether to try and be helpful and prefix the given URL
  170. with the usual REST API path, if it doesn't contain it.
  171. federation_auth_origin (bytes|None): if set to not-None, we will add a fake
  172. Authorization header pretenting to be the given server name.
  173. Returns:
  174. Tuple[synapse.http.site.SynapseRequest, channel]
  175. """
  176. if not isinstance(method, bytes):
  177. method = method.encode("ascii")
  178. if not isinstance(path, bytes):
  179. path = path.encode("ascii")
  180. # Decorate it to be the full path, if we're using shorthand
  181. if shorthand and not path.startswith(b"/_matrix"):
  182. path = b"/_matrix/identity/v2/" + path
  183. path = path.replace(b"//", b"/")
  184. if not path.startswith(b"/"):
  185. path = b"/" + path
  186. if isinstance(content, dict):
  187. content = json.dumps(content)
  188. if isinstance(content, str):
  189. content = content.encode("utf8")
  190. site = FakeSite()
  191. channel = FakeChannel(site, reactor)
  192. req = request(channel)
  193. req.process = lambda: b""
  194. req.content = BytesIO(content)
  195. req.postpath = list(map(unquote, path[1:].split(b"/")))
  196. if access_token:
  197. req.requestHeaders.addRawHeader(
  198. b"Authorization", b"Bearer " + access_token.encode("ascii")
  199. )
  200. if federation_auth_origin is not None:
  201. req.requestHeaders.addRawHeader(
  202. b"Authorization",
  203. b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
  204. )
  205. if content:
  206. req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
  207. req.requestReceived(method, path, b"1.1")
  208. return req, channel
  209. class ToTwistedHandler(logging.Handler):
  210. """logging handler which sends the logs to the twisted log"""
  211. tx_log = twisted.logger.Logger()
  212. def emit(self, record):
  213. log_entry = self.format(record)
  214. log_level = record.levelname.lower().replace("warning", "warn")
  215. self.tx_log.emit(
  216. twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
  217. )
  218. def setup_logging():
  219. """Configure the python logging appropriately for the tests.
  220. (Logs will end up in _trial_temp.)
  221. """
  222. root_logger = logging.getLogger()
  223. log_format = "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s" " - %(message)s"
  224. handler = ToTwistedHandler()
  225. formatter = logging.Formatter(log_format)
  226. handler.setFormatter(formatter)
  227. root_logger.addHandler(handler)
  228. log_level = os.environ.get("SYDENT_TEST_LOG_LEVEL", "ERROR")
  229. root_logger.setLevel(log_level)
  230. setup_logging()
  231. @implementer(IReactorPluggableNameResolver)
  232. class ResolvingMemoryReactorClock(MemoryReactorClock):
  233. """
  234. A MemoryReactorClock that supports name resolution.
  235. """
  236. def __init__(self):
  237. lookups = self.lookups = {} # type: Dict[str, str]
  238. @implementer(IResolverSimple)
  239. class FakeResolver:
  240. def getHostByName(self, name, timeout=None):
  241. if name not in lookups:
  242. return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
  243. return succeed(lookups[name])
  244. self.nameResolver = SimpleResolverComplexifier(FakeResolver())
  245. super().__init__()
  246. def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
  247. raise NotImplementedError()
  248. class AsyncMock(MagicMock):
  249. async def __call__(self, *args, **kwargs):
  250. return super().__call__(*args, **kwargs)