# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib.parse
from http import HTTPStatus
from io import BytesIO
from typing import (
    TYPE_CHECKING,
    Any,
    BinaryIO,
    Callable,
    Dict,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import treq
from canonicaljson import encode_canonical_json
from netaddr import AddrFormatError, IPAddress, IPSet
from prometheus_client import Counter
from typing_extensions import Protocol
from zope.interface import implementer, provider

from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
    IAddress,
    IHostResolution,
    IReactorPluggableNameResolver,
    IResolutionReceiver,
    ITCPTransport,
)
from twisted.internet.protocol import connectionDone
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from twisted.web.client import (
    Agent,
    HTTPConnectionPool,
    ResponseNeverReceived,
    readBody,
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web.iweb import (
    UNKNOWN_LENGTH,
    IAgent,
    IBodyProducer,
    IPolicyForHTTPS,
    IResponse,
)

from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
incoming_responses_counter = Counter(
    "synapse_http_client_responses", "", ["method", "code"]
)

# the type of the headers list, to be passed to the t.w.h.Headers.
# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so
# we simplify.
RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]

# the value actually has to be a List, but List is invariant so we can't specify that
# the entries can either be Lists or bytes.
RawHeaderValue = Sequence[Union[str, bytes]]

# the type of the query params, to be passed into `urlencode`
QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]]
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]

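# For illustration only (not part of the module's API): a headers mapping such
# as {b"User-Agent": [b"Synapse"]} satisfies RawHeaders, and a query mapping
# such as {"limit": "10", "room": ["a", "b"]} satisfies QueryParams.

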
def check_against_blacklist(
    ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
) -> bool:
    """
    Compares an IP address to allowed and disallowed IP sets.

    Args:
        ip_address: The IP address to check
        ip_whitelist: Allowed IP addresses.
        ip_blacklist: Disallowed IP addresses.

    Returns:
        True if the IP address is in the blacklist and not in the whitelist.
    """
    if ip_address in ip_blacklist:
        if ip_whitelist is None or ip_address not in ip_whitelist:
            return True
    return False

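# Illustrative sketch of the precedence rule above: the whitelist punches holes
# in the blacklist, so an address appearing in both sets is allowed.
#
#     blacklist = IPSet(["10.0.0.0/8"])
#     whitelist = IPSet(["10.0.0.1"])
#     check_against_blacklist(IPAddress("10.0.0.1"), whitelist, blacklist)  # False
#     check_against_blacklist(IPAddress("10.0.0.2"), whitelist, blacklist)  # True

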
_EPSILON = 0.00000001


def _make_scheduler(reactor):
    """Makes a scheduler suitable for a Cooperator using the given reactor.
    (This is effectively just a copy from `twisted.internet.task`)
    """

    def _scheduler(x):
        return reactor.callLater(_EPSILON, x)

    return _scheduler

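# Illustrative usage, mirroring SimpleHttpClient.__init__ below: the returned
# callable is handed to a Cooperator so that body producers are driven by the
# intended reactor rather than the global one.
#
#     cooperator = Cooperator(scheduler=_make_scheduler(reactor))

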
class _IPBlacklistingResolver:
    """
    A proxy for reactor.nameResolver which only produces non-blacklisted IP
    addresses, preventing DNS rebinding attacks on URL preview.
    """

    def __init__(
        self,
        reactor: IReactorPluggableNameResolver,
        ip_whitelist: Optional[IPSet],
        ip_blacklist: IPSet,
    ):
        """
        Args:
            reactor: The twisted reactor.
            ip_whitelist: IP addresses to allow.
            ip_blacklist: IP addresses to disallow.
        """
        self._reactor = reactor
        self._ip_whitelist = ip_whitelist
        self._ip_blacklist = ip_blacklist

    def resolveHostName(
        self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
    ) -> IResolutionReceiver:
        addresses: List[IAddress] = []

        def _callback() -> None:
            has_bad_ip = False
            for address in addresses:
                # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
                # should go through this path.
                if not isinstance(address, (IPv4Address, IPv6Address)):
                    continue

                ip_address = IPAddress(address.host)

                if check_against_blacklist(
                    ip_address, self._ip_whitelist, self._ip_blacklist
                ):
                    logger.info(
                        "Dropped %s from DNS resolution to %s due to blacklist",
                        ip_address,
                        hostname,
                    )
                    has_bad_ip = True

            # if we have a blacklisted IP, we'd like to raise an error to block the
            # request, but all we can really do from here is claim that there were no
            # valid results.
            if not has_bad_ip:
                for address in addresses:
                    recv.addressResolved(address)
            recv.resolutionComplete()

        @provider(IResolutionReceiver)
        class EndpointReceiver:
            @staticmethod
            def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
                recv.resolutionBegan(resolutionInProgress)

            @staticmethod
            def addressResolved(address: IAddress) -> None:
                addresses.append(address)

            @staticmethod
            def resolutionComplete() -> None:
                _callback()

        self._reactor.nameResolver.resolveHostName(
            EndpointReceiver, hostname, portNumber=portNumber
        )

        return recv


@implementer(ISynapseReactor)
class BlacklistingReactorWrapper:
    """
    A Reactor wrapper which will prevent DNS resolution to blacklisted IP
    addresses, to prevent DNS rebinding.
    """

    def __init__(
        self,
        reactor: IReactorPluggableNameResolver,
        ip_whitelist: Optional[IPSet],
        ip_blacklist: IPSet,
    ):
        self._reactor = reactor

        # We need to use a DNS resolver which filters out blacklisted IP
        # addresses, to prevent DNS rebinding.
        self._nameResolver = _IPBlacklistingResolver(
            self._reactor, ip_whitelist, ip_blacklist
        )

    def __getattr__(self, attr: str) -> Any:
        # Passthrough to the real reactor except for the DNS resolver.
        if attr == "nameResolver":
            return self._nameResolver
        else:
            return getattr(self._reactor, attr)

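# Illustrative sketch (hypothetical variable names): wrapping a reactor so that
# anything built on top of it resolves hostnames through the blacklisting
# resolver, as SimpleHttpClient.__init__ does below.
#
#     safe_reactor = BlacklistingReactorWrapper(reactor, ip_whitelist, ip_blacklist)
#     agent = Agent(safe_reactor)

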
class BlacklistingAgentWrapper(Agent):
    """
    An Agent wrapper which blocks requests made directly to blacklisted IP
    addresses (i.e. where no DNS lookup happens, so the blacklisting resolver
    cannot intervene).
    """

    def __init__(
        self,
        agent: IAgent,
        ip_whitelist: Optional[IPSet] = None,
        ip_blacklist: Optional[IPSet] = None,
    ):
        """
        Args:
            agent: The Agent to wrap.
            ip_whitelist: IP addresses to allow.
            ip_blacklist: IP addresses to disallow.
        """
        self._agent = agent
        self._ip_whitelist = ip_whitelist
        self._ip_blacklist = ip_blacklist

    def request(
        self,
        method: bytes,
        uri: bytes,
        headers: Optional[Headers] = None,
        bodyProducer: Optional[IBodyProducer] = None,
    ) -> defer.Deferred:
        h = urllib.parse.urlparse(uri.decode("ascii"))

        try:
            ip_address = IPAddress(h.hostname)
        except AddrFormatError:
            # Not an IP
            pass
        else:
            if check_against_blacklist(
                ip_address, self._ip_whitelist, self._ip_blacklist
            ):
                logger.info("Blocking access to %s due to blacklist", ip_address)

                e = SynapseError(
                    HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry"
                )
                return defer.fail(Failure(e))

        return self._agent.request(
            method, uri, headers=headers, bodyProducer=bodyProducer
        )

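# Illustrative sketch: layering the wrapper over an existing agent, as
# SimpleHttpClient.__init__ does below.
#
#     agent = BlacklistingAgentWrapper(
#         agent, ip_whitelist=ip_whitelist, ip_blacklist=ip_blacklist
#     )

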
class SimpleHttpClient:
    """
    A simple, no-frills HTTP client with methods that wrap up common ways of
    using HTTP in Matrix.
    """

    def __init__(
        self,
        hs: "HomeServer",
        treq_args: Optional[Dict[str, Any]] = None,
        ip_whitelist: Optional[IPSet] = None,
        ip_blacklist: Optional[IPSet] = None,
        use_proxy: bool = False,
    ):
        """
        Args:
            hs: The HomeServer instance.
            treq_args: Extra keyword arguments to be given to treq.request.
            ip_blacklist: IP addresses that are blacklisted and may not be
                requested.
            ip_whitelist: IP addresses that may be requested even if they
                would otherwise be caught by the blacklist.
            use_proxy: Whether proxy settings should be discovered and used
                from conventional environment variables.
        """
        self.hs = hs

        self._ip_whitelist = ip_whitelist
        self._ip_blacklist = ip_blacklist
        self._extra_treq_args = treq_args or {}

        self.clock = hs.get_clock()

        user_agent = hs.version_string
        if hs.config.server.user_agent_suffix:
            user_agent = "%s %s" % (
                user_agent,
                hs.config.server.user_agent_suffix,
            )
        self.user_agent = user_agent.encode("ascii")

        # We use this for our body producers to ensure that they use the correct
        # reactor.
        self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))

        if self._ip_blacklist:
            # If we have an IP blacklist, we need to use a DNS resolver which
            # filters out blacklisted IP addresses, to prevent DNS rebinding.
            self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
                hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
            )
        else:
            self.reactor = hs.get_reactor()

        # the pusher makes lots of concurrent SSL connections to sygnal, and
        # tends to do so in batches, so we need to allow the pool to keep
        # lots of idle connections around.
        pool = HTTPConnectionPool(self.reactor)
        # XXX: The justification for using the cache factor here is that larger
        # instances will need both more cache and more connections.
        # Still, this should probably be a separate dial
        pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
        pool.cachedConnectionTimeout = 2 * 60

        self.agent: IAgent = ProxyAgent(
            self.reactor,
            hs.get_reactor(),
            connectTimeout=15,
            contextFactory=self.hs.get_http_client_context_factory(),
            pool=pool,
            use_proxy=use_proxy,
        )

        if self._ip_blacklist:
            # If we have an IP blacklist, we then install the blacklisting Agent,
            # which prevents direct access to IP addresses that are not caught
            # by the DNS resolution.
            self.agent = BlacklistingAgentWrapper(
                self.agent,
                ip_whitelist=self._ip_whitelist,
                ip_blacklist=self._ip_blacklist,
            )

    async def request(
        self,
        method: str,
        uri: str,
        data: Optional[bytes] = None,
        headers: Optional[Headers] = None,
    ) -> IResponse:
        """
        Args:
            method: HTTP method to use.
            uri: URI to query.
            data: Data to send in the request body, if applicable.
            headers: Request headers.

        Returns:
            Response object, once the headers have been read.

        Raises:
            RequestTimedOutError: if the request times out before the headers
                are read.
        """
        outgoing_requests_counter.labels(method).inc()

        # log request but strip `access_token` (AS requests for example include this)
        logger.debug("Sending request %s %s", method, redact_uri(uri))

        with start_active_span(
            "outgoing-client-request",
            tags={
                tags.SPAN_KIND: tags.SPAN_KIND_RPC_CLIENT,
                tags.HTTP_METHOD: method,
                tags.HTTP_URL: uri,
            },
            finish_on_close=True,
        ):
            try:
                body_producer = None
                if data is not None:
                    body_producer = QuieterFileBodyProducer(
                        BytesIO(data),
                        cooperator=self._cooperator,
                    )

                request_deferred: defer.Deferred = treq.request(
                    method,
                    uri,
                    agent=self.agent,
                    data=body_producer,
                    headers=headers,
                    # Avoid buffering the body in treq since we do not reuse
                    # response bodies.
                    unbuffered=True,
                    **self._extra_treq_args,
                )

                # we use our own timeout mechanism rather than treq's as a workaround
                # for https://twistedmatrix.com/trac/ticket/9534.
                request_deferred = timeout_deferred(
                    request_deferred,
                    60,
                    self.hs.get_reactor(),
                )

                # turn timeouts into RequestTimedOutErrors
                request_deferred.addErrback(_timeout_to_request_timed_out_error)

                response = await make_deferred_yieldable(request_deferred)

                incoming_responses_counter.labels(method, response.code).inc()
                logger.info(
                    "Received response to %s %s: %s",
                    method,
                    redact_uri(uri),
                    response.code,
                )
                return response
            except Exception as e:
                incoming_responses_counter.labels(method, "ERR").inc()
                logger.info(
                    "Error sending request to %s %s: %s %s",
                    method,
                    redact_uri(uri),
                    type(e).__name__,
                    e.args[0],
                )
                set_tag(tags.ERROR, True)
                set_tag("error_reason", e.args[0])
                raise

    async def post_urlencoded_get_json(
        self,
        uri: str,
        args: Optional[Mapping[str, Union[str, List[str]]]] = None,
        headers: Optional[RawHeaders] = None,
    ) -> Any:
        """
        Args:
            uri: uri to query
            args: parameters to be url-encoded in the body
            headers: a map from header name to a list of values for that header

        Returns:
            parsed json

        Raises:
            RequestTimedOutError: if there is a timeout before the response headers
                are received. Note there is currently no timeout on reading the
                response body.

            HttpResponseException: On a non-2xx HTTP response.

            ValueError: if the response was not JSON
        """
        # TODO: Do we ever want to log message contents?
        logger.debug("post_urlencoded_get_json args: %s", args)

        query_bytes = encode_query_args(args)

        actual_headers = {
            b"Content-Type": [b"application/x-www-form-urlencoded"],
            b"User-Agent": [self.user_agent],
            b"Accept": [b"application/json"],
        }
        if headers:
            actual_headers.update(headers)  # type: ignore

        response = await self.request(
            "POST", uri, headers=Headers(actual_headers), data=query_bytes
        )

        body = await make_deferred_yieldable(readBody(response))

        if 200 <= response.code < 300:
            return json_decoder.decode(body.decode("utf-8"))
        else:
            raise HttpResponseException(
                response.code, response.phrase.decode("ascii", errors="replace"), body
            )

    async def post_json_get_json(
        self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
    ) -> Any:
        """
        Args:
            uri: URI to query.
            post_json: request body, to be encoded as json
            headers: a map from header name to a list of values for that header

        Returns:
            parsed json

        Raises:
            RequestTimedOutError: if there is a timeout before the response headers
                are received. Note there is currently no timeout on reading the
                response body.

            HttpResponseException: On a non-2xx HTTP response.

            ValueError: if the response was not JSON
        """
        json_str = encode_canonical_json(post_json)

        logger.debug("HTTP POST %s -> %s", json_str, uri)

        actual_headers = {
            b"Content-Type": [b"application/json"],
            b"User-Agent": [self.user_agent],
            b"Accept": [b"application/json"],
        }
        if headers:
            actual_headers.update(headers)  # type: ignore

        response = await self.request(
            "POST", uri, headers=Headers(actual_headers), data=json_str
        )

        body = await make_deferred_yieldable(readBody(response))

        if 200 <= response.code < 300:
            return json_decoder.decode(body.decode("utf-8"))
        else:
            raise HttpResponseException(
                response.code, response.phrase.decode("ascii", errors="replace"), body
            )

    async def get_json(
        self,
        uri: str,
        args: Optional[QueryParams] = None,
        headers: Optional[RawHeaders] = None,
    ) -> Any:
        """Gets some json from the given URI.

        Args:
            uri: The URI to request, not including query parameters
            args: A dictionary used to create query string
            headers: a map from header name to a list of values for that header

        Returns:
            Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.

        Raises:
            RequestTimedOutError: if there is a timeout before the response headers
                are received. Note there is currently no timeout on reading the
                response body.

            HttpResponseException: On a non-2xx HTTP response.

            ValueError: if the response was not JSON
        """
        actual_headers = {b"Accept": [b"application/json"]}
        if headers:
            actual_headers.update(headers)  # type: ignore

        body = await self.get_raw(uri, args, headers=actual_headers)
        return json_decoder.decode(body.decode("utf-8"))

    async def put_json(
        self,
        uri: str,
        json_body: Any,
        args: Optional[QueryParams] = None,
        headers: Optional[RawHeaders] = None,
    ) -> Any:
        """Puts some json to the given URI.

        Args:
            uri: The URI to request, not including query parameters
            json_body: The JSON to put in the HTTP body.
            args: A dictionary used to create query strings
            headers: a map from header name to a list of values for that header

        Returns:
            Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.

        Raises:
            RequestTimedOutError: if there is a timeout before the response headers
                are received. Note there is currently no timeout on reading the
                response body.

            HttpResponseException: On a non-2xx HTTP response.

            ValueError: if the response was not JSON
        """
        if args:
            query_str = urllib.parse.urlencode(args, True)
            uri = "%s?%s" % (uri, query_str)

        json_str = encode_canonical_json(json_body)

        actual_headers = {
            b"Content-Type": [b"application/json"],
            b"User-Agent": [self.user_agent],
            b"Accept": [b"application/json"],
        }
        if headers:
            actual_headers.update(headers)  # type: ignore

        response = await self.request(
            "PUT", uri, headers=Headers(actual_headers), data=json_str
        )

        body = await make_deferred_yieldable(readBody(response))

        if 200 <= response.code < 300:
            return json_decoder.decode(body.decode("utf-8"))
        else:
            raise HttpResponseException(
                response.code, response.phrase.decode("ascii", errors="replace"), body
            )

    async def get_raw(
        self,
        uri: str,
        args: Optional[QueryParams] = None,
        headers: Optional[RawHeaders] = None,
    ) -> bytes:
        """Gets raw text from the given URI.

        Args:
            uri: The URI to request, not including query parameters
            args: A dictionary used to create query strings
            headers: a map from header name to a list of values for that header

        Returns:
            Succeeds when we get a 2xx HTTP response, with the
            HTTP body as bytes.

        Raises:
            RequestTimedOutError: if there is a timeout before the response headers
                are received. Note there is currently no timeout on reading the
                response body.

            HttpResponseException: On a non-2xx HTTP response.
        """
        if args:
            query_str = urllib.parse.urlencode(args, True)
            uri = "%s?%s" % (uri, query_str)

        actual_headers = {b"User-Agent": [self.user_agent]}
        if headers:
            actual_headers.update(headers)  # type: ignore

        response = await self.request("GET", uri, headers=Headers(actual_headers))

        body = await make_deferred_yieldable(readBody(response))

        if 200 <= response.code < 300:
            return body
        else:
            raise HttpResponseException(
                response.code, response.phrase.decode("ascii", errors="replace"), body
            )

    # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
    # The two should be factored out.
    async def get_file(
        self,
        url: str,
        output_stream: BinaryIO,
        max_size: Optional[int] = None,
        headers: Optional[RawHeaders] = None,
        is_allowed_content_type: Optional[Callable[[str], bool]] = None,
    ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
        """GETs a file from a given URL

        Args:
            url: The URL to GET
            output_stream: File to write the response body to.
            max_size: The maximum size of the file to download, in bytes; the
                download is aborted with a SynapseError if this is exceeded.
            headers: A map from header name to a list of values for that header
            is_allowed_content_type: A predicate to determine whether the
                content type of the file we're downloading is allowed. If set and
                it evaluates to False when called with the content type, the
                request will be terminated before completing the download by
                raising SynapseError.

        Returns:
            A tuple of the file length, dict of the response
            headers, absolute URI of the response and HTTP response code.

        Raises:
            RequestTimedOutError: if there is a timeout before the response headers
                are received. Note there is currently no timeout on reading the
                response body.

            SynapseError: if the response is not a 2xx, the remote file is too large,
                or another exception happens during the download.
        """
        actual_headers = {b"User-Agent": [self.user_agent]}
        if headers:
            actual_headers.update(headers)  # type: ignore

        response = await self.request("GET", url, headers=Headers(actual_headers))

        resp_headers = dict(response.headers.getAllRawHeaders())

        if response.code > 299:
            logger.warning("Got %d when downloading %s", response.code, url)
            raise SynapseError(
                HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN
            )

        if is_allowed_content_type and b"Content-Type" in resp_headers:
            content_type = resp_headers[b"Content-Type"][0].decode("ascii")
            if not is_allowed_content_type(content_type):
                raise SynapseError(
                    HTTPStatus.BAD_GATEWAY,
                    (
                        "Requested file's content type not allowed for this operation: %s"
                        % content_type
                    ),
                )

        # TODO: if our Content-Type is HTML or something, just read the first
        # N bytes into RAM rather than saving it all to disk only to read it
        # straight back in again

        try:
            d = read_body_with_max_size(response, output_stream, max_size)

            # Ensure that the body is not read forever.
            d = timeout_deferred(d, 30, self.hs.get_reactor())

            length = await make_deferred_yieldable(d)
        except BodyExceededMaxSize:
            raise SynapseError(
                HTTPStatus.BAD_GATEWAY,
                "Requested file is too large > %r bytes" % (max_size,),
                Codes.TOO_LARGE,
            )
        except defer.TimeoutError:
            raise SynapseError(
                HTTPStatus.BAD_GATEWAY,
                "Requested file took too long to download",
                Codes.TOO_LARGE,
            )
        except Exception as e:
            raise SynapseError(
                HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e)
            ) from e

        return (
            length,
            resp_headers,
            response.request.absoluteURI.decode("ascii"),
            response.code,
        )

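# Illustrative usage sketch (hypothetical names and URL; `hs` is a HomeServer,
# and the IPSet arguments would normally come from configuration):
#
#     client = SimpleHttpClient(hs, ip_blacklist=blacklist, ip_whitelist=whitelist)
#     info = await client.get_json("https://example.com/info?format=json")
#     await client.post_json_get_json("https://example.com/api", {"key": "value"})

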
def _timeout_to_request_timed_out_error(f: Failure):
    if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
        # The TCP connection has its own timeout (set by the 'connectTimeout' param
        # on the Agent), which raises twisted_error.TimeoutError exception.
        raise RequestTimedOutError("Timeout connecting to remote server")
    elif f.check(defer.TimeoutError, ResponseNeverReceived):
        # this one means that we hit our overall timeout on the request
        raise RequestTimedOutError("Timeout waiting for response from remote server")

    return f

class ByteWriteable(Protocol):
    """The type of object which must be passed into read_body_with_max_size.

    Typically this is a file object.
    """

    def write(self, data: bytes) -> int:
        pass


class BodyExceededMaxSize(Exception):
    """The maximum allowed size of the HTTP body was exceeded."""


class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
    """A protocol which immediately errors upon receiving data."""

    transport: Optional[ITCPTransport] = None

    def __init__(self, deferred: defer.Deferred):
        self.deferred = deferred

    def _maybe_fail(self):
        """
        Report a max size exceeded error and disconnect the first time this is called.
        """
        if not self.deferred.called:
            self.deferred.errback(BodyExceededMaxSize())
            # Close the connection (forcefully) since all the data will get
            # discarded anyway.
            assert self.transport is not None
            self.transport.abortConnection()

    def dataReceived(self, data: bytes) -> None:
        self._maybe_fail()

    def connectionLost(self, reason: Failure = connectionDone) -> None:
        self._maybe_fail()


class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
    """A protocol which reads the body into a stream, erroring if the body exceeds
    a maximum size."""

    transport: Optional[ITCPTransport] = None

    def __init__(
        self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int]
    ):
        self.stream = stream
        self.deferred = deferred
        self.length = 0
        self.max_size = max_size

    def dataReceived(self, data: bytes) -> None:
        # If the deferred was called, bail early.
        if self.deferred.called:
            return

        try:
            self.stream.write(data)
        except Exception:
            self.deferred.errback()
            return

        self.length += len(data)
        # The first time the maximum size is exceeded, error and cancel the
        # connection. dataReceived might be called again if data was received
        # in the meantime.
        if self.max_size is not None and self.length >= self.max_size:
            self.deferred.errback(BodyExceededMaxSize())
            # Close the connection (forcefully) since all the data will get
            # discarded anyway.
            assert self.transport is not None
            self.transport.abortConnection()

    def connectionLost(self, reason: Failure = connectionDone) -> None:
        # If the maximum size was already exceeded, there's nothing to do.
        if self.deferred.called:
            return

        if reason.check(ResponseDone):
            self.deferred.callback(self.length)
        elif reason.check(PotentialDataLoss):
            # stolen from https://github.com/twisted/treq/pull/49/files
            # http://twistedmatrix.com/trac/ticket/4840
            self.deferred.callback(self.length)
        else:
            self.deferred.errback(reason)

def read_body_with_max_size(
    response: IResponse, stream: ByteWriteable, max_size: Optional[int]
) -> "defer.Deferred[int]":
    """
    Read an HTTP response body into a file-object, optionally enforcing a maximum
    file size.

    If the maximum file size is reached, the returned Deferred will resolve to a
    Failure with a BodyExceededMaxSize exception.

    Args:
        response: The HTTP response to read from.
        stream: The file-object to write to.
        max_size: The maximum file size to allow.

    Returns:
        A Deferred which resolves to the length of the read body.
    """
    d: "defer.Deferred[int]" = defer.Deferred()

    # If the Content-Length header gives a size larger than the maximum allowed
    # size, do not bother downloading the body.
    if max_size is not None and response.length != UNKNOWN_LENGTH:
        if response.length > max_size:
            response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
            return d

    response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
    return d

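# Illustrative usage sketch, mirroring get_file above (`response` is assumed to
# be an IResponse obtained from an earlier request):
#
#     buffer = BytesIO()
#     d = read_body_with_max_size(response, buffer, max_size=1024 * 1024)
#     length = await make_deferred_yieldable(d)  # raises BodyExceededMaxSize if too big

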
def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
    """
    Encodes a map of query arguments to bytes which can be appended to a URL.

    Args:
        args: The query arguments, a mapping of string to string or list of strings.

    Returns:
        The query arguments encoded as bytes.
    """
    if args is None:
        return b""

    encoded_args = {}
    for k, vs in args.items():
        if isinstance(vs, str):
            vs = [vs]
        encoded_args[k] = [v.encode("utf8") for v in vs]

    query_str = urllib.parse.urlencode(encoded_args, True)

    return query_str.encode("utf8")

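# Illustrative example: list values are expanded into repeated query parameters
# by urlencode's doseq behaviour.
#
#     encode_query_args({"limit": "10", "room": ["a", "b"]})
#     # => b"limit=10&room=a&room=b"

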
@implementer(IPolicyForHTTPS)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
    """
    Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.

    Do not use this since it allows an attacker to intercept your communications.
    """

    def __init__(self):
        self._context = SSL.Context(SSL.SSLv23_METHOD)
        self._context.set_verify(VERIFY_NONE, lambda *_: False)

    def getContext(self, hostname=None, port=None):
        return self._context

    def creatorForNetloc(self, hostname, port):
        return self