123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- # Copyright 2019-2021 The Matrix.org Foundation C.I.C.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Utilities for running the unit tests
- """
- import json
- import sys
- import warnings
- from asyncio import Future
- from binascii import unhexlify
- from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, TypeVar
- from unittest.mock import Mock
- import attr
- import zope.interface
- from twisted.internet.interfaces import IProtocol
- from twisted.python.failure import Failure
- from twisted.web.client import ResponseDone
- from twisted.web.http import RESPONSES
- from twisted.web.http_headers import Headers
- from twisted.web.iweb import IResponse
- from synapse.types import JsonDict
- if TYPE_CHECKING:
- from sys import UnraisableHookArgs
- TV = TypeVar("TV")
def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
    """Return the result of an awaitable that is expected to have finished.

    The awaitable is stepped exactly once. If that single step exhausts it,
    the value carried by the resulting ``StopIteration`` is returned.

    Args:
        awaitable: the awaitable to extract a result from.

    Returns:
        The value the awaitable completed with.

    Raises:
        Exception: if the awaitable yields instead of finishing, i.e. it has
            no result ready yet. An exception raised by the awaitable itself
            propagates unchanged.
    """
    stepper = awaitable.__await__()
    try:
        next(stepper)
    except StopIteration as stop:
        # The iterator finished on its first step, so a result is available.
        return stop.value
    else:
        # The step yielded rather than finishing: the awaitable is pending.
        raise Exception("awaitable has not yet completed")
def make_awaitable(result: TV) -> Awaitable[TV]:
    """Wrap a value in an awaitable that has already completed.

    Intended as the return value when mocking an `async` function. A Future
    is used because it can be awaited more than once, so the same object can
    safely be handed to several callers.

    Args:
        result: the value the returned awaitable resolves to.

    Returns:
        A Future whose result has already been set to ``result``.
    """
    completed: Future[TV] = Future()
    completed.set_result(result)
    return completed
def setup_awaitable_errors() -> Callable[[], None]:
    """Convert warnings from non-awaited coroutines into errors.

    Installs an "error" filter for ``RuntimeWarning`` and, where the
    interpreter supports it, replaces ``sys.unraisablehook`` with a hook that
    records the value of every unraisable exception it is given.

    Returns:
        A clean-up callable. Calling it puts the previous
        ``sys.unraisablehook`` back and, if any unraisable exception was
        recorded, re-raises the most recently recorded one. On interpreters
        without ``sys.unraisablehook`` the callable does nothing.
    """
    warnings.simplefilter("error", RuntimeWarning)

    # unraisablehook was added in Python 3.8.
    if not hasattr(sys, "unraisablehook"):
        return lambda: None

    # State shared between the recording hook and the clean-up callable.
    captured = []
    previous_hook = sys.unraisablehook

    def record_unraisable(unraisable: "UnraisableHookArgs") -> None:
        captured.append(unraisable.exc_value)

    sys.unraisablehook = record_unraisable

    def cleanup() -> None:
        """
        A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
        """
        sys.unraisablehook = previous_hook
        if not captured:
            return
        latest = captured.pop()
        assert latest is not None
        raise latest

    return cleanup
def simple_async_mock(
    return_value: Optional[TV] = None, raises: Optional[Exception] = None
) -> Mock:
    """Build a Mock that behaves like a very simple ``async`` function.

    Every call to the returned Mock produces a fresh coroutine. Awaiting that
    coroutine either raises ``raises`` or returns ``return_value``; the
    arguments passed to the call are accepted and ignored.

    Args:
        return_value: the value the coroutine returns when ``raises`` is
            not given.
        raises: if not None, the exception the coroutine raises when awaited.
            It takes precedence over ``return_value``.

    Returns:
        A Mock whose side effect is the coroutine function described above.
    """

    # AsyncMock was only added in Python 3.8, so it is not available on older
    # interpreters; this mimics part of its behaviour.
    async def cb(*args: Any, **kwargs: Any) -> Optional[TV]:
        # Compare against None explicitly: an exception whose class defines
        # __bool__ or __len__ can be falsy, and it must still be raised.
        if raises is not None:
            raise raises
        return return_value

    return Mock(side_effect=cb)
# Type ignore: it does not fully implement IResponse, but is good enough for tests
@zope.interface.implementer(IResponse)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class FakeResponse:  # type: ignore[misc]
    """A stand-in for a twisted.web.IResponse object, for use in tests.

    treq ships a similar class in treq.test.test_response, but that one has
    no `phrase` attribute and only recently gained deliverBody support.
    """

    # (protocol name, major version, minor version)
    version: Tuple[bytes, int, int] = (b"HTTP", 1, 1)

    # HTTP response code
    code: int = 200

    # body of the response
    body: bytes = b""

    # response headers; each instance gets its own Headers object by default
    headers: Headers = attr.Factory(Headers)

    @property
    def phrase(self) -> bytes:
        """The reason phrase for `code`, or b"Unknown Status" if not in RESPONSES."""
        try:
            return RESPONSES[self.code]
        except KeyError:
            return b"Unknown Status"

    @property
    def length(self) -> int:
        """The number of bytes in the body."""
        return len(self.body)

    def deliverBody(self, protocol: IProtocol) -> None:
        """Feed the response body to `protocol`, then signal completion.

        The entire body is passed in a single dataReceived call, after which
        connectionLost is called with a Failure wrapping ResponseDone.
        """
        body_bytes = self.body
        protocol.dataReceived(body_bytes)

        finished = Failure(ResponseDone())
        protocol.connectionLost(finished)

    @classmethod
    def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse":
        """Build a response whose body is `payload` serialised as JSON.

        Args:
            code: the HTTP response code to use.
            payload: the object to encode as the UTF-8 JSON body.

        Returns:
            A FakeResponse with a Content-Type header of application/json.
        """
        encoded = json.dumps(payload).encode("utf-8")
        return cls(
            code=code,
            body=encoded,
            headers=Headers({"Content-Type": ["application/json"]}),
        )
# A small image used in some tests.
#
# Resolution: 1×1, MIME type: image/png, Extension: png, Size: 67 B
#
# The file contents are kept as hex text, split into pieces purely to keep
# the lines short; each piece is decoded on its own and the results joined.
SMALL_PNG = b"".join(
    unhexlify(hex_piece)
    for hex_piece in (
        b"89504e470d0a1a0a0000000d4948445200000001000000010806",
        b"0000001f15c4890000000a49444154789c63000100000500010d",
        b"0a2db40000000049454e44ae426082",
    )
)
|