utils.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. import json
  2. from io import BytesIO
  3. import logging
  4. import os
  5. from typing import Dict
  6. import attr
  7. from six import text_type
  8. from zope.interface import implementer
  9. from twisted.internet._resolver import SimpleResolverComplexifier
  10. from twisted.internet.defer import fail, succeed
  11. from twisted.internet.error import DNSLookupError
  12. from twisted.internet.interfaces import (
  13. IHostnameResolver,
  14. IReactorPluggableNameResolver,
  15. IResolverSimple,
  16. )
  17. from twisted.internet import address
  18. import twisted.logger
  19. from twisted.web.http_headers import Headers
  20. from twisted.web.server import Request, Site
  21. from twisted.web.http import unquote
  22. from twisted.test.proto_helpers import MemoryReactorClock
  23. from OpenSSL import crypto
  24. from sydent.sydent import Sydent, parse_config_dict
  25. # Expires on Jan 11 2030 at 17:53:40 GMT
  26. FAKE_SERVER_CERT_PEM = """
  27. -----BEGIN CERTIFICATE-----
  28. MIIDlzCCAn+gAwIBAgIUC8tnJVZ8Cawh5tqr7PCAOfvyGTYwDQYJKoZIhvcNAQEL
  29. BQAwWzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
  30. GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UEAwwLZmFrZS5zZXJ2ZXIw
  31. HhcNMjAwMTE0MTc1MzQwWhcNMzAwMTExMTc1MzQwWjBbMQswCQYDVQQGEwJBVTET
  32. MBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQ
  33. dHkgTHRkMRQwEgYDVQQDDAtmYWtlLnNlcnZlcjCCASIwDQYJKoZIhvcNAQEBBQAD
  34. ggEPADCCAQoCggEBANNzY7YHBLm4uj52ojQc/dfQCoR+63IgjxZ6QdnThhIlOYgE
  35. 3y0Ks49bt3GKmAweOFRRKfDhJRKCYfqZTYudMcdsQg696s2HhiTY0SpqO0soXwW4
  36. 6kEIxnTy2TqkPjWlsWgGTtbVnKc5pnLs7MaQwLIQfxirqD2znn+9r68WMOJRlzkv
  37. VmrXDXjxKPANJJ9b0PiGrL2SF4QcF3zHk8Tjf24OGRX4JTNwiGraU/VN9rrqSHug
  38. CLWcfZ1mvcav3scvtGfgm4kxcw8K6heiQAc3QAMWIrdWhiunaWpQYgw7euS8lZ/O
  39. C7HZ7YbdoldknWdK8o7HJZmxUP9yW9Pqa3n8p9UCAwEAAaNTMFEwHQYDVR0OBBYE
  40. FHwfTq0Mdk9YKqjyfdYm4v9zRP8nMB8GA1UdIwQYMBaAFHwfTq0Mdk9YKqjyfdYm
  41. 4v9zRP8nMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAEPVM5/+
  42. Sj9P/CvNG7F2PxlDQC1/+aVl6ARAz/bZmm7yJnWEleBSwwFLerEQU6KFrgjA243L
  43. qgY6Qf2EYUn1O9jroDg/IumlcQU1H4DXZ03YLKS2bXFGj630Piao547/l4/PaKOP
  44. wSvwDcJlBatKfwjMVl3Al/EcAgUJL8eVosnqHDSINdBuFEc8Kw4LnDSFoTEIx19i
  45. c+DKmtnJNI68wNydLJ3lhSaj4pmsX4PsRqsRzw+jgkPXIG1oGlUDMO3k7UwxfYKR
  46. XkU5mFYkohPTgxv5oYGq2FCOPixkbov7geCEvEUs8m8c8MAm4ErBUzemOAj8KVhE
  47. tWVEpHfT+G7AjA8=
  48. -----END CERTIFICATE-----
  49. """
  50. def make_sydent(test_config={}):
  51. """Create a new sydent
  52. Args:
  53. test_config (dict): any configuration variables for overriding the default sydent
  54. config
  55. """
  56. # Use an in-memory SQLite database. Note that the database isn't cleaned up between
  57. # tests, so by default the same database will be used for each test if changed to be
  58. # a file on disk.
  59. if "db" not in test_config:
  60. test_config["db"] = {"db.file": ":memory:"}
  61. else:
  62. test_config["db"].setdefault("db.file", ":memory:")
  63. reactor = ResolvingMemoryReactorClock()
  64. return Sydent(reactor=reactor, cfg=parse_config_dict(test_config), use_tls_for_federation=False)
  65. @attr.s
  66. class FakeChannel(object):
  67. """
  68. A fake Twisted Web Channel (the part that interfaces with the
  69. wire). Mostly copied from Synapse's tests framework.
  70. """
  71. site = attr.ib(type=Site)
  72. _reactor = attr.ib()
  73. result = attr.ib(default=attr.Factory(dict))
  74. _producer = None
  75. @property
  76. def json_body(self):
  77. if not self.result:
  78. raise Exception("No result yet.")
  79. return json.loads(self.result["body"].decode("utf8"))
  80. @property
  81. def code(self):
  82. if not self.result:
  83. raise Exception("No result yet.")
  84. return int(self.result["code"])
  85. @property
  86. def headers(self):
  87. if not self.result:
  88. raise Exception("No result yet.")
  89. h = Headers()
  90. for i in self.result["headers"]:
  91. h.addRawHeader(*i)
  92. return h
  93. def writeHeaders(self, version, code, reason, headers):
  94. self.result["version"] = version
  95. self.result["code"] = code
  96. self.result["reason"] = reason
  97. self.result["headers"] = headers
  98. def write(self, content):
  99. assert isinstance(content, bytes), "Should be bytes! " + repr(content)
  100. if "body" not in self.result:
  101. self.result["body"] = b""
  102. self.result["body"] += content
  103. def registerProducer(self, producer, streaming):
  104. self._producer = producer
  105. self.producerStreaming = streaming
  106. def _produce():
  107. if self._producer:
  108. self._producer.resumeProducing()
  109. self._reactor.callLater(0.1, _produce)
  110. if not streaming:
  111. self._reactor.callLater(0.0, _produce)
  112. def unregisterProducer(self):
  113. if self._producer is None:
  114. return
  115. self._producer = None
  116. def requestDone(self, _self):
  117. self.result["done"] = True
  118. def getPeer(self):
  119. # We give an address so that getClientIP returns a non null entry,
  120. # causing us to record the MAU
  121. return address.IPv4Address("TCP", "127.0.0.1", 3423)
  122. def getHost(self):
  123. return None
  124. @property
  125. def transport(self):
  126. return self
  127. def getPeerCertificate(self):
  128. """Returns the hardcoded TLS certificate for fake.server."""
  129. return crypto.load_certificate(crypto.FILETYPE_PEM, FAKE_SERVER_CERT_PEM)
  130. class FakeSite:
  131. """A fake Twisted Web Site."""
  132. pass
  133. def make_request(
  134. reactor,
  135. method,
  136. path,
  137. content=b"",
  138. access_token=None,
  139. request=Request,
  140. shorthand=True,
  141. federation_auth_origin=None,
  142. ):
  143. """
  144. Make a web request using the given method and path, feed it the
  145. content, and return the Request and the Channel underneath. Mostly
  146. Args:
  147. reactor (IReactor): The Twisted reactor to use when performing the request.
  148. method (bytes or unicode): The HTTP request method ("verb").
  149. path (bytes or unicode): The HTTP path, suitably URL encoded (e.g.
  150. escaped UTF-8 & spaces and such).
  151. content (str or dict): The body of the request. JSON-encoded, if
  152. a dict.
  153. access_token (unicode): An access token to use to authenticate the request,
  154. None if no access token needs to be included.
  155. request (IRequest): The class to use when instantiating the request object.
  156. shorthand: Whether to try and be helpful and prefix the given URL
  157. with the usual REST API path, if it doesn't contain it.
  158. federation_auth_origin (bytes|None): if set to not-None, we will add a fake
  159. Authorization header pretenting to be the given server name.
  160. Returns:
  161. Tuple[synapse.http.site.SynapseRequest, channel]
  162. """
  163. if not isinstance(method, bytes):
  164. method = method.encode("ascii")
  165. if not isinstance(path, bytes):
  166. path = path.encode("ascii")
  167. # Decorate it to be the full path, if we're using shorthand
  168. if shorthand and not path.startswith(b"/_matrix"):
  169. path = b"/_matrix/identity/v2/" + path
  170. path = path.replace(b"//", b"/")
  171. if not path.startswith(b"/"):
  172. path = b"/" + path
  173. if isinstance(content, dict):
  174. content = json.dumps(content)
  175. if isinstance(content, text_type):
  176. content = content.encode("utf8")
  177. site = FakeSite()
  178. channel = FakeChannel(site, reactor)
  179. req = request(channel)
  180. req.process = lambda: b""
  181. req.content = BytesIO(content)
  182. req.postpath = list(map(unquote, path[1:].split(b"/")))
  183. if access_token:
  184. req.requestHeaders.addRawHeader(
  185. b"Authorization", b"Bearer " + access_token.encode("ascii")
  186. )
  187. if federation_auth_origin is not None:
  188. req.requestHeaders.addRawHeader(
  189. b"Authorization",
  190. b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
  191. )
  192. if content:
  193. req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
  194. req.requestReceived(method, path, b"1.1")
  195. return req, channel
  196. class ToTwistedHandler(logging.Handler):
  197. """logging handler which sends the logs to the twisted log"""
  198. tx_log = twisted.logger.Logger()
  199. def emit(self, record):
  200. log_entry = self.format(record)
  201. log_level = record.levelname.lower().replace("warning", "warn")
  202. self.tx_log.emit(
  203. twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
  204. )
  205. def setup_logging():
  206. """Configure the python logging appropriately for the tests.
  207. (Logs will end up in _trial_temp.)
  208. """
  209. root_logger = logging.getLogger()
  210. log_format = "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s" " - %(message)s"
  211. handler = ToTwistedHandler()
  212. formatter = logging.Formatter(log_format)
  213. handler.setFormatter(formatter)
  214. root_logger.addHandler(handler)
  215. log_level = os.environ.get("SYDENT_TEST_LOG_LEVEL", "ERROR")
  216. root_logger.setLevel(log_level)
  217. setup_logging()
  218. @implementer(IReactorPluggableNameResolver)
  219. class ResolvingMemoryReactorClock(MemoryReactorClock):
  220. """
  221. A MemoryReactorClock that supports name resolution.
  222. """
  223. def __init__(self):
  224. lookups = self.lookups = {} # type: Dict[str, str]
  225. @implementer(IResolverSimple)
  226. class FakeResolver:
  227. def getHostByName(self, name, timeout=None):
  228. if name not in lookups:
  229. return fail(DNSLookupError("OH NO: unknown %s" % (name,)))
  230. return succeed(lookups[name])
  231. self.nameResolver = SimpleResolverComplexifier(FakeResolver())
  232. super().__init__()
  233. def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver:
  234. raise NotImplementedError()