Просмотр исходного кода

Move server name config handling to SydentConfig

Azrenbeth 2 лет назад
Родитель
Сommit
1492c1c2fc

+ 3 - 3
scripts/casefold_db.py

@@ -143,7 +143,7 @@ def update_local_associations(
 
 
 def update_global_associations(
-    sydent, db: sqlite3.Connection, send_email: bool, dry_run: bool
+    sydent: Sydent, db: sqlite3.Connection, send_email: bool, dry_run: bool
 ):
     """Update the DB table global_threepid_associations so that all stored
     emails are casefolded, the signed association is re-signed and any duplicate
@@ -153,7 +153,7 @@ def update_global_associations(
     """
 
     # get every row where the local server is origin server and medium is email
-    origin_server = sydent.server_name
+    origin_server = sydent.config.general.server_name
     medium = "email"
 
     cur = db.cursor()
@@ -180,7 +180,7 @@ def update_global_associations(
         sg_assoc["address"] = address.casefold()
         sg_assoc = json.dumps(
             signedjson.sign.sign_json(
-                sg_assoc, sydent.server_name, sydent.keyring.ed25519
+                sg_assoc, sydent.config.general.server_name, sydent.keyring.ed25519
             )
         )
 

+ 3 - 0
sydent/config/__init__.py

@@ -17,6 +17,7 @@ from configparser import ConfigParser
 from sydent.config.crypto import CryptoConfig
 from sydent.config.database import DatabaseConfig
 from sydent.config.email import EmailConfig
+from sydent.config.general import GeneralConfig
 from sydent.config.http import HTTPConfig
 from sydent.config.sms import SMSConfig
 
@@ -32,6 +33,7 @@ class SydentConfig:
     """
 
     def __init__(self):
+        self.general = GeneralConfig()
         self.database = DatabaseConfig()
         self.crypto = CryptoConfig()
         self.sms = SMSConfig()
@@ -39,6 +41,7 @@ class SydentConfig:
         self.http = HTTPConfig()
 
         self.config_sections = [
+            self.general,
             self.database,
             self.crypto,
             self.sms,

+ 39 - 0
sydent/config/general.py

@@ -0,0 +1,39 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from configparser import ConfigParser
+
+logger = logging.getLogger(__name__)
+
+
+class GeneralConfig:
+    def parse_config(self, cfg: "ConfigParser") -> None:
+        """
+        Parse the 'general' section of the config
+
+        :param cfg: the configuration to be parsed
+        """
+        self.server_name = cfg.get("general", "server.name")
+        if self.server_name == "":
+            self.server_name = os.uname()[1]
+            logger.warning(
+                "You have not specified a server name. I have guessed that this server is called '%s' ."
+                "If this is incorrect, you should edit 'general.server.name' in the config file."
+                % (self.server_name,)
+            )

+ 1 - 1
sydent/hs_federation/verifier.py

@@ -174,7 +174,7 @@ class Verifier:
         json_request = {
             "method": request.method,
             "uri": request.uri,
-            "destination_is": self.sydent.server_name,
+            "destination_is": self.sydent.config.general.server_name,
             "signatures": {},
         }
 

+ 1 - 1
sydent/http/servlets/blindlysignstuffservlet.py

@@ -36,7 +36,7 @@ class BlindlySignStuffServlet(Resource):
 
     def __init__(self, syd: "Sydent", require_auth: bool = False) -> None:
         self.sydent = syd
-        self.server_name = syd.server_name
+        self.server_name = syd.config.general.server_name
         self.tokenStore = JoinTokenStore(syd)
         self.require_auth = require_auth
 

+ 4 - 2
sydent/http/servlets/lookupservlet.py

@@ -63,7 +63,7 @@ class LookupServlet(Resource):
             return {}
 
         sgassoc = json_decoder.decode(sgassoc)
-        if self.sydent.server_name not in sgassoc["signatures"]:
+        if self.sydent.config.general.server_name not in sgassoc["signatures"]:
             # We have not yet worked out what the proper trust model should be.
             #
             # Maybe clients implicitly trust a server they talk to (and so we
@@ -82,7 +82,9 @@ class LookupServlet(Resource):
             # replication, so that we can undo this decision in the future if
             # we wish, without having destroyed the raw underlying data.
             sgassoc = signedjson.sign.sign_json(
-                sgassoc, self.sydent.server_name, self.sydent.keyring.ed25519
+                sgassoc,
+                self.sydent.config.general.server_name,
+                self.sydent.keyring.ed25519,
             )
         return sgassoc
 

+ 2 - 2
sydent/replication/peer.py

@@ -62,7 +62,7 @@ class LocalPeer(Peer):
     """
 
     def __init__(self, sydent: "Sydent") -> None:
-        super().__init__(sydent.server_name, {})
+        super().__init__(sydent.config.general.server_name, {})
         self.sydent = sydent
         self.hashing_store = HashingMetadataStore(sydent)
 
@@ -103,7 +103,7 @@ class LocalPeer(Peer):
                     globalAssocStore.addAssociation(
                         assocObj,
                         json.dumps(sgAssocs[localId]),
-                        self.sydent.server_name,
+                        self.sydent.config.general.server_name,
                         localId,
                     )
                 else:

+ 1 - 15
sydent/sydent.py

@@ -220,20 +220,6 @@ class Sydent:
 
         self.db = SqliteDatabase(self).db
 
-        self.server_name = self.cfg.get("general", "server.name")
-        if self.server_name == "":
-            self.server_name = os.uname()[1]
-            logger.warning(
-                (
-                    "You had not specified a server name. I have guessed that this server is called '%s' "
-                    + "and saved this in the config file. If this is incorrect, you should edit server.name in "
-                    + "the config file."
-                )
-                % (self.server_name,)
-            )
-            self.cfg.set("general", "server.name", self.server_name)
-            self.save_config()
-
         if self.cfg.has_option("general", "sentry_dsn"):
             # Only import and start sentry SDK if configured.
             import sentry_sdk
@@ -242,7 +228,7 @@ class Sydent:
                 dsn=self.cfg.get("general", "sentry_dsn"),
             )
             with sentry_sdk.configure_scope() as scope:
-                scope.set_tag("sydent_server_name", self.server_name)
+                scope.set_tag("sydent_server_name", self.config.general.server_name)
 
         if self.cfg.has_option("general", "prometheus_port"):
             import prometheus_client

+ 3 - 1
sydent/threepid/bind.py

@@ -99,7 +99,9 @@ class ThreepidBinder:
                 "token": cast(str, token["token"]),
             }
             token["signed"] = signedjson.sign.sign_json(
-                token["signed"], self.sydent.server_name, self.sydent.keyring.ed25519
+                token["signed"],
+                self.sydent.config.general.server_name,
+                self.sydent.keyring.ed25519,
             )
             invites.append(token)
         if invites:

+ 1 - 1
sydent/threepid/signer.py

@@ -43,6 +43,6 @@ class Signer:
         }
         sgassoc.update(assoc.extra_fields)
         sgassoc = signedjson.sign.sign_json(
-            sgassoc, self.sydent.server_name, self.sydent.keyring.ed25519
+            sgassoc, self.sydent.config.general.server_name, self.sydent.keyring.ed25519
         )
         return sgassoc

+ 1 - 1
tests/test_casefold_migration.py

@@ -95,7 +95,7 @@ class MigrationTestCase(unittest.TestCase):
 
         # create some global associations
         associations = []
-        originServer = self.sydent.server_name
+        originServer = self.sydent.config.general.server_name
 
         for i in range(10):
             address = "bob%d@example.com" % i