Browse Source

Convert some test cases to use HomeserverTestCase. (#9377)

This has the side-effect of being able to remove use of `inlineCallbacks`
in the test-cases for cleaner tracebacks.
Patrick Cloke 3 years ago
parent
commit
8a33d217bd

+ 1 - 0
changelog.d/9377.misc

@@ -0,0 +1 @@
+Convert tests to use `HomeserverTestCase`.

+ 54 - 79
tests/handlers/test_auth.py

@@ -16,28 +16,21 @@ from mock import Mock
 
 
 import pymacaroons
 import pymacaroons
 
 
-from twisted.internet import defer
-
-import synapse
-import synapse.api.errors
-from synapse.api.errors import ResourceLimitError
+from synapse.api.errors import AuthError, ResourceLimitError
 
 
 from tests import unittest
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.test_utils import make_awaitable
-from tests.utils import setup_test_homeserver
 
 
 
 
-class AuthTestCase(unittest.TestCase):
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield setup_test_homeserver(self.addCleanup)
-        self.auth_handler = self.hs.get_auth_handler()
-        self.macaroon_generator = self.hs.get_macaroon_generator()
+class AuthTestCase(unittest.HomeserverTestCase):
+    def prepare(self, reactor, clock, hs):
+        self.auth_handler = hs.get_auth_handler()
+        self.macaroon_generator = hs.get_macaroon_generator()
 
 
         # MAU tests
         # MAU tests
         # AuthBlocking reads from the hs' config on initialization. We need to
         # AuthBlocking reads from the hs' config on initialization. We need to
         # modify its config instead of the hs'
         # modify its config instead of the hs'
-        self.auth_blocking = self.hs.get_auth()._auth_blocking
+        self.auth_blocking = hs.get_auth()._auth_blocking
         self.auth_blocking._max_mau_value = 50
         self.auth_blocking._max_mau_value = 50
 
 
         self.small_number_of_users = 1
         self.small_number_of_users = 1
@@ -52,8 +45,6 @@ class AuthTestCase(unittest.TestCase):
             self.fail("some_user was not in %s" % macaroon.inspect())
             self.fail("some_user was not in %s" % macaroon.inspect())
 
 
     def test_macaroon_caveats(self):
     def test_macaroon_caveats(self):
-        self.hs.get_clock().now = 5000
-
         token = self.macaroon_generator.generate_access_token("a_user")
         token = self.macaroon_generator.generate_access_token("a_user")
         macaroon = pymacaroons.Macaroon.deserialize(token)
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
 
@@ -76,29 +67,25 @@ class AuthTestCase(unittest.TestCase):
         v.satisfy_general(verify_nonce)
         v.satisfy_general(verify_nonce)
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
 
 
-    @defer.inlineCallbacks
     def test_short_term_login_token_gives_user_id(self):
     def test_short_term_login_token_gives_user_id(self):
-        self.hs.get_clock().now = 1000
-
         token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
         token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
-        user_id = yield defer.ensureDeferred(
+        user_id = self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
             self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
         )
         )
         self.assertEqual("a_user", user_id)
         self.assertEqual("a_user", user_id)
 
 
         # when we advance the clock, the token should be rejected
         # when we advance the clock, the token should be rejected
-        self.hs.get_clock().now = 6000
-        with self.assertRaises(synapse.api.errors.AuthError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
-            )
+        self.reactor.advance(6)
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+            AuthError,
+        )
 
 
-    @defer.inlineCallbacks
     def test_short_term_login_token_cannot_replace_user_id(self):
     def test_short_term_login_token_cannot_replace_user_id(self):
         token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
         token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
         macaroon = pymacaroons.Macaroon.deserialize(token)
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
 
-        user_id = yield defer.ensureDeferred(
+        user_id = self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 macaroon.serialize()
                 macaroon.serialize()
             )
             )
@@ -109,102 +96,90 @@ class AuthTestCase(unittest.TestCase):
         # user_id.
         # user_id.
         macaroon.add_first_party_caveat("user_id = b_user")
         macaroon.add_first_party_caveat("user_id = b_user")
 
 
-        with self.assertRaises(synapse.api.errors.AuthError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                    macaroon.serialize()
-                )
-            )
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                macaroon.serialize()
+            ),
+            AuthError,
+        )
 
 
-    @defer.inlineCallbacks
     def test_mau_limits_disabled(self):
     def test_mau_limits_disabled(self):
         self.auth_blocking._limit_usage_by_mau = False
         self.auth_blocking._limit_usage_by_mau = False
         # Ensure does not throw exception
         # Ensure does not throw exception
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.get_access_token_for_user_id(
             self.auth_handler.get_access_token_for_user_id(
                 "user_a", device_id=None, valid_until_ms=None
                 "user_a", device_id=None, valid_until_ms=None
             )
             )
         )
         )
 
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 self._get_macaroon().serialize()
                 self._get_macaroon().serialize()
             )
             )
         )
         )
 
 
-    @defer.inlineCallbacks
     def test_mau_limits_exceeded_large(self):
     def test_mau_limits_exceeded_large(self):
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._limit_usage_by_mau = True
         self.hs.get_datastore().get_monthly_active_count = Mock(
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
             return_value=make_awaitable(self.large_number_of_users)
         )
         )
 
 
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.get_access_token_for_user_id(
-                    "user_a", device_id=None, valid_until_ms=None
-                )
-            )
+        self.get_failure(
+            self.auth_handler.get_access_token_for_user_id(
+                "user_a", device_id=None, valid_until_ms=None
+            ),
+            ResourceLimitError,
+        )
 
 
         self.hs.get_datastore().get_monthly_active_count = Mock(
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
             return_value=make_awaitable(self.large_number_of_users)
         )
         )
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                    self._get_macaroon().serialize()
-                )
-            )
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                self._get_macaroon().serialize()
+            ),
+            ResourceLimitError,
+        )
 
 
-    @defer.inlineCallbacks
     def test_mau_limits_parity(self):
     def test_mau_limits_parity(self):
+        # Ensure we're not at the unix epoch.
+        self.reactor.advance(1)
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._limit_usage_by_mau = True
 
 
-        # If not in monthly active cohort
+        # Set the server to be at the edge of too many users.
         self.hs.get_datastore().get_monthly_active_count = Mock(
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.auth_blocking._max_mau_value)
             return_value=make_awaitable(self.auth_blocking._max_mau_value)
         )
         )
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.get_access_token_for_user_id(
-                    "user_a", device_id=None, valid_until_ms=None
-                )
-            )
 
 
-        self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
+        # If not in monthly active cohort
+        self.get_failure(
+            self.auth_handler.get_access_token_for_user_id(
+                "user_a", device_id=None, valid_until_ms=None
+            ),
+            ResourceLimitError,
         )
         )
