Browse Source

Prefer `type(x) is int` to `isinstance(x, int)` (#14945)

* Perfer `type(x) is int` to `isinstance(x, int)`

This covered all additional instances I could see where `x` was
user-controlled.
The remaining cases are

```
$ rg -s 'isinstance.*[^_]int'
tests/replication/_base.py
576:        if isinstance(obj, int):

synapse/util/caches/stream_change_cache.py
136:        assert isinstance(stream_pos, int)
214:        assert isinstance(stream_pos, int)
246:        assert isinstance(stream_pos, int)
267:        assert isinstance(stream_pos, int)

synapse/replication/tcp/external_cache.py
133:        if isinstance(result, int):

synapse/metrics/__init__.py
100:        if isinstance(calls, (int, float)):

synapse/handlers/appservice.py
262:        assert isinstance(new_token, int)

synapse/config/_util.py
62:        if isinstance(p, int):
```

which cover metrics, logic related to `jsonschema`, and replication and
data streams. AFAICS these are all internal to Synapse

* Changelog
David Robertson 1 year ago
parent
commit
796a4b7482

+ 1 - 0
changelog.d/14945.misc

@@ -0,0 +1 @@
+Fix various long-standing bugs in Synapse's config, event and request handling where booleans were unintentionally accepted where an integer was expected.

+ 50 - 22
synapse/config/_base.py

@@ -174,15 +174,29 @@ class Config:
 
     @staticmethod
     def parse_size(value: Union[str, int]) -> int:
-        if isinstance(value, int):
+        """Interpret `value` as a number of bytes.
+
+        If an integer is provided it is treated as bytes and is unchanged.
+
+        String byte sizes can have a suffix of 'K' or `M`, representing kibibytes and
+        mebibytes respectively. No suffix is understood as a plain byte count.
+
+        Raises:
+            TypeError, if given something other than an integer or a string
+            ValueError: if given a string not of the form described above.
+        """
+        if type(value) is int:
             return value
-        sizes = {"K": 1024, "M": 1024 * 1024}
-        size = 1
-        suffix = value[-1]
-        if suffix in sizes:
-            value = value[:-1]
-            size = sizes[suffix]
-        return int(value) * size
+        elif type(value) is str:
+            sizes = {"K": 1024, "M": 1024 * 1024}
+            size = 1
+            suffix = value[-1]
+            if suffix in sizes:
+                value = value[:-1]
+                size = sizes[suffix]
+            return int(value) * size
+        else:
+            raise TypeError(f"Bad byte size {value!r}")
 
     @staticmethod
     def parse_duration(value: Union[str, int]) -> int:
@@ -198,22 +212,36 @@ class Config:
 
         Returns:
             The number of milliseconds in the duration.
+
+        Raises:
+            TypeError, if given something other than an integer or a string
+            ValueError: if given a string not of the form described above.
         """
-        if isinstance(value, int):
+        if type(value) is int:
             return value
-        second = 1000
-        minute = 60 * second
-        hour = 60 * minute
-        day = 24 * hour
-        week = 7 * day
-        year = 365 * day
-        sizes = {"s": second, "m": minute, "h": hour, "d": day, "w": week, "y": year}
-        size = 1
-        suffix = value[-1]
-        if suffix in sizes:
-            value = value[:-1]
-            size = sizes[suffix]
-        return int(value) * size
+        elif type(value) is str:
+            second = 1000
+            minute = 60 * second
+            hour = 60 * minute
+            day = 24 * hour
+            week = 7 * day
+            year = 365 * day
+            sizes = {
+                "s": second,
+                "m": minute,
+                "h": hour,
+                "d": day,
+                "w": week,
+                "y": year,
+            }
+            size = 1
+            suffix = value[-1]
+            if suffix in sizes:
+                value = value[:-1]
+                size = sizes[suffix]
+            return int(value) * size
+        else:
+            raise TypeError(f"Bad duration {value!r}")
 
     @staticmethod
     def abspath(file_path: str) -> str:

+ 2 - 2
synapse/config/cache.py

@@ -126,7 +126,7 @@ class CacheConfig(Config):
 
         cache_config = config.get("caches") or {}
         self.global_factor = cache_config.get("global_factor", _DEFAULT_FACTOR_SIZE)
-        if not isinstance(self.global_factor, (int, float)):
+        if type(self.global_factor) not in (int, float):
             raise ConfigError("caches.global_factor must be a number.")
 
         # Load cache factors from the config
@@ -151,7 +151,7 @@ class CacheConfig(Config):
         )
 
         for cache, factor in individual_factors.items():
-            if not isinstance(factor, (int, float)):
+            if type(factor) not in (int, float):
                 raise ConfigError(
                     "caches.per_cache_factors.%s must be a number" % (cache,)
                 )

+ 1 - 1
synapse/config/server.py

@@ -904,7 +904,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig:
         raise ConfigError(DIRECT_TCP_ERROR, ("listeners", str(num), "type"))
 
     port = listener.get("port")
-    if not isinstance(port, int):
+    if type(port) is not int:
         raise ConfigError("Listener configuration is lacking a valid 'port' option")
 
     tls = listener.get("tls", False)

+ 2 - 2
synapse/events/validator.py

@@ -139,7 +139,7 @@ class EventValidator:
         max_lifetime = event.content.get("max_lifetime")
 
         if min_lifetime is not None:
-            if not isinstance(min_lifetime, int):
+            if type(min_lifetime) is not int:
                 raise SynapseError(
                     code=400,
                     msg="'min_lifetime' must be an integer",
@@ -147,7 +147,7 @@ class EventValidator:
                 )
 
         if max_lifetime is not None:
-            if not isinstance(max_lifetime, int):
+            if type(max_lifetime) is not int:
                 raise SynapseError(
                     code=400,
                     msg="'max_lifetime' must be an integer",

+ 1 - 1
synapse/federation/federation_client.py

@@ -1864,7 +1864,7 @@ class TimestampToEventResponse:
             )
 
         origin_server_ts = d.get("origin_server_ts")
-        if not isinstance(origin_server_ts, int):
+        if type(origin_server_ts) is not int:
             raise ValueError(
                 "Invalid response: 'origin_server_ts' must be a int but received %r"
                 % origin_server_ts

+ 1 - 1
synapse/handlers/message.py

@@ -377,7 +377,7 @@ class MessageHandler:
         """
 
         expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
-        if not isinstance(expiry_ts, int) or event.is_state():
+        if type(expiry_ts) is not int or event.is_state():
             return
 
         # _schedule_expiry_for_event won't actually schedule anything if there's already

+ 1 - 1
synapse/rest/admin/__init__.py

@@ -152,7 +152,7 @@ class PurgeHistoryRestServlet(RestServlet):
             logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
         elif "purge_up_to_ts" in body:
             ts = body["purge_up_to_ts"]
-            if not isinstance(ts, int):
+            if type(ts) is not int:
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST,
                     "purge_up_to_ts must be an int",

+ 7 - 8
synapse/rest/admin/registration_tokens.py

@@ -143,7 +143,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
         else:
             # Get length of token to generate (default is 16)
             length = body.get("length", 16)
-            if not isinstance(length, int):
+            if type(length) is not int:
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST,
                     "length must be an integer",
@@ -163,8 +163,7 @@ class NewRegistrationTokenRestServlet(RestServlet):
 
         uses_allowed = body.get("uses_allowed", None)
         if not (
-            uses_allowed is None
-            or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+            uses_allowed is None or (type(uses_allowed) is int and uses_allowed >= 0)
         ):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
@@ -173,13 +172,13 @@ class NewRegistrationTokenRestServlet(RestServlet):
             )
 
         expiry_time = body.get("expiry_time", None)
