Browse Source

Add missing type hints to tests.config. (#14681)

Patrick Cloke 1 year ago
parent
commit
3aeca2588b

+ 1 - 0
changelog.d/14681.misc

@@ -0,0 +1 @@
+Add missing type hints.

+ 1 - 3
mypy.ini

@@ -36,8 +36,6 @@ exclude = (?x)
    |tests/api/test_ratelimiting.py
    |tests/app/test_openid_listener.py
    |tests/appservice/test_scheduler.py
-   |tests/config/test_cache.py
-   |tests/config/test_tls.py
    |tests/crypto/test_keyring.py
    |tests/events/test_presence_router.py
    |tests/events/test_utils.py
@@ -89,7 +87,7 @@ disallow_untyped_defs = False
 [mypy-tests.*]
 disallow_untyped_defs = False
 
-[mypy-tests.config.test_api]
+[mypy-tests.config.*]
 disallow_untyped_defs = True
 
 [mypy-tests.federation.transport.test_client]

+ 2 - 2
synapse/config/cache.py

@@ -16,7 +16,7 @@ import logging
 import os
 import re
 import threading
-from typing import Any, Callable, Dict, Optional
+from typing import Any, Callable, Dict, Mapping, Optional
 
 import attr
 
@@ -94,7 +94,7 @@ def add_resizable_cache(
 
 class CacheConfig(Config):
     section = "caches"
-    _environ = os.environ
+    _environ: Mapping[str, str] = os.environ
 
     event_cache_size: int
     cache_factors: Dict[str, float]

+ 2 - 7
synapse/util/caches/lrucache.py

@@ -788,26 +788,21 @@ class LruCache(Generic[KT, VT]):
     def __contains__(self, key: KT) -> bool:
         return self.contains(key)
 
-    def set_cache_factor(self, factor: float) -> bool:
+    def set_cache_factor(self, factor: float) -> None:
         """
         Set the cache factor for this individual cache.
 
         This will trigger a resize if it changes, which may require evicting
         items from the cache.
-
-        Returns:
-            Whether the cache changed size or not.
         """
         if not self.apply_cache_factor_from_config:
-            return False
+            return
 
         new_size = int(self._original_max_size * factor)
         if new_size != self.max_size:
             self.max_size = new_size
             if self._on_resize:
                 self._on_resize()
-            return True
-        return False
 
     def __del__(self) -> None:
         # We're about to be deleted, so we make sure to clear up all the nodes

+ 3 - 3
tests/config/test___main__.py

@@ -17,15 +17,15 @@ from tests.config.utils import ConfigFileTestCase
 
 
 class ConfigMainFileTestCase(ConfigFileTestCase):
-    def test_executes_without_an_action(self):
+    def test_executes_without_an_action(self) -> None:
         self.generate_config()
         main(["", "-c", self.config_file])
 
-    def test_read__error_if_key_not_found(self):
+    def test_read__error_if_key_not_found(self) -> None:
         self.generate_config()
         with self.assertRaises(SystemExit):
             main(["", "read", "foo.bar.hello", "-c", self.config_file])
 
-    def test_read__passes_if_key_found(self):
+    def test_read__passes_if_key_found(self) -> None:
         self.generate_config()
         main(["", "read", "server.server_name", "-c", self.config_file])

+ 2 - 2
tests/config/test_background_update.py

@@ -22,7 +22,7 @@ class BackgroundUpdateConfigTestCase(HomeserverTestCase):
     # Tests that the default values in the config are correctly loaded. Note that the default
     # values are loaded when the corresponding config options are commented out, which is why there isn't
     # a config specified here.
-    def test_default_configuration(self):
+    def test_default_configuration(self) -> None:
         background_updater = BackgroundUpdater(
             self.hs, self.hs.get_datastores().main.db_pool
         )
@@ -46,7 +46,7 @@ class BackgroundUpdateConfigTestCase(HomeserverTestCase):
             """
         )
     )
-    def test_custom_configuration(self):
+    def test_custom_configuration(self) -> None:
         background_updater = BackgroundUpdater(
             self.hs, self.hs.get_datastores().main.db_pool
         )

+ 5 - 5
tests/config/test_base.py

@@ -24,13 +24,13 @@ from tests import unittest
 
 
 class BaseConfigTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         # The root object needs a server property with a public_baseurl.
         root = Mock()
         root.server.public_baseurl = "http://test"
         self.config = Config(root)
 
-    def test_loading_missing_templates(self):
+    def test_loading_missing_templates(self) -> None:
         # Use a temporary directory that exists on the system, but that isn't likely to
         # contain template files
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -50,7 +50,7 @@ class BaseConfigTestCase(unittest.TestCase):
             "Template file did not contain our test string",
         )
 
-    def test_loading_custom_templates(self):
+    def test_loading_custom_templates(self) -> None:
         # Use a temporary directory that exists on the system
         with tempfile.TemporaryDirectory() as tmp_dir:
             # Create a temporary bogus template file
@@ -79,7 +79,7 @@ class BaseConfigTestCase(unittest.TestCase):
             "Template file did not contain our test string",
         )
 
-    def test_multiple_custom_template_directories(self):
+    def test_multiple_custom_template_directories(self) -> None:
         """Tests that directories are searched in the right order if multiple custom
         template directories are provided.
         """
@@ -137,7 +137,7 @@ class BaseConfigTestCase(unittest.TestCase):
         for td in tempdirs:
             td.cleanup()
 
-    def test_loading_template_from_nonexistent_custom_directory(self):
+    def test_loading_template_from_nonexistent_custom_directory(self) -> None:
         with self.assertRaises(ConfigError):
             self.config.read_templates(
                 ["some_filename.html"], ("a_nonexistent_directory",)

+ 29 - 28
tests/config/test_cache.py

@@ -13,26 +13,27 @@
 # limitations under the License.
 
 from synapse.config.cache import CacheConfig, add_resizable_cache
+from synapse.types import JsonDict
 from synapse.util.caches.lrucache import LruCache
 
 from tests.unittest import TestCase
 
 
 class CacheConfigTests(TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         # Reset caches before each test since there's global state involved.
         self.config = CacheConfig()
         self.config.reset()
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         # Also reset the caches after each test to leave state pristine.
         self.config.reset()
 
-    def test_individual_caches_from_environ(self):
+    def test_individual_caches_from_environ(self) -> None:
         """
         Individual cache factors will be loaded from the environment.
         """
-        config = {}
+        config: JsonDict = {}
         self.config._environ = {
             "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
             "SYNAPSE_NOT_CACHE": "BLAH",
@@ -42,15 +43,15 @@ class CacheConfigTests(TestCase):
 
         self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0})
 
-    def test_config_overrides_environ(self):
+    def test_config_overrides_environ(self) -> None:
         """
         Individual cache factors defined in the environment will take precedence
         over those in the config.
         """
-        config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
+        config: JsonDict = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}}
         self.config._environ = {
             "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2",
-            "SYNAPSE_CACHE_FACTOR_FOO": 1,
+            "SYNAPSE_CACHE_FACTOR_FOO": "1",
         }
         self.config.read_config(config, config_dir_path="", data_dir_path="")
         self.config.resize_all_caches()
@@ -60,104 +61,104 @@ class CacheConfigTests(TestCase):
             {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0},
         )
 
-    def test_individual_instantiated_before_config_load(self):
+    def test_individual_instantiated_before_config_load(self) -> None:
         """
         If a cache is instantiated before the config is read, it will be given
         the default cache size in the interim, and then resized once the config
         is loaded.
         """
-        cache = LruCache(100)
+        cache: LruCache = LruCache(100)
 
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 50)
 
-        config = {"caches": {"per_cache_factors": {"foo": 3}}}
+        config: JsonDict = {"caches": {"per_cache_factors": {"foo": 3}}}
         self.config.read_config(config)
         self.config.resize_all_caches()
 
         self.assertEqual(cache.max_size, 300)
 
-    def test_individual_instantiated_after_config_load(self):
+    def test_individual_instantiated_after_config_load(self) -> None:
         """
         If a cache is instantiated after the config is read, it will be
         immediately resized to the correct size given the per_cache_factor if
         there is one.
         """
-        config = {"caches": {"per_cache_factors": {"foo": 2}}}
+        config: JsonDict = {"caches": {"per_cache_factors": {"foo": 2}}}
         self.config.read_config(config, config_dir_path="", data_dir_path="")
         self.config.resize_all_caches()
 
-        cache = LruCache(100)
+        cache: LruCache = LruCache(100)
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 200)
 
-    def test_global_instantiated_before_config_load(self):
+    def test_global_instantiated_before_config_load(self) -> None:
         """
         If a cache is instantiated before the config is read, it will be given
         the default cache size in the interim, and then resized to the new
         default cache size once the config is loaded.
         """
-        cache = LruCache(100)
+        cache: LruCache = LruCache(100)
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 50)
 
-        config = {"caches": {"global_factor": 4}}
+        config: JsonDict = {"caches": {"global_factor": 4}}
         self.config.read_config(config, config_dir_path="", data_dir_path="")
         self.config.resize_all_caches()
 
         self.assertEqual(cache.max_size, 400)
 
-    def test_global_instantiated_after_config_load(self):
+    def test_global_instantiated_after_config_load(self) -> None:
         """
         If a cache is instantiated after the config is read, it will be
         immediately resized to the correct size given the global factor if there
         is no per-cache factor.
         """
-        config = {"caches": {"global_factor": 1.5}}
+        config: JsonDict = {"caches": {"global_factor": 1.5}}
         self.config.read_config(config, config_dir_path="", data_dir_path="")
         self.config.resize_all_caches()
 
-        cache = LruCache(100)
+        cache: LruCache = LruCache(100)
         add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor)
         self.assertEqual(cache.max_size, 150)
 
-    def test_cache_with_asterisk_in_name(self):
+    def test_cache_with_asterisk_in_name(self) -> None:
         """Some caches have asterisks in their name, test that they are set correctly."""
 
-        config = {
+        config: JsonDict = {
             "caches": {
                 "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2}
             }
         }
         self.config._environ = {
             "SYNAPSE_CACHE_FACTOR_CACHE_A": "2",
-            "SYNAPSE_CACHE_FACTOR_CACHE_B": 3,
+            "SYNAPSE_CACHE_FACTOR_CACHE_B": "3",
         }
         self.config.read_config(config, config_dir_path="", data_dir_path="")
         self.config.resize_all_caches()
 
-        cache_a = LruCache(100)
+        cache_a: LruCache = LruCache(100)
         add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor)
         self.assertEqual(cache_a.max_size, 200)
 
-        cache_b = LruCache(100)
+        cache_b: LruCache = LruCache(100)
         add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor)
         self.assertEqual(cache_b.max_size, 300)
 
-        cache_c = LruCache(100)
+        cache_c: LruCache = LruCache(100)
         add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor)
         self.assertEqual(cache_c.max_size, 200)
 
-    def test_apply_cache_factor_from_config(self):
+    def test_apply_cache_factor_from_config(self) -> None:
         """Caches can disable applying cache factor updates, mainly used by
         event cache size.
         """
 
-        config = {"caches": {"event_cache_size": "10k"}}
+        config: JsonDict = {"caches": {"event_cache_size": "10k"}}
         self.config.read_config(config, config_dir_path="", data_dir_path="")
         self.config.resize_all_caches()
 
-        cache = LruCache(
+        cache: LruCache = LruCache(
             max_size=self.config.event_cache_size,
             apply_cache_factor_from_config=False,
         )

+ 1 - 1
tests/config/test_database.py

@@ -20,7 +20,7 @@ from tests import unittest
 
 
 class DatabaseConfigTestCase(unittest.TestCase):
-    def test_database_configured_correctly(self):
+    def test_database_configured_correctly(self) -> None:
         conf = yaml.safe_load(
             DatabaseConfig().generate_config_section(data_dir_path="/data_dir_path")
         )

+ 4 - 4
tests/config/test_generate.py

@@ -25,14 +25,14 @@ from tests import unittest
 
 
 class ConfigGenerationTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.dir = tempfile.mkdtemp()
         self.file = os.path.join(self.dir, "homeserver.yaml")
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         shutil.rmtree(self.dir)
 
-    def test_generate_config_generates_files(self):
+    def test_generate_config_generates_files(self) -> None:
         with redirect_stdout(StringIO()):
             HomeServerConfig.load_or_generate_config(
                 "",
@@ -56,7 +56,7 @@ class ConfigGenerationTestCase(unittest.TestCase):
             os.path.join(os.getcwd(), "homeserver.log"),
         )
 
-    def assert_log_filename_is(self, log_config_file, expected):
+    def assert_log_filename_is(self, log_config_file: str, expected: str) -> None:
         with open(log_config_file) as f:
             config = f.read()
             # find the 'filename' line

+ 6 - 6
tests/config/test_load.py

@@ -21,14 +21,14 @@ from tests.config.utils import ConfigFileTestCase
 
 
 class ConfigLoadingFileTestCase(ConfigFileTestCase):
-    def test_load_fails_if_server_name_missing(self):
+    def test_load_fails_if_server_name_missing(self) -> None:
         self.generate_config_and_remove_lines_containing("server_name")
         with self.assertRaises(ConfigError):
             HomeServerConfig.load_config("", ["-c", self.config_file])
         with self.assertRaises(ConfigError):
             HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
 
-    def test_generates_and_loads_macaroon_secret_key(self):
+    def test_generates_and_loads_macaroon_secret_key(self) -> None:
         self.generate_config()
 
         with open(self.config_file) as f:
@@ -58,7 +58,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
                 "was: %r" % (config2.key.macaroon_secret_key,)
             )
 
-    def test_load_succeeds_if_macaroon_secret_key_missing(self):
+    def test_load_succeeds_if_macaroon_secret_key_missing(self) -> None:
         self.generate_config_and_remove_lines_containing("macaroon")
         config1 = HomeServerConfig.load_config("", ["-c", self.config_file])
         config2 = HomeServerConfig.load_config("", ["-c", self.config_file])
@@ -73,7 +73,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
             config1.key.macaroon_secret_key, config3.key.macaroon_secret_key
         )
 
-    def test_disable_registration(self):
+    def test_disable_registration(self) -> None:
         self.generate_config()
         self.add_lines_to_config(
             ["enable_registration: true", "disable_registration: true"]
@@ -93,7 +93,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
         assert config3 is not None
         self.assertTrue(config3.registration.enable_registration)
 
-    def test_stats_enabled(self):
+    def test_stats_enabled(self) -> None:
         self.generate_config_and_remove_lines_containing("enable_metrics")
         self.add_lines_to_config(["enable_metrics: true"])
 
@@ -101,7 +101,7 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
         config = HomeServerConfig.load_config("", ["-c", self.config_file])
         self.assertFalse(config.metrics.metrics_flags.known_servers)
 
-    def test_depreciated_identity_server_flag_throws_error(self):
+    def test_depreciated_identity_server_flag_throws_error(self) -> None:
         self.generate_config()
         # Needed to ensure that actual key/value pair added below don't end up on a line with a comment
         self.add_lines_to_config([" "])

+ 1 - 1
tests/config/test_ratelimiting.py

@@ -18,7 +18,7 @@ from tests.utils import default_config
 
 
 class RatelimitConfigTestCase(TestCase):
-    def test_parse_rc_federation(self):
+    def test_parse_rc_federation(self) -> None:
         config_dict = default_config("test")
         config_dict["rc_federation"] = {
             "window_size": 20000,

+ 2 - 2
tests/config/test_registration_config.py

@@ -21,7 +21,7 @@ from tests.utils import default_config
 
 
 class RegistrationConfigTestCase(ConfigFileTestCase):
-    def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self):
+    def test_session_lifetime_must_not_be_exceeded_by_smaller_lifetimes(self) -> None:
         """
         session_lifetime should logically be larger than, or at least as large as,
         all the different token lifetimes.
@@ -91,7 +91,7 @@ class RegistrationConfigTestCase(ConfigFileTestCase):
             "",
         )
 
-    def test_refuse_to_start_if_open_registration_and_no_verification(self):
+    def test_refuse_to_start_if_open_registration_and_no_verification(self) -> None:
         self.generate_config()
         self.add_lines_to_config(
             [

+ 2 - 2
tests/config/test_room_directory.py

@@ -20,7 +20,7 @@ from tests import unittest
 
 
 class RoomDirectoryConfigTestCase(unittest.TestCase):
-    def test_alias_creation_acl(self):
+    def test_alias_creation_acl(self) -> None:
         config = yaml.safe_load(
             """
         alias_creation_rules:
@@ -78,7 +78,7 @@ class RoomDirectoryConfigTestCase(unittest.TestCase):
             )
         )
 
-    def test_room_publish_acl(self):
+    def test_room_publish_acl(self) -> None:
         config = yaml.safe_load(
             """
         alias_creation_rules: []

+ 9 - 9
tests/config/test_server.py

@@ -21,7 +21,7 @@ from tests import unittest
 
 
 class ServerConfigTestCase(unittest.TestCase):
-    def test_is_threepid_reserved(self):
+    def test_is_threepid_reserved(self) -> None:
         user1 = {"medium": "email", "address": "user1@example.com"}
         user2 = {"medium": "email", "address": "user2@example.com"}
         user3 = {"medium": "email", "address": "user3@example.com"}
@@ -32,7 +32,7 @@ class ServerConfigTestCase(unittest.TestCase):
         self.assertFalse(is_threepid_reserved(config, user3))
         self.assertFalse(is_threepid_reserved(config, user1_msisdn))
 
-    def test_unsecure_listener_no_listeners_open_private_ports_false(self):
+    def test_unsecure_listener_no_listeners_open_private_ports_false(self) -> None:
         conf = yaml.safe_load(
             ServerConfig().generate_config_section(
                 "CONFDIR", "/data_dir_path", "che.org", False, None
@@ -52,7 +52,7 @@ class ServerConfigTestCase(unittest.TestCase):
 
         self.assertEqual(conf["listeners"], expected_listeners)
 
-    def test_unsecure_listener_no_listeners_open_private_ports_true(self):
+    def test_unsecure_listener_no_listeners_open_private_ports_true(self) -> None:
         conf = yaml.safe_load(
             ServerConfig().generate_config_section(
                 "CONFDIR", "/data_dir_path", "che.org", True, None
@@ -71,7 +71,7 @@ class ServerConfigTestCase(unittest.TestCase):
 
         self.assertEqual(conf["listeners"], expected_listeners)
 
-    def test_listeners_set_correctly_open_private_ports_false(self):
+    def test_listeners_set_correctly_open_private_ports_false(self) -> None:
         listeners = [
             {
                 "port": 8448,
@@ -95,7 +95,7 @@ class ServerConfigTestCase(unittest.TestCase):
 
         self.assertEqual(conf["listeners"], listeners)
 
-    def test_listeners_set_correctly_open_private_ports_true(self):
+    def test_listeners_set_correctly_open_private_ports_true(self) -> None:
         listeners = [
             {
                 "port": 8448,
@@ -131,14 +131,14 @@ class ServerConfigTestCase(unittest.TestCase):
 
 
 class GenerateIpSetTestCase(unittest.TestCase):
-    def test_empty(self):
+    def test_empty(self) -> None:
         ip_set = generate_ip_set(())
         self.assertFalse(ip_set)
 
         ip_set = generate_ip_set((), ())
         self.assertFalse(ip_set)
 
-    def test_generate(self):
+    def test_generate(self) -> None:
         """Check adding IPv4 and IPv6 addresses."""
         # IPv4 address
         ip_set = generate_ip_set(("1.2.3.4",))
@@ -160,7 +160,7 @@ class GenerateIpSetTestCase(unittest.TestCase):
         ip_set = generate_ip_set(("1.2.3.4", "::1.2.3.4"))
         self.assertEqual(len(ip_set.iter_cidrs()), 4)
 
-    def test_extra(self):
+    def test_extra(self) -> None:
         """Extra IP addresses are treated the same."""
         ip_set = generate_ip_set((), ("1.2.3.4",))
         self.assertEqual(len(ip_set.iter_cidrs()), 4)
@@ -172,7 +172,7 @@ class GenerateIpSetTestCase(unittest.TestCase):
         ip_set = generate_ip_set(("1.2.3.4",), ("1.2.3.4",))
         self.assertEqual(len(ip_set.iter_cidrs()), 4)
 
-    def test_bad_value(self):
+    def test_bad_value(self) -> None:
         """An error should be raised if a bad value is passed in."""
         with self.assertRaises(ConfigError):
             generate_ip_set(("not-an-ip",))

+ 31 - 22
tests/config/test_tls.py

@@ -13,13 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import cast
+
 import idna
 
 from OpenSSL import SSL
 
 from synapse.config._base import Config, RootConfig
+from synapse.config.homeserver import HomeServerConfig
 from synapse.config.tls import ConfigError, TlsConfig
-from synapse.crypto.context_factory import FederationPolicyForHTTPS
+from synapse.crypto.context_factory import (
+    FederationPolicyForHTTPS,
+    SSLClientConnectionCreator,
+)
+from synapse.types import JsonDict
 
 from tests.unittest import TestCase
 
@@ -27,7 +34,7 @@ from tests.unittest import TestCase
 class FakeServer(Config):
     section = "server"
 
-    def has_tls_listener(self):
+    def has_tls_listener(self) -> bool:
         return False
 
 
@@ -36,21 +43,21 @@ class TestConfig(RootConfig):
 
 
 class TLSConfigTests(TestCase):
-    def test_tls_client_minimum_default(self):
+    def test_tls_client_minimum_default(self) -> None:
         """
         The default client TLS version is 1.0.
         """
-        config = {}
+        config: JsonDict = {}
         t = TestConfig()
         t.tls.read_config(config, config_dir_path="", data_dir_path="")
 
         self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
 
-    def test_tls_client_minimum_set(self):
+    def test_tls_client_minimum_set(self) -> None:
         """
         The default client TLS version can be set to 1.0, 1.1, and 1.2.
         """
-        config = {"federation_client_minimum_tls_version": 1}
+        config: JsonDict = {"federation_client_minimum_tls_version": 1}
         t = TestConfig()
         t.tls.read_config(config, config_dir_path="", data_dir_path="")
         self.assertEqual(t.tls.federation_client_minimum_tls_version, "1")
@@ -76,7 +83,7 @@ class TLSConfigTests(TestCase):
         t.tls.read_config(config, config_dir_path="", data_dir_path="")
         self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2")
 
-    def test_tls_client_minimum_1_point_3_missing(self):
+    def test_tls_client_minimum_1_point_3_missing(self) -> None:
         """
         If TLS 1.3 support is missing and it's configured, it will raise a
         ConfigError.
@@ -88,7 +95,7 @@ class TLSConfigTests(TestCase):
             self.addCleanup(setattr, SSL, "SSL.OP_NO_TLSv1_3", OP_NO_TLSv1_3)
             assert not hasattr(SSL, "OP_NO_TLSv1_3")
 
-        config = {"federation_client_minimum_tls_version": 1.3}
+        config: JsonDict = {"federation_client_minimum_tls_version": 1.3}
         t = TestConfig()
         with self.assertRaises(ConfigError) as e:
             t.tls.read_config(config, config_dir_path="", data_dir_path="")
@@ -100,7 +107,7 @@ class TLSConfigTests(TestCase):
             ),
         )
 
-    def test_tls_client_minimum_1_point_3_exists(self):
+    def test_tls_client_minimum_1_point_3_exists(self) -> None:
         """
         If TLS 1.3 support exists and it's configured, it will be settable.
         """
@@ -110,20 +117,20 @@ class TLSConfigTests(TestCase):
             self.addCleanup(lambda: delattr(SSL, "OP_NO_TLSv1_3"))
             assert hasattr(SSL, "OP_NO_TLSv1_3")
 
-        config = {"federation_client_minimum_tls_version": 1.3}
+        config: JsonDict = {"federation_client_minimum_tls_version": 1.3}
         t = TestConfig()
         t.tls.read_config(config, config_dir_path="", data_dir_path="")
         self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3")
 
-    def test_tls_client_minimum_set_passed_through_1_2(self):
+    def test_tls_client_minimum_set_passed_through_1_2(self) -> None:
         """
         The configured TLS version is correctly configured by the ContextFactory.
         """
-        config = {"federation_client_minimum_tls_version": 1.2}
+        config: JsonDict = {"federation_client_minimum_tls_version": 1.2}
         t = TestConfig()
         t.tls.read_config(config, config_dir_path="", data_dir_path="")
 
-        cf = FederationPolicyForHTTPS(t)
+        cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t))
         options = _get_ssl_context_options(cf._verify_ssl_context)
 
         # The context has had NO_TLSv1_1 and NO_TLSv1_0 set, but not NO_TLSv1_2
@@ -131,15 +138,15 @@ class TLSConfigTests(TestCase):
         self.assertNotEqual(options & SSL.OP_NO_TLSv1_1, 0)
         self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
 
-    def test_tls_client_minimum_set_passed_through_1_0(self):
+    def test_tls_client_minimum_set_passed_through_1_0(self) -> None:
         """
         The configured TLS version is correctly configured by the ContextFactory.
         """
-        config = {"federation_client_minimum_tls_version": 1}
+        config: JsonDict = {"federation_client_minimum_tls_version": 1}
         t = TestConfig()
         t.tls.read_config(config, config_dir_path="", data_dir_path="")
 
-        cf = FederationPolicyForHTTPS(t)
+        cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t))
         options = _get_ssl_context_options(cf._verify_ssl_context)
 
         # The context has not had any of the NO_TLS set.
@@ -147,11 +154,11 @@ class TLSConfigTests(TestCase):
         self.assertEqual(options & SSL.OP_NO_TLSv1_1, 0)
         self.assertEqual(options & SSL.OP_NO_TLSv1_2, 0)
 
-    def test_whitelist_idna_failure(self):
+    def test_whitelist_idna_failure(self) -> None:
         """
         The federation certificate whitelist will not allow IDNA domain names.
         """
-        config = {
+        config: JsonDict = {
             "federation_certificate_verification_whitelist": [
                 "example.com",
                 "*.ドメイン.テスト",
@@ -163,11 +170,11 @@ class TLSConfigTests(TestCase):
         )
         self.assertIn("IDNA domain names", str(e))
 
-    def test_whitelist_idna_result(self):
+    def test_whitelist_idna_result(self) -> None:
         """
         The federation certificate whitelist will match on IDNA encoded names.
         """
-        config = {
+        config: JsonDict = {
             "federation_certificate_verification_whitelist": [
                 "example.com",
                 "*.xn--eckwd4c7c.xn--zckzah",
@@ -176,14 +183,16 @@ class TLSConfigTests(TestCase):
         t = TestConfig()
         t.tls.read_config(config, config_dir_path="", data_dir_path="")
 
-        cf = FederationPolicyForHTTPS(t)
+        cf = FederationPolicyForHTTPS(cast(HomeServerConfig, t))
 
         # Not in the whitelist
         opts = cf.get_options(b"notexample.com")
+        assert isinstance(opts, SSLClientConnectionCreator)
         self.assertTrue(opts._verifier._verify_certs)
 
         # Caught by the wildcard
         opts = cf.get_options(idna.encode("テスト.ドメイン.テスト"))
+        assert isinstance(opts, SSLClientConnectionCreator)
         self.assertFalse(opts._verifier._verify_certs)
 
 
@@ -191,4 +200,4 @@ def _get_ssl_context_options(ssl_context: SSL.Context) -> int:
     """get the options bits from an openssl context object"""
     # the OpenSSL.SSL.Context wrapper doesn't expose get_options, so we have to
     # use the low-level interface
-    return SSL._lib.SSL_CTX_get_options(ssl_context._context)
+    return SSL._lib.SSL_CTX_get_options(ssl_context._context)  # type: ignore[attr-defined]

+ 1 - 1
tests/config/test_util.py

@@ -21,7 +21,7 @@ from tests.unittest import TestCase
 class ValidateConfigTestCase(TestCase):
     """Test cases for synapse.config._util.validate_config"""
 
-    def test_bad_object_in_array(self):
+    def test_bad_object_in_array(self) -> None:
         """malformed objects within an array should be validated correctly"""
 
         # consider a structure:

+ 6 - 5
tests/config/utils.py

@@ -17,19 +17,20 @@ import tempfile
 import unittest
 from contextlib import redirect_stdout
 from io import StringIO
+from typing import List
 
 from synapse.config.homeserver import HomeServerConfig
 
 
 class ConfigFileTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         self.dir = tempfile.mkdtemp()
         self.config_file = os.path.join(self.dir, "homeserver.yaml")
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         shutil.rmtree(self.dir)
 
-    def generate_config(self):
+    def generate_config(self) -> None:
         with redirect_stdout(StringIO()):
             HomeServerConfig.load_or_generate_config(
                 "",
@@ -43,7 +44,7 @@ class ConfigFileTestCase(unittest.TestCase):
                 ],
             )
 
-    def generate_config_and_remove_lines_containing(self, needle):
+    def generate_config_and_remove_lines_containing(self, needle: str) -> None:
         self.generate_config()
 
         with open(self.config_file) as f:
@@ -52,7 +53,7 @@ class ConfigFileTestCase(unittest.TestCase):
         with open(self.config_file, "w") as f:
             f.write("".join(contents))
 
-    def add_lines_to_config(self, lines):
+    def add_lines_to_config(self, lines: List[str]) -> None:
         with open(self.config_file, "a") as f:
             for line in lines:
                 f.write(line + "\n")