# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib.parse
from http import HTTPStatus
from io import BytesIO
from typing import (
    TYPE_CHECKING,
    Any,
    BinaryIO,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
)

import treq
from canonicaljson import encode_canonical_json
from netaddr import AddrFormatError, IPAddress, IPSet
from prometheus_client import Counter
from typing_extensions import Protocol
from zope.interface import implementer, provider

from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import (
    IAddress,
    IDelayedCall,
    IHostResolution,
    IOpenSSLContextFactory,
    IReactorCore,
    IReactorPluggableNameResolver,
    IReactorTime,
    IResolutionReceiver,
    ITCPTransport,
)
from twisted.internet.protocol import connectionDone
from twisted.internet.task import Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from twisted.web.client import (
    Agent,
    HTTPConnectionPool,
    ResponseNeverReceived,
    readBody,
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web.iweb import (
    UNKNOWN_LENGTH,
    IAgent,
    IBodyProducer,
    IPolicyForHTTPS,
    IResponse,
)

from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
from synapse.http.proxyagent import ProxyAgent
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
incoming_responses_counter = Counter(
    "synapse_http_client_responses", "", ["method", "code"]
)

# the type of the headers map, to be passed to the t.w.h.Headers.
#
# The actual type accepted by Twisted is
#   Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
# allowing us to mix and match str and bytes freely. However: any str is also a
# Sequence[str]; passing a header string value which is a standalone str is
# interpreted as a sequence of 1-codepoint strings. This is a disastrous footgun.
# We use a narrower value type (RawHeaderValue) to avoid this footgun.
#
# We also simplify the keys to be either all str or all bytes. This helps because
# Dict[K, V] is invariant in K (and indeed V).
RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]

# the value actually has to be a List, but List is invariant so we can't specify
# that the entries can either be strs or bytes.
RawHeaderValue = Union[
    List[str],
    List[bytes],
    List[Union[str, bytes]],
    Tuple[str, ...],
    Tuple[bytes, ...],
    Tuple[Union[str, bytes], ...],
]


def check_against_blacklist(
    ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
) -> bool:
    """
    Compares an IP address to allowed and disallowed IP sets.

    Args:
        ip_address: The IP address to check
        ip_whitelist: Allowed IP addresses.
        ip_blacklist: Disallowed IP addresses.

    Returns:
        True if the IP address is in the blacklist and not in the whitelist.
    """
    if ip_address in ip_blacklist:
        if ip_whitelist is None or ip_address not in ip_whitelist:
            return True
    return False
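

# A minimal illustration of the check above, using made-up CIDR ranges; this is
# a sketch for readers, not part of Synapse's API surface:
#
#     from netaddr import IPAddress, IPSet
#
#     blacklist = IPSet(["10.0.0.0/8", "127.0.0.0/8"])
#     whitelist = IPSet(["10.1.2.3/32"])
#
#     check_against_blacklist(IPAddress("127.0.0.1"), whitelist, blacklist)  # True
#     check_against_blacklist(IPAddress("10.1.2.3"), whitelist, blacklist)   # False: whitelisted
#     check_against_blacklist(IPAddress("8.8.8.8"), whitelist, blacklist)    # False: not blacklisted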


_EPSILON = 0.00000001


def _make_scheduler(
    reactor: IReactorTime,
) -> Callable[[Callable[[], object]], IDelayedCall]:
    """Makes a scheduler suitable for a Cooperator using the given reactor.

    (This is effectively just a copy from `twisted.internet.task`.)
    """

    def _scheduler(x: Callable[[], object]) -> IDelayedCall:
        return reactor.callLater(_EPSILON, x)

    return _scheduler


class _IPBlacklistingResolver:
    """
    A proxy for reactor.nameResolver which only produces non-blacklisted IP
    addresses, preventing DNS rebinding attacks on URL preview.
    """

    def __init__(
        self,
        reactor: IReactorPluggableNameResolver,
        ip_whitelist: Optional[IPSet],
        ip_blacklist: IPSet,
    ):
        """
        Args:
            reactor: The twisted reactor.
            ip_whitelist: IP addresses to allow.
            ip_blacklist: IP addresses to disallow.
        """
        self._reactor = reactor
        self._ip_whitelist = ip_whitelist
        self._ip_blacklist = ip_blacklist

    def resolveHostName(
        self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
    ) -> IResolutionReceiver:
        addresses: List[IAddress] = []

        def _callback() -> None:
            has_bad_ip = False
            for address in addresses:
                # We only expect IPv4 and IPv6 addresses since only A/AAAA lookups
                # should go through this path.
                if not isinstance(address, (IPv4Address, IPv6Address)):
                    continue

                ip_address = IPAddress(address.host)

                if check_against_blacklist(
                    ip_address, self._ip_whitelist, self._ip_blacklist
                ):
                    logger.info(
                        "Dropped %s from DNS resolution to %s due to blacklist"
                        % (ip_address, hostname)
                    )
                    has_bad_ip = True

            # if we have a blacklisted IP, we'd like to raise an error to block the
            # request, but all we can really do from here is claim that there were no
            # valid results.
            if not has_bad_ip:
                for address in addresses:
                    recv.addressResolved(address)
            recv.resolutionComplete()

        @provider(IResolutionReceiver)
        class EndpointReceiver:
            @staticmethod
            def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
                recv.resolutionBegan(resolutionInProgress)

            @staticmethod
            def addressResolved(address: IAddress) -> None:
                addresses.append(address)

            @staticmethod
            def resolutionComplete() -> None:
                _callback()

        self._reactor.nameResolver.resolveHostName(
            EndpointReceiver, hostname, portNumber=portNumber
        )

        return recv


# ISynapseReactor implies IReactorCore, but explicitly marking it as an implementer
# of IReactorCore seems to keep mypy-zope happier.
@implementer(IReactorCore, ISynapseReactor)
class BlacklistingReactorWrapper:
    """
    A Reactor wrapper which will prevent DNS resolution to blacklisted IP
    addresses, to prevent DNS rebinding.
    """

    def __init__(
        self,
        reactor: IReactorPluggableNameResolver,
        ip_whitelist: Optional[IPSet],
        ip_blacklist: IPSet,
    ):
        self._reactor = reactor

        # We need to use a DNS resolver which filters out blacklisted IP
        # addresses, to prevent DNS rebinding.
        self._nameResolver = _IPBlacklistingResolver(
            self._reactor, ip_whitelist, ip_blacklist
        )

    def __getattr__(self, attr: str) -> Any:
        # Passthrough to the real reactor except for the DNS resolver.
        if attr == "nameResolver":
            return self._nameResolver
        else:
            return getattr(self._reactor, attr)


class BlacklistingAgentWrapper(Agent):
    """
    An Agent wrapper which will prevent access to IP addresses being accessed
    directly (without an IP address lookup).
    """

    def __init__(
        self,
        agent: IAgent,
        ip_blacklist: IPSet,
        ip_whitelist: Optional[IPSet] = None,
    ):
        """
        Args:
            agent: The Agent to wrap.
            ip_blacklist: IP addresses to disallow.
            ip_whitelist: IP addresses to allow.
        """
        self._agent = agent
        self._ip_whitelist = ip_whitelist
        self._ip_blacklist = ip_blacklist

    def request(
        self,
        method: bytes,
        uri: bytes,
        headers: Optional[Headers] = None,
        bodyProducer: Optional[IBodyProducer] = None,
    ) -> defer.Deferred:
        h = urllib.parse.urlparse(uri.decode("ascii"))

        try:
            # h.hostname is Optional[str]; None raises an AddrFormatError, so
            # this is safe even though IPAddress requires a str.
            ip_address = IPAddress(h.hostname)  # type: ignore[arg-type]
        except AddrFormatError:
            # Not an IP
            pass
        else:
            if check_against_blacklist(
                ip_address, self._ip_whitelist, self._ip_blacklist
            ):
                logger.info("Blocking access to %s due to blacklist" % (ip_address,))
                e = SynapseError(
                    HTTPStatus.FORBIDDEN, "IP address blocked by IP blacklist entry"
                )
                return defer.fail(Failure(e))

        return self._agent.request(
            method, uri, headers=headers, bodyProducer=bodyProducer
        )


class SimpleHttpClient:
    """
    A simple, no-frills HTTP client with methods that wrap up common ways of
    using HTTP in Matrix.
    """

    def __init__(
        self,
        hs: "HomeServer",
        treq_args: Optional[Dict[str, Any]] = None,
        ip_whitelist: Optional[IPSet] = None,
        ip_blacklist: Optional[IPSet] = None,
        use_proxy: bool = False,
    ):
        """
        Args:
            hs: The HomeServer instance.
            treq_args: Extra keyword arguments to be given to treq.request.
            ip_blacklist: IP addresses that we may not request; matching
                requests are blocked.
            ip_whitelist: IP addresses that may be requested even if they
                would otherwise be caught by the blacklist.
            use_proxy: Whether proxy settings should be discovered and used
                from conventional environment variables.
        """
        self.hs = hs

        self._ip_whitelist = ip_whitelist
        self._ip_blacklist = ip_blacklist
        self._extra_treq_args = treq_args or {}

        self.clock = hs.get_clock()

        user_agent = hs.version_string
        if hs.config.server.user_agent_suffix:
            user_agent = "%s %s" % (
                user_agent,
                hs.config.server.user_agent_suffix,
            )
        self.user_agent = user_agent.encode("ascii")

        # We use this for our body producers to ensure that they use the correct
        # reactor.
        self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor()))

        if self._ip_blacklist:
            # If we have an IP blacklist, we need to use a DNS resolver which
            # filters out blacklisted IP addresses, to prevent DNS rebinding.
            self.reactor: ISynapseReactor = BlacklistingReactorWrapper(
                hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
            )
        else:
            self.reactor = hs.get_reactor()

        # the pusher makes lots of concurrent SSL connections to sygnal, and
        # tends to do so in batches, so we need to allow the pool to keep
        # lots of idle connections around.
        pool = HTTPConnectionPool(self.reactor)
        # XXX: The justification for using the cache factor here is that larger
        # instances will need both more cache and more connections.
        # Still, this should probably be a separate dial.
        pool.maxPersistentPerHost = max(int(100 * hs.config.caches.global_factor), 5)
        pool.cachedConnectionTimeout = 2 * 60

        self.agent: IAgent = ProxyAgent(
            self.reactor,
            hs.get_reactor(),
            connectTimeout=15,
            contextFactory=self.hs.get_http_client_context_factory(),
            pool=pool,
            use_proxy=use_proxy,
        )

        if self._ip_blacklist:
            # If we have an IP blacklist, we then install the blacklisting Agent
            # which prevents direct access to IP addresses that are not caught
            # by the DNS resolution.
            self.agent = BlacklistingAgentWrapper(
                self.agent,
                ip_blacklist=self._ip_blacklist,
                ip_whitelist=self._ip_whitelist,
            )

    async def request(
        self,
        method: str,
        uri: str,
        data: Optional[bytes] = None,
        headers: Optional[Headers] = None,
    ) -> IResponse:
        """
        Args:
            method: HTTP method to use.
            uri: URI to query.
            data: Data to send in the request body, if applicable.
            headers: Request headers.

        Returns:
            Response object, once the headers have been read.

        Raises:
            RequestTimedOutError: if the request times out before the headers are read.
        """
        outgoing_requests_counter.labels(method).inc()

        # log request but strip `access_token` (AS requests for example include this)
        logger.debug("Sending request %s %s", method, redact_uri(uri))

        with start_active_span(
            "outgoing-client-request",
            tags={
                tags.SPAN_KIND: tags.SPAN_KIND_RPC_CLIENT,
                tags.HTTP_METHOD: method,
                tags.HTTP_URL: uri,
            },
            finish_on_close=True,
        ):
            try:
                body_producer = None
                if data is not None:
                    body_producer = QuieterFileBodyProducer(
                        BytesIO(data),
                        cooperator=self._cooperator,
                    )

                request_deferred: defer.Deferred = treq.request(
                    method,
                    uri,
                    agent=self.agent,
                    data=body_producer,
                    headers=headers,
                    # Avoid buffering the body in treq since we do not reuse
                    # response bodies.
                    unbuffered=True,
                    **self._extra_treq_args,
                )

                # we use our own timeout mechanism rather than treq's as a workaround
                # for https://twistedmatrix.com/trac/ticket/9534.
                request_deferred = timeout_deferred(
                    request_deferred,
                    60,
                    self.hs.get_reactor(),
                )

                # turn timeouts into RequestTimedOutErrors
                request_deferred.addErrback(_timeout_to_request_timed_out_error)

                response = await make_deferred_yieldable(request_deferred)

                incoming_responses_counter.labels(method, response.code).inc()
                logger.info(
                    "Received response to %s %s: %s",
                    method,
                    redact_uri(uri),
                    response.code,
                )
                return response
            except Exception as e:
                incoming_responses_counter.labels(method, "ERR").inc()
                logger.info(
                    "Error sending request to %s %s: %s %s",
                    method,
                    redact_uri(uri),
                    type(e).__name__,
                    e.args[0],
                )
                set_tag(tags.ERROR, True)
                set_tag("error_reason", e.args[0])
                raise

    async def post_urlencoded_get_json(
        self,
        uri: str,
        args: Optional[Mapping[str, Union[str, List[str]]]] = None,
        headers: Optional[RawHeaders] = None,
    ) -> Any:
        """
        Args:
            uri: uri to query
            args: parameters to be url-encoded in the body
            headers: a map from header name to a list of values for that header

        Returns:
            parsed json

        Raises:
            RequestTimedOutError: if there is a timeout before the response headers
                are received. Note there is currently no timeout on reading the
                response body.
            HttpResponseException: On a non-2xx HTTP response.
            ValueError: if the response was not JSON
        """
        # TODO: Do we ever want to log message contents?
        logger.debug("post_urlencoded_get_json args: %s", args)

        query_bytes = encode_query_args(args)

        actual_headers = {
            b"Content-Type": [b"application/x-www-form-urlencoded"],
            b"User-Agent": [self.user_agent],
            b"Accept": [b"application/json"],
        }
        if headers:
            actual_headers.update(headers)  # type: ignore

        response = await self.request(
            "POST", uri, headers=Headers(actual_headers), data=query_bytes
        )

        body = await make_deferred_yieldable(readBody(response))

        if 200 <= response.code < 300:
            return json_decoder.decode(body.decode("utf-8"))
        else:
            raise HttpResponseException(
                response.code, response.phrase.decode("ascii", errors="replace"), body
            )
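
    # A minimal usage sketch (hypothetical endpoint and homeserver `hs`), showing
    # how the `args` mapping becomes a form-encoded POST body:
    #
    #     client = SimpleHttpClient(hs)
    #     result = await client.post_urlencoded_get_json(
    #         "https://example.com/oauth/token",
    #         args={"grant_type": "authorization_code", "code": "abc123"},
    #     )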

    async def post_json_get_json(
        self, uri: str, post_json: Any, headers: Optional[RawHeaders] = None
    ) -> Any:
        """
        Args:
            uri: URI to query.
            post_json: request body, to be encoded as json
            headers: a map from header name to a list of values for that header

        Returns:
            parsed json

        Raises:
            RequestTimedOutError: if there is a timeout before the response headers
                are received. Note there is currently no timeout on reading the
                response body.
            HttpResponseException: On a non-2xx HTTP response.
            ValueError: if the response was not JSON
        """
        json_str = encode_canonical_json(post_json)

        logger.debug("HTTP POST %s -> %s", json_str, uri)

        actual_headers = {
            b"Content-Type": [b"application/json"],
            b"User-Agent": [self.user_agent],
            b"Accept": [b"application/json"],
        }
        if headers:
            actual_headers.update(headers)  # type: ignore

        response = await self.request(
            "POST", uri, headers=Headers(actual_headers), data=json_str
        )

        body = await make_deferred_yieldable(readBody(response))

        if 200 <= response.code < 300:
            return json_decoder.decode(body.decode("utf-8"))
        else:
            raise HttpResponseException(
                response.code, response.phrase.decode("ascii", errors="replace"), body
            )

    async def get_json(
        self,
        uri: str,
        args: Optional[QueryParams] = None,
        headers: Optional[RawHeaders] = None,
    ) -> Any:
        """Gets some json from the given URI.

        Args:
            uri: The URI to request, not including query parameters
            args: A dictionary used to create the query string
            headers: a map from header name to a list of values for that header

        Returns:
            Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.

        Raises:
            RequestTimedOutError: if there is a timeout before the response headers
                are received. Note there is currently no timeout on reading the
                response body.
            HttpResponseException: On a non-2xx HTTP response.
            ValueError: if the response was not JSON
        """
        actual_headers = {b"Accept": [b"application/json"]}
        if headers:
            actual_headers.update(headers)  # type: ignore

        body = await self.get_raw(uri, args, headers=actual_headers)
        return json_decoder.decode(body.decode("utf-8"))
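
    # A minimal usage sketch (hypothetical homeserver `hs` and URL); query
    # parameters in `args` are URL-encoded onto the request:
    #
    #     client = SimpleHttpClient(hs)
    #     versions = await client.get_json(
    #         "https://example.com/_matrix/client/versions",
    #     )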

    async def put_json(
        self,
        uri: str,
        json_body: Any,
        args: Optional[QueryParams] = None,
        headers: Optional[RawHeaders] = None,
    ) -> Any:
        """Puts some json to the given URI.

        Args:
            uri: The URI to request, not including query parameters
            json_body: The JSON to put in the HTTP body.
            args: A dictionary used to create query strings
            headers: a map from header name to a list of values for that header

        Returns:
            Succeeds when we get a 2xx HTTP response, with the HTTP body as JSON.

        Raises:
            RequestTimedOutError: if there is a timeout before the response headers
                are received. Note there is currently no timeout on reading the
                response body.
            HttpResponseException: On a non-2xx HTTP response.
            ValueError: if the response was not JSON
        """
        if args:
            query_str = urllib.parse.urlencode(args, True)
            uri = "%s?%s" % (uri, query_str)

        json_str = encode_canonical_json(json_body)

        actual_headers = {
            b"Content-Type": [b"application/json"],
            b"User-Agent": [self.user_agent],
            b"Accept": [b"application/json"],
        }
        if headers:
            actual_headers.update(headers)  # type: ignore

        response = await self.request(
            "PUT", uri, headers=Headers(actual_headers), data=json_str
        )

        body = await make_deferred_yieldable(readBody(response))

        if 200 <= response.code < 300:
            return json_decoder.decode(body.decode("utf-8"))
        else:
            raise HttpResponseException(
                response.code, response.phrase.decode("ascii", errors="replace"), body
            )

    async def get_raw(
        self,
        uri: str,
        args: Optional[QueryParams] = None,
        headers: Optional[RawHeaders] = None,
    ) -> bytes:
        """Gets raw text from the given URI.

        Args:
            uri: The URI to request, not including query parameters
            args: A dictionary used to create query strings
            headers: a map from header name to a list of values for that header

        Returns:
            Succeeds when we get a 2xx HTTP response, with the
            HTTP body as bytes.

        Raises:
            RequestTimedOutError: if there is a timeout before the response headers
                are received. Note there is currently no timeout on reading the
                response body.
            HttpResponseException: on a non-2xx HTTP response.
        """
        if args:
            query_str = urllib.parse.urlencode(args, True)
            uri = "%s?%s" % (uri, query_str)

        actual_headers = {b"User-Agent": [self.user_agent]}
        if headers:
            actual_headers.update(headers)  # type: ignore

        response = await self.request("GET", uri, headers=Headers(actual_headers))

        body = await make_deferred_yieldable(readBody(response))

        if 200 <= response.code < 300:
            return body
        else:
            raise HttpResponseException(
                response.code, response.phrase.decode("ascii", errors="replace"), body
            )

    # XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
    # The two should be factored out.
    async def get_file(
        self,
        url: str,
        output_stream: BinaryIO,
        max_size: Optional[int] = None,
        headers: Optional[RawHeaders] = None,
        is_allowed_content_type: Optional[Callable[[str], bool]] = None,
    ) -> Tuple[int, Dict[bytes, List[bytes]], str, int]:
        """GETs a file from a given URL

        Args:
            url: The URL to GET
            output_stream: File to write the response body to.
            max_size: The maximum size (in bytes) of the body to read; if
                exceeded, the download is aborted and a SynapseError is raised.
                If None, no limit is applied.
            headers: A map from header name to a list of values for that header
            is_allowed_content_type: A predicate to determine whether the
                content type of the file we're downloading is allowed. If set and
                it evaluates to False when called with the content type, the
                request will be terminated before completing the download by
                raising SynapseError.

        Returns:
            A tuple of the file length, dict of the response
            headers, absolute URI of the response and HTTP response code.

        Raises:
            RequestTimedOutError: if there is a timeout before the response headers
                are received. Note there is currently no timeout on reading the
                response body.
            SynapseError: if the response is not a 2xx, the remote file is too large,
                or another exception happens during the download.
        """
        actual_headers = {b"User-Agent": [self.user_agent]}
        if headers:
            actual_headers.update(headers)  # type: ignore

        response = await self.request("GET", url, headers=Headers(actual_headers))

        resp_headers = dict(response.headers.getAllRawHeaders())

        if response.code > 299:
            logger.warning("Got %d when downloading %s" % (response.code, url))
            raise SynapseError(
                HTTPStatus.BAD_GATEWAY, "Got error %d" % (response.code,), Codes.UNKNOWN
            )

        if is_allowed_content_type and b"Content-Type" in resp_headers:
            content_type = resp_headers[b"Content-Type"][0].decode("ascii")
            if not is_allowed_content_type(content_type):
                raise SynapseError(
                    HTTPStatus.BAD_GATEWAY,
                    (
                        "Requested file's content type not allowed for this operation: %s"
                        % content_type
                    ),
                )

        # TODO: if our Content-Type is HTML or something, just read the first
        # N bytes into RAM rather than saving it all to disk only to read it
        # straight back in again

        try:
            d = read_body_with_max_size(response, output_stream, max_size)

            # Ensure that the body is not read forever.
            d = timeout_deferred(d, 30, self.hs.get_reactor())

            length = await make_deferred_yieldable(d)
        except BodyExceededMaxSize:
            raise SynapseError(
                HTTPStatus.BAD_GATEWAY,
                "Requested file is too large > %r bytes" % (max_size,),
                Codes.TOO_LARGE,
            )
        except defer.TimeoutError:
            raise SynapseError(
                HTTPStatus.BAD_GATEWAY,
                "Requested file took too long to download",
                Codes.TOO_LARGE,
            )
        except Exception as e:
            raise SynapseError(
                HTTPStatus.BAD_GATEWAY, ("Failed to download remote body: %s" % e)
            ) from e

        return (
            length,
            resp_headers,
            response.request.absoluteURI.decode("ascii"),
            response.code,
        )
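

# A minimal usage sketch for `get_file` (hypothetical URL and homeserver `hs`),
# streaming the body to disk with a 10 MiB cap:
#
#     client = SimpleHttpClient(hs)
#     with open("avatar.png", "wb") as f:
#         length, headers, uri, code = await client.get_file(
#             "https://example.com/media/avatar.png",
#             output_stream=f,
#             max_size=10 * 1024 * 1024,
#         )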


def _timeout_to_request_timed_out_error(f: Failure) -> Failure:
    if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
        # The TCP connection has its own timeout (set by the 'connectTimeout' param
        # on the Agent), which raises twisted_error.TimeoutError exception.
        raise RequestTimedOutError("Timeout connecting to remote server")
    elif f.check(defer.TimeoutError, ResponseNeverReceived):
        # this one means that we hit our overall timeout on the request
        raise RequestTimedOutError("Timeout waiting for response from remote server")

    return f


class ByteWriteable(Protocol):
    """The type of object which must be passed into read_body_with_max_size.

    Typically this is a file object.
    """

    def write(self, data: bytes) -> int:
        pass


class BodyExceededMaxSize(Exception):
    """The maximum allowed size of the HTTP body was exceeded."""


class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
    """A protocol which immediately errors upon receiving data."""

    transport: Optional[ITCPTransport] = None

    def __init__(self, deferred: defer.Deferred):
        self.deferred = deferred

    def _maybe_fail(self) -> None:
        """
        Report a max size exceeded error and disconnect the first time this is called.
        """
        if not self.deferred.called:
            self.deferred.errback(BodyExceededMaxSize())
            # Close the connection (forcefully) since all the data will get
            # discarded anyway.
            assert self.transport is not None
            self.transport.abortConnection()

    def dataReceived(self, data: bytes) -> None:
        self._maybe_fail()

    def connectionLost(self, reason: Failure = connectionDone) -> None:
        self._maybe_fail()


class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
    """A protocol which reads the body to a stream, erroring if the body exceeds a maximum size."""

    transport: Optional[ITCPTransport] = None

    def __init__(
        self, stream: ByteWriteable, deferred: defer.Deferred, max_size: Optional[int]
    ):
        self.stream = stream
        self.deferred = deferred
        self.length = 0
        self.max_size = max_size

    def dataReceived(self, data: bytes) -> None:
        # If the deferred was called, bail early.
        if self.deferred.called:
            return

        try:
            self.stream.write(data)
        except Exception:
            self.deferred.errback()
            return

        self.length += len(data)
        # The first time the maximum size is exceeded, error and cancel the
        # connection. dataReceived might be called again if data was received
        # in the meantime.
        if self.max_size is not None and self.length >= self.max_size:
            self.deferred.errback(BodyExceededMaxSize())
            # Close the connection (forcefully) since all the data will get
            # discarded anyway.
            assert self.transport is not None
            self.transport.abortConnection()

    def connectionLost(self, reason: Failure = connectionDone) -> None:
        # If the maximum size was already exceeded, there's nothing to do.
        if self.deferred.called:
            return

        if reason.check(ResponseDone):
            self.deferred.callback(self.length)
        elif reason.check(PotentialDataLoss):
            # stolen from https://github.com/twisted/treq/pull/49/files
            # http://twistedmatrix.com/trac/ticket/4840
            self.deferred.callback(self.length)
        else:
            self.deferred.errback(reason)


def read_body_with_max_size(
    response: IResponse, stream: ByteWriteable, max_size: Optional[int]
) -> "defer.Deferred[int]":
    """
    Read an HTTP response body to a file-object, optionally enforcing a maximum file size.

    If the maximum file size is reached, the returned Deferred will resolve to a
    Failure with a BodyExceededMaxSize exception.

    Args:
        response: The HTTP response to read from.
        stream: The file-object to write to.
        max_size: The maximum file size to allow.

    Returns:
        A Deferred which resolves to the length of the read body.
    """
    d: "defer.Deferred[int]" = defer.Deferred()

    # If the Content-Length header gives a size larger than the maximum allowed
    # size, do not bother downloading the body.
    if max_size is not None and response.length != UNKNOWN_LENGTH:
        if response.length > max_size:
            response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
            return d

    response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
    return d
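

# A minimal sketch of driving read_body_with_max_size from a coroutine, assuming
# `response` is an IResponse already obtained from an agent:
#
#     buffer = BytesIO()
#     length = await make_deferred_yieldable(
#         read_body_with_max_size(response, buffer, max_size=1024 * 1024)
#     )
#     body = buffer.getvalue()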


def encode_query_args(args: Optional[QueryParams]) -> bytes:
    """
    Encodes a map of query arguments to bytes which can be appended to a URL.

    Args:
        args: The query arguments, a mapping of string to string or list of strings.

    Returns:
        The query arguments encoded as bytes.
    """
    if args is None:
        return b""

    query_str = urllib.parse.urlencode(args, True)

    return query_str.encode("utf8")
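

# A quick illustration (made-up values): the doseq-style encoding expands list
# values into repeated query keys:
#
#     encode_query_args({"limit": "10", "types": ["m.room.message", "m.room.name"]})
#     # => b'limit=10&types=m.room.message&types=m.room.name'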


@implementer(IPolicyForHTTPS)
class InsecureInterceptableContextFactory(ssl.ContextFactory):
    """
    Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.

    Do not use this since it allows an attacker to intercept your communications.
    """

    def __init__(self) -> None:
        self._context = SSL.Context(SSL.SSLv23_METHOD)
        self._context.set_verify(VERIFY_NONE, lambda *_: False)

    def getContext(self) -> SSL.Context:
        return self._context

    def creatorForNetloc(self, hostname: bytes, port: int) -> IOpenSSLContextFactory:
        return self