utils.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270
  1. import json
  2. from io import BytesIO
  3. import logging
  4. import os
  5. import attr
  6. from six import text_type
  7. from twisted.internet import address
  8. import twisted.logger
  9. from twisted.web.http_headers import Headers
  10. from twisted.web.server import Request, Site
  11. from twisted.web.http import unquote
  12. from twisted.test.proto_helpers import MemoryReactorClock
  13. from OpenSSL import crypto
  14. from sydent.sydent import Sydent, parse_config_dict
# PEM-encoded X.509 certificate whose subject common name is "fake.server".
# FakeChannel.getPeerCertificate() parses and returns it, so tests can act as
# though the connection's TLS peer presented this certificate.
# Expires on Jan 11 2030 at 17:53:40 GMT -- regenerate it before then, or
# tests that check certificate validity will start failing.
FAKE_SERVER_CERT_PEM = """
-----BEGIN CERTIFICATE-----
MIIDlzCCAn+gAwIBAgIUC8tnJVZ8Cawh5tqr7PCAOfvyGTYwDQYJKoZIhvcNAQEL
BQAwWzELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM
GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDEUMBIGA1UEAwwLZmFrZS5zZXJ2ZXIw
HhcNMjAwMTE0MTc1MzQwWhcNMzAwMTExMTc1MzQwWjBbMQswCQYDVQQGEwJBVTET
MBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQ
dHkgTHRkMRQwEgYDVQQDDAtmYWtlLnNlcnZlcjCCASIwDQYJKoZIhvcNAQEBBQAD
ggEPADCCAQoCggEBANNzY7YHBLm4uj52ojQc/dfQCoR+63IgjxZ6QdnThhIlOYgE
3y0Ks49bt3GKmAweOFRRKfDhJRKCYfqZTYudMcdsQg696s2HhiTY0SpqO0soXwW4
6kEIxnTy2TqkPjWlsWgGTtbVnKc5pnLs7MaQwLIQfxirqD2znn+9r68WMOJRlzkv
VmrXDXjxKPANJJ9b0PiGrL2SF4QcF3zHk8Tjf24OGRX4JTNwiGraU/VN9rrqSHug
CLWcfZ1mvcav3scvtGfgm4kxcw8K6heiQAc3QAMWIrdWhiunaWpQYgw7euS8lZ/O
C7HZ7YbdoldknWdK8o7HJZmxUP9yW9Pqa3n8p9UCAwEAAaNTMFEwHQYDVR0OBBYE
FHwfTq0Mdk9YKqjyfdYm4v9zRP8nMB8GA1UdIwQYMBaAFHwfTq0Mdk9YKqjyfdYm
4v9zRP8nMA8GA1UdEwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBAEPVM5/+
Sj9P/CvNG7F2PxlDQC1/+aVl6ARAz/bZmm7yJnWEleBSwwFLerEQU6KFrgjA243L
qgY6Qf2EYUn1O9jroDg/IumlcQU1H4DXZ03YLKS2bXFGj630Piao547/l4/PaKOP
wSvwDcJlBatKfwjMVl3Al/EcAgUJL8eVosnqHDSINdBuFEc8Kw4LnDSFoTEIx19i
c+DKmtnJNI68wNydLJ3lhSaj4pmsX4PsRqsRzw+jgkPXIG1oGlUDMO3k7UwxfYKR
XkU5mFYkohPTgxv5oYGq2FCOPixkbov7geCEvEUs8m8c8MAm4ErBUzemOAj8KVhE
tWVEpHfT+G7AjA8=
-----END CERTIFICATE-----
"""
  40. def make_sydent(test_config={}):
  41. """Create a new sydent
  42. Args:
  43. test_config (dict): any configuration variables for overriding the default sydent
  44. config
  45. """
  46. # Use an in-memory SQLite database. Note that the database isn't cleaned up between
  47. # tests, so by default the same database will be used for each test if changed to be
  48. # a file on disk.
  49. if 'db' not in test_config:
  50. test_config['db'] = {'db.file': ':memory:'}
  51. else:
  52. test_config['db'].setdefault('db.file', ':memory:')
  53. reactor = MemoryReactorClock()
  54. return Sydent(reactor=reactor, cfg=parse_config_dict(test_config))
  55. @attr.s
  56. class FakeChannel(object):
  57. """
  58. A fake Twisted Web Channel (the part that interfaces with the
  59. wire). Mostly copied from Synapse's tests framework.
  60. """
  61. site = attr.ib(type=Site)
  62. _reactor = attr.ib()
  63. result = attr.ib(default=attr.Factory(dict))
  64. _producer = None
  65. @property
  66. def json_body(self):
  67. if not self.result:
  68. raise Exception("No result yet.")
  69. return json.loads(self.result["body"].decode("utf8"))
  70. @property
  71. def code(self):
  72. if not self.result:
  73. raise Exception("No result yet.")
  74. return int(self.result["code"])
  75. @property
  76. def headers(self):
  77. if not self.result:
  78. raise Exception("No result yet.")
  79. h = Headers()
  80. for i in self.result["headers"]:
  81. h.addRawHeader(*i)
  82. return h
  83. def writeHeaders(self, version, code, reason, headers):
  84. self.result["version"] = version
  85. self.result["code"] = code
  86. self.result["reason"] = reason
  87. self.result["headers"] = headers
  88. def write(self, content):
  89. assert isinstance(content, bytes), "Should be bytes! " + repr(content)
  90. if "body" not in self.result:
  91. self.result["body"] = b""
  92. self.result["body"] += content
  93. def registerProducer(self, producer, streaming):
  94. self._producer = producer
  95. self.producerStreaming = streaming
  96. def _produce():
  97. if self._producer:
  98. self._producer.resumeProducing()
  99. self._reactor.callLater(0.1, _produce)
  100. if not streaming:
  101. self._reactor.callLater(0.0, _produce)
  102. def unregisterProducer(self):
  103. if self._producer is None:
  104. return
  105. self._producer = None
  106. def requestDone(self, _self):
  107. self.result["done"] = True
  108. def getPeer(self):
  109. # We give an address so that getClientIP returns a non null entry,
  110. # causing us to record the MAU
  111. return address.IPv4Address("TCP", "127.0.0.1", 3423)
  112. def getHost(self):
  113. return None
  114. @property
  115. def transport(self):
  116. return self
  117. def getPeerCertificate(self):
  118. """Returns the hardcoded TLS certificate for fake.server."""
  119. return crypto.load_certificate(crypto.FILETYPE_PEM, FAKE_SERVER_CERT_PEM)
  120. class FakeSite:
  121. """A fake Twisted Web Site."""
  122. pass
  123. def make_request(
  124. reactor,
  125. method,
  126. path,
  127. content=b"",
  128. access_token=None,
  129. request=Request,
  130. shorthand=True,
  131. federation_auth_origin=None,
  132. ):
  133. """
  134. Make a web request using the given method and path, feed it the
  135. content, and return the Request and the Channel underneath. Mostly
  136. Args:
  137. reactor (IReactor): The Twisted reactor to use when performing the request.
  138. method (bytes or unicode): The HTTP request method ("verb").
  139. path (bytes or unicode): The HTTP path, suitably URL encoded (e.g.
  140. escaped UTF-8 & spaces and such).
  141. content (str or dict): The body of the request. JSON-encoded, if
  142. a dict.
  143. access_token (unicode): An access token to use to authenticate the request,
  144. None if no access token needs to be included.
  145. request (IRequest): The class to use when instantiating the request object.
  146. shorthand: Whether to try and be helpful and prefix the given URL
  147. with the usual REST API path, if it doesn't contain it.
  148. federation_auth_origin (bytes|None): if set to not-None, we will add a fake
  149. Authorization header pretenting to be the given server name.
  150. Returns:
  151. Tuple[synapse.http.site.SynapseRequest, channel]
  152. """
  153. if not isinstance(method, bytes):
  154. method = method.encode("ascii")
  155. if not isinstance(path, bytes):
  156. path = path.encode("ascii")
  157. # Decorate it to be the full path, if we're using shorthand
  158. if (
  159. shorthand
  160. and not path.startswith(b"/_matrix")
  161. ):
  162. path = b"/_matrix/identity/v2/" + path
  163. path = path.replace(b"//", b"/")
  164. if not path.startswith(b"/"):
  165. path = b"/" + path
  166. if isinstance(content, dict):
  167. content = json.dumps(content)
  168. if isinstance(content, text_type):
  169. content = content.encode("utf8")
  170. site = FakeSite()
  171. channel = FakeChannel(site, reactor)
  172. req = request(channel)
  173. req.process = lambda: b""
  174. req.content = BytesIO(content)
  175. req.postpath = list(map(unquote, path[1:].split(b"/")))
  176. if access_token:
  177. req.requestHeaders.addRawHeader(
  178. b"Authorization", b"Bearer " + access_token.encode("ascii")
  179. )
  180. if federation_auth_origin is not None:
  181. req.requestHeaders.addRawHeader(
  182. b"Authorization",
  183. b"X-Matrix origin=%s,key=,sig=" % (federation_auth_origin,),
  184. )
  185. if content:
  186. req.requestHeaders.addRawHeader(b"Content-Type", b"application/json")
  187. req.requestReceived(method, path, b"1.1")
  188. return req, channel
  189. class ToTwistedHandler(logging.Handler):
  190. """logging handler which sends the logs to the twisted log"""
  191. tx_log = twisted.logger.Logger()
  192. def emit(self, record):
  193. log_entry = self.format(record)
  194. log_level = record.levelname.lower().replace("warning", "warn")
  195. self.tx_log.emit(
  196. twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
  197. )
  198. def setup_logging():
  199. """Configure the python logging appropriately for the tests.
  200. (Logs will end up in _trial_temp.)
  201. """
  202. root_logger = logging.getLogger()
  203. log_format = (
  204. "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s"
  205. " - %(message)s"
  206. )
  207. handler = ToTwistedHandler()
  208. formatter = logging.Formatter(log_format)
  209. handler.setFormatter(formatter)
  210. root_logger.addHandler(handler)
  211. log_level = os.environ.get("SYDENT_TEST_LOG_LEVEL", "ERROR")
  212. root_logger.setLevel(log_level)
  213. setup_logging()