test_url_preview.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import attr
  17. from netaddr import IPSet
  18. from twisted.internet._resolver import HostResolution
  19. from twisted.internet.address import IPv4Address, IPv6Address
  20. from twisted.internet.error import DNSLookupError
  21. from twisted.python.failure import Failure
  22. from twisted.test.proto_helpers import AccumulatingProtocol
  23. from twisted.web._newclient import ResponseDone
  24. from synapse.config.repository import MediaStorageProviderConfig
  25. from synapse.util.module_loader import load_module
  26. from tests import unittest
  27. from tests.server import FakeTransport
  28. @attr.s
  29. class FakeResponse(object):
  30. version = attr.ib()
  31. code = attr.ib()
  32. phrase = attr.ib()
  33. headers = attr.ib()
  34. body = attr.ib()
  35. absoluteURI = attr.ib()
  36. @property
  37. def request(self):
  38. @attr.s
  39. class FakeTransport(object):
  40. absoluteURI = self.absoluteURI
  41. return FakeTransport()
  42. def deliverBody(self, protocol):
  43. protocol.dataReceived(self.body)
  44. protocol.connectionLost(Failure(ResponseDone()))
  45. class URLPreviewTests(unittest.HomeserverTestCase):
  46. hijack_auth = True
  47. user_id = "@test:user"
  48. end_content = (
  49. b'<html><head>'
  50. b'<meta property="og:title" content="~matrix~" />'
  51. b'<meta property="og:description" content="hi" />'
  52. b'</head></html>'
  53. )
  54. def make_homeserver(self, reactor, clock):
  55. self.storage_path = self.mktemp()
  56. os.mkdir(self.storage_path)
  57. config = self.default_config()
  58. config.url_preview_enabled = True
  59. config.max_spider_size = 9999999
  60. config.url_preview_ip_range_blacklist = IPSet(
  61. (
  62. "192.168.1.1",
  63. "1.0.0.0/8",
  64. "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
  65. "2001:800::/21",
  66. )
  67. )
  68. config.url_preview_ip_range_whitelist = IPSet(("1.1.1.1",))
  69. config.url_preview_url_blacklist = []
  70. config.media_store_path = self.storage_path
  71. provider_config = {
  72. "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
  73. "store_local": True,
  74. "store_synchronous": False,
  75. "store_remote": True,
  76. "config": {"directory": self.storage_path},
  77. }
  78. loaded = list(load_module(provider_config)) + [
  79. MediaStorageProviderConfig(False, False, False)
  80. ]
  81. config.media_storage_providers = [loaded]
  82. hs = self.setup_test_homeserver(config=config)
  83. return hs
  84. def prepare(self, reactor, clock, hs):
  85. self.media_repo = hs.get_media_repository_resource()
  86. self.preview_url = self.media_repo.children[b'preview_url']
  87. self.lookups = {}
  88. class Resolver(object):
  89. def resolveHostName(
  90. _self,
  91. resolutionReceiver,
  92. hostName,
  93. portNumber=0,
  94. addressTypes=None,
  95. transportSemantics='TCP',
  96. ):
  97. resolution = HostResolution(hostName)
  98. resolutionReceiver.resolutionBegan(resolution)
  99. if hostName not in self.lookups:
  100. raise DNSLookupError("OH NO")
  101. for i in self.lookups[hostName]:
  102. resolutionReceiver.addressResolved(i[0]('TCP', i[1], portNumber))
  103. resolutionReceiver.resolutionComplete()
  104. return resolutionReceiver
  105. self.reactor.nameResolver = Resolver()
  106. def test_cache_returns_correct_type(self):
  107. self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
  108. request, channel = self.make_request(
  109. "GET", "url_preview?url=http://matrix.org", shorthand=False
  110. )
  111. request.render(self.preview_url)
  112. self.pump()
  113. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  114. server = AccumulatingProtocol()
  115. server.makeConnection(FakeTransport(client, self.reactor))
  116. client.makeConnection(FakeTransport(server, self.reactor))
  117. client.dataReceived(
  118. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
  119. % (len(self.end_content),)
  120. + self.end_content
  121. )
  122. self.pump()
  123. self.assertEqual(channel.code, 200)
  124. self.assertEqual(
  125. channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
  126. )
  127. # Check the cache returns the correct response
  128. request, channel = self.make_request(
  129. "GET", "url_preview?url=http://matrix.org", shorthand=False
  130. )
  131. request.render(self.preview_url)
  132. self.pump()
  133. # Check the cache response has the same content
  134. self.assertEqual(channel.code, 200)
  135. self.assertEqual(
  136. channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
  137. )
  138. # Clear the in-memory cache
  139. self.assertIn("http://matrix.org", self.preview_url._cache)
  140. self.preview_url._cache.pop("http://matrix.org")
  141. self.assertNotIn("http://matrix.org", self.preview_url._cache)
  142. # Check the database cache returns the correct response
  143. request, channel = self.make_request(
  144. "GET", "url_preview?url=http://matrix.org", shorthand=False
  145. )
  146. request.render(self.preview_url)
  147. self.pump()
  148. # Check the cache response has the same content
  149. self.assertEqual(channel.code, 200)
  150. self.assertEqual(
  151. channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
  152. )
  153. def test_non_ascii_preview_httpequiv(self):
  154. self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
  155. end_content = (
  156. b'<html><head>'
  157. b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>'
  158. b'<meta property="og:title" content="\xe4\xea\xe0" />'
  159. b'<meta property="og:description" content="hi" />'
  160. b'</head></html>'
  161. )
  162. request, channel = self.make_request(
  163. "GET", "url_preview?url=http://matrix.org", shorthand=False
  164. )
  165. request.render(self.preview_url)
  166. self.pump()
  167. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  168. server = AccumulatingProtocol()
  169. server.makeConnection(FakeTransport(client, self.reactor))
  170. client.makeConnection(FakeTransport(server, self.reactor))
  171. client.dataReceived(
  172. (
  173. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
  174. b"Content-Type: text/html; charset=\"utf8\"\r\n\r\n"
  175. )
  176. % (len(end_content),)
  177. + end_content
  178. )
  179. self.pump()
  180. self.assertEqual(channel.code, 200)
  181. self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
  182. def test_non_ascii_preview_content_type(self):
  183. self.lookups["matrix.org"] = [(IPv4Address, "8.8.8.8")]
  184. end_content = (
  185. b'<html><head>'
  186. b'<meta property="og:title" content="\xe4\xea\xe0" />'
  187. b'<meta property="og:description" content="hi" />'
  188. b'</head></html>'
  189. )
  190. request, channel = self.make_request(
  191. "GET", "url_preview?url=http://matrix.org", shorthand=False
  192. )
  193. request.render(self.preview_url)
  194. self.pump()
  195. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  196. server = AccumulatingProtocol()
  197. server.makeConnection(FakeTransport(client, self.reactor))
  198. client.makeConnection(FakeTransport(server, self.reactor))
  199. client.dataReceived(
  200. (
  201. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
  202. b"Content-Type: text/html; charset=\"windows-1251\"\r\n\r\n"
  203. )
  204. % (len(end_content),)
  205. + end_content
  206. )
  207. self.pump()
  208. self.assertEqual(channel.code, 200)
  209. self.assertEqual(channel.json_body["og:title"], u"\u0434\u043a\u0430")
  210. def test_ipaddr(self):
  211. """
  212. IP addresses can be previewed directly.
  213. """
  214. self.lookups["example.com"] = [(IPv4Address, "8.8.8.8")]
  215. request, channel = self.make_request(
  216. "GET", "url_preview?url=http://example.com", shorthand=False
  217. )
  218. request.render(self.preview_url)
  219. self.pump()
  220. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  221. server = AccumulatingProtocol()
  222. server.makeConnection(FakeTransport(client, self.reactor))
  223. client.makeConnection(FakeTransport(server, self.reactor))
  224. client.dataReceived(
  225. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
  226. % (len(self.end_content),)
  227. + self.end_content
  228. )
  229. self.pump()
  230. self.assertEqual(channel.code, 200)
  231. self.assertEqual(
  232. channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
  233. )
  234. def test_blacklisted_ip_specific(self):
  235. """
  236. Blacklisted IP addresses, found via DNS, are not spidered.
  237. """
  238. self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
  239. request, channel = self.make_request(
  240. "GET", "url_preview?url=http://example.com", shorthand=False
  241. )
  242. request.render(self.preview_url)
  243. self.pump()
  244. # No requests made.
  245. self.assertEqual(len(self.reactor.tcpClients), 0)
  246. self.assertEqual(channel.code, 403)
  247. self.assertEqual(
  248. channel.json_body,
  249. {
  250. 'errcode': 'M_UNKNOWN',
  251. 'error': 'IP address blocked by IP blacklist entry',
  252. },
  253. )
  254. def test_blacklisted_ip_range(self):
  255. """
  256. Blacklisted IP ranges, IPs found over DNS, are not spidered.
  257. """
  258. self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
  259. request, channel = self.make_request(
  260. "GET", "url_preview?url=http://example.com", shorthand=False
  261. )
  262. request.render(self.preview_url)
  263. self.pump()
  264. self.assertEqual(channel.code, 403)
  265. self.assertEqual(
  266. channel.json_body,
  267. {
  268. 'errcode': 'M_UNKNOWN',
  269. 'error': 'IP address blocked by IP blacklist entry',
  270. },
  271. )
  272. def test_blacklisted_ip_specific_direct(self):
  273. """
  274. Blacklisted IP addresses, accessed directly, are not spidered.
  275. """
  276. request, channel = self.make_request(
  277. "GET", "url_preview?url=http://192.168.1.1", shorthand=False
  278. )
  279. request.render(self.preview_url)
  280. self.pump()
  281. # No requests made.
  282. self.assertEqual(len(self.reactor.tcpClients), 0)
  283. self.assertEqual(channel.code, 403)
  284. self.assertEqual(
  285. channel.json_body,
  286. {
  287. 'errcode': 'M_UNKNOWN',
  288. 'error': 'IP address blocked by IP blacklist entry',
  289. },
  290. )
  291. def test_blacklisted_ip_range_direct(self):
  292. """
  293. Blacklisted IP ranges, accessed directly, are not spidered.
  294. """
  295. request, channel = self.make_request(
  296. "GET", "url_preview?url=http://1.1.1.2", shorthand=False
  297. )
  298. request.render(self.preview_url)
  299. self.pump()
  300. self.assertEqual(channel.code, 403)
  301. self.assertEqual(
  302. channel.json_body,
  303. {
  304. 'errcode': 'M_UNKNOWN',
  305. 'error': 'IP address blocked by IP blacklist entry',
  306. },
  307. )
  308. def test_blacklisted_ip_range_whitelisted_ip(self):
  309. """
  310. Blacklisted but then subsequently whitelisted IP addresses can be
  311. spidered.
  312. """
  313. self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
  314. request, channel = self.make_request(
  315. "GET", "url_preview?url=http://example.com", shorthand=False
  316. )
  317. request.render(self.preview_url)
  318. self.pump()
  319. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  320. server = AccumulatingProtocol()
  321. server.makeConnection(FakeTransport(client, self.reactor))
  322. client.makeConnection(FakeTransport(server, self.reactor))
  323. client.dataReceived(
  324. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
  325. % (len(self.end_content),)
  326. + self.end_content
  327. )
  328. self.pump()
  329. self.assertEqual(channel.code, 200)
  330. self.assertEqual(
  331. channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
  332. )
  333. def test_blacklisted_ip_with_external_ip(self):
  334. """
  335. If a hostname resolves a blacklisted IP, even if there's a
  336. non-blacklisted one, it will be rejected.
  337. """
  338. # Hardcode the URL resolving to the IP we want.
  339. self.lookups[u"example.com"] = [
  340. (IPv4Address, "1.1.1.2"),
  341. (IPv4Address, "8.8.8.8"),
  342. ]
  343. request, channel = self.make_request(
  344. "GET", "url_preview?url=http://example.com", shorthand=False
  345. )
  346. request.render(self.preview_url)
  347. self.pump()
  348. self.assertEqual(channel.code, 403)
  349. self.assertEqual(
  350. channel.json_body,
  351. {
  352. 'errcode': 'M_UNKNOWN',
  353. 'error': 'IP address blocked by IP blacklist entry',
  354. },
  355. )
  356. def test_blacklisted_ipv6_specific(self):
  357. """
  358. Blacklisted IP addresses, found via DNS, are not spidered.
  359. """
  360. self.lookups["example.com"] = [
  361. (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
  362. ]
  363. request, channel = self.make_request(
  364. "GET", "url_preview?url=http://example.com", shorthand=False
  365. )
  366. request.render(self.preview_url)
  367. self.pump()
  368. # No requests made.
  369. self.assertEqual(len(self.reactor.tcpClients), 0)
  370. self.assertEqual(channel.code, 403)
  371. self.assertEqual(
  372. channel.json_body,
  373. {
  374. 'errcode': 'M_UNKNOWN',
  375. 'error': 'IP address blocked by IP blacklist entry',
  376. },
  377. )
  378. def test_blacklisted_ipv6_range(self):
  379. """
  380. Blacklisted IP ranges, IPs found over DNS, are not spidered.
  381. """
  382. self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
  383. request, channel = self.make_request(
  384. "GET", "url_preview?url=http://example.com", shorthand=False
  385. )
  386. request.render(self.preview_url)
  387. self.pump()
  388. self.assertEqual(channel.code, 403)
  389. self.assertEqual(
  390. channel.json_body,
  391. {
  392. 'errcode': 'M_UNKNOWN',
  393. 'error': 'IP address blocked by IP blacklist entry',
  394. },
  395. )