-        with self.assertRaises(ResourceLimitError):
-            yield defer.ensureDeferred(
-                self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                    self._get_macaroon().serialize()
-                )
-            )
+        self.get_failure(
+            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+                self._get_macaroon().serialize()
+            ),
+            ResourceLimitError,
+        )
+
         # If in monthly active cohort
         # If in monthly active cohort
         self.hs.get_datastore().user_last_seen_monthly_active = Mock(
         self.hs.get_datastore().user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.hs.get_clock().time_msec())
+            return_value=make_awaitable(self.clock.time_msec())
         )
         )
-        self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
-        )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.get_access_token_for_user_id(
             self.auth_handler.get_access_token_for_user_id(
                 "user_a", device_id=None, valid_until_ms=None
                 "user_a", device_id=None, valid_until_ms=None
             )
             )
         )
         )
-        self.hs.get_datastore().user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.hs.get_clock().time_msec())
-        )
-        self.hs.get_datastore().get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.auth_blocking._max_mau_value)
-        )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 self._get_macaroon().serialize()
                 self._get_macaroon().serialize()
             )
             )
         )
         )
 
 
-    @defer.inlineCallbacks
     def test_mau_limits_not_exceeded(self):
     def test_mau_limits_not_exceeded(self):
         self.auth_blocking._limit_usage_by_mau = True
         self.auth_blocking._limit_usage_by_mau = True
 
 
@@ -212,7 +187,7 @@ class AuthTestCase(unittest.TestCase):
             return_value=make_awaitable(self.small_number_of_users)
             return_value=make_awaitable(self.small_number_of_users)
         )
         )
         # Ensure does not raise exception
         # Ensure does not raise exception
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.get_access_token_for_user_id(
             self.auth_handler.get_access_token_for_user_id(
                 "user_a", device_id=None, valid_until_ms=None
                 "user_a", device_id=None, valid_until_ms=None
             )
             )
@@ -221,7 +196,7 @@ class AuthTestCase(unittest.TestCase):
         self.hs.get_datastore().get_monthly_active_count = Mock(
         self.hs.get_datastore().get_monthly_active_count = Mock(
             return_value=make_awaitable(self.small_number_of_users)
             return_value=make_awaitable(self.small_number_of_users)
         )
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
             self.auth_handler.validate_short_term_login_token_and_get_user_id(
                 self._get_macaroon().serialize()
                 self._get_macaroon().serialize()
             )
             )

+ 91 - 139
tests/handlers/test_e2e_keys.py

@@ -18,42 +18,27 @@ import mock
 
 
 from signedjson import key as key, sign as sign
 from signedjson import key as key, sign as sign
 
 
-from twisted.internet import defer
-
-import synapse.handlers.e2e_keys
-import synapse.storage
-from synapse.api import errors
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.constants import RoomEncryptionAlgorithms
+from synapse.api.errors import Codes, SynapseError
 
 
-from tests import unittest, utils
+from tests import unittest
 
 
 
 
-class E2eKeysHandlerTestCase(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.hs = None  # type: synapse.server.HomeServer
-        self.handler = None  # type: synapse.handlers.e2e_keys.E2eKeysHandler
-        self.store = None  # type: synapse.storage.Storage
+class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(federation_client=mock.Mock())
 
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield utils.setup_test_homeserver(
-            self.addCleanup, federation_client=mock.Mock()
-        )
-        self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs)
+    def prepare(self, reactor, clock, hs):
+        self.handler = hs.get_e2e_keys_handler()
         self.store = self.hs.get_datastore()
         self.store = self.hs.get_datastore()
 
 
-    @defer.inlineCallbacks
     def test_query_local_devices_no_devices(self):
     def test_query_local_devices_no_devices(self):
         """If the user has no devices, we expect an empty list.
         """If the user has no devices, we expect an empty list.
         """
         """
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
-        res = yield defer.ensureDeferred(
-            self.handler.query_local_devices({local_user: None})
-        )
+        res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
         self.assertDictEqual(res, {local_user: {}})
 
 
-    @defer.inlineCallbacks
     def test_reupload_one_time_keys(self):
     def test_reupload_one_time_keys(self):
         """we should be able to re-upload the same keys"""
         """we should be able to re-upload the same keys"""
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
@@ -64,7 +49,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "alg2:k3": {"key": "key3"},
             "alg2:k3": {"key": "key3"},
         }
         }
 
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
                 local_user, device_id, {"one_time_keys": keys}
             )
             )
@@ -73,14 +58,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
 
         # we should be able to change the signature without a problem
         # we should be able to change the signature without a problem
         keys["alg2:k2"]["signatures"]["k1"] = "sig2"
         keys["alg2:k2"]["signatures"]["k1"] = "sig2"
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
                 local_user, device_id, {"one_time_keys": keys}
             )
             )
         )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
 
-    @defer.inlineCallbacks
     def test_change_one_time_keys(self):
     def test_change_one_time_keys(self):
         """attempts to change one-time-keys should be rejected"""
         """attempts to change one-time-keys should be rejected"""
 
 
@@ -92,75 +76,64 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "alg2:k3": {"key": "key3"},
             "alg2:k3": {"key": "key3"},
         }
         }
 
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
                 local_user, device_id, {"one_time_keys": keys}
             )
             )
         )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1, "alg2": 2}})
 
 
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
-                )
-            )
-            self.fail("No error when changing string key")
-        except errors.SynapseError:
-            pass
-
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
-                )
-            )
-            self.fail("No error when replacing dict key with string")
-        except errors.SynapseError:
-            pass
-
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user,
-                    device_id,
-                    {"one_time_keys": {"alg1:k1": {"key": "key"}}},
-                )
-            )
-            self.fail("No error when replacing string key with dict")
-        except errors.SynapseError:
-            pass
-
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_keys_for_user(
-                    local_user,
-                    device_id,
-                    {
-                        "one_time_keys": {
-                            "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
-                        }
-                    },
-                )
-            )
-            self.fail("No error when replacing dict key")
-        except errors.SynapseError:
-            pass
+        # Error when changing string key
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}}
+            ),
+            SynapseError,
+        )
+
+        # Error when replacing dict key with strin
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}}
+            ),
+            SynapseError,
+        )
+
+        # Error when replacing string key with dict
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg1:k1": {"key": "key"}}},
+            ),
+            SynapseError,
+        )
+
+        # Error when replacing dict key
+        self.get_failure(
+            self.handler.upload_keys_for_user(
+                local_user,
+                device_id,
+                {
+                    "one_time_keys": {
+                        "alg2:k2": {"key": "key3", "signatures": {"k1": "sig1"}}
+                    }
+                },
+            ),
+            SynapseError,
+        )
 
 
-    @defer.inlineCallbacks
     def test_claim_one_time_key(self):
     def test_claim_one_time_key(self):
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         device_id = "xyz"
         keys = {"alg1:k1": "key1"}
         keys = {"alg1:k1": "key1"}
 
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.upload_keys_for_user(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": keys}
                 local_user, device_id, {"one_time_keys": keys}
             )
             )
         )
         )
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
         self.assertDictEqual(res, {"one_time_key_counts": {"alg1": 1}})
 
 