-        if not isinstance(expiry_time, (int, type(None))):
+        if type(expiry_time) not in (int, type(None)):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "expiry_time must be an integer or null",
                 Codes.INVALID_PARAM,
             )
-        if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+        if type(expiry_time) is int and expiry_time < self.clock.time_msec():
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "expiry_time must not be in the past",
@@ -284,7 +283,7 @@ class RegistrationTokenRestServlet(RestServlet):
             uses_allowed = body["uses_allowed"]
             if not (
                 uses_allowed is None
-                or (isinstance(uses_allowed, int) and uses_allowed >= 0)
+                or (type(uses_allowed) is int and uses_allowed >= 0)
             ):
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST,
@@ -295,13 +294,13 @@ class RegistrationTokenRestServlet(RestServlet):
 
         if "expiry_time" in body:
             expiry_time = body["expiry_time"]
-            if not isinstance(expiry_time, (int, type(None))):
+            if type(expiry_time) not in (int, type(None)):
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST,
                     "expiry_time must be an integer or null",
                     Codes.INVALID_PARAM,
                 )
-            if isinstance(expiry_time, int) and expiry_time < self.clock.time_msec():
+            if type(expiry_time) is int and expiry_time < self.clock.time_msec():
                 raise SynapseError(
                     HTTPStatus.BAD_REQUEST,
                     "expiry_time must not be in the past",

+ 3 - 3
synapse/rest/admin/users.py

@@ -973,7 +973,7 @@ class UserTokenRestServlet(RestServlet):
         body = parse_json_object_from_request(request, allow_empty_body=True)
 
         valid_until_ms = body.get("valid_until_ms")
-        if valid_until_ms and not isinstance(valid_until_ms, int):
+        if type(valid_until_ms) not in (int, type(None)):
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST, "'valid_until_ms' parameter must be an int"
             )
