_base.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. # Copyright 2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import functools
  15. import logging
  16. import re
  17. from synapse.api.errors import Codes, FederationDeniedError, SynapseError
  18. from synapse.api.urls import FEDERATION_V1_PREFIX
  19. from synapse.http.servlet import parse_json_object_from_request
  20. from synapse.logging import opentracing
  21. from synapse.logging.context import run_in_background
  22. from synapse.logging.opentracing import (
  23. SynapseTags,
  24. start_active_span,
  25. start_active_span_from_request,
  26. tags,
  27. whitelisted_homeserver,
  28. )
  29. from synapse.server import HomeServer
  30. from synapse.util.ratelimitutils import FederationRateLimiter
  31. from synapse.util.stringutils import parse_and_validate_server_name
  32. logger = logging.getLogger(__name__)
  33. class AuthenticationError(SynapseError):
  34. """There was a problem authenticating the request"""
  35. class NoAuthenticationError(AuthenticationError):
  36. """The request had no authentication information"""
  37. class Authenticator:
  38. def __init__(self, hs: HomeServer):
  39. self._clock = hs.get_clock()
  40. self.keyring = hs.get_keyring()
  41. self.server_name = hs.hostname
  42. self.store = hs.get_datastore()
  43. self.federation_domain_whitelist = hs.config.federation_domain_whitelist
  44. self.notifier = hs.get_notifier()
  45. self.replication_client = None
  46. if hs.config.worker.worker_app:
  47. self.replication_client = hs.get_tcp_replication()
  48. # A method just so we can pass 'self' as the authenticator to the Servlets
  49. async def authenticate_request(self, request, content):
  50. now = self._clock.time_msec()
  51. json_request = {
  52. "method": request.method.decode("ascii"),
  53. "uri": request.uri.decode("ascii"),
  54. "destination": self.server_name,
  55. "signatures": {},
  56. }
  57. if content is not None:
  58. json_request["content"] = content
  59. origin = None
  60. auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
  61. if not auth_headers:
  62. raise NoAuthenticationError(
  63. 401, "Missing Authorization headers", Codes.UNAUTHORIZED
  64. )
  65. for auth in auth_headers:
  66. if auth.startswith(b"X-Matrix"):
  67. (origin, key, sig) = _parse_auth_header(auth)
  68. json_request["origin"] = origin
  69. json_request["signatures"].setdefault(origin, {})[key] = sig
  70. if (
  71. self.federation_domain_whitelist is not None
  72. and origin not in self.federation_domain_whitelist
  73. ):
  74. raise FederationDeniedError(origin)
  75. if origin is None or not json_request["signatures"]:
  76. raise NoAuthenticationError(
  77. 401, "Missing Authorization headers", Codes.UNAUTHORIZED
  78. )
  79. await self.keyring.verify_json_for_server(
  80. origin,
  81. json_request,
  82. now,
  83. )
  84. logger.debug("Request from %s", origin)
  85. request.requester = origin
  86. # If we get a valid signed request from the other side, its probably
  87. # alive
  88. retry_timings = await self.store.get_destination_retry_timings(origin)
  89. if retry_timings and retry_timings.retry_last_ts:
  90. run_in_background(self._reset_retry_timings, origin)
  91. return origin
  92. async def _reset_retry_timings(self, origin):
  93. try:
  94. logger.info("Marking origin %r as up", origin)
  95. await self.store.set_destination_retry_timings(origin, None, 0, 0)
  96. # Inform the relevant places that the remote server is back up.
  97. self.notifier.notify_remote_server_up(origin)
  98. if self.replication_client:
  99. # If we're on a worker we try and inform master about this. The
  100. # replication client doesn't hook into the notifier to avoid
  101. # infinite loops where we send a `REMOTE_SERVER_UP` command to
  102. # master, which then echoes it back to us which in turn pokes
  103. # the notifier.
  104. self.replication_client.send_remote_server_up(origin)
  105. except Exception:
  106. logger.exception("Error resetting retry timings on %s", origin)
  107. def _parse_auth_header(header_bytes):
  108. """Parse an X-Matrix auth header
  109. Args:
  110. header_bytes (bytes): header value
  111. Returns:
  112. Tuple[str, str, str]: origin, key id, signature.
  113. Raises:
  114. AuthenticationError if the header could not be parsed
  115. """
  116. try:
  117. header_str = header_bytes.decode("utf-8")
  118. params = header_str.split(" ")[1].split(",")
  119. param_dict = dict(kv.split("=") for kv in params)
  120. def strip_quotes(value):
  121. if value.startswith('"'):
  122. return value[1:-1]
  123. else:
  124. return value
  125. origin = strip_quotes(param_dict["origin"])
  126. # ensure that the origin is a valid server name
  127. parse_and_validate_server_name(origin)
  128. key = strip_quotes(param_dict["key"])
  129. sig = strip_quotes(param_dict["sig"])
  130. return origin, key, sig
  131. except Exception as e:
  132. logger.warning(
  133. "Error parsing auth header '%s': %s",
  134. header_bytes.decode("ascii", "replace"),
  135. e,
  136. )
  137. raise AuthenticationError(
  138. 400, "Malformed Authorization header", Codes.UNAUTHORIZED
  139. )
  140. class BaseFederationServlet:
  141. """Abstract base class for federation servlet classes.
  142. The servlet object should have a PATH attribute which takes the form of a regexp to
  143. match against the request path (excluding the /federation/v1 prefix).
  144. The servlet should also implement one or more of on_GET, on_POST, on_PUT, to match
  145. the appropriate HTTP method. These methods must be *asynchronous* and have the
  146. signature:
  147. on_<METHOD>(self, origin, content, query, **kwargs)
  148. With arguments:
  149. origin (unicode|None): The authenticated server_name of the calling server,
  150. unless REQUIRE_AUTH is set to False and authentication failed.
  151. content (unicode|None): decoded json body of the request. None if the
  152. request was a GET.
  153. query (dict[bytes, list[bytes]]): Query params from the request. url-decoded
  154. (ie, '+' and '%xx' are decoded) but note that it is *not* utf8-decoded
  155. yet.
  156. **kwargs (dict[unicode, unicode]): the dict mapping keys to path
  157. components as specified in the path match regexp.
  158. Returns:
  159. Optional[Tuple[int, object]]: either (response code, response object) to
  160. return a JSON response, or None if the request has already been handled.
  161. Raises:
  162. SynapseError: to return an error code
  163. Exception: other exceptions will be caught, logged, and a 500 will be
  164. returned.
  165. """
  166. PATH = "" # Overridden in subclasses, the regex to match against the path.
  167. REQUIRE_AUTH = True
  168. PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version
  169. RATELIMIT = True # Whether to rate limit requests or not
  170. def __init__(
  171. self,
  172. hs: HomeServer,
  173. authenticator: Authenticator,
  174. ratelimiter: FederationRateLimiter,
  175. server_name: str,
  176. ):
  177. self.hs = hs
  178. self.authenticator = authenticator
  179. self.ratelimiter = ratelimiter
  180. self.server_name = server_name
  181. def _wrap(self, func):
  182. authenticator = self.authenticator
  183. ratelimiter = self.ratelimiter
  184. @functools.wraps(func)
  185. async def new_func(request, *args, **kwargs):
  186. """A callback which can be passed to HttpServer.RegisterPaths
  187. Args:
  188. request (twisted.web.http.Request):
  189. *args: unused?
  190. **kwargs (dict[unicode, unicode]): the dict mapping keys to path
  191. components as specified in the path match regexp.
  192. Returns:
  193. Tuple[int, object]|None: (response code, response object) as returned by
  194. the callback method. None if the request has already been handled.
  195. """
  196. content = None
  197. if request.method in [b"PUT", b"POST"]:
  198. # TODO: Handle other method types? other content types?
  199. content = parse_json_object_from_request(request)
  200. try:
  201. origin = await authenticator.authenticate_request(request, content)
  202. except NoAuthenticationError:
  203. origin = None
  204. if self.REQUIRE_AUTH:
  205. logger.warning(
  206. "authenticate_request failed: missing authentication"
  207. )
  208. raise
  209. except Exception as e:
  210. logger.warning("authenticate_request failed: %s", e)
  211. raise
  212. request_tags = {
  213. SynapseTags.REQUEST_ID: request.get_request_id(),
  214. tags.SPAN_KIND: tags.SPAN_KIND_RPC_SERVER,
  215. tags.HTTP_METHOD: request.get_method(),
  216. tags.HTTP_URL: request.get_redacted_uri(),
  217. tags.PEER_HOST_IPV6: request.getClientIP(),
  218. "authenticated_entity": origin,
  219. "servlet_name": request.request_metrics.name,
  220. }
  221. # Only accept the span context if the origin is authenticated
  222. # and whitelisted
  223. if origin and whitelisted_homeserver(origin):
  224. scope = start_active_span_from_request(
  225. request, "incoming-federation-request", tags=request_tags
  226. )
  227. else:
  228. scope = start_active_span(
  229. "incoming-federation-request", tags=request_tags
  230. )
  231. with scope:
  232. opentracing.inject_response_headers(request.responseHeaders)
  233. if origin and self.RATELIMIT:
  234. with ratelimiter.ratelimit(origin) as d:
  235. await d
  236. if request._disconnected:
  237. logger.warning(
  238. "client disconnected before we started processing "
  239. "request"
  240. )
  241. return -1, None
  242. response = await func(
  243. origin, content, request.args, *args, **kwargs
  244. )
  245. else:
  246. response = await func(
  247. origin, content, request.args, *args, **kwargs
  248. )
  249. return response
  250. return new_func
  251. def register(self, server):
  252. pattern = re.compile("^" + self.PREFIX + self.PATH + "$")
  253. for method in ("GET", "PUT", "POST"):
  254. code = getattr(self, "on_%s" % (method), None)
  255. if code is None:
  256. continue
  257. server.register_paths(
  258. method,
  259. (pattern,),
  260. self._wrap(code),
  261. self.__class__.__name__,
  262. )