-        res2 = yield defer.ensureDeferred(
+        res2 = self.get_success(
             self.handler.claim_one_time_keys(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
             )
@@ -173,7 +146,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             },
             },
         )
         )
 
 
-    @defer.inlineCallbacks
     def test_fallback_key(self):
     def test_fallback_key(self):
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
         device_id = "xyz"
@@ -181,12 +153,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         otk = {"alg1:k2": "key2"}
         otk = {"alg1:k2": "key2"}
 
 
         # we shouldn't have any unused fallback keys yet
         # we shouldn't have any unused fallback keys yet
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
         )
         self.assertEqual(res, [])
         self.assertEqual(res, [])
 
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
             self.handler.upload_keys_for_user(
                 local_user,
                 local_user,
                 device_id,
                 device_id,
@@ -195,14 +167,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
         )
 
 
         # we should now have an unused alg1 key
         # we should now have an unused alg1 key
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
         )
         self.assertEqual(res, ["alg1"])
         self.assertEqual(res, ["alg1"])
 
 
         # claiming an OTK when no OTKs are available should return the fallback
         # claiming an OTK when no OTKs are available should return the fallback
         # key
         # key
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
             )
@@ -213,13 +185,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
         )
 
 
         # we shouldn't have any unused fallback keys again
         # we shouldn't have any unused fallback keys again
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
             self.store.get_e2e_unused_fallback_key_types(local_user, device_id)
         )
         )
         self.assertEqual(res, [])
         self.assertEqual(res, [])
 
 
         # claiming an OTK again should return the same fallback key
         # claiming an OTK again should return the same fallback key
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
             )
@@ -231,13 +203,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
 
         # if the user uploads a one-time key, the next claim should fetch the
         # if the user uploads a one-time key, the next claim should fetch the
         # one-time key, and then go back to the fallback
         # one-time key, and then go back to the fallback
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"one_time_keys": otk}
                 local_user, device_id, {"one_time_keys": otk}
             )
             )
         )
         )
 
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
             )
@@ -246,7 +218,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
             res, {"failures": {}, "one_time_keys": {local_user: {device_id: otk}}},
         )
         )
 
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.claim_one_time_keys(
             self.handler.claim_one_time_keys(
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
                 {"one_time_keys": {local_user: {device_id: "alg1"}}}, timeout=None
             )
             )
@@ -256,7 +228,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
             {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key}}},
         )
         )
 
 
-    @defer.inlineCallbacks
     def test_replace_master_key(self):
     def test_replace_master_key(self):
         """uploading a new signing key should make the old signing key unavailable"""
         """uploading a new signing key should make the old signing key unavailable"""
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
@@ -270,9 +241,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
                 },
             }
             }
         }
         }
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys1)
-        )
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
 
         keys2 = {
         keys2 = {
             "master_key": {
             "master_key": {
@@ -284,16 +253,13 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
                 },
             }
             }
         }
         }
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys2)
-        )
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
 
 
-        devices = yield defer.ensureDeferred(
+        devices = self.get_success(
             self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
             self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
         )
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
 
-    @defer.inlineCallbacks
     def test_reupload_signatures(self):
     def test_reupload_signatures(self):
         """re-uploading a signature should not fail"""
         """re-uploading a signature should not fail"""
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
@@ -326,9 +292,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
             "nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
             "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
             "2lonYOM6xYKdEsO+6KrC766xBcHnYnim1x/4LFGF8B0",
         )
         )
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys1)
-        )
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
 
         # upload two device keys, which will be signed later by the self-signing key
         # upload two device keys, which will be signed later by the self-signing key
         device_key_1 = {
         device_key_1 = {
@@ -358,12 +322,12 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "signatures": {local_user: {"ed25519:def": "base64+signature"}},
             "signatures": {local_user: {"ed25519:def": "base64+signature"}},
         }
         }
 
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
             self.handler.upload_keys_for_user(
                 local_user, "abc", {"device_keys": device_key_1}
                 local_user, "abc", {"device_keys": device_key_1}
             )
             )
         )
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
             self.handler.upload_keys_for_user(
                 local_user, "def", {"device_keys": device_key_2}
                 local_user, "def", {"device_keys": device_key_2}
             )
             )
@@ -372,7 +336,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         # sign the first device key and upload it
         # sign the first device key and upload it
         del device_key_1["signatures"]
         del device_key_1["signatures"]
         sign.sign_json(device_key_1, local_user, signing_key)
         sign.sign_json(device_key_1, local_user, signing_key)
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signatures_for_device_keys(
             self.handler.upload_signatures_for_device_keys(
                 local_user, {local_user: {"abc": device_key_1}}
                 local_user, {local_user: {"abc": device_key_1}}
             )
             )
@@ -383,7 +347,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         # signature for it
         # signature for it
         del device_key_2["signatures"]
         del device_key_2["signatures"]
         sign.sign_json(device_key_2, local_user, signing_key)
         sign.sign_json(device_key_2, local_user, signing_key)
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signatures_for_device_keys(
             self.handler.upload_signatures_for_device_keys(
                 local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
                 local_user, {local_user: {"abc": device_key_1, "def": device_key_2}}
             )
             )
@@ -391,7 +355,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
 
 
         device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
         device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
         device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
         device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
-        devices = yield defer.ensureDeferred(
+        devices = self.get_success(
             self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
             self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
         )
         )
         del devices["device_keys"][local_user]["abc"]["unsigned"]
         del devices["device_keys"][local_user]["abc"]["unsigned"]
@@ -399,7 +363,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
         self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1)
         self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
         self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2)
 
 
-    @defer.inlineCallbacks
     def test_self_signing_key_doesnt_show_up_as_device(self):
     def test_self_signing_key_doesnt_show_up_as_device(self):
         """signing keys should be hidden when fetching a user's devices"""
         """signing keys should be hidden when fetching a user's devices"""
         local_user = "@boris:" + self.hs.hostname
         local_user = "@boris:" + self.hs.hostname
@@ -413,29 +376,22 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
                 },
                 },
             }
             }
         }
         }
-        yield defer.ensureDeferred(
-            self.handler.upload_signing_keys_for_user(local_user, keys1)
-        )
-
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.hs.get_device_handler().check_device_registered(
-                    user_id=local_user,
-                    device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
-                    initial_device_display_name="new display name",
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
-        self.assertEqual(res, 400)
+        self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys1))
 
 
-        res = yield defer.ensureDeferred(
-            self.handler.query_local_devices({local_user: None})
+        e = self.get_failure(
+            self.hs.get_device_handler().check_device_registered(
+                user_id=local_user,
+                device_id="nqOvzeuGWT/sRx3h7+MHoInYj3Uk2LD/unI9kDYcHwk",
+                initial_device_display_name="new display name",
+            ),
+            SynapseError,
         )
         )
