keyclient.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import logging
  16. from six.moves import urllib
  17. from canonicaljson import json
  18. from twisted.internet import defer, reactor
  19. from twisted.internet.error import ConnectError
  20. from twisted.internet.protocol import Factory
  21. from twisted.names.error import DomainError
  22. from twisted.web.http import HTTPClient
  23. from synapse.http.endpoint import matrix_federation_endpoint
  24. from synapse.util import logcontext
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Path template for the v2 server-key endpoint; the single %s placeholder is
# filled with the URL-quoted key id by fetch_server_key.
KEY_API_V2 = "/_matrix/key/v2/server/%s"
  27. @defer.inlineCallbacks
  28. def fetch_server_key(server_name, tls_client_options_factory, key_id):
  29. """Fetch the keys for a remote server."""
  30. factory = SynapseKeyClientFactory()
  31. factory.path = KEY_API_V2 % (urllib.parse.quote(key_id), )
  32. factory.host = server_name
  33. endpoint = matrix_federation_endpoint(
  34. reactor, server_name, tls_client_options_factory, timeout=30
  35. )
  36. for i in range(5):
  37. try:
  38. with logcontext.PreserveLoggingContext():
  39. protocol = yield endpoint.connect(factory)
  40. server_response, server_certificate = yield protocol.remote_key
  41. defer.returnValue((server_response, server_certificate))
  42. except SynapseKeyClientError as e:
  43. logger.warn("Error getting key for %r: %s", server_name, e)
  44. if e.status.startswith(b"4"):
  45. # Don't retry for 4xx responses.
  46. raise IOError("Cannot get key for %r" % server_name)
  47. except (ConnectError, DomainError) as e:
  48. logger.warn("Error getting key for %r: %s", server_name, e)
  49. except Exception:
  50. logger.exception("Error getting key for %r", server_name)
  51. raise IOError("Cannot get key for %r" % server_name)
  52. class SynapseKeyClientError(Exception):
  53. """The key wasn't retrieved from the remote server."""
  54. status = None
  55. pass
  56. class SynapseKeyClientProtocol(HTTPClient):
  57. """Low level HTTPS client which retrieves an application/json response from
  58. the server and extracts the X.509 certificate for the remote peer from the
  59. SSL connection."""
  60. timeout = 30
  61. def __init__(self):
  62. self.remote_key = defer.Deferred()
  63. self.host = None
  64. self._peer = None
  65. def connectionMade(self):
  66. self._peer = self.transport.getPeer()
  67. logger.debug("Connected to %s", self._peer)
  68. if not isinstance(self.path, bytes):
  69. self.path = self.path.encode('ascii')
  70. if not isinstance(self.host, bytes):
  71. self.host = self.host.encode('ascii')
  72. self.sendCommand(b"GET", self.path)
  73. if self.host:
  74. self.sendHeader(b"Host", self.host)
  75. self.endHeaders()
  76. self.timer = reactor.callLater(
  77. self.timeout,
  78. self.on_timeout
  79. )
  80. def errback(self, error):
  81. if not self.remote_key.called:
  82. self.remote_key.errback(error)
  83. def callback(self, result):
  84. if not self.remote_key.called:
  85. self.remote_key.callback(result)
  86. def handleStatus(self, version, status, message):
  87. if status != b"200":
  88. # logger.info("Non-200 response from %s: %s %s",
  89. # self.transport.getHost(), status, message)
  90. error = SynapseKeyClientError(
  91. "Non-200 response %r from %r" % (status, self.host)
  92. )
  93. error.status = status
  94. self.errback(error)
  95. self.transport.abortConnection()
  96. def handleResponse(self, response_body_bytes):
  97. try:
  98. json_response = json.loads(response_body_bytes)
  99. except ValueError:
  100. # logger.info("Invalid JSON response from %s",
  101. # self.transport.getHost())
  102. self.transport.abortConnection()
  103. return
  104. certificate = self.transport.getPeerCertificate()
  105. self.callback((json_response, certificate))
  106. self.transport.abortConnection()
  107. self.timer.cancel()
  108. def on_timeout(self):
  109. logger.debug(
  110. "Timeout waiting for response from %s: %s",
  111. self.host, self._peer,
  112. )
  113. self.errback(IOError("Timeout waiting for response"))
  114. self.transport.abortConnection()
  115. class SynapseKeyClientFactory(Factory):
  116. def protocol(self):
  117. protocol = SynapseKeyClientProtocol()
  118. protocol.path = self.path
  119. protocol.host = self.host
  120. return protocol