@@ -1125,14 +1125,14 @@ class RateLimitRestServlet(RestServlet):
         messages_per_second = body.get("messages_per_second", 0)
         burst_count = body.get("burst_count", 0)
 
-        if not isinstance(messages_per_second, int) or messages_per_second < 0:
+        if type(messages_per_second) is not int or messages_per_second < 0:
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "%r parameter must be a positive int" % (messages_per_second,),
                 errcode=Codes.INVALID_PARAM,
             )
 
-        if not isinstance(burst_count, int) or burst_count < 0:
+        if type(burst_count) is not int or burst_count < 0:
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "%r parameter must be a positive int" % (burst_count,),

+ 1 - 1
synapse/rest/client/report_event.py

@@ -54,7 +54,7 @@ class ReportEventRestServlet(RestServlet):
                 "Param 'reason' must be a string",
                 Codes.BAD_JSON,
             )
-        if not isinstance(body.get("score", 0), int):
+        if type(body.get("score", 0)) is not int:
             raise SynapseError(
                 HTTPStatus.BAD_REQUEST,
                 "Param 'score' must be an integer",

+ 1 - 1
synapse/rest/media/v1/oembed.py

@@ -200,7 +200,7 @@ class OEmbedProvider:
                 calc_description_and_urls(open_graph_response, oembed["html"])
             for size in ("width", "height"):
                 val = oembed.get(size)
-                if val is not None and isinstance(val, int):
+                if type(val) is int:
                     open_graph_response[f"og:video:{size}"] = val
 
         elif oembed_type == "link":

+ 1 - 1
synapse/rest/media/v1/thumbnailer.py

@@ -77,7 +77,7 @@ class Thumbnailer:
             image_exif = self.image._getexif()  # type: ignore
             if image_exif is not None:
                 image_orientation = image_exif.get(EXIF_ORIENTATION_TAG)
-                assert isinstance(image_orientation, int)
+                assert type(image_orientation) is int
                 self.transpose_method = EXIF_TRANSPOSE_MAPPINGS.get(image_orientation)
         except Exception as e:
             # A lot of parsing errors can happen when parsing EXIF

+ 3 - 3
synapse/storage/databases/main/events.py

@@ -1651,7 +1651,7 @@ class PersistEventsStore:
             if self._ephemeral_messages_enabled:
                 # If there's an expiry timestamp on the event, store it.
                 expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
-                if isinstance(expiry_ts, int) and not event.is_state():
+                if type(expiry_ts) is int and not event.is_state():
                     self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
 
         # Insert into the room_memberships table.
@@ -2133,10 +2133,10 @@ class PersistEventsStore:
         ):
             if (
                 "min_lifetime" in event.content
-                and not isinstance(event.content.get("min_lifetime"), int)
+                and type(event.content["min_lifetime"]) is not int
             ) or (
                 "max_lifetime" in event.content
-                and not isinstance(event.content.get("max_lifetime"), int)
+                and type(event.content["max_lifetime"]) is not int
             ):
                 # Ignore the event if one of the value isn't an integer.
                 return