utils.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. import json
  2. import logging
  3. import os
  4. from io import BytesIO
  5. from typing import Dict
  6. from unittest.mock import MagicMock
  7. import attr
  8. import twisted.logger
  9. from OpenSSL import crypto
  10. from twisted.internet import address
  11. from twisted.internet._resolver import SimpleResolverComplexifier
  12. from twisted.internet.defer import fail, succeed
  13. from twisted.internet.error import DNSLookupError
  14. from twisted.internet.interfaces import (
  15. IHostnameResolver,
  16. IReactorPluggableNameResolver,
  17. IResolverSimple,
  18. )
  19. from twisted.test.proto_helpers import MemoryReactorClock
  20. from twisted.web.http import unquote
  21. from twisted.web.http_headers import Headers
  22. from twisted.web.server import Request, Site
  23. from zope.interface import implementer
  24. from sydent.config import SydentConfig
  25. from sydent.sydent import Sydent
# Hardcoded PEM-encoded TLS certificate for the fake server name "fake.server".
# FakeChannel.getPeerCertificate parses and returns this certificate, so requests
# built with FakeChannel appear to come from a TLS peer presenting it.
# Expires on Jan 11 2030 at 17:53:40 GMT (tests relying on it will need a new
# certificate after that date).
FAKE_SERVER_CERT_PEM = """
-----BEGIN CERTIFICATE-----
MIIDlzCCAn+gAwIBAgIUC8tnJVZ8Cawh5tqr7PCAOfvyGTYwDQYJKoZIhvcNAQEL
BQAwWzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UEAwwLZmFrZS5zZXJ2ZXIw
HhcNMjAwMTE0MTc1MzQwWhcNMzAwMTExMTc1MzQwWjBbMQswCQYDVQQGEwJBVTET
MBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQ
dHkgTHRkMRQwEgYDVQQDDAtmYWtlLnNlcnZlcjCCASIwDQYJKoZIhvcNAQEBBQAD
ggEPADCCAQoCggEBANNzY7YHBLm4uj52ojQc/dfQCoR+63IgjxZ6QdnThhIlOYgE
3y0Ks49bt3GKmAweOFRRKfDhJRKCYfqZTYudMcdsQg696s2HhiTY0SpqO0soXwW4
6kEIxnTy2TqkPjWlsWgGTtbVnKc5pnLs7MaQwLIQfxirqD2znn+9r68WMOJRlzkv
VmrXDXjxKPANJJ9b0PiGrL2SF4QcF3zHk8Tjf24OGRX4JTNwiGraU/VN9rrqSHug
CLWcfZ1mvcav3scvtGfgm4kxcw8K6heiQAc3QAMWIrdWhiunaWpQYgw7euS8lZ/O
C7HZ7YbdoldknWdK8o7HJZmxUP9yW9Pqa3n8p9UCAwEAAaNTMFEwHQYDVR0OBBYE
FHwfTq0Mdk9YKqjyfdYm4v9zRP8nMB8GA1UdIwQYMBaAFHwfTq0Mdk9YKqjyfdYm
4v9zRP8nMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAEPVM5/+
Sj9P/CvNG7F2PxlDQC1/+aVl6ARAz/bZmm7yJnWEleBSwwFLerEQU6KFrgjA243L
qgY6Qf2EYUn1O9jroDg/IumlcQU1H4DXZ03YLKS2bXFGj630Piao547/l4/PaKOP
wSvwDcJlBatKfwjMVl3Al/EcAgUJL8eVosnqHDSINdBuFEc8Kw4LnDSFoTEIx19i
c+DKmtnJNI68wNydLJ3lhSaj4pmsX4PsRqsRzw+jgkPXIG1oGlUDMO3k7UwxfYKR
XkU5mFYkohPTgxv5oYGq2FCOPixkbov7geCEvEUs8m8c8MAm4ErBUzemOAj8KVhE
tWVEpHfT+G7AjA8=
-----END CERTIFICATE-----
"""
  51. def make_sydent(test_config={}):
  52. """Create a new sydent
  53. Args:
  54. test_config (dict): any configuration variables for overriding the default sydent
  55. config
  56. """
  57. # Use an in-memory SQLite database. Note that the database isn't cleaned up between
  58. # tests, so by default the same database will be used for each test if changed to be
  59. # a file on disk.
  60. if "db" not in test_config:
  61. test_config["db"] = {"db.file": ":memory:"}
  62. else:
  63. test_config["db"].setdefault("db.file", ":memory:")
  64. reactor = ResolvingMemoryReactorClock()
  65. sydent_config = SydentConfig()
  66. sydent_config.parse_config_dict(test_config)
  67. return Sydent(
  68. reactor=reactor,
  69. sydent_config=sydent_config,
  70. use_tls_for_federation=False,
  71. )
  72. @attr.s
  73. class FakeChannel:
  74. """
  75. A fake Twisted Web Channel (the part that interfaces with the
  76. wire). Mostly copied from Synapse's tests framework.
  77. """
  78. site = attr.ib(type=Site)
  79. _reactor = attr.ib()
  80. result = attr.ib(default=attr.Factory(dict))
  81. _producer = None
  82. @property
  83. def json_body(self):
  84. if not self.result:
  85. raise Exception("No result yet.")
  86. return json.loads(self.result["body"].decode("utf8"))
  87. @property
  88. def code(self):
  89. if not self.result:
  90. raise Exception("No result yet.")
  91. return int(self.result["code"])
  92. @property
  93. def headers(self):
  94. if not self.result:
  95. raise Exception("No result yet.")
  96. h = Headers()
  97. for i in self.result["headers"]:
  98. h.addRawHeader(*i)
  99. return h
  100. def writeHeaders(self, version, code, reason, headers):
  101. self.result["version"] = version
  102. self.result["code"] = code
  103. self.result["reason"] = reason
  104. self.result["headers"] = headers
  105. def write(self, content):
  106. assert isinstance(content, bytes), "Should be bytes! " + repr(content)
  107. if "body" not in self.result:
  108. self.result["body"] = b""
  109. self.result["body"] += content
  110. def registerProducer(self, producer, streaming):
  111. self._producer = producer
  112. self.producerStreaming = streaming
  113. def _produce():
  114. if self._producer:
  115. self._producer.resumeProducing()
  116. self._reactor.callLater(0.1, _produce)
  117. if not streaming:
  118. self._reactor.callLater(0.0, _produce)
  119. def unregisterProducer(self):
  120. if self._producer is None:
  121. return
  122. self._producer = None
  123. def requestDone(self, _self):
  124. self.result["done"] = True
  125. def getPeer(self):
  126. # We give an address so that getClientAddress().host returns a non null entry,
  127. # causing us to record the MAU
  128. return address.IPv4Address("TCP", "127.0.0.1", 3423)
  129. def getHost(self):
  130. return None
  131. @property
  132. def transport(self):
  133. return self
  134. def getPeerCertificate(self):
  135. """Returns the hardcoded TLS certificate for fake.server."""
  136. return crypto.load_certificate(crypto.FILETYPE_PEM, FAKE_SERVER_CERT_PEM)
  137. class FakeSite:
  138. """A fake Twisted Web Site."""
  139. pass
  140. def make_request(
  141. reactor,
  142. method,
  143. path,
  144. content=b"",
  145. access_token=None,
  146. request=Request,
  147. shorthand=True,
  148. federation_auth_origin=None,
  149. ):
  150. """
  151. Make a web request using the given method and path, feed it the
  152. content, and return the Request and the Channel underneath. Mostly
  153. Args:
  154. reactor (IReactor): The Twisted reactor to use when performing the request.
  155. method (bytes or unicode): The HTTP request method ("verb").
  156. path (bytes or unicode): The HTTP path, suitably URL encoded (e.g.
  157. escaped UTF-8 & spaces and such).
  158. content (bytes or dict): The body of the request. JSON-encoded, if
  159. a dict.
  160. access_token (unicode): An access token to use to authenticate the request,
  161. None if no access token needs to be included.
  162. request (IRequest): The class to use when instantiating the request object.
  163. shorthand: Whether to try and be helpful and prefix the given URL
  164. with the usual REST API path, if it doesn't contain it.
  165. federation_auth_origin (bytes|None): if set to not-None, we will add a fake
  166. Authorization header pretenting to be the given server name.
  167. Returns:
  168. Tuple[synapse.http.site.SynapseRequest, channel]
  169. """
  170. if not isinstance(method, bytes):
  171. method = method.encode("ascii")
  172. if not isinstance(path, bytes):
  173. path = path.encode("ascii")
  174. # Decorate it to be the full path, if we're using shorthand
  175. if shorthand and not path.startswith(b"/_matrix"):
  176. path = b"/_matrix/identity/v2/" + path
  177. path = path.replace(b"//", b"/")
  178. if not path.startswith(b"/"):
  179. path = b"/" + path
  180. if isinstance(content, dict):
  181. content = json.dumps(content)
  182. if isinstance(content, str):
  183. content = content.encode("utf8")
  184. site = FakeSite()
  185. channel = FakeChannel(site, reactor)
  186. req = request(channel)
  187. req.process = lambda: b""
  188. req.content = BytesIO(content)
  189. req.postpath = list(map(unquote, path[1:].split(b"/")))
  190. if access_token:
  191. req.requestHeaders.addRawHeader(
  192. b"Authorization", b"Bearer " + access_token.encode("ascii")
  193. )
  194. if federation_auth_origin is not None:
  195. req.requestHeaders.addRawHeader(
  196. b"Authorization",
  197. b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
  198. )
  199. if content:
  200. req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
  201. req.requestReceived(method, path, b"1.1")
  202. return req, channel
  203. class ToTwistedHandler(logging.Handler):
  204. """logging handler which sends the logs to the twisted log"""
  205. tx_log = twisted.logger.Logger()
  206. def emit(self, record):
  207. log_entry = self.format(record)
  208. log_level = record.levelname.lower().replace("warning", "warn")
  209. self.tx_log.emit(
  210. twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
  211. )
  212. def setup_logging():
  213. """Configure the python logging appropriately for the tests.
  214. (Logs will end up in _trial_temp.)
  215. """
  216. root_logger = logging.getLogger()
  217. log_format = "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s" " - %(message)s"
  218. handler = ToTwistedHandler()
  219. formatter = logging.Formatter(log_format)
  220. handler.setFormatter(formatter)
  221. root_logger.addHandler(handler)
  222. log_level = os.environ.get("SYDENT_TEST_LOG_LEVEL", "ERROR")
  223. root_logger.setLevel(log_level)
  224. setup_logging()
  225. @implementer(IReactorPluggableNameResolver)
  226. class ResolvingMemoryReactorClock(MemoryReactorClock):
  227. """
  228. A MemoryReactorClock that supports name resolution.
  229. """
  230. def __init__(self):
  231. lookups = self.lookups = {} # type: Dict[str, str]
  232. @implementer(IResolverSimple)
  233. class FakeResolver:
  234. def getHostByName(self, name, timeout=None):
  235. if name not in lookups:
  236. return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
  237. return succeed(lookups[name])
  238. self.nameResolver = SimpleResolverComplexifier(FakeResolver())
  239. super().__init__()
  240. def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
  241. raise NotImplementedError()
  242. class AsyncMock(MagicMock):
  243. async def __call__(self, *args, **kwargs):
  244. return super().__call__(*args, **kwargs)