Bladeren bron

Be stricter about JSON that is accepted by Synapse (#8106)

Patrick Cloke 3 jaren geleden
bovenliggende
commit
eebf52be06

+ 1 - 0
changelog.d/8106.bugfix

@@ -0,0 +1 @@
+Fix a long-standing bug where invalid JSON would be accepted by Synapse.

+ 3 - 3
synapse/api/errors.py

@@ -21,10 +21,10 @@ import typing
 from http import HTTPStatus
 from typing import Dict, List, Optional, Union
 
-from canonicaljson import json
-
 from twisted.web import http
 
+from synapse.util import json_decoder
+
 if typing.TYPE_CHECKING:
     from synapse.types import JsonDict
 
@@ -593,7 +593,7 @@ class HttpResponseException(CodeMessageException):
         # try to parse the body as json, to get better errcode/msg, but
         # default to M_UNKNOWN with the HTTP status as the error text
         try:
-            j = json.loads(self.response.decode("utf-8"))
+            j = json_decoder.decode(self.response.decode("utf-8"))
         except ValueError:
             j = {}
 

+ 2 - 3
synapse/federation/federation_server.py

@@ -28,7 +28,6 @@ from typing import (
     Union,
 )
 
-from canonicaljson import json
 from prometheus_client import Counter, Histogram
 
 from twisted.internet import defer
@@ -63,7 +62,7 @@ from synapse.replication.http.federation import (
     ReplicationGetQueryRestServlet,
 )
 from synapse.types import JsonDict, get_domain_from_id
-from synapse.util import glob_to_regex, unwrapFirstError
+from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.caches.response_cache import ResponseCache
 
@@ -551,7 +550,7 @@ class FederationServer(FederationBase):
             for device_id, keys in device_keys.items():
                 for key_id, json_str in keys.items():
                     json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json.loads(json_str)
+                        key_id: json_decoder.decode(json_str)
                     }
 
         logger.info(

+ 2 - 3
synapse/federation/sender/transaction_manager.py

@@ -15,8 +15,6 @@
 import logging
 from typing import TYPE_CHECKING, List, Tuple
 
-from canonicaljson import json
-
 from synapse.api.errors import HttpResponseException
 from synapse.events import EventBase
 from synapse.federation.persistence import TransactionActions
@@ -28,6 +26,7 @@ from synapse.logging.opentracing import (
     tags,
     whitelisted_homeserver,
 )
+from synapse.util import json_decoder
 from synapse.util.metrics import measure_func
 
 if TYPE_CHECKING:
@@ -71,7 +70,7 @@ class TransactionManager(object):
         for edu in pending_edus:
             context = edu.get_context()
             if context:
-                span_contexts.append(extract_text_map(json.loads(context)))
+                span_contexts.append(extract_text_map(json_decoder.decode(context)))
             if keep_destination:
                 edu.strip_context()
 

+ 4 - 4
synapse/handlers/e2e_keys.py

@@ -19,7 +19,7 @@ import logging
 from typing import Dict, List, Optional, Tuple
 
 import attr
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 from signedjson.key import VerifyKey, decode_verify_key_bytes
 from signedjson.sign import SignatureVerifyException, verify_signed_json
 from unpaddedbase64 import decode_base64
@@ -35,7 +35,7 @@ from synapse.types import (
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
-from synapse.util import unwrapFirstError
+from synapse.util import json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.retryutils import NotRetryingDestination
@@ -404,7 +404,7 @@ class E2eKeysHandler(object):
             for device_id, keys in device_keys.items():
                 for key_id, json_bytes in keys.items():
                     json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json.loads(json_bytes)
+                        key_id: json_decoder.decode(json_bytes)
                     }
 
         @trace
@@ -1186,7 +1186,7 @@ def _exception_to_failure(e):
 
 
 def _one_time_keys_match(old_key_json, new_key):
-    old_key = json.loads(old_key_json)
+    old_key = json_decoder.decode(old_key_json)
 
     # if either is a string rather than an object, they must match exactly
     if not isinstance(old_key, dict) or not isinstance(new_key, dict):

+ 2 - 3
synapse/handlers/identity.py

@@ -21,8 +21,6 @@ import logging
 import urllib.parse
 from typing import Awaitable, Callable, Dict, List, Optional, Tuple
 
-from canonicaljson import json
-
 from twisted.internet.error import TimeoutError
 
 from synapse.api.errors import (
@@ -34,6 +32,7 @@ from synapse.api.errors import (
 from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http.client import SimpleHttpClient
 from synapse.types import JsonDict, Requester
+from synapse.util import json_decoder
 from synapse.util.hash import sha256_and_url_safe_base64
 from synapse.util.stringutils import assert_valid_client_secret, random_string
 
@@ -177,7 +176,7 @@ class IdentityHandler(BaseHandler):
         except TimeoutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except CodeMessageException as e:
-            data = json.loads(e.msg)  # XXX WAT?
+            data = json_decoder.decode(e.msg)  # XXX WAT?
             return data
 
         logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)

+ 3 - 2
synapse/handlers/message.py

@@ -17,7 +17,7 @@
 import logging
 from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 
 from twisted.internet.interfaces import IDelayedCall
 
@@ -55,6 +55,7 @@ from synapse.types import (
     UserID,
     create_requester,
 )
+from synapse.util import json_decoder
 from synapse.util.async_helpers import Linearizer
 from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.metrics import measure_func
@@ -864,7 +865,7 @@ class EventCreationHandler(object):
         # Ensure that we can round trip before trying to persist in db
         try:
             dump = frozendict_json_encoder.encode(event.content)
-            json.loads(dump)
+            json_decoder.decode(dump)
         except Exception:
             logger.exception("Failed to encode content: %r", event.content)
             raise

+ 3 - 3
synapse/handlers/oidc_handler.py

@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import logging
 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
 from urllib.parse import urlencode
@@ -39,6 +38,7 @@ from synapse.http.server import respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import UserID, map_username_to_mxid_localpart
+from synapse.util import json_decoder
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -367,7 +367,7 @@ class OidcHandler:
             # and check for an error field. If not, we respond with a generic
             # error message.
             try:
-                resp = json.loads(resp_body.decode("utf-8"))
+                resp = json_decoder.decode(resp_body.decode("utf-8"))
                 error = resp["error"]
                 description = resp.get("error_description", error)
             except (ValueError, KeyError):
@@ -384,7 +384,7 @@ class OidcHandler:
 
         # Since it is a not a 5xx code, body should be a valid JSON. It will
         # raise if not.
-        resp = json.loads(resp_body.decode("utf-8"))
+        resp = json_decoder.decode(resp_body.decode("utf-8"))
 
         if "error" in resp:
             error = resp["error"]

+ 2 - 3
synapse/handlers/ui_auth/checkers.py

@@ -16,13 +16,12 @@
 import logging
 from typing import Any
 
-from canonicaljson import json
-
 from twisted.web.client import PartialDownloadError
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.config.emailconfig import ThreepidBehaviour
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
@@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
         except PartialDownloadError as pde:
             # Twisted is silly
             data = pde.response
-            resp_body = json.loads(data.decode("utf-8"))
+            resp_body = json_decoder.decode(data.decode("utf-8"))
 
         if "success" in resp_body:
             # Note that we do NOT check the hostname here: we explicitly

+ 6 - 5
synapse/http/client.py

@@ -19,7 +19,7 @@ import urllib
 from io import BytesIO
 
 import treq
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 from netaddr import IPAddress
 from prometheus_client import Counter
 from zope.interface import implementer, provider
@@ -47,6 +47,7 @@ from synapse.http import (
 from synapse.http.proxyagent import ProxyAgent
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 
 logger = logging.getLogger(__name__)
@@ -391,7 +392,7 @@ class SimpleHttpClient(object):
         body = await make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
-            return json.loads(body.decode("utf-8"))
+            return json_decoder.decode(body.decode("utf-8"))
         else:
             raise HttpResponseException(
                 response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -433,7 +434,7 @@ class SimpleHttpClient(object):
         body = await make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
-            return json.loads(body.decode("utf-8"))
+            return json_decoder.decode(body.decode("utf-8"))
         else:
             raise HttpResponseException(
                 response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -463,7 +464,7 @@ class SimpleHttpClient(object):
             actual_headers.update(headers)
 
         body = await self.get_raw(uri, args, headers=headers)
-        return json.loads(body.decode("utf-8"))
+        return json_decoder.decode(body.decode("utf-8"))
 
     async def put_json(self, uri, json_body, args={}, headers=None):
         """ Puts some json to the given URI.
@@ -506,7 +507,7 @@ class SimpleHttpClient(object):
         body = await make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
-            return json.loads(body.decode("utf-8"))
+            return json_decoder.decode(body.decode("utf-8"))
         else:
             raise HttpResponseException(
                 response.code, response.phrase.decode("ascii", errors="replace"), body

+ 2 - 3
synapse/http/federation/well_known_resolver.py

@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import logging
 import random
 import time
@@ -26,7 +25,7 @@ from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
 
 from synapse.logging.context import make_deferred_yieldable
-from synapse.util import Clock
+from synapse.util import Clock, json_decoder
 from synapse.util.caches.ttlcache import TTLCache
 from synapse.util.metrics import Measure
 
@@ -181,7 +180,7 @@ class WellKnownResolver(object):
             if response.code != 200:
                 raise Exception("Non-200 response %s" % (response.code,))
 
-            parsed_body = json.loads(body.decode("utf-8"))
+            parsed_body = json_decoder.decode(body.decode("utf-8"))
             logger.info("Response from .well-known: %s", parsed_body)
 
             result = parsed_body["m.server"].encode("ascii")

+ 2 - 3
synapse/http/servlet.py

@@ -17,9 +17,8 @@
 
 import logging
 
-from canonicaljson import json
-
 from synapse.api.errors import Codes, SynapseError
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
@@ -215,7 +214,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
         return None
 
     try:
-        content = json.loads(content_bytes.decode("utf-8"))
+        content = json_decoder.decode(content_bytes.decode("utf-8"))
     except Exception as e:
         logger.warning("Unable to parse JSON: %s", e)
         raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)

+ 5 - 2
synapse/logging/opentracing.py

@@ -177,6 +177,7 @@ from canonicaljson import json
 from twisted.internet import defer
 
 from synapse.config import ConfigError
+from synapse.util import json_decoder
 
 if TYPE_CHECKING:
     from synapse.http.site import SynapseRequest
@@ -499,7 +500,9 @@ def start_active_span_from_edu(
     if opentracing is None:
         return _noop_context_manager()
 
-    carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
+    carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
+        "opentracing", {}
+    )
     context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
     _references = [
         opentracing.child_of(span_context_from_string(x))
@@ -699,7 +702,7 @@ def span_context_from_string(carrier):
     Returns:
         The active span context decoded from a string.
     """
-    carrier = json.loads(carrier)
+    carrier = json_decoder.decode(carrier)
     return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
 
 

+ 5 - 7
synapse/replication/tcp/commands.py

@@ -21,9 +21,7 @@ import abc
 import logging
 from typing import Tuple, Type
 
-from canonicaljson import json
-
-from synapse.util import json_encoder as _json_encoder
+from synapse.util import json_decoder, json_encoder
 
 logger = logging.getLogger(__name__)
 
@@ -125,7 +123,7 @@ class RdataCommand(Command):
             stream_name,
             instance_name,
             None if token == "batch" else int(token),
-            json.loads(row_json),
+            json_decoder.decode(row_json),
         )
 
     def to_line(self):
@@ -134,7 +132,7 @@ class RdataCommand(Command):
                 self.stream_name,
                 self.instance_name,
                 str(self.token) if self.token is not None else "batch",
-                _json_encoder.encode(self.row),
+                json_encoder.encode(self.row),
             )
         )
 
@@ -359,7 +357,7 @@ class UserIpCommand(Command):
     def from_line(cls, line):
         user_id, jsn = line.split(" ", 1)
 
-        access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
+        access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
 
         return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
 
@@ -367,7 +365,7 @@ class UserIpCommand(Command):
         return (
             self.user_id
             + " "
-            + _json_encoder.encode(
+            + json_encoder.encode(
                 (
                     self.access_token,
                     self.ip,

+ 7 - 4
synapse/rest/client/v1/room.py

@@ -21,8 +21,6 @@ import re
 from typing import List, Optional
 from urllib import parse as urlparse
 
-from canonicaljson import json
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
@@ -46,6 +44,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
+from synapse.util import json_decoder
 
 MYPY = False
 if MYPY:
@@ -519,7 +518,9 @@ class RoomMessageListRestServlet(RestServlet):
         filter_str = parse_string(request, b"filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter = Filter(json.loads(filter_json))  # type: Optional[Filter]
+            event_filter = Filter(
+                json_decoder.decode(filter_json)
+            )  # type: Optional[Filter]
             if (
                 event_filter
                 and event_filter.filter_json.get("event_format", "client")
@@ -631,7 +632,9 @@ class RoomEventContextServlet(RestServlet):
         filter_str = parse_string(request, b"filter", encoding="utf-8")
         if filter_str:
             filter_json = urlparse.unquote(filter_str)
-            event_filter = Filter(json.loads(filter_json))  # type: Optional[Filter]
+            event_filter = Filter(
+                json_decoder.decode(filter_json)
+            )  # type: Optional[Filter]
         else:
             event_filter = None
 

+ 2 - 3
synapse/rest/client/v2_alpha/sync.py

@@ -16,8 +16,6 @@
 import itertools
 import logging
 
-from canonicaljson import json
-
 from synapse.api.constants import PresenceState
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
@@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state
 from synapse.handlers.sync import SyncConfig
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.types import StreamToken
+from synapse.util import json_decoder
 
 from ._base import client_patterns, set_timeline_upper_limit
 
@@ -125,7 +124,7 @@ class SyncRestServlet(RestServlet):
             filter_collection = DEFAULT_FILTER_COLLECTION
         elif filter_id.startswith("{"):
             try:
-                filter_object = json.loads(filter_id)
+                filter_object = json_decoder.decode(filter_id)
                 set_timeline_upper_limit(
                     filter_object, self.hs.config.filter_timeline_limit
                 )

+ 5 - 3
synapse/rest/key/v2/remote_key_resource.py

@@ -15,19 +15,19 @@
 import logging
 from typing import Dict, Set
 
-from canonicaljson import json
 from signedjson.sign import sign_json
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.crypto.keyring import ServerKeyFetcher
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
 
 class RemoteKey(DirectServeJsonResource):
-    """HTTP resource for retreiving the TLS certificate and NACL signature
+    """HTTP resource for retrieving the TLS certificate and NACL signature
     verification keys for a collection of servers. Checks that the reported
     X.509 TLS certificate matches the one used in the HTTPS connection. Checks
     that the NACL signature for the remote server is valid. Returns a dict of
@@ -209,13 +209,15 @@ class RemoteKey(DirectServeJsonResource):
                     # Cast to bytes since postgresql returns a memoryview.
                     json_results.add(bytes(result["key_json"]))
 
+        # If there is a cache miss, request the missing keys, then recurse (and
+        # ensure the result is sent).
         if cache_misses and query_remote_on_cache_miss:
             await self.fetcher.get_keys(cache_misses)
             await self.query_keys(request, query, query_remote_on_cache_miss=False)
         else:
             signed_keys = []
             for key_json in json_results:
-                key_json = json.loads(key_json.decode("utf-8"))
+                key_json = json_decoder.decode(key_json.decode("utf-8"))
                 for signing_key in self.config.key_server_signing_keys:
                     key_json = sign_json(key_json, self.config.server_name, signing_key)
 

+ 3 - 4
synapse/storage/_base.py

@@ -19,12 +19,11 @@ import random
 from abc import ABCMeta
 from typing import Any, Optional
 
-from canonicaljson import json
-
 from synapse.storage.database import LoggingTransaction  # noqa: F401
 from synapse.storage.database import make_in_list_sql_clause  # noqa: F401
 from synapse.storage.database import DatabasePool
 from synapse.types import Collection, get_domain_from_id
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
@@ -99,13 +98,13 @@ def db_to_json(db_content):
     if isinstance(db_content, memoryview):
         db_content = db_content.tobytes()
 
-    # Decode it to a Unicode string before feeding it to json.loads, since
+    # Decode it to a Unicode string before feeding it to the JSON decoder, since
     # Python 3.5 does not support deserializing bytes.
     if isinstance(db_content, (bytes, bytearray)):
         db_content = db_content.decode("utf8")
 
     try:
-        return json.loads(db_content)
+        return json_decoder.decode(db_content)
     except Exception:
         logging.warning("Tried to decode '%r' as JSON and failed", db_content)
         raise

+ 14 - 2
synapse/storage/databases/main/events_worker.py

@@ -596,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore):
             if not allow_rejected and rejected_reason:
                 continue
 
-            d = db_to_json(row["json"])
-            internal_metadata = db_to_json(row["internal_metadata"])
+            # If the event or metadata cannot be parsed, log the error and act
+            # as if the event is unknown.
+            try:
+                d = db_to_json(row["json"])
+            except ValueError:
+                logger.error("Unable to parse json from event: %s", event_id)
+                continue
+            try:
+                internal_metadata = db_to_json(row["internal_metadata"])
+            except ValueError:
+                logger.error(
+                    "Unable to parse internal_metadata from event: %s", event_id
+                )
+                continue
 
             format_version = row["format_version"]
             if format_version is None:

+ 12 - 2
synapse/util/__init__.py

@@ -25,8 +25,18 @@ from synapse.logging import context
 
 logger = logging.getLogger(__name__)
 
-# Create a custom encoder to reduce the whitespace produced by JSON encoding.
-json_encoder = json.JSONEncoder(separators=(",", ":"))
+
+def _reject_invalid_json(val):
+    """Do not allow Infinity, -Infinity, or NaN values in JSON."""
+    raise json.JSONDecodeError("Invalid JSON value: '%s'" % val)
+
+
+# Create a custom encoder to reduce the whitespace produced by JSON encoding and
+# ensure that valid JSON is produced.
+json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
+
+# Create a custom decoder to reject Python extensions to JSON.
+json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
 
 
 def unwrapFirstError(failure):