+        res = e.value.code
+        self.assertEqual(res, 400)
+
+        res = self.get_success(self.handler.query_local_devices({local_user: None}))
         self.assertDictEqual(res, {local_user: {}})
         self.assertDictEqual(res, {local_user: {}})
 
 
-    @defer.inlineCallbacks
     def test_upload_signatures(self):
     def test_upload_signatures(self):
         """should check signatures that are uploaded"""
         """should check signatures that are uploaded"""
         # set up a user with cross-signing keys and a device.  This user will
         # set up a user with cross-signing keys and a device.  This user will
@@ -458,7 +414,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
             "ed25519", "xyz", "OMkooTr76ega06xNvXIGPbgvvxAOzmQncN8VObS7aBA"
         )
         )
 
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_keys_for_user(
             self.handler.upload_keys_for_user(
                 local_user, device_id, {"device_keys": device_key}
                 local_user, device_id, {"device_keys": device_key}
             )
             )
@@ -501,7 +457,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "user_signing_key": usersigning_key,
             "user_signing_key": usersigning_key,
             "self_signing_key": selfsigning_key,
             "self_signing_key": selfsigning_key,
         }
         }
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
             self.handler.upload_signing_keys_for_user(local_user, cross_signing_keys)
         )
         )
 
 
@@ -515,14 +471,14 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
             "usage": ["master"],
             "usage": ["master"],
             "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
             "keys": {"ed25519:" + other_master_pubkey: other_master_pubkey},
         }
         }
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_signing_keys_for_user(
             self.handler.upload_signing_keys_for_user(
                 other_user, {"master_key": other_master_key}
                 other_user, {"master_key": other_master_key}
             )
             )
         )
         )
 
 
         # test various signature failures (see below)
         # test various signature failures (see below)
