test_media_storage.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import os
  16. import shutil
  17. import tempfile
  18. from binascii import unhexlify
  19. from io import BytesIO
  20. from typing import Optional
  21. from urllib import parse
  22. from mock import Mock
  23. import attr
  24. from parameterized import parameterized_class
  25. from PIL import Image as Image
  26. from twisted.internet import defer
  27. from twisted.internet.defer import Deferred
  28. from synapse.logging.context import make_deferred_yieldable
  29. from synapse.rest.media.v1._base import FileInfo
  30. from synapse.rest.media.v1.filepath import MediaFilePaths
  31. from synapse.rest.media.v1.media_storage import MediaStorage
  32. from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
  33. from tests import unittest
  34. class MediaStorageTests(unittest.HomeserverTestCase):
  35. needs_threadpool = True
  36. def prepare(self, reactor, clock, hs):
  37. self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
  38. self.addCleanup(shutil.rmtree, self.test_dir)
  39. self.primary_base_path = os.path.join(self.test_dir, "primary")
  40. self.secondary_base_path = os.path.join(self.test_dir, "secondary")
  41. hs.config.media_store_path = self.primary_base_path
  42. storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)]
  43. self.filepaths = MediaFilePaths(self.primary_base_path)
  44. self.media_storage = MediaStorage(
  45. hs, self.primary_base_path, self.filepaths, storage_providers
  46. )
  47. def test_ensure_media_is_in_local_cache(self):
  48. media_id = "some_media_id"
  49. test_body = "Test\n"
  50. # First we create a file that is in a storage provider but not in the
  51. # local primary media store
  52. rel_path = self.filepaths.local_media_filepath_rel(media_id)
  53. secondary_path = os.path.join(self.secondary_base_path, rel_path)
  54. os.makedirs(os.path.dirname(secondary_path))
  55. with open(secondary_path, "w") as f:
  56. f.write(test_body)
  57. # Now we run ensure_media_is_in_local_cache, which should copy the file
  58. # to the local cache.
  59. file_info = FileInfo(None, media_id)
  60. # This uses a real blocking threadpool so we have to wait for it to be
  61. # actually done :/
  62. x = defer.ensureDeferred(
  63. self.media_storage.ensure_media_is_in_local_cache(file_info)
  64. )
  65. # Hotloop until the threadpool does its job...
  66. self.wait_on_thread(x)
  67. local_path = self.get_success(x)
  68. self.assertTrue(os.path.exists(local_path))
  69. # Asserts the file is under the expected local cache directory
  70. self.assertEquals(
  71. os.path.commonprefix([self.primary_base_path, local_path]),
  72. self.primary_base_path,
  73. )
  74. with open(local_path) as f:
  75. body = f.read()
  76. self.assertEqual(test_body, body)
  77. @attr.s
  78. class _TestImage:
  79. """An image for testing thumbnailing with the expected results
  80. Attributes:
  81. data: The raw image to thumbnail
  82. content_type: The type of the image as a content type, e.g. "image/png"
  83. extension: The extension associated with the format, e.g. ".png"
  84. expected_cropped: The expected bytes from cropped thumbnailing, or None if
  85. test should just check for success.
  86. expected_scaled: The expected bytes from scaled thumbnailing, or None if
  87. test should just check for a valid image returned.
  88. """
  89. data = attr.ib(type=bytes)
  90. content_type = attr.ib(type=bytes)
  91. extension = attr.ib(type=bytes)
  92. expected_cropped = attr.ib(type=Optional[bytes])
  93. expected_scaled = attr.ib(type=Optional[bytes])
  94. @parameterized_class(
  95. ("test_image",),
  96. [
  97. # smol png
  98. (
  99. _TestImage(
  100. unhexlify(
  101. b"89504e470d0a1a0a0000000d4948445200000001000000010806"
  102. b"0000001f15c4890000000a49444154789c63000100000500010d"
  103. b"0a2db40000000049454e44ae426082"
  104. ),
  105. b"image/png",
  106. b".png",
  107. unhexlify(
  108. b"89504e470d0a1a0a0000000d4948445200000020000000200806"
  109. b"000000737a7af40000001a49444154789cedc101010000008220"
  110. b"ffaf6e484001000000ef0610200001194334ee0000000049454e"
  111. b"44ae426082"
  112. ),
  113. unhexlify(
  114. b"89504e470d0a1a0a0000000d4948445200000001000000010806"
  115. b"0000001f15c4890000000d49444154789c636060606000000005"
  116. b"0001a5f645400000000049454e44ae426082"
  117. ),
  118. ),
  119. ),
  120. # small lossless webp
  121. (
  122. _TestImage(
  123. unhexlify(
  124. b"524946461a000000574542505650384c0d0000002f0000001007"
  125. b"1011118888fe0700"
  126. ),
  127. b"image/webp",
  128. b".webp",
  129. None,
  130. None,
  131. ),
  132. ),
  133. ],
  134. )
  135. class MediaRepoTests(unittest.HomeserverTestCase):
  136. hijack_auth = True
  137. user_id = "@test:user"
  138. def make_homeserver(self, reactor, clock):
  139. self.fetches = []
  140. def get_file(destination, path, output_stream, args=None, max_size=None):
  141. """
  142. Returns tuple[int,dict,str,int] of file length, response headers,
  143. absolute URI, and response code.
  144. """
  145. def write_to(r):
  146. data, response = r
  147. output_stream.write(data)
  148. return response
  149. d = Deferred()
  150. d.addCallback(write_to)
  151. self.fetches.append((d, destination, path, args))
  152. return make_deferred_yieldable(d)
  153. client = Mock()
  154. client.get_file = get_file
  155. self.storage_path = self.mktemp()
  156. self.media_store_path = self.mktemp()
  157. os.mkdir(self.storage_path)
  158. os.mkdir(self.media_store_path)
  159. config = self.default_config()
  160. config["media_store_path"] = self.media_store_path
  161. config["thumbnail_requirements"] = {}
  162. config["max_image_pixels"] = 2000000
  163. provider_config = {
  164. "module": "synapse.rest.media.v1.storage_provider.FileStorageProviderBackend",
  165. "store_local": True,
  166. "store_synchronous": False,
  167. "store_remote": True,
  168. "config": {"directory": self.storage_path},
  169. }
  170. config["media_storage_providers"] = [provider_config]
  171. hs = self.setup_test_homeserver(config=config, http_client=client)
  172. return hs
  173. def prepare(self, reactor, clock, hs):
  174. self.media_repo = hs.get_media_repository_resource()
  175. self.download_resource = self.media_repo.children[b"download"]
  176. self.thumbnail_resource = self.media_repo.children[b"thumbnail"]
  177. self.media_id = "example.com/12345"
  178. def _req(self, content_disposition):
  179. request, channel = self.make_request("GET", self.media_id, shorthand=False)
  180. request.render(self.download_resource)
  181. self.pump()
  182. # We've made one fetch, to example.com, using the media URL, and asking
  183. # the other server not to do a remote fetch
  184. self.assertEqual(len(self.fetches), 1)
  185. self.assertEqual(self.fetches[0][1], "example.com")
  186. self.assertEqual(
  187. self.fetches[0][2], "/_matrix/media/r0/download/" + self.media_id
  188. )
  189. self.assertEqual(self.fetches[0][3], {"allow_remote": "false"})
  190. headers = {
  191. b"Content-Length": [b"%d" % (len(self.test_image.data))],
  192. b"Content-Type": [self.test_image.content_type],
  193. }
  194. if content_disposition:
  195. headers[b"Content-Disposition"] = [content_disposition]
  196. self.fetches[0][0].callback(
  197. (self.test_image.data, (len(self.test_image.data), headers))
  198. )
  199. self.pump()
  200. self.assertEqual(channel.code, 200)
  201. return channel
  202. def test_disposition_filename_ascii(self):
  203. """
  204. If the filename is filename=<ascii> then Synapse will decode it as an
  205. ASCII string, and use filename= in the response.
  206. """
  207. channel = self._req(b"inline; filename=out" + self.test_image.extension)
  208. headers = channel.headers
  209. self.assertEqual(
  210. headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
  211. )
  212. self.assertEqual(
  213. headers.getRawHeaders(b"Content-Disposition"),
  214. [b"inline; filename=out" + self.test_image.extension],
  215. )
  216. def test_disposition_filenamestar_utf8escaped(self):
  217. """
  218. If the filename is filename=*utf8''<utf8 escaped> then Synapse will
  219. correctly decode it as the UTF-8 string, and use filename* in the
  220. response.
  221. """
  222. filename = parse.quote("\u2603".encode("utf8")).encode("ascii")
  223. channel = self._req(
  224. b"inline; filename*=utf-8''" + filename + self.test_image.extension
  225. )
  226. headers = channel.headers
  227. self.assertEqual(
  228. headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
  229. )
  230. self.assertEqual(
  231. headers.getRawHeaders(b"Content-Disposition"),
  232. [b"inline; filename*=utf-8''" + filename + self.test_image.extension],
  233. )
  234. def test_disposition_none(self):
  235. """
  236. If there is no filename, one isn't passed on in the Content-Disposition
  237. of the request.
  238. """
  239. channel = self._req(None)
  240. headers = channel.headers
  241. self.assertEqual(
  242. headers.getRawHeaders(b"Content-Type"), [self.test_image.content_type]
  243. )
  244. self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
  245. def test_thumbnail_crop(self):
  246. self._test_thumbnail("crop", self.test_image.expected_cropped)
  247. def test_thumbnail_scale(self):
  248. self._test_thumbnail("scale", self.test_image.expected_scaled)
  249. def _test_thumbnail(self, method, expected_body):
  250. params = "?width=32&height=32&method=" + method
  251. request, channel = self.make_request(
  252. "GET", self.media_id + params, shorthand=False
  253. )
  254. request.render(self.thumbnail_resource)
  255. self.pump()
  256. headers = {
  257. b"Content-Length": [b"%d" % (len(self.test_image.data))],
  258. b"Content-Type": [self.test_image.content_type],
  259. }
  260. self.fetches[0][0].callback(
  261. (self.test_image.data, (len(self.test_image.data), headers))
  262. )
  263. self.pump()
  264. self.assertEqual(channel.code, 200)
  265. if expected_body is not None:
  266. self.assertEqual(
  267. channel.result["body"], expected_body, channel.result["body"]
  268. )
  269. else:
  270. # ensure that the result is at least some valid image
  271. Image.open(BytesIO(channel.result["body"]))