Sfoglia il codice sorgente

Make sydent.validators pass `mypy --strict` (#425)

* Bump phonenumbers so we can use its type stubs
* Use snake_case instead of lowerCamelCase
* Don't int(...) an int
David Robertson 2 anni fa
parent
commit
57ba780bbd

+ 1 - 0
changelog.d/425.misc

@@ -0,0 +1 @@
+Add type hints so `sydent.validators` passes `mypy --strict`.

+ 1 - 1
pyproject.toml

@@ -52,6 +52,7 @@ files = [
     "sydent/db",
     "sydent/users",
     "sydent/util",
+    "sydent/validators",
     # TODO the rest of CI checks these---mypy ought to too.
     # "tests",
     # "matrix_is_test",
@@ -66,7 +67,6 @@ module = [
     "nacl.*",
     "netaddr",
     "prometheus_client",
-    "phonenumbers",
     "sentry_sdk",
     "signedjson.*",
     "sortedcontainers",

+ 1 - 1
setup.py

@@ -47,7 +47,7 @@ setup(
         "Twisted>=18.4.0",
         # twisted warns about about the absence of this
         "service_identity>=1.0.0",
-        "phonenumbers",
+        "phonenumbers>=8.12.32",
         "pyopenssl",
         "attrs>=19.1.0",
         "netaddr>=0.7.0",

+ 34 - 25
sydent/db/valsession.py

@@ -13,15 +13,17 @@
 # limitations under the License.
 
 from random import SystemRandom
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Optional, Tuple
 
 import sydent.util.tokenutils
 from sydent.util import time_msec
 from sydent.validators import (
+    THREEPID_SESSION_VALID_LIFETIME_MS,
     IncorrectClientSecretException,
     InvalidSessionIdException,
     SessionExpiredException,
     SessionNotValidatedException,
+    TokenInfo,
     ValidationSession,
 )
 
@@ -36,7 +38,7 @@ class ThreePidValSessionStore:
 
     def getOrCreateTokenSession(
         self, medium: str, address: str, clientSecret: str
-    ) -> ValidationSession:
+    ) -> Tuple[ValidationSession, TokenInfo]:
         """
         Retrieves the validation session for a given medium, address and client secret,
         or creates one if none was found.
@@ -56,13 +58,16 @@ class ThreePidValSessionStore:
             "where s.medium = ? and s.address = ? and s.clientSecret = ? and t.validationSession = s.id",
             (medium, address, clientSecret),
         )
-        row = cur.fetchone()
+        row: Optional[
+            Tuple[int, str, str, str, Optional[int], int, str, int]
+        ] = cur.fetchone()
 
         if row:
-            s = ValidationSession(
-                row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7]
+            session = ValidationSession(
+                row[0], row[1], row[2], row[3], bool(row[4]), row[5]
             )
-            return s
+            token_info = TokenInfo(row[6], row[7])
+            return session, token_info
 
         sid = self.addValSession(
             medium, address, clientSecret, time_msec(), commit=False
@@ -76,10 +81,16 @@ class ThreePidValSessionStore:
         )
         self.sydent.db.commit()
 
-        s = ValidationSession(
-            sid, medium, address, clientSecret, False, time_msec(), tokenString, -1
+        session = ValidationSession(
+            sid,
+            medium,
+            address,
+            clientSecret,
+            False,
+            time_msec(),
         )
-        return s
+        token_info = TokenInfo(tokenString, -1)
+        return session, token_info
 
     def addValSession(
         self,
@@ -178,16 +189,16 @@ class ThreePidValSessionStore:
             + "threepid_validation_sessions where id = ?",
             (sid,),
         )
-        row = cur.fetchone()
+        row: Optional[Tuple[int, str, str, str, Optional[int], int]] = cur.fetchone()
 
         if not row:
             return None
 
-        return ValidationSession(
-            row[0], row[1], row[2], row[3], row[4], row[5], None, None
-        )
+        return ValidationSession(row[0], row[1], row[2], row[3], bool(row[4]), row[5])
 
-    def getTokenSessionById(self, sid: int) -> Optional[ValidationSession]:
+    def getTokenSessionById(
+        self, sid: int
+    ) -> Optional[Tuple[ValidationSession, TokenInfo]]:
         """
         Retrieves a validation session using the session's ID.
 
@@ -203,23 +214,23 @@ class ThreePidValSessionStore:
             "where s.id = ? and t.validationSession = s.id",
             (sid,),
         )
+        row: Optional[Tuple[int, str, str, str, Optional[int], int, str, int]]
         row = cur.fetchone()
 
         if row:
-            s = ValidationSession(
-                row[0], row[1], row[2], row[3], row[4], row[5], row[6], row[7]
-            )
-            return s
+            s = ValidationSession(row[0], row[1], row[2], row[3], bool(row[4]), row[5])
+            t = TokenInfo(row[6], row[7])
+            return s, t
 
         return None
 
-    def getValidatedSession(self, sid: int, clientSecret: str) -> ValidationSession:
+    def getValidatedSession(self, sid: int, client_secret: str) -> ValidationSession:
         """
         Retrieve a validated and still-valid session whose client secret matches the
         one passed in.
 
         :param sid: The ID of the session to retrieve.
-        :param clientSecret: A client secret to check against the one retrieved from
+        :param client_secret: A client secret to check against the one retrieved from
             the database.
 
         :return: The retrieved session.
@@ -236,10 +247,10 @@ class ThreePidValSessionStore:
         if not s:
             raise InvalidSessionIdException()
 
-        if not s.clientSecret == clientSecret:
+        if not s.client_secret == client_secret:
             raise IncorrectClientSecretException()
 
-        if s.mtime + ValidationSession.THREEPID_SESSION_VALID_LIFETIME_MS < time_msec():
+        if s.mtime + THREEPID_SESSION_VALID_LIFETIME_MS < time_msec():
             raise SessionExpiredException()
 
         if not s.validated:
@@ -252,9 +263,7 @@ class ThreePidValSessionStore:
 
         cur = self.sydent.db.cursor()
 
-        delete_before_ts = (
-            time_msec() - 5 * ValidationSession.THREEPID_SESSION_VALID_LIFETIME_MS
-        )
+        delete_before_ts = time_msec() - 5 * THREEPID_SESSION_VALID_LIFETIME_MS
 
         sql = """
             DELETE FROM threepid_validation_sessions

+ 21 - 26
sydent/validators/__init__.py

@@ -11,35 +11,30 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
 
+import attr
 
+# how long a user can wait before validating a session after starting it
+THREEPID_SESSION_VALIDATION_TIMEOUT_MS = 24 * 60 * 60 * 1000
+
+# how long we keep sessions for after they've been validated
+THREEPID_SESSION_VALID_LIFETIME_MS = 24 * 60 * 60 * 1000
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
 class ValidationSession:
-    # how long a user can wait before validating a session after starting it
-    THREEPID_SESSION_VALIDATION_TIMEOUT_MS = 24 * 60 * 60 * 1000
-
-    # how long we keep sessions for after they've been validated
-    THREEPID_SESSION_VALID_LIFETIME_MS = 24 * 60 * 60 * 1000
-
-    def __init__(
-        self,
-        _id: int,
-        _medium: str,
-        _address: str,
-        _clientSecret: str,
-        _validated: int,  # bool, but sqlite has no bool type
-        _mtime: int,
-        _token: Optional[str],
-        _sendAttemptNumber: Optional[int],
-    ):
-        self.id = _id
-        self.medium = _medium
-        self.address = _address
-        self.clientSecret = _clientSecret
-        self.validated = _validated
-        self.mtime = _mtime
-        self.token = _token
-        self.sendAttemptNumber = _sendAttemptNumber
+    id: int
+    medium: str
+    address: str
+    client_secret: str
+    validated: bool
+    mtime: int
+
+
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class TokenInfo:
+    token: str
+    send_attempt_number: int
 
 
 class IncorrectClientSecretException(Exception):

+ 10 - 8
sydent/validators/common.py

@@ -4,11 +4,11 @@ from typing import TYPE_CHECKING, Dict
 from sydent.db.valsession import ThreePidValSessionStore
 from sydent.util import time_msec
 from sydent.validators import (
+    THREEPID_SESSION_VALIDATION_TIMEOUT_MS,
     IncorrectClientSecretException,
     IncorrectSessionTokenException,
     InvalidSessionIdException,
     SessionExpiredException,
-    ValidationSession,
 )
 
 if TYPE_CHECKING:
@@ -39,16 +39,18 @@ def validateSessionWithToken(
     :raise IncorrectSessionTokenException: The provided token is incorrect
     """
     valSessionStore = ThreePidValSessionStore(sydent)
-    s = valSessionStore.getTokenSessionById(sid)
-    if not s:
+    result = valSessionStore.getTokenSessionById(sid)
+    if not result:
         logger.info("Session ID %s not found", sid)
         raise InvalidSessionIdException()
 
-    if not clientSecret == s.clientSecret:
+    session, token_info = result
+
+    if not clientSecret == session.client_secret:
         logger.info("Incorrect client secret", sid)
         raise IncorrectClientSecretException()
 
-    if s.mtime + ValidationSession.THREEPID_SESSION_VALIDATION_TIMEOUT_MS < time_msec():
+    if session.mtime + THREEPID_SESSION_VALIDATION_TIMEOUT_MS < time_msec():
         logger.info("Session expired")
         raise SessionExpiredException()
 
@@ -56,9 +58,9 @@ def validateSessionWithToken(
     # if tokenObj.validated and clientSecret == tokenObj.clientSecret:
     #    return True
 
-    if s.token == token:
-        logger.info("Setting session %s as validated", s.id)
-        valSessionStore.setValidated(s.id, True)
+    if token_info.token == token:
+        logger.info("Setting session %s as validated", session.id)
+        valSessionStore.setValidated(session.id, True)
 
         return {"success": True}
     else:

+ 19 - 13
sydent/validators/emailvalidator.py

@@ -23,7 +23,6 @@ from sydent.validators import common
 
 if TYPE_CHECKING:
     from sydent.sydent import Sydent
-    from sydent.validators import ValidationSession
 
 logger = logging.getLogger(__name__)
 
@@ -57,7 +56,7 @@ class EmailValidator:
         """
         valSessionStore = ThreePidValSessionStore(self.sydent)
 
-        valSession = valSessionStore.getOrCreateTokenSession(
+        valSession, token_info = valSessionStore.getOrCreateTokenSession(
             medium="email", address=emailAddress, clientSecret=clientSecret
         )
 
@@ -72,11 +71,11 @@ class EmailValidator:
         else:
             templateFile = self.sydent.config.email.template
 
-        if int(valSession.sendAttemptNumber) >= int(sendAttempt):
+        if token_info.send_attempt_number >= sendAttempt:
             logger.info(
                 "Not mailing code because current send attempt (%d) is not less than given send attempt (%s)",
-                int(sendAttempt),
-                int(valSession.sendAttemptNumber),
+                sendAttempt,
+                token_info.send_attempt_number,
             )
             return valSession.id
 
@@ -84,12 +83,14 @@ class EmailValidator:
 
         substitutions = {
             "ipaddress": ipstring,
-            "link": self.makeValidateLink(valSession, clientSecret, nextLink),
-            "token": valSession.token,
+            "link": self.makeValidateLink(
+                valSession.id, token_info.token, clientSecret, nextLink
+            ),
+            "token": token_info.token,
         }
         logger.info(
             "Attempting to mail code %s (nextLink: %s) to %s",
-            valSession.token,
+            token_info.token,
             nextLink,
             emailAddress,
         )
@@ -100,12 +101,17 @@ class EmailValidator:
         return valSession.id
 
     def makeValidateLink(
-        self, valSession: "ValidationSession", clientSecret: str, nextLink: str
+        self,
+        session_id: int,
+        token: str,
+        clientSecret: str,
+        nextLink: str,
     ) -> str:
         """
         Creates a validation link that can be sent via email to the user.
 
-        :param valSession: The current validation session.
+        :param session_id: The current validation session's ID.
+        :param token: The token to make a link for.
         :param clientSecret: The client secret to include in the link.
         :param nextLink: The link to redirect the user to once they have completed the
             validation.
@@ -115,9 +121,9 @@ class EmailValidator:
         base = self.sydent.config.http.server_http_url_base
         link = "%s/_matrix/identity/api/v1/validate/email/submitToken?token=%s&client_secret=%s&sid=%d" % (
             base,
-            urllib.parse.quote(valSession.token),
+            urllib.parse.quote(token),
             urllib.parse.quote(clientSecret),
-            valSession.id,
+            session_id,
         )
         if nextLink:
             # manipulate the nextLink to add the sid, because
@@ -127,7 +133,7 @@ class EmailValidator:
                 nextLink += "&"
             else:
                 nextLink += "?"
-            nextLink += "sid=" + urllib.parse.quote(str(valSession.id))
+            nextLink += "sid=" + urllib.parse.quote(str(session_id))
 
             link += "&nextLink=%s" % (urllib.parse.quote(nextLink))
         return link

+ 10 - 10
sydent/validators/msisdnvalidator.py

@@ -16,7 +16,7 @@
 import logging
 from typing import TYPE_CHECKING, Dict, Optional
 
-import phonenumbers  # type: ignore
+import phonenumbers
 
 from sydent.db.valsession import ThreePidValSessionStore
 from sydent.sms.openmarket import OpenMarketSMS
@@ -42,7 +42,7 @@ class MsisdnValidator:
         self,
         phoneNumber: phonenumbers.PhoneNumber,
         clientSecret: str,
-        sendAttempt: int,
+        send_attempt: int,
         brand: Optional[str] = None,
     ) -> int:
         """
@@ -51,7 +51,7 @@ class MsisdnValidator:
 
         :param phoneNumber: The phone number to send the email to.
         :param clientSecret: The client secret to use.
-        :param sendAttempt: The current send attempt.
+        :param send_attempt: The current send attempt.
         :param brand: A hint at a brand from the request.
 
         :return: The ID of the session created (or of the existing one if any)
@@ -67,17 +67,17 @@ class MsisdnValidator:
             phoneNumber, phonenumbers.PhoneNumberFormat.E164
         )[1:]
 
-        valSession = valSessionStore.getOrCreateTokenSession(
+        valSession, token_info = valSessionStore.getOrCreateTokenSession(
             medium="msisdn", address=msisdn, clientSecret=clientSecret
         )
 
         valSessionStore.setMtime(valSession.id, time_msec())
 
-        if int(valSession.sendAttemptNumber) >= int(sendAttempt):
+        if token_info.send_attempt_number >= send_attempt:
             logger.info(
                 "Not texting code because current send attempt (%d) is not less than given send attempt (%s)",
-                int(sendAttempt),
-                int(valSession.sendAttemptNumber),
+                send_attempt,
+                token_info.send_attempt_number,
             )
             return valSession.id
 
@@ -86,17 +86,17 @@ class MsisdnValidator:
 
         logger.info(
             "Attempting to text code %s to %s (country %d) with originator %s",
-            valSession.token,
+            token_info.token,
             msisdn,
             phoneNumber.country_code,
             originator,
         )
 
-        smsBody = smsBodyTemplate.format(token=valSession.token)
+        smsBody = smsBodyTemplate.format(token=token_info.token)
 
         self.omSms.sendTextSMS(smsBody, msisdn, originator)
 
-        valSessionStore.setSendAttemptNumber(valSession.id, sendAttempt)
+        valSessionStore.setSendAttemptNumber(valSession.id, send_attempt)
 
         return valSession.id