test_media_storage.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import shutil
  17. import tempfile
  18. from binascii import unhexlify
  19. from io import BytesIO
  20. from typing import Optional
  21. from urllib import parse
  22. from mock import Mock
  23. import attr
  24. from parameterized import parameterized_class
  25. from PIL import Image as Image
  26. from twisted.internet import defer
  27. from twisted.internet.defer import Deferred
  28. from synapse.logging.context import make_deferred_yieldable
  29. from synapse.rest.media.v1._base import FileInfo
  30. from synapse.rest.media.v1.filepath import MediaFilePaths
  31. from synapse.rest.media.v1.media_storage import MediaStorage
  32. from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
  33. from tests import unittest
  34. class MediaStorageTests(unittest.HomeserverTestCase):
  35. needs_threadpool = True
  36. def prepare(self, reactor, clock, hs):
  37. self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
  38. self.addCleanup(shutil.rmtree, self.test_dir)
  39. self.primary_base_path = os.path.join(self.test_dir, "primary")
  40. self.secondary_base_path = os.path.join(self.test_dir, "secondary")
  41. hs.config.media_store_path = self.primary_base_path
  42. storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]
  43. self.filepaths = MediaFilePaths(self.primary_base_path)
  44. self.media_storage = MediaStorage(
  45. hs, self.primary_base_path, self.filepaths, storage_providers
  46. )
  47. def test_ensure_media_is_in_local_cache(self):
  48. media_id = "some_media_id"
  49. test_body = "Test\n"
  50. # First we create a file that is in a storage provider but not in the
  51. # local primary media store
  52. rel_path = self.filepaths.local_media_filepath_rel(media_id)
  53. secondary_path = os.path.join(self.secondary_base_path, rel_path)
  54. os.makedirs(os.path.dirname(secondary_path))
  55. with open(secondary_path, "w") as f:
  56. f.write(test_body)
  57. # Now we run ensure_media_is_in_local_cache, which should copy the file
  58. # to the local cache.
  59. file_info = FileInfo(None, media_id)
  60. # This uses a real blocking threadpool so we have to wait for it to be
  61. # actually done :/
  62. x = defer.ensureDeferred(
  63. self.media_storage.ensure_media_is_in_local_cache(file_info)
  64. )
  65. # Hotloop until the threadpool does its job...
  66. self.wait_on_thread(x)
  67. local_path = self.get_success(x)
  68. self.assertTrue(os.path.exists(local_path))
  69. # Asserts the file is under the expected local cache directory
  70. self.assertEquals(
  71. os.path.commonprefix([self.primary_base_path, local_path]),
  72. self.primary_base_path,
  73. )
  74. with open(local_path) as f:
  75. body = f.read()
  76. self.assertEqual(test_body, body)
  77. @attr.s
  78. class _TestImage:
  79. """An image for testing thumbnailing with the expected results
  80. Attributes:
  81. data: The raw image to thumbnail
  82. content_type: The type of the image as a content type, e.g. "image/png"
  83. extension: The extension associated with the format, e.g. ".png"
  84. expected_cropped: The expected bytes from cropped thumbnailing, or None if
  85. test should just check for success.
  86. expected_scaled: The expected bytes from scaled thumbnailing, or None if
  87. test should just check for a valid image returned.
  88. """
  89. data = attr.ib(type=bytes)
  90. content_type = attr.ib(type=bytes)
  91. extension = attr.ib(type=bytes)
  92. expected_cropped = attr.ib(type=Optional[bytes])
  93. expected_scaled = attr.ib(type=Optional[bytes])
  94. expected_found = attr.ib(default=True, type=bool)
  95. @parameterized_class(
  96. ("test_image",),
  97. [
  98. # smoll png
  99. (
  100. _TestImage(
  101. unhexlify(
  102. b"89504e470d0a1a0a0000000d4948445200000001000000010806"
  103. b"0000001f15c4890000000a49444154789c63000100000500010d"
  104. b"0a2db40000000049454e44ae426082"
  105. ),
  106. b"image/png",
  107. b".png",
  108. unhexlify(
  109. b"89504e470d0a1a0a0000000d4948445200000020000000200806"
  110. b"000000737a7af40000001a49444154789cedc101010000008220"
  111. b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
  112. b"44ae426082"
  113. ),
  114. unhexlify(
  115. b"89504e470d0a1a0a0000000d4948445200000001000000010806"
  116. b"0000001f15c4890000000d49444154789c636060606000000005"
  117. b"0001a5f645400000000049454e44ae426082"
  118. ),
  119. ),
  120. ),
  121. # small lossless webp
  122. (
  123. _TestImage(
  124. unhexlify(
  125. b"524946461a000000574542505650384c0d0000002f0000001007"
  126. b"1011118888fe0700"
  127. ),
  128. b"image/webp",
  129. b".webp",
  130. None,
  131. None,
  132. ),
  133. ),
  134. # an empty file
  135. (_TestImage(b"", b"image/gif", b".gif", None, None, False,),),
  136. ],
  137. )
  138. class MediaRepoTests(unittest.HomeserverTestCase):
  139. hijack_auth = True
  140. user_id = "@test:user"
  141. def make_homeserver(self, reactor, clock):
  142. self.fetches = []
  143. def get_file(destination, path, output_stream, args=None, max_size=None):
  144. """
  145. Returns tuple[int,dict,str,int] of file length, response headers,
  146. absolute URI, and response code.
  147. """
  148. def write_to(r):
  149. data, response = r
  150. output_stream.write(data)
  151. return response
  152. d = Deferred()
  153. d.addCallback(write_to)
  154. self.fetches.append((d, destination, path, args))
  155. return make_deferred_yieldable(d)
  156. client = Mock()
  157. client.get_file = get_file
  158. self.storage_path = self.mktemp()
  159. self.media_store_path = self.mktemp()
  160. os.mkdir(self.storage_path)
  161. os.mkdir(self.media_store_path)
  162. config = self.default_config()
  163. config["media_store_path"] = self.media_store_path
  164. config["thumbnail_requirements"] = {}
  165. config["max_image_pixels"] = 2000000
  166. provider_config = {
  167. "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
  168. "store_local": True,
  169. "store_synchronous": False,
  170. "store_remote": True,
  171. "config": {"directory": self.storage_path},
  172. }
  173. config["media_storage_providers"] = [provider_config]
  174. hs = self.setup_test_homeserver(config=config, http_client=client)
  175. return hs
  176. def prepare(self, reactor, clock, hs):
  177. self.media_repo = hs.get_media_repository_resource()
  178. self.download_resource = self.media_repo.children[b"download"]
  179. self.thumbnail_resource = self.media_repo.children[b"thumbnail"]
  180. self.media_id = "example.com/12345"
  181. def _req(self, content_disposition):
  182. request, channel = self.make_request("GET", self.media_id, shorthand=False)
  183. request.render(self.download_resource)
  184. self.pump()
  185. # We've made one fetch, to example.com, using the media URL, and asking
  186. # the other server not to do a remote fetch
  187. self.assertEqual(len(self.fetches), 1)
  188. self.assertEqual(self.fetches[0][1], "example.com")
  189. self.assertEqual(
  190. self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
  191. )
  192. self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
  193. headers = {
  194. b"Content-Length": [b"%d" % (len(self.test_image.data))],
  195. b"Content-Type": [self.test_image.content_type],
  196. }
  197. if content_disposition:
  198. headers[b"Content-Disposition"] = [content_disposition]
  199. self.fetches[0][0].callback(
  200. (self.test_image.data, (len(self.test_image.data), headers))
  201. )
  202. self.pump()
  203. self.assertEqual(channel.code, 200)
  204. return channel
  205. def test_disposition_filename_ascii(self):
  206. """
  207. If the filename is filename=<ascii> then Synapse will decode it as an
  208. ASCII string, and use filename= in the response.
  209. """
  210. channel = self._req(b"inline; filename=out" + self.test_image.extension)
  211. headers = channel.headers
  212. self.assertEqual(
  213. headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
  214. )
  215. self.assertEqual(
  216. headers.getRawHeaders(b"Content-Disposition"),
  217. [b"inline; filename=out" + self.test_image.extension],
  218. )
  219. def test_disposition_filenamestar_utf8escaped(self):
  220. """
  221. If the filename is filename=*utf8''<utf8 escaped> then Synapse will
  222. correctly decode it as the UTF-8 string, and use filename* in the
  223. response.
  224. """
  225. filename = parse.quote("\u2603".encode("utf8")).encode("ascii")
  226. channel = self._req(
  227. b"inline; filename*=utf-8''" + filename + self.test_image.extension
  228. )
  229. headers = channel.headers
  230. self.assertEqual(
  231. headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
  232. )
  233. self.assertEqual(
  234. headers.getRawHeaders(b"Content-Disposition"),
  235. [b"inline; filename*=utf-8''" + filename + self.test_image.extension],
  236. )
  237. def test_disposition_none(self):
  238. """
  239. If there is no filename, one isn't passed on in the Content-Disposition
  240. of the request.
  241. """
  242. channel = self._req(None)
  243. headers = channel.headers
  244. self.assertEqual(
  245. headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
  246. )
  247. self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
  248. def test_thumbnail_crop(self):
  249. self._test_thumbnail(
  250. "crop", self.test_image.expected_cropped, self.test_image.expected_found
  251. )
  252. def test_thumbnail_scale(self):
  253. self._test_thumbnail(
  254. "scale", self.test_image.expected_scaled, self.test_image.expected_found
  255. )
  256. def _test_thumbnail(self, method, expected_body, expected_found):
  257. params = "?width=32&height=32&method=" + method
  258. request, channel = self.make_request(
  259. "GET", self.media_id + params, shorthand=False
  260. )
  261. request.render(self.thumbnail_resource)
  262. self.pump()
  263. headers = {
  264. b"Content-Length": [b"%d" % (len(self.test_image.data))],
  265. b"Content-Type": [self.test_image.content_type],
  266. }
  267. self.fetches[0][0].callback(
  268. (self.test_image.data, (len(self.test_image.data), headers))
  269. )
  270. self.pump()
  271. if expected_found:
  272. self.assertEqual(channel.code, 200)
  273. if expected_body is not None:
  274. self.assertEqual(
  275. channel.result["body"], expected_body, channel.result["body"]
  276. )
  277. else:
  278. # ensure that the result is at least some valid image
  279. Image.open(BytesIO(channel.result["body"]))
  280. else:
  281. # A 404 with a JSON body.
  282. self.assertEqual(channel.code, 404)
  283. self.assertEqual(
  284. channel.json_body,
  285. {
  286. "errcode": "M_NOT_FOUND",
  287. "error": "Not found [b'example.com', b'12345?width=32&height=32&method=%s']"
  288. % method,
  289. },
  290. )