-        ret = yield defer.ensureDeferred(
+        ret = self.get_success(
             self.handler.upload_signatures_for_device_keys(
             self.handler.upload_signatures_for_device_keys(
                 local_user,
                 local_user,
                 {
                 {
@@ -602,20 +558,16 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         )
         )
 
 
         user_failures = ret["failures"][local_user]
         user_failures = ret["failures"][local_user]
+        self.assertEqual(user_failures[device_id]["errcode"], Codes.INVALID_SIGNATURE)
         self.assertEqual(
         self.assertEqual(
-            user_failures[device_id]["errcode"], errors.Codes.INVALID_SIGNATURE
+            user_failures[master_pubkey]["errcode"], Codes.INVALID_SIGNATURE
         )
         )
-        self.assertEqual(
-            user_failures[master_pubkey]["errcode"], errors.Codes.INVALID_SIGNATURE
-        )
-        self.assertEqual(user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND)
+        self.assertEqual(user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
 
 
         other_user_failures = ret["failures"][other_user]
         other_user_failures = ret["failures"][other_user]
+        self.assertEqual(other_user_failures["unknown"]["errcode"], Codes.NOT_FOUND)
         self.assertEqual(
         self.assertEqual(
-            other_user_failures["unknown"]["errcode"], errors.Codes.NOT_FOUND
-        )
-        self.assertEqual(
-            other_user_failures[other_master_pubkey]["errcode"], errors.Codes.UNKNOWN
+            other_user_failures[other_master_pubkey]["errcode"], Codes.UNKNOWN
         )
         )
 
 
         # test successful signatures
         # test successful signatures
@@ -623,7 +575,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         sign.sign_json(device_key, local_user, selfsigning_signing_key)
         sign.sign_json(device_key, local_user, selfsigning_signing_key)
         sign.sign_json(master_key, local_user, device_signing_key)
         sign.sign_json(master_key, local_user, device_signing_key)
         sign.sign_json(other_master_key, local_user, usersigning_signing_key)
         sign.sign_json(other_master_key, local_user, usersigning_signing_key)
-        ret = yield defer.ensureDeferred(
+        ret = self.get_success(
             self.handler.upload_signatures_for_device_keys(
             self.handler.upload_signatures_for_device_keys(
                 local_user,
                 local_user,
                 {
                 {
@@ -636,7 +588,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(ret["failures"], {})
         self.assertEqual(ret["failures"], {})
 
 
         # fetch the signed keys/devices and make sure that the signatures are there
         # fetch the signed keys/devices and make sure that the signatures are there
-        ret = yield defer.ensureDeferred(
+        ret = self.get_success(
             self.handler.query_devices(
             self.handler.query_devices(
                 {"device_keys": {local_user: [], other_user: []}}, 0, local_user
                 {"device_keys": {local_user: [], other_user: []}}, 0, local_user
             )
             )

+ 117 - 188
tests/handlers/test_e2e_room_keys.py

@@ -19,14 +19,9 @@ import copy
 
 
 import mock
 import mock
 
 
-from twisted.internet import defer
+from synapse.api.errors import SynapseError
 
 
-import synapse.api.errors
-import synapse.handlers.e2e_room_keys
-import synapse.storage
-from synapse.api import errors
-
-from tests import unittest, utils
+from tests import unittest
 
 
 # sample room_key data for use in the tests
 # sample room_key data for use in the tests
 room_keys = {
 room_keys = {
@@ -45,51 +40,39 @@ room_keys = {
 }
 }
 
 
 
 
-class E2eRoomKeysHandlerTestCase(unittest.TestCase):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.hs = None  # type: synapse.server.HomeServer
-        self.handler = None  # type: synapse.handlers.e2e_keys.E2eRoomKeysHandler
+class E2eRoomKeysHandlerTestCase(unittest.HomeserverTestCase):
+    def make_homeserver(self, reactor, clock):
+        return self.setup_test_homeserver(replication_layer=mock.Mock())
 
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        self.hs = yield utils.setup_test_homeserver(
-            self.addCleanup, replication_layer=mock.Mock()
-        )
-        self.handler = synapse.handlers.e2e_room_keys.E2eRoomKeysHandler(self.hs)
-        self.local_user = "@boris:" + self.hs.hostname
+    def prepare(self, reactor, clock, hs):
+        self.handler = hs.get_e2e_room_keys_handler()
+        self.local_user = "@boris:" + hs.hostname
 
 
-    @defer.inlineCallbacks
     def test_get_missing_current_version_info(self):
     def test_get_missing_current_version_info(self):
         """Check that we get a 404 if we ask for info about the current version
         """Check that we get a 404 if we ask for info about the current version
         if there is no version.
         if there is no version.
         """
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_version_info(self.local_user), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
         self.assertEqual(res, 404)
 
 
-    @defer.inlineCallbacks
     def test_get_missing_version_info(self):
     def test_get_missing_version_info(self):
         """Check that we get a 404 if we ask for info about a specific version
         """Check that we get a 404 if we ask for info about a specific version
         if it doesn't exist.
         if it doesn't exist.
         """
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.get_version_info(self.local_user, "bogus_version")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_version_info(self.local_user, "bogus_version"),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
         self.assertEqual(res, 404)
 
 
-    @defer.inlineCallbacks
     def test_create_version(self):
     def test_create_version(self):
         """Check that we can create and then retrieve versions.
         """Check that we can create and then retrieve versions.
         """
         """
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -101,7 +84,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, "1")
         self.assertEqual(res, "1")
 
 
         # check we can retrieve it as the current version
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         version_etag = res["etag"]
         version_etag = res["etag"]
         self.assertIsInstance(version_etag, str)
         self.assertIsInstance(version_etag, str)
         del res["etag"]
         del res["etag"]
@@ -116,9 +99,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         )
 
 
         # check we can retrieve it as a specific version
         # check we can retrieve it as a specific version
-        res = yield defer.ensureDeferred(
-            self.handler.get_version_info(self.local_user, "1")
-        )
+        res = self.get_success(self.handler.get_version_info(self.local_user, "1"))
         self.assertEqual(res["etag"], version_etag)
         self.assertEqual(res["etag"], version_etag)
         del res["etag"]
         del res["etag"]
         self.assertDictEqual(
         self.assertDictEqual(
@@ -132,7 +113,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         )
 
 
         # upload a new one...
         # upload a new one...
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -144,7 +125,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, "2")
         self.assertEqual(res, "2")
 
 
         # check we can retrieve it as the current version
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         del res["etag"]
         del res["etag"]
         self.assertDictEqual(
         self.assertDictEqual(
             res,
             res,
@@ -156,11 +137,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
             },
             },
         )
         )
 
 
-    @defer.inlineCallbacks
     def test_update_version(self):
     def test_update_version(self):
         """Check that we can update versions.
         """Check that we can update versions.
         """
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -171,7 +151,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         )
         self.assertEqual(version, "1")
         self.assertEqual(version, "1")
 
 
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.update_version(
             self.handler.update_version(
                 self.local_user,
                 self.local_user,
                 version,
                 version,
@@ -185,7 +165,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, {})
         self.assertDictEqual(res, {})
 
 
         # check we can retrieve it as the current version
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         del res["etag"]
         del res["etag"]
         self.assertDictEqual(
         self.assertDictEqual(
             res,
             res,
@@ -197,32 +177,28 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
             },
             },
         )
         )
 
 
-    @defer.inlineCallbacks
     def test_update_missing_version(self):
     def test_update_missing_version(self):
         """Check that we get a 404 on updating nonexistent versions
         """Check that we get a 404 on updating nonexistent versions
         """
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.update_version(
-                    self.local_user,
-                    "1",
-                    {
-                        "algorithm": "m.megolm_backup.v1",
-                        "auth_data": "revised_first_version_auth_data",
-                        "version": "1",
-                    },
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.update_version(
+                self.local_user,
+                "1",
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "revised_first_version_auth_data",
+                    "version": "1",
+                },
+            ),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
         self.assertEqual(res, 404)
 
 
-    @defer.inlineCallbacks
     def test_update_omitted_version(self):
     def test_update_omitted_version(self):
         """Check that the update succeeds if the version is missing from the body
         """Check that the update succeeds if the version is missing from the body
         """
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -233,7 +209,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         )
         self.assertEqual(version, "1")
         self.assertEqual(version, "1")
 
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.update_version(
             self.handler.update_version(
                 self.local_user,
                 self.local_user,
                 version,
                 version,
@@ -245,7 +221,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         )
 
 
         # check we can retrieve it as the current version
         # check we can retrieve it as the current version
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         del res["etag"]  # etag is opaque, so don't test its contents
         del res["etag"]  # etag is opaque, so don't test its contents
         self.assertDictEqual(
         self.assertDictEqual(
             res,
             res,
@@ -257,11 +233,10 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
             },
             },
         )
         )
 
 
-    @defer.inlineCallbacks
     def test_update_bad_version(self):
     def test_update_bad_version(self):
         """Check that we get a 400 if the version in the body doesn't match
         """Check that we get a 400 if the version in the body doesn't match
         """
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -272,52 +247,41 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         )
         self.assertEqual(version, "1")
         self.assertEqual(version, "1")
 
 
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.update_version(
-                    self.local_user,
-                    version,
-                    {
-                        "algorithm": "m.megolm_backup.v1",
-                        "auth_data": "revised_first_version_auth_data",
-                        "version": "incorrect",
-                    },
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.update_version(
+                self.local_user,
+                version,
+                {
+                    "algorithm": "m.megolm_backup.v1",
+                    "auth_data": "revised_first_version_auth_data",
+                    "version": "incorrect",
+                },
+            ),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 400)
         self.assertEqual(res, 400)
 
 
-    @defer.inlineCallbacks
     def test_delete_missing_version(self):
     def test_delete_missing_version(self):
         """Check that we get a 404 on deleting nonexistent versions
         """Check that we get a 404 on deleting nonexistent versions
         """
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.delete_version(self.local_user, "1")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.delete_version(self.local_user, "1"), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
         self.assertEqual(res, 404)
 
 
-    @defer.inlineCallbacks
     def test_delete_missing_current_version(self):
     def test_delete_missing_current_version(self):
         """Check that we get a 404 on deleting nonexistent current version
         """Check that we get a 404 on deleting nonexistent current version
         """
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(self.handler.delete_version(self.local_user))
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(self.handler.delete_version(self.local_user), SynapseError)
+        res = e.value.code
         self.assertEqual(res, 404)
         self.assertEqual(res, 404)
 
 
-    @defer.inlineCallbacks
     def test_delete_version(self):
     def test_delete_version(self):
         """Check that we can create and then delete versions.
         """Check that we can create and then delete versions.
         """
         """
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -329,36 +293,28 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(res, "1")
         self.assertEqual(res, "1")
 
 
         # check we can delete it
         # check we can delete it
-        yield defer.ensureDeferred(self.handler.delete_version(self.local_user, "1"))
+        self.get_success(self.handler.delete_version(self.local_user, "1"))
 
 
         # check that it's gone
         # check that it's gone
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.get_version_info(self.local_user, "1")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_version_info(self.local_user, "1"), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
         self.assertEqual(res, 404)
 
 
-    @defer.inlineCallbacks
     def test_get_missing_backup(self):
     def test_get_missing_backup(self):
         """Check that we get a 404 on querying missing backup
         """Check that we get a 404 on querying missing backup
         """
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.get_room_keys(self.local_user, "bogus_version")
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.get_room_keys(self.local_user, "bogus_version"), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
         self.assertEqual(res, 404)
 
 
-    @defer.inlineCallbacks
     def test_get_missing_room_keys(self):
     def test_get_missing_room_keys(self):
         """Check we get an empty response from an empty backup
         """Check we get an empty response from an empty backup
         """
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -369,33 +325,27 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         )
         self.assertEqual(version, "1")
         self.assertEqual(version, "1")
 
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertDictEqual(res, {"rooms": {}})
         self.assertDictEqual(res, {"rooms": {}})
 
 
     # TODO: test the locking semantics when uploading room_keys,
     # TODO: test the locking semantics when uploading room_keys,
     # although this is probably best done in sytest
     # although this is probably best done in sytest
 
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_no_versions(self):
     def test_upload_room_keys_no_versions(self):
         """Check that we get a 404 on uploading keys when no versions are defined
         """Check that we get a 404 on uploading keys when no versions are defined
         """
         """
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_room_keys(self.local_user, "no_version", room_keys)
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.upload_room_keys(self.local_user, "no_version", room_keys),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
         self.assertEqual(res, 404)
 
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_bogus_version(self):
     def test_upload_room_keys_bogus_version(self):
         """Check that we get a 404 on uploading keys when an nonexistent version
         """Check that we get a 404 on uploading keys when an nonexistent version
         is specified
         is specified
         """
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -406,22 +356,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         )
         self.assertEqual(version, "1")
         self.assertEqual(version, "1")
 
 
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_room_keys(
-                    self.local_user, "bogus_version", room_keys
-                )
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.upload_room_keys(self.local_user, "bogus_version", room_keys),
+            SynapseError,
+        )
+        res = e.value.code
         self.assertEqual(res, 404)
         self.assertEqual(res, 404)
 
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_wrong_version(self):
     def test_upload_room_keys_wrong_version(self):
         """Check that we get a 403 on uploading keys for an old version
         """Check that we get a 403 on uploading keys for an old version
         """
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -432,7 +377,7 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         )
         self.assertEqual(version, "1")
         self.assertEqual(version, "1")
 
 
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -443,20 +388,16 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         )
         self.assertEqual(version, "2")
         self.assertEqual(version, "2")
 
 
-        res = None
-        try:
-            yield defer.ensureDeferred(
-                self.handler.upload_room_keys(self.local_user, "1", room_keys)
-            )
-        except errors.SynapseError as e:
-            res = e.code
+        e = self.get_failure(
+            self.handler.upload_room_keys(self.local_user, "1", room_keys), SynapseError
+        )
+        res = e.value.code
         self.assertEqual(res, 403)
         self.assertEqual(res, 403)
 
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_insert(self):
     def test_upload_room_keys_insert(self):
         """Check that we can insert and retrieve keys for a session
         """Check that we can insert and retrieve keys for a session
         """
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -467,17 +408,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         )
         self.assertEqual(version, "1")
         self.assertEqual(version, "1")
 
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
         )
 
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertDictEqual(res, room_keys)
         self.assertDictEqual(res, room_keys)
 
 
         # check getting room_keys for a given room
         # check getting room_keys for a given room
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org"
                 self.local_user, version, room_id="!abc:matrix.org"
             )
             )
@@ -485,18 +424,17 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, room_keys)
         self.assertDictEqual(res, room_keys)
 
 
         # check getting room_keys for a given session_id
         # check getting room_keys for a given session_id
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
             )
         )
         )
         self.assertDictEqual(res, room_keys)
         self.assertDictEqual(res, room_keys)
 
 
-    @defer.inlineCallbacks
     def test_upload_room_keys_merge(self):
     def test_upload_room_keys_merge(self):
         """Check that we can upload a new room_key for an existing session and
         """Check that we can upload a new room_key for an existing session and
         have it correctly merged"""
         have it correctly merged"""
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -507,12 +445,12 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         )
         )
         self.assertEqual(version, "1")
         self.assertEqual(version, "1")
 
 
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
         )
 
 
         # get the etag to compare to future versions
         # get the etag to compare to future versions
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         backup_etag = res["etag"]
         backup_etag = res["etag"]
         self.assertEqual(res["count"], 1)
         self.assertEqual(res["count"], 1)
 
 
@@ -522,37 +460,33 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # test that increasing the message_index doesn't replace the existing session
         # test that increasing the message_index doesn't replace the existing session
         new_room_key["first_message_index"] = 2
         new_room_key["first_message_index"] = 2
         new_room_key["session_data"] = "new"
         new_room_key["session_data"] = "new"
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
         )
         )
 
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertEqual(
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"],
             "SSBBTSBBIEZJU0gK",
             "SSBBTSBBIEZJU0gK",
         )
         )
 
 
         # the etag should be the same since the session did not change
         # the etag should be the same since the session did not change
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         self.assertEqual(res["etag"], backup_etag)
         self.assertEqual(res["etag"], backup_etag)
 
 
         # test that marking the session as verified however /does/ replace it
         # test that marking the session as verified however /does/ replace it
         new_room_key["is_verified"] = True
         new_room_key["is_verified"] = True
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
         )
         )
 
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertEqual(
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
         )
 
 
         # the etag should NOT be equal now, since the key changed
         # the etag should NOT be equal now, since the key changed
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         self.assertNotEqual(res["etag"], backup_etag)
         self.assertNotEqual(res["etag"], backup_etag)
         backup_etag = res["etag"]
         backup_etag = res["etag"]
 
 
@@ -560,28 +494,25 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         # with a lower forwarding count
         # with a lower forwarding count
         new_room_key["forwarded_count"] = 2
         new_room_key["forwarded_count"] = 2
         new_room_key["session_data"] = "other"
         new_room_key["session_data"] = "other"
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
             self.handler.upload_room_keys(self.local_user, version, new_room_keys)
         )
         )
 
 
-        res = yield defer.ensureDeferred(
-            self.handler.get_room_keys(self.local_user, version)
-        )
+        res = self.get_success(self.handler.get_room_keys(self.local_user, version))
         self.assertEqual(
         self.assertEqual(
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
             res["rooms"]["!abc:matrix.org"]["sessions"]["c0ff33"]["session_data"], "new"
         )
         )
 
 
         # the etag should be the same since the session did not change
         # the etag should be the same since the session did not change
-        res = yield defer.ensureDeferred(self.handler.get_version_info(self.local_user))
+        res = self.get_success(self.handler.get_version_info(self.local_user))
         self.assertEqual(res["etag"], backup_etag)
         self.assertEqual(res["etag"], backup_etag)
 
 
         # TODO: check edge cases as well as the common variations here
         # TODO: check edge cases as well as the common variations here
 
 
-    @defer.inlineCallbacks
     def test_delete_room_keys(self):
     def test_delete_room_keys(self):
         """Check that we can insert and delete keys for a session
         """Check that we can insert and delete keys for a session
         """
         """
-        version = yield defer.ensureDeferred(
+        version = self.get_success(
             self.handler.create_version(
             self.handler.create_version(
                 self.local_user,
                 self.local_user,
                 {
                 {
@@ -593,13 +524,11 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertEqual(version, "1")
         self.assertEqual(version, "1")
 
 
         # check for bulk-delete
         # check for bulk-delete
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
         )
-        yield defer.ensureDeferred(
-            self.handler.delete_room_keys(self.local_user, version)
-        )
-        res = yield defer.ensureDeferred(
+        self.get_success(self.handler.delete_room_keys(self.local_user, version))
+        res = self.get_success(
             self.handler.get_room_keys(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
             )
@@ -607,15 +536,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, {"rooms": {}})
         self.assertDictEqual(res, {"rooms": {}})
 
 
         # check for bulk-delete per room
         # check for bulk-delete per room
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.delete_room_keys(
             self.handler.delete_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org"
                 self.local_user, version, room_id="!abc:matrix.org"
             )
             )
         )
         )
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
             )
@@ -623,15 +552,15 @@ class E2eRoomKeysHandlerTestCase(unittest.TestCase):
         self.assertDictEqual(res, {"rooms": {}})
         self.assertDictEqual(res, {"rooms": {}})
 
 
         # check for bulk-delete per session
         # check for bulk-delete per session
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.upload_room_keys(self.local_user, version, room_keys)
             self.handler.upload_room_keys(self.local_user, version, room_keys)
         )
         )
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.delete_room_keys(
             self.handler.delete_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
             )
         )
         )
-        res = yield defer.ensureDeferred(
+        res = self.get_success(
             self.handler.get_room_keys(
             self.handler.get_room_keys(
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
                 self.local_user, version, room_id="!abc:matrix.org", session_id="c0ff33"
             )
             )

+ 39 - 82
tests/handlers/test_profile.py

@@ -13,25 +13,20 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 
 
-
 from mock import Mock
 from mock import Mock
 
 
-from twisted.internet import defer
-
 import synapse.types
 import synapse.types
 from synapse.api.errors import AuthError, SynapseError
 from synapse.api.errors import AuthError, SynapseError
 from synapse.types import UserID
 from synapse.types import UserID
 
 
 from tests import unittest
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.test_utils import make_awaitable
-from tests.utils import setup_test_homeserver
 
 
 
 
-class ProfileTestCase(unittest.TestCase):
+class ProfileTestCase(unittest.HomeserverTestCase):
     """ Tests profile management. """
     """ Tests profile management. """
 
 
-    @defer.inlineCallbacks
-    def setUp(self):
+    def make_homeserver(self, reactor, clock):
         self.mock_federation = Mock()
         self.mock_federation = Mock()
         self.mock_registry = Mock()
         self.mock_registry = Mock()
 
 
@@ -42,39 +37,35 @@ class ProfileTestCase(unittest.TestCase):
 
 
         self.mock_registry.register_query_handler = register_query_handler
         self.mock_registry.register_query_handler = register_query_handler
 
 
-        hs = yield setup_test_homeserver(
-            self.addCleanup,
+        hs = self.setup_test_homeserver(
             federation_client=self.mock_federation,
             federation_client=self.mock_federation,
             federation_server=Mock(),
             federation_server=Mock(),
             federation_registry=self.mock_registry,
             federation_registry=self.mock_registry,
         )
         )
+        return hs
 
 
+    def prepare(self, reactor, clock, hs):
         self.store = hs.get_datastore()
         self.store = hs.get_datastore()
 
 
         self.frank = UserID.from_string("@1234ABCD:test")
         self.frank = UserID.from_string("@1234ABCD:test")
         self.bob = UserID.from_string("@4567:test")
         self.bob = UserID.from_string("@4567:test")
         self.alice = UserID.from_string("@alice:remote")
         self.alice = UserID.from_string("@alice:remote")
 
 
-        yield defer.ensureDeferred(self.store.create_profile(self.frank.localpart))
+        self.get_success(self.store.create_profile(self.frank.localpart))
 
 
         self.handler = hs.get_profile_handler()
         self.handler = hs.get_profile_handler()
-        self.hs = hs
 
 
-    @defer.inlineCallbacks
     def test_get_my_name(self):
     def test_get_my_name(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
         )
         )
 
 
-        displayname = yield defer.ensureDeferred(
-            self.handler.get_displayname(self.frank)
-        )
+        displayname = self.get_success(self.handler.get_displayname(self.frank))
 
 
         self.assertEquals("Frank", displayname)
         self.assertEquals("Frank", displayname)
 
 
-    @defer.inlineCallbacks
     def test_set_my_name(self):
     def test_set_my_name(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_displayname(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
             )
             )
@@ -82,7 +73,7 @@ class ProfileTestCase(unittest.TestCase):
 
 
         self.assertEquals(
         self.assertEquals(
             (
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
                     self.store.get_profile_displayname(self.frank.localpart)
                 )
                 )
             ),
             ),
@@ -90,7 +81,7 @@ class ProfileTestCase(unittest.TestCase):
         )
         )
 
 
         # Set displayname again
         # Set displayname again
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_displayname(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank"
                 self.frank, synapse.types.create_requester(self.frank), "Frank"
             )
             )
@@ -98,7 +89,7 @@ class ProfileTestCase(unittest.TestCase):
 
 
         self.assertEquals(
         self.assertEquals(
             (
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
                     self.store.get_profile_displayname(self.frank.localpart)
                 )
                 )
             ),
             ),
@@ -106,32 +97,27 @@ class ProfileTestCase(unittest.TestCase):
         )
         )
 
 
         # Set displayname to an empty string
         # Set displayname to an empty string
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_displayname(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), ""
                 self.frank, synapse.types.create_requester(self.frank), ""
             )
             )
         )
         )
 
 
         self.assertIsNone(
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_displayname(self.frank.localpart)
-                )
-            )
+            (self.get_success(self.store.get_profile_displayname(self.frank.localpart)))
         )
         )
 
 
-    @defer.inlineCallbacks
     def test_set_my_name_if_disabled(self):
     def test_set_my_name_if_disabled(self):
         self.hs.config.enable_set_displayname = False
         self.hs.config.enable_set_displayname = False
 
 
         # Setting displayname for the first time is allowed
         # Setting displayname for the first time is allowed
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
             self.store.set_profile_displayname(self.frank.localpart, "Frank")
         )
         )
 
 
         self.assertEquals(
         self.assertEquals(
             (
             (
-                yield defer.ensureDeferred(
+                self.get_success(
                     self.store.get_profile_displayname(self.frank.localpart)
                     self.store.get_profile_displayname(self.frank.localpart)
                 )
                 )
             ),
             ),
@@ -139,33 +125,27 @@ class ProfileTestCase(unittest.TestCase):
         )
         )
 
 
         # Setting displayname a second time is forbidden
         # Setting displayname a second time is forbidden
-        d = defer.ensureDeferred(
+        self.get_failure(
             self.handler.set_displayname(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
                 self.frank, synapse.types.create_requester(self.frank), "Frank Jr."
-            )
+            ),
+            SynapseError,
         )
         )
 
 
-        yield self.assertFailure(d, SynapseError)
-
-    @defer.inlineCallbacks
     def test_set_my_name_noauth(self):
     def test_set_my_name_noauth(self):
-        d = defer.ensureDeferred(
+        self.get_failure(
             self.handler.set_displayname(
             self.handler.set_displayname(
                 self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
                 self.frank, synapse.types.create_requester(self.bob), "Frank Jr."
-            )
+            ),
+            AuthError,
         )
         )
 
 
-        yield self.assertFailure(d, AuthError)
-
-    @defer.inlineCallbacks
     def test_get_other_name(self):
     def test_get_other_name(self):
         self.mock_federation.make_query.return_value = make_awaitable(
         self.mock_federation.make_query.return_value = make_awaitable(
             {"displayname": "Alice"}
             {"displayname": "Alice"}
         )
         )
 
 
-        displayname = yield defer.ensureDeferred(
-            self.handler.get_displayname(self.alice)
-        )
+        displayname = self.get_success(self.handler.get_displayname(self.alice))
 
 
         self.assertEquals(displayname, "Alice")
         self.assertEquals(displayname, "Alice")
         self.mock_federation.make_query.assert_called_with(
         self.mock_federation.make_query.assert_called_with(
@@ -175,14 +155,11 @@ class ProfileTestCase(unittest.TestCase):
             ignore_backoff=True,
             ignore_backoff=True,
         )
         )
 
 
-    @defer.inlineCallbacks
     def test_incoming_fed_query(self):
     def test_incoming_fed_query(self):
-        yield defer.ensureDeferred(self.store.create_profile("caroline"))
-        yield defer.ensureDeferred(
-            self.store.set_profile_displayname("caroline", "Caroline")
-        )
+        self.get_success(self.store.create_profile("caroline"))
+        self.get_success(self.store.set_profile_displayname("caroline", "Caroline"))
 
 
-        response = yield defer.ensureDeferred(
+        response = self.get_success(
             self.query_handlers["profile"](
             self.query_handlers["profile"](
                 {"user_id": "@caroline:test", "field": "displayname"}
                 {"user_id": "@caroline:test", "field": "displayname"}
             )
             )
@@ -190,20 +167,18 @@ class ProfileTestCase(unittest.TestCase):
 
 
         self.assertEquals({"displayname": "Caroline"}, response)
         self.assertEquals({"displayname": "Caroline"}, response)
 
 
-    @defer.inlineCallbacks
     def test_get_my_avatar(self):
     def test_get_my_avatar(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_avatar_url(
             self.store.set_profile_avatar_url(
                 self.frank.localpart, "http://my.server/me.png"
                 self.frank.localpart, "http://my.server/me.png"
             )
             )
         )
         )
-        avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
+        avatar_url = self.get_success(self.handler.get_avatar_url(self.frank))
 
 
         self.assertEquals("http://my.server/me.png", avatar_url)
         self.assertEquals("http://my.server/me.png", avatar_url)
 
 
-    @defer.inlineCallbacks
     def test_set_my_avatar(self):
     def test_set_my_avatar(self):
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_avatar_url(
             self.handler.set_avatar_url(
                 self.frank,
                 self.frank,
                 synapse.types.create_requester(self.frank),
                 synapse.types.create_requester(self.frank),
@@ -212,16 +187,12 @@ class ProfileTestCase(unittest.TestCase):
         )
         )
 
 
         self.assertEquals(
         self.assertEquals(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/pic.gif",
             "http://my.server/pic.gif",
         )
         )
 
 
         # Set avatar again
         # Set avatar again
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_avatar_url(
             self.handler.set_avatar_url(
                 self.frank,
                 self.frank,
                 synapse.types.create_requester(self.frank),
                 synapse.types.create_requester(self.frank),
@@ -230,56 +201,42 @@ class ProfileTestCase(unittest.TestCase):
         )
         )
 
 
         self.assertEquals(
         self.assertEquals(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/me.png",
             "http://my.server/me.png",
         )
         )
 
 
         # Set avatar to an empty string
         # Set avatar to an empty string
-        yield defer.ensureDeferred(
+        self.get_success(
             self.handler.set_avatar_url(
             self.handler.set_avatar_url(
                 self.frank, synapse.types.create_requester(self.frank), "",
                 self.frank, synapse.types.create_requester(self.frank), "",
             )
             )
         )
         )
 
 
         self.assertIsNone(
         self.assertIsNone(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
         )
         )
 
 
-    @defer.inlineCallbacks
     def test_set_my_avatar_if_disabled(self):
     def test_set_my_avatar_if_disabled(self):
         self.hs.config.enable_set_avatar_url = False
         self.hs.config.enable_set_avatar_url = False
 
 
         # Setting displayname for the first time is allowed
         # Setting displayname for the first time is allowed
-        yield defer.ensureDeferred(
+        self.get_success(
             self.store.set_profile_avatar_url(
             self.store.set_profile_avatar_url(
                 self.frank.localpart, "http://my.server/me.png"
                 self.frank.localpart, "http://my.server/me.png"
             )
             )
         )
         )
 
 
         self.assertEquals(
         self.assertEquals(
-            (
-                yield defer.ensureDeferred(
-                    self.store.get_profile_avatar_url(self.frank.localpart)
-                )
-            ),
+            (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))),
             "http://my.server/me.png",
             "http://my.server/me.png",
         )
         )
 
 
         # Set avatar a second time is forbidden
         # Set avatar a second time is forbidden
-        d = defer.ensureDeferred(
+        self.get_failure(
             self.handler.set_avatar_url(
             self.handler.set_avatar_url(
                 self.frank,
                 self.frank,
                 synapse.types.create_requester(self.frank),
                 synapse.types.create_requester(self.frank),
                 "http://my.server/pic.gif",
                 "http://my.server/pic.gif",
-            )
+            ),
+            SynapseError,
         )
         )
-
-        yield self.assertFailure(d, SynapseError)

+ 0 - 28
tests/rest/client/v1/test_typing.py

@@ -18,8 +18,6 @@
 
 
 from mock import Mock
 from mock import Mock
 
 
-from twisted.internet import defer
-
 from synapse.rest.client.v1 import room
 from synapse.rest.client.v1 import room
 from synapse.types import UserID
 from synapse.types import UserID
 
 
@@ -60,32 +58,6 @@ class RoomTypingTestCase(unittest.HomeserverTestCase):
 
 
         hs.get_datastore().insert_client_ip = _insert_client_ip
         hs.get_datastore().insert_client_ip = _insert_client_ip
 
 
-        def get_room_members(room_id):
-            if room_id == self.room_id:
-                return defer.succeed([self.user])
-            else:
-                return defer.succeed([])
-
-        @defer.inlineCallbacks
-        def fetch_room_distributions_into(
-            room_id, localusers=None, remotedomains=None, ignore_user=None
-        ):
-            members = yield get_room_members(room_id)
-            for member in members:
-                if ignore_user is not None and member == ignore_user:
-                    continue
-
-                if hs.is_mine(member):
-                    if localusers is not None:
-                        localusers.add(member)
-                else:
-                    if remotedomains is not None:
-                        remotedomains.add(member.domain)
-
-        hs.get_room_member_handler().fetch_room_distributions_into = (
-            fetch_room_distributions_into
-        )
-
         return hs
         return hs
 
 
     def prepare(self, reactor, clock, hs):
     def prepare(self, reactor, clock, hs):