test_url_preview.py 22 KB


# Copyright 2018 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import json
  15. import os
  16. import re
  17. from unittest.mock import patch
  18. from twisted.internet._resolver import HostResolution
  19. from twisted.internet.address import IPv4Address, IPv6Address
  20. from twisted.internet.error import DNSLookupError
  21. from twisted.test.proto_helpers import AccumulatingProtocol
  22. from tests import unittest
  23. from tests.server import FakeTransport
  24. try:
  25. import lxml
  26. except ImportError:
  27. lxml = None
  28. class URLPreviewTests(unittest.HomeserverTestCase):
  29. if not lxml:
  30. skip = "url preview feature requires lxml"
  31. hijack_auth = True
  32. user_id = "@test:user"
  33. end_content = (
  34. b"<html><head>"
  35. b'<meta property="og:title" content="~matrix~" />'
  36. b'<meta property="og:description" content="hi" />'
  37. b"</head></html>"
  38. )
  39. def make_homeserver(self, reactor, clock):
  40. config = self.default_config()
  41. config["url_preview_enabled"] = True
  42. config["max_spider_size"] = 9999999
  43. config["url_preview_ip_range_blacklist"] = (
  44. "192.168.1.1",
  45. "1.0.0.0/8",
  46. "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff",
  47. "2001:800::/21",
  48. )
  49. config["url_preview_ip_range_whitelist"] = ("1.1.1.1",)
  50. config["url_preview_url_blacklist"] = []
  51. config["url_preview_accept_language"] = [
  52. "en-UK",
  53. "en-US;q=0.9",
  54. "fr;q=0.8",
  55. "*;q=0.7",
  56. ]
  57. self.storage_path = self.mktemp()
  58. self.media_store_path = self.mktemp()
  59. os.mkdir(self.storage_path)
  60. os.mkdir(self.media_store_path)
  61. config["media_store_path"] = self.media_store_path
  62. provider_config = {
  63. "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
  64. "store_local": True,
  65. "store_synchronous": False,
  66. "store_remote": True,
  67. "config": {"directory": self.storage_path},
  68. }
  69. config["media_storage_providers"] = [provider_config]
  70. hs = self.setup_test_homeserver(config=config)
  71. return hs
  72. def prepare(self, reactor, clock, hs):
  73. self.media_repo = hs.get_media_repository_resource()
  74. self.preview_url = self.media_repo.children[b"preview_url"]
  75. self.lookups = {}
  76. class Resolver:
  77. def resolveHostName(
  78. _self,
  79. resolutionReceiver,
  80. hostName,
  81. portNumber=0,
  82. addressTypes=None,
  83. transportSemantics="TCP",
  84. ):
  85. resolution = HostResolution(hostName)
  86. resolutionReceiver.resolutionBegan(resolution)
  87. if hostName not in self.lookups:
  88. raise DNSLookupError("OH NO")
  89. for i in self.lookups[hostName]:
  90. resolutionReceiver.addressResolved(i[0]("TCP", i[1], portNumber))
  91. resolutionReceiver.resolutionComplete()
  92. return resolutionReceiver
  93. self.reactor.nameResolver = Resolver()
  94. def create_test_resource(self):
  95. return self.hs.get_media_repository_resource()
  96. def test_cache_returns_correct_type(self):
  97. self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
  98. channel = self.make_request(
  99. "GET",
  100. "preview_url?url=http://matrix.org",
  101. shorthand=False,
  102. await_result=False,
  103. )
  104. self.pump()
  105. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  106. server = AccumulatingProtocol()
  107. server.makeConnection(FakeTransport(client, self.reactor))
  108. client.makeConnection(FakeTransport(server, self.reactor))
  109. client.dataReceived(
  110. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
  111. % (len(self.end_content),)
  112. + self.end_content
  113. )
  114. self.pump()
  115. self.assertEqual(channel.code, 200)
  116. self.assertEqual(
  117. channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
  118. )
  119. # Check the cache returns the correct response
  120. channel = self.make_request(
  121. "GET", "preview_url?url=http://matrix.org", shorthand=False
  122. )
  123. # Check the cache response has the same content
  124. self.assertEqual(channel.code, 200)
  125. self.assertEqual(
  126. channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
  127. )
  128. # Clear the in-memory cache
  129. self.assertIn("http://matrix.org", self.preview_url._cache)
  130. self.preview_url._cache.pop("http://matrix.org")
  131. self.assertNotIn("http://matrix.org", self.preview_url._cache)
  132. # Check the database cache returns the correct response
  133. channel = self.make_request(
  134. "GET", "preview_url?url=http://matrix.org", shorthand=False
  135. )
  136. # Check the cache response has the same content
  137. self.assertEqual(channel.code, 200)
  138. self.assertEqual(
  139. channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
  140. )
  141. def test_non_ascii_preview_httpequiv(self):
  142. self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
  143. end_content = (
  144. b"<html><head>"
  145. b'<meta http-equiv="Content-Type" content="text/html; charset=windows-1251"/>'
  146. b'<meta property="og:title" content="\xe4\xea\xe0" />'
  147. b'<meta property="og:description" content="hi" />'
  148. b"</head></html>"
  149. )
  150. channel = self.make_request(
  151. "GET",
  152. "preview_url?url=http://matrix.org",
  153. shorthand=False,
  154. await_result=False,
  155. )
  156. self.pump()
  157. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  158. server = AccumulatingProtocol()
  159. server.makeConnection(FakeTransport(client, self.reactor))
  160. client.makeConnection(FakeTransport(server, self.reactor))
  161. client.dataReceived(
  162. (
  163. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
  164. b'Content-Type: text/html; charset="utf8"\r\n\r\n'
  165. )
  166. % (len(end_content),)
  167. + end_content
  168. )
  169. self.pump()
  170. self.assertEqual(channel.code, 200)
  171. self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
  172. def test_non_ascii_preview_content_type(self):
  173. self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
  174. end_content = (
  175. b"<html><head>"
  176. b'<meta property="og:title" content="\xe4\xea\xe0" />'
  177. b'<meta property="og:description" content="hi" />'
  178. b"</head></html>"
  179. )
  180. channel = self.make_request(
  181. "GET",
  182. "preview_url?url=http://matrix.org",
  183. shorthand=False,
  184. await_result=False,
  185. )
  186. self.pump()
  187. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  188. server = AccumulatingProtocol()
  189. server.makeConnection(FakeTransport(client, self.reactor))
  190. client.makeConnection(FakeTransport(server, self.reactor))
  191. client.dataReceived(
  192. (
  193. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
  194. b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
  195. )
  196. % (len(end_content),)
  197. + end_content
  198. )
  199. self.pump()
  200. self.assertEqual(channel.code, 200)
  201. self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
  202. def test_overlong_title(self):
  203. self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
  204. end_content = (
  205. b"<html><head>"
  206. b"<title>" + b"x" * 2000 + b"</title>"
  207. b'<meta property="og:description" content="hi" />'
  208. b"</head></html>"
  209. )
  210. channel = self.make_request(
  211. "GET",
  212. "preview_url?url=http://matrix.org",
  213. shorthand=False,
  214. await_result=False,
  215. )
  216. self.pump()
  217. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  218. server = AccumulatingProtocol()
  219. server.makeConnection(FakeTransport(client, self.reactor))
  220. client.makeConnection(FakeTransport(server, self.reactor))
  221. client.dataReceived(
  222. (
  223. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
  224. b'Content-Type: text/html; charset="windows-1251"\r\n\r\n'
  225. )
  226. % (len(end_content),)
  227. + end_content
  228. )
  229. self.pump()
  230. self.assertEqual(channel.code, 200)
  231. res = channel.json_body
  232. # We should only see the `og:description` field, as `title` is too long and should be stripped out
  233. self.assertCountEqual(["og:description"], res.keys())
  234. def test_ipaddr(self):
  235. """
  236. IP addresses can be previewed directly.
  237. """
  238. self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
  239. channel = self.make_request(
  240. "GET",
  241. "preview_url?url=http://example.com",
  242. shorthand=False,
  243. await_result=False,
  244. )
  245. self.pump()
  246. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  247. server = AccumulatingProtocol()
  248. server.makeConnection(FakeTransport(client, self.reactor))
  249. client.makeConnection(FakeTransport(server, self.reactor))
  250. client.dataReceived(
  251. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
  252. % (len(self.end_content),)
  253. + self.end_content
  254. )
  255. self.pump()
  256. self.assertEqual(channel.code, 200)
  257. self.assertEqual(
  258. channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
  259. )
  260. def test_blacklisted_ip_specific(self):
  261. """
  262. Blacklisted IP addresses, found via DNS, are not spidered.
  263. """
  264. self.lookups["example.com"] = [(IPv4Address, "192.168.1.1")]
  265. channel = self.make_request(
  266. "GET", "preview_url?url=http://example.com", shorthand=False
  267. )
  268. # No requests made.
  269. self.assertEqual(len(self.reactor.tcpClients), 0)
  270. self.assertEqual(channel.code, 502)
  271. self.assertEqual(
  272. channel.json_body,
  273. {
  274. "errcode": "M_UNKNOWN",
  275. "error": "DNS resolution failure during URL preview generation",
  276. },
  277. )
  278. def test_blacklisted_ip_range(self):
  279. """
  280. Blacklisted IP ranges, IPs found over DNS, are not spidered.
  281. """
  282. self.lookups["example.com"] = [(IPv4Address, "1.1.1.2")]
  283. channel = self.make_request(
  284. "GET", "preview_url?url=http://example.com", shorthand=False
  285. )
  286. self.assertEqual(channel.code, 502)
  287. self.assertEqual(
  288. channel.json_body,
  289. {
  290. "errcode": "M_UNKNOWN",
  291. "error": "DNS resolution failure during URL preview generation",
  292. },
  293. )
  294. def test_blacklisted_ip_specific_direct(self):
  295. """
  296. Blacklisted IP addresses, accessed directly, are not spidered.
  297. """
  298. channel = self.make_request(
  299. "GET", "preview_url?url=http://192.168.1.1", shorthand=False
  300. )
  301. # No requests made.
  302. self.assertEqual(len(self.reactor.tcpClients), 0)
  303. self.assertEqual(
  304. channel.json_body,
  305. {
  306. "errcode": "M_UNKNOWN",
  307. "error": "IP address blocked by IP blacklist entry",
  308. },
  309. )
  310. self.assertEqual(channel.code, 403)
  311. def test_blacklisted_ip_range_direct(self):
  312. """
  313. Blacklisted IP ranges, accessed directly, are not spidered.
  314. """
  315. channel = self.make_request(
  316. "GET", "preview_url?url=http://1.1.1.2", shorthand=False
  317. )
  318. self.assertEqual(channel.code, 403)
  319. self.assertEqual(
  320. channel.json_body,
  321. {
  322. "errcode": "M_UNKNOWN",
  323. "error": "IP address blocked by IP blacklist entry",
  324. },
  325. )
  326. def test_blacklisted_ip_range_whitelisted_ip(self):
  327. """
  328. Blacklisted but then subsequently whitelisted IP addresses can be
  329. spidered.
  330. """
  331. self.lookups["example.com"] = [(IPv4Address, "1.1.1.1")]
  332. channel = self.make_request(
  333. "GET",
  334. "preview_url?url=http://example.com",
  335. shorthand=False,
  336. await_result=False,
  337. )
  338. self.pump()
  339. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  340. server = AccumulatingProtocol()
  341. server.makeConnection(FakeTransport(client, self.reactor))
  342. client.makeConnection(FakeTransport(server, self.reactor))
  343. client.dataReceived(
  344. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
  345. % (len(self.end_content),)
  346. + self.end_content
  347. )
  348. self.pump()
  349. self.assertEqual(channel.code, 200)
  350. self.assertEqual(
  351. channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
  352. )
  353. def test_blacklisted_ip_with_external_ip(self):
  354. """
  355. If a hostname resolves a blacklisted IP, even if there's a
  356. non-blacklisted one, it will be rejected.
  357. """
  358. # Hardcode the URL resolving to the IP we want.
  359. self.lookups["example.com"] = [
  360. (IPv4Address, "1.1.1.2"),
  361. (IPv4Address, "10.1.2.3"),
  362. ]
  363. channel = self.make_request(
  364. "GET", "preview_url?url=http://example.com", shorthand=False
  365. )
  366. self.assertEqual(channel.code, 502)
  367. self.assertEqual(
  368. channel.json_body,
  369. {
  370. "errcode": "M_UNKNOWN",
  371. "error": "DNS resolution failure during URL preview generation",
  372. },
  373. )
  374. def test_blacklisted_ipv6_specific(self):
  375. """
  376. Blacklisted IP addresses, found via DNS, are not spidered.
  377. """
  378. self.lookups["example.com"] = [
  379. (IPv6Address, "3fff:ffff:ffff:ffff:ffff:ffff:ffff:ffff")
  380. ]
  381. channel = self.make_request(
  382. "GET", "preview_url?url=http://example.com", shorthand=False
  383. )
  384. # No requests made.
  385. self.assertEqual(len(self.reactor.tcpClients), 0)
  386. self.assertEqual(channel.code, 502)
  387. self.assertEqual(
  388. channel.json_body,
  389. {
  390. "errcode": "M_UNKNOWN",
  391. "error": "DNS resolution failure during URL preview generation",
  392. },
  393. )
  394. def test_blacklisted_ipv6_range(self):
  395. """
  396. Blacklisted IP ranges, IPs found over DNS, are not spidered.
  397. """
  398. self.lookups["example.com"] = [(IPv6Address, "2001:800::1")]
  399. channel = self.make_request(
  400. "GET", "preview_url?url=http://example.com", shorthand=False
  401. )
  402. self.assertEqual(channel.code, 502)
  403. self.assertEqual(
  404. channel.json_body,
  405. {
  406. "errcode": "M_UNKNOWN",
  407. "error": "DNS resolution failure during URL preview generation",
  408. },
  409. )
  410. def test_OPTIONS(self):
  411. """
  412. OPTIONS returns the OPTIONS.
  413. """
  414. channel = self.make_request(
  415. "OPTIONS", "preview_url?url=http://example.com", shorthand=False
  416. )
  417. self.assertEqual(channel.code, 200)
  418. self.assertEqual(channel.json_body, {})
  419. def test_accept_language_config_option(self):
  420. """
  421. Accept-Language header is sent to the remote server
  422. """
  423. self.lookups["example.com"] = [(IPv4Address, "10.1.2.3")]
  424. # Build and make a request to the server
  425. channel = self.make_request(
  426. "GET",
  427. "preview_url?url=http://example.com",
  428. shorthand=False,
  429. await_result=False,
  430. )
  431. self.pump()
  432. # Extract Synapse's tcp client
  433. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  434. # Build a fake remote server to reply with
  435. server = AccumulatingProtocol()
  436. # Connect the two together
  437. server.makeConnection(FakeTransport(client, self.reactor))
  438. client.makeConnection(FakeTransport(server, self.reactor))
  439. # Tell Synapse that it has received some data from the remote server
  440. client.dataReceived(
  441. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: text/html\r\n\r\n"
  442. % (len(self.end_content),)
  443. + self.end_content
  444. )
  445. # Move the reactor along until we get a response on our original channel
  446. self.pump()
  447. self.assertEqual(channel.code, 200)
  448. self.assertEqual(
  449. channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
  450. )
  451. # Check that the server received the Accept-Language header as part
  452. # of the request from Synapse
  453. self.assertIn(
  454. (
  455. b"Accept-Language: en-UK\r\n"
  456. b"Accept-Language: en-US;q=0.9\r\n"
  457. b"Accept-Language: fr;q=0.8\r\n"
  458. b"Accept-Language: *;q=0.7"
  459. ),
  460. server.data,
  461. )
  462. def test_oembed_photo(self):
  463. """Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
  464. # Route the HTTP version to an HTTP endpoint so that the tests work.
  465. with patch.dict(
  466. "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
  467. {
  468. re.compile(
  469. r"http://twitter\.com/.+/status/.+"
  470. ): "http://publish.twitter.com/oembed",
  471. },
  472. clear=True,
  473. ):
  474. self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
  475. self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
  476. result = {
  477. "version": "1.0",
  478. "type": "photo",
  479. "url": "http://cdn.twitter.com/matrixdotorg",
  480. }
  481. oembed_content = json.dumps(result).encode("utf-8")
  482. end_content = (
  483. b"<html><head>"
  484. b"<title>Some Title</title>"
  485. b'<meta property="og:description" content="hi" />'
  486. b"</head></html>"
  487. )
  488. channel = self.make_request(
  489. "GET",
  490. "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
  491. shorthand=False,
  492. await_result=False,
  493. )
  494. self.pump()
  495. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  496. server = AccumulatingProtocol()
  497. server.makeConnection(FakeTransport(client, self.reactor))
  498. client.makeConnection(FakeTransport(server, self.reactor))
  499. client.dataReceived(
  500. (
  501. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
  502. b'Content-Type: application/json; charset="utf8"\r\n\r\n'
  503. )
  504. % (len(oembed_content),)
  505. + oembed_content
  506. )
  507. self.pump()
  508. client = self.reactor.tcpClients[1][2].buildProtocol(None)
  509. server = AccumulatingProtocol()
  510. server.makeConnection(FakeTransport(client, self.reactor))
  511. client.makeConnection(FakeTransport(server, self.reactor))
  512. client.dataReceived(
  513. (
  514. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
  515. b'Content-Type: text/html; charset="utf8"\r\n\r\n'
  516. )
  517. % (len(end_content),)
  518. + end_content
  519. )
  520. self.pump()
  521. self.assertEqual(channel.code, 200)
  522. self.assertEqual(
  523. channel.json_body, {"og:title": "Some Title", "og:description": "hi"}
  524. )
  525. def test_oembed_rich(self):
  526. """Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
  527. # Route the HTTP version to an HTTP endpoint so that the tests work.
  528. with patch.dict(
  529. "synapse.rest.media.v1.preview_url_resource._oembed_patterns",
  530. {
  531. re.compile(
  532. r"http://twitter\.com/.+/status/.+"
  533. ): "http://publish.twitter.com/oembed",
  534. },
  535. clear=True,
  536. ):
  537. self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
  538. result = {
  539. "version": "1.0",
  540. "type": "rich",
  541. "html": "<div>Content Preview</div>",
  542. }
  543. end_content = json.dumps(result).encode("utf-8")
  544. channel = self.make_request(
  545. "GET",
  546. "preview_url?url=http://twitter.com/matrixdotorg/status/12345",
  547. shorthand=False,
  548. await_result=False,
  549. )
  550. self.pump()
  551. client = self.reactor.tcpClients[0][2].buildProtocol(None)
  552. server = AccumulatingProtocol()
  553. server.makeConnection(FakeTransport(client, self.reactor))
  554. client.makeConnection(FakeTransport(server, self.reactor))
  555. client.dataReceived(
  556. (
  557. b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n"
  558. b'Content-Type: application/json; charset="utf8"\r\n\r\n'
  559. )
  560. % (len(end_content),)
  561. + end_content
  562. )
  563. self.pump()
  564. self.assertEqual(channel.code, 200)
  565. self.assertEqual(
  566. channel.json_body,
  567. {"og:title": None, "og:description": "Content Preview"},
  568. )