test_url_preview.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import attr
  17. from twisted.internet._resolver import HostResolution
  18. from twisted.internet.address import IPv4Address, IPv6Address
  19. from twisted.internet.error import DNSLookupError
  20. from twisted.python.failure import Failure
  21. from twisted.test.proto_helpers import AccumulatingProtocol
  22. from twisted.web._newclient import ResponseDone
  23. from tests import unittest
  24. from tests.server import FakeTransport
  @attr.s
  class FakeResponse(object):
      """
      Lightweight stand-in for an HTTP client response, for use in tests.

      The response fields are plain attrs attributes, and the stored body can
      be replayed into a protocol through ``deliverBody``.
      """

      version = attr.ib()
      code = attr.ib()
      phrase = attr.ib()
      headers = attr.ib()
      body = attr.ib()
      absoluteURI = attr.ib()

      @property
      def request(self):
          """
          Return a throwaway object exposing only this response's absoluteURI.
          """
          uri = self.absoluteURI

          # NOTE: inside this property the local class name shadows the
          # FakeTransport imported from tests.server.
          @attr.s
          class FakeTransport(object):
              absoluteURI = uri

          return FakeTransport()

      def deliverBody(self, protocol):
          """
          Hand the whole body to ``protocol`` and then close it.

          The protocol receives the body in a single dataReceived call,
          followed by connectionLost with a Failure wrapping ResponseDone.
          """
          protocol.dataReceived(self.body)
          finished = Failure(ResponseDone())
          protocol.connectionLost(finished)
  42. class URLPreviewTests(unittest.HomeserverTestCase):
  43. hijack_auth = True
  44. user_id = "@test:user"
  45. end_content = (
  46. b"<html><head>"
  47. b'<meta property="og:title" content="~matrix~" />'
  48. b'<meta property="og:description" content="hi" />'
  49. b"</head></html>"
  50. )
  def make_homeserver(self, reactor, clock):
      """
      Build the homeserver under test with URL previews switched on.

      The config enables previews, sets IP blacklist/whitelist entries (IPv4
      and IPv6, single addresses and ranges), and points the media store and a
      file storage provider at two freshly created temporary directories.
      Returns whatever ``setup_test_homeserver`` produces for that config.
      """
      config = self.default_config()
      config.update(
          {
              "url_preview_enabled": True,
              "max_spider_size": 9999999,
              "url_preview_ip_range_blacklist": (
                  "192.168.1.1",
                  "1.0.0.0/8",
                  "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
                  "2001:800::/21",
              ),
              "url_preview_ip_range_whitelist": ("1.1.1.1",),
              "url_preview_url_blacklist": [],
          }
      )

      # Allocate both scratch paths before creating either directory.
      self.storage_path = self.mktemp()
      self.media_store_path = self.mktemp()
      for directory in (self.storage_path, self.media_store_path):
          os.mkdir(directory)

      config["media_store_path"] = self.media_store_path
      config["media_storage_providers"] = [
          {
              "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
              "store_local": True,
              "store_synchronous": False,
              "store_remote": True,
              "config": {"directory": self.storage_path},
          }
      ]

      return self.setup_test_homeserver(config=config)
  def prepare(self, reactor, clock, hs):
      """
      Look up the preview_url resource and install a fake name resolver.

      Tests fill ``self.lookups`` with ``hostname -> [(address class, ip)]``
      entries; the resolver installed on ``self.reactor`` answers from that
      mapping and raises DNSLookupError for any hostname not in it.
      """
      media_repo = hs.get_media_repository_resource()
      self.media_repo = media_repo
      self.preview_url = media_repo.children[b"preview_url"]

      self.lookups = {}

      class Resolver(object):
          # The first parameter is named ``_self`` so that ``self`` inside the
          # method still refers to the enclosing test case (and its lookups).
          def resolveHostName(
              _self,
              resolutionReceiver,
              hostName,
              portNumber=0,
              addressTypes=None,
              transportSemantics="TCP",
          ):
              resolutionReceiver.resolutionBegan(HostResolution(hostName))

              if hostName not in self.lookups:
                  raise DNSLookupError("OH NO")

              for entry in self.lookups[hostName]:
                  address_cls, host = entry[0], entry[1]
                  resolutionReceiver.addressResolved(
                      address_cls("TCP", host, portNumber)
                  )

              resolutionReceiver.resolutionComplete()
              return resolutionReceiver

      self.reactor.nameResolver = Resolver()
  def test_cache_returns_correct_type(self):
      """
      Repeated previews of one URL return the same JSON every time.

      The first request is answered from a simulated upstream response. The
      second is made without feeding any new upstream data. The third is made
      after the URL has been popped from the resource's in-memory ``_cache``.
      """
      url = "http://matrix.org"
      path = "url_preview?url=http://matrix.org"
      expected = {"og:title": "~matrix~", "og:description": "hi"}

      self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]

      request, channel = self.make_request("GET", path, shorthand=False)
      request.render(self.preview_url)
      self.pump()

      # Wire the outgoing HTTP client up to a fake server.
      client_protocol = self.reactor.tcpClients[0][2].buildProtocol(None)
      server_protocol = AccumulatingProtocol()
      server_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
      client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))

      status_and_headers = (
          b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
          % (len(self.end_content),)
      )
      client_protocol.dataReceived(status_and_headers + self.end_content)
      self.pump()

      self.assertEqual(channel.code, 200)
      self.assertEqual(channel.json_body, expected)

      # Second request: expected to be answered from the cache.
      request, channel = self.make_request("GET", path, shorthand=False)
      request.render(self.preview_url)
      self.pump()

      self.assertEqual(channel.code, 200)
      self.assertEqual(channel.json_body, expected)

      # Drop the in-memory entry, checking it was there and is now gone.
      self.assertIn(url, self.preview_url._cache)
      self.preview_url._cache.pop(url)
      self.assertNotIn(url, self.preview_url._cache)

      # Third request: expected to be answered from the database cache.
      request, channel = self.make_request("GET", path, shorthand=False)
      request.render(self.preview_url)
      self.pump()

      self.assertEqual(channel.code, 200)
      self.assertEqual(channel.json_body, expected)
  def test_non_ascii_preview_httpequiv(self):
      """
      The charset in the page's http-equiv meta tag is used to decode it.

      The HTTP header announces utf8 while the document's meta tag announces
      windows-1251; the title bytes must come back decoded as windows-1251.
      """
      self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]

      page = (
          b"<html><head>"
          b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>'
          b'<meta property="og:title" content="\xe4\xea\xe0" />'
          b'<meta property="og:description" content="hi" />'
          b"</head></html>"
      )

      request, channel = self.make_request(
          "GET", "url_preview?url=http://matrix.org", shorthand=False
      )
      request.render(self.preview_url)
      self.pump()

      client_protocol = self.reactor.tcpClients[0][2].buildProtocol(None)
      server_protocol = AccumulatingProtocol()
      server_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
      client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))

      status_and_headers = (
          b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
          b'Content-Type: text/html; charset="utf8"\r\n\r\n'
      ) % (len(page),)
      client_protocol.dataReceived(status_and_headers + page)
      self.pump()

      self.assertEqual(channel.code, 200)
      self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
  def test_non_ascii_preview_content_type(self):
      """
      The charset in the Content-Type response header is used to decode a page.

      The document carries no charset of its own; the header announces
      windows-1251 and the title bytes must come back decoded accordingly.
      """
      self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]

      page = (
          b"<html><head>"
          b'<meta property="og:title" content="\xe4\xea\xe0" />'
          b'<meta property="og:description" content="hi" />'
          b"</head></html>"
      )

      request, channel = self.make_request(
          "GET", "url_preview?url=http://matrix.org", shorthand=False
      )
      request.render(self.preview_url)
      self.pump()

      client_protocol = self.reactor.tcpClients[0][2].buildProtocol(None)
      server_protocol = AccumulatingProtocol()
      server_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
      client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))

      status_and_headers = (
          b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
          b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
      ) % (len(page),)
      client_protocol.dataReceived(status_and_headers + page)
      self.pump()

      self.assertEqual(channel.code, 200)
      self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
  def test_ipaddr(self):
      """
      A hostname resolving to a non-blacklisted IP address can be previewed.

      Note that the requested URL uses a hostname (example.com, resolved to
      8.8.8.8 by the fake resolver), not an IP literal.
      """
      self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]

      request, channel = self.make_request(
          "GET", "url_preview?url=http://example.com", shorthand=False
      )
      request.render(self.preview_url)
      self.pump()

      client_protocol = self.reactor.tcpClients[0][2].buildProtocol(None)
      server_protocol = AccumulatingProtocol()
      server_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
      client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))

      status_and_headers = (
          b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
          % (len(self.end_content),)
      )
      client_protocol.dataReceived(status_and_headers + self.end_content)
      self.pump()

      expected = {"og:title": "~matrix~", "og:description": "hi"}
      self.assertEqual(channel.code, 200)
      self.assertEqual(channel.json_body, expected)
  def test_blacklisted_ip_specific(self):
      """
      A hostname resolving to an individually blacklisted IPv4 address is
      not spidered.
      """
      self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]

      preview_path = "url_preview?url=http://example.com"
      request, channel = self.make_request("GET", preview_path, shorthand=False)
      request.render(self.preview_url)
      self.pump()

      # The reactor must not have seen any outgoing connection attempt.
      self.assertEqual(len(self.reactor.tcpClients), 0)

      expected_error = {
          "errcode": "M_UNKNOWN",
          "error": "DNS resolution failure during URL preview generation",
      }
      self.assertEqual(channel.code, 502)
      self.assertEqual(channel.json_body, expected_error)
  def test_blacklisted_ip_range(self):
      """
      A hostname resolving to an IPv4 address inside a blacklisted range is
      not spidered.
      """
      self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]

      preview_path = "url_preview?url=http://example.com"
      request, channel = self.make_request("GET", preview_path, shorthand=False)
      request.render(self.preview_url)
      self.pump()

      expected_error = {
          "errcode": "M_UNKNOWN",
          "error": "DNS resolution failure during URL preview generation",
      }
      self.assertEqual(channel.code, 502)
      self.assertEqual(channel.json_body, expected_error)
  def test_blacklisted_ip_specific_direct(self):
      """
      A URL naming an individually blacklisted IP address directly is
      refused and not spidered.
      """
      preview_path = "url_preview?url=http://192.168.1.1"
      request, channel = self.make_request("GET", preview_path, shorthand=False)
      request.render(self.preview_url)
      self.pump()

      # The reactor must not have seen any outgoing connection attempt.
      self.assertEqual(len(self.reactor.tcpClients), 0)

      expected_error = {
          "errcode": "M_UNKNOWN",
          "error": "IP address blocked by IP blacklist entry",
      }
      self.assertEqual(channel.json_body, expected_error)
      self.assertEqual(channel.code, 403)
  def test_blacklisted_ip_range_direct(self):
      """
      A URL naming an IP address inside a blacklisted range directly is
      refused.
      """
      preview_path = "url_preview?url=http://1.1.1.2"
      request, channel = self.make_request("GET", preview_path, shorthand=False)
      request.render(self.preview_url)
      self.pump()

      expected_error = {
          "errcode": "M_UNKNOWN",
          "error": "IP address blocked by IP blacklist entry",
      }
      self.assertEqual(channel.code, 403)
      self.assertEqual(channel.json_body, expected_error)
  def test_blacklisted_ip_range_whitelisted_ip(self):
      """
      An address inside a blacklisted range that is also whitelisted can
      still be spidered.
      """
      self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]

      request, channel = self.make_request(
          "GET", "url_preview?url=http://example.com", shorthand=False
      )
      request.render(self.preview_url)
      self.pump()

      client_protocol = self.reactor.tcpClients[0][2].buildProtocol(None)
      server_protocol = AccumulatingProtocol()
      server_protocol.makeConnection(FakeTransport(client_protocol, self.reactor))
      client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))

      status_and_headers = (
          b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
          % (len(self.end_content),)
      )
      client_protocol.dataReceived(status_and_headers + self.end_content)
      self.pump()

      expected = {"og:title": "~matrix~", "og:description": "hi"}
      self.assertEqual(channel.code, 200)
      self.assertEqual(channel.json_body, expected)
  def test_blacklisted_ip_with_external_ip(self):
      """
      A hostname is rejected when any of its addresses is blacklisted, even
      if it also resolves to a non-blacklisted one.
      """
      # Resolve the hostname to one blacklisted and one ordinary address.
      blacklisted = (IPv4Address, "1.1.1.2")
      ordinary = (IPv4Address, "8.8.8.8")
      self.lookups["example.com"] = [blacklisted, ordinary]

      preview_path = "url_preview?url=http://example.com"
      request, channel = self.make_request("GET", preview_path, shorthand=False)
      request.render(self.preview_url)
      self.pump()

      expected_error = {
          "errcode": "M_UNKNOWN",
          "error": "DNS resolution failure during URL preview generation",
      }
      self.assertEqual(channel.code, 502)
      self.assertEqual(channel.json_body, expected_error)
  def test_blacklisted_ipv6_specific(self):
      """
      A hostname resolving to an individually blacklisted IPv6 address is
      not spidered.
      """
      blacklisted = (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
      self.lookups["example.com"] = [blacklisted]

      preview_path = "url_preview?url=http://example.com"
      request, channel = self.make_request("GET", preview_path, shorthand=False)
      request.render(self.preview_url)
      self.pump()

      # The reactor must not have seen any outgoing connection attempt.
      self.assertEqual(len(self.reactor.tcpClients), 0)

      expected_error = {
          "errcode": "M_UNKNOWN",
          "error": "DNS resolution failure during URL preview generation",
      }
      self.assertEqual(channel.code, 502)
      self.assertEqual(channel.json_body, expected_error)
  def test_blacklisted_ipv6_range(self):
      """
      A hostname resolving to an IPv6 address inside a blacklisted range is
      not spidered.
      """
      self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]

      preview_path = "url_preview?url=http://example.com"
      request, channel = self.make_request("GET", preview_path, shorthand=False)
      request.render(self.preview_url)
      self.pump()

      expected_error = {
          "errcode": "M_UNKNOWN",
          "error": "DNS resolution failure during URL preview generation",
      }
      self.assertEqual(channel.code, 502)
      self.assertEqual(channel.json_body, expected_error)
  def test_OPTIONS(self):
      """
      An OPTIONS request to the preview resource returns 200 and an empty
      JSON object.
      """
      preview_path = "url_preview?url=http://example.com"
      request, channel = self.make_request("OPTIONS", preview_path, shorthand=False)
      request.render(self.preview_url)
      self.pump()

      self.assertEqual(channel.code, 200)
      self.assertEqual(channel.json_body, {})