|
@@ -18,6 +18,7 @@ from twisted.internet import defer
|
|
|
|
|
|
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
|
|
from synapse.api.errors import Codes, SynapseError, FederationDeniedError
|
|
|
+from synapse.http.endpoint import parse_server_name
|
|
|
from synapse.http.server import JsonResource
|
|
|
from synapse.http.servlet import (
|
|
|
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
|
|
@@ -99,26 +100,6 @@ class Authenticator(object):
|
|
|
|
|
|
origin = None
|
|
|
|
|
|
- def parse_auth_header(header_str):
|
|
|
- try:
|
|
|
- params = auth.split(" ")[1].split(",")
|
|
|
- param_dict = dict(kv.split("=") for kv in params)
|
|
|
-
|
|
|
- def strip_quotes(value):
|
|
|
- if value.startswith("\""):
|
|
|
- return value[1:-1]
|
|
|
- else:
|
|
|
- return value
|
|
|
-
|
|
|
- origin = strip_quotes(param_dict["origin"])
|
|
|
- key = strip_quotes(param_dict["key"])
|
|
|
- sig = strip_quotes(param_dict["sig"])
|
|
|
- return (origin, key, sig)
|
|
|
- except Exception:
|
|
|
- raise AuthenticationError(
|
|
|
- 400, "Malformed Authorization header", Codes.UNAUTHORIZED
|
|
|
- )
|
|
|
-
|
|
|
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
|
|
|
|
|
|
if not auth_headers:
|
|
@@ -127,8 +108,8 @@ class Authenticator(object):
|
|
|
)
|
|
|
|
|
|
for auth in auth_headers:
|
|
|
- if auth.startswith("X-Matrix"):
|
|
|
- (origin, key, sig) = parse_auth_header(auth)
|
|
|
+ if auth.startswith(b"X-Matrix"):
|
|
|
+ (origin, key, sig) = _parse_auth_header(auth)
|
|
|
json_request["origin"] = origin
|
|
|
json_request["signatures"].setdefault(origin, {})[key] = sig
|
|
|
|
|
@@ -165,6 +146,47 @@ class Authenticator(object):
|
|
|
logger.exception("Error resetting retry timings on %s", origin)
|
|
|
|
|
|
|
|
|
+def _parse_auth_header(header_bytes):
|
|
|
+ """Parse an X-Matrix auth header
|
|
|
+
|
|
|
+ Args:
|
|
|
+ header_bytes (bytes): header value
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Tuple[str, str, str]: origin, key id, signature.
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ AuthenticationError if the header could not be parsed
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ header_str = header_bytes.decode('utf-8')
|
|
|
+ params = header_str.split(" ")[1].split(",")
|
|
|
+ param_dict = dict(kv.split("=") for kv in params)
|
|
|
+
|
|
|
+ def strip_quotes(value):
|
|
|
+ if value.startswith(b"\""):
|
|
|
+ return value[1:-1]
|
|
|
+ else:
|
|
|
+ return value
|
|
|
+
|
|
|
+ origin = strip_quotes(param_dict["origin"])
|
|
|
+ # ensure that the origin is a valid server name
|
|
|
+ parse_server_name(origin)
|
|
|
+
|
|
|
+ key = strip_quotes(param_dict["key"])
|
|
|
+ sig = strip_quotes(param_dict["sig"])
|
|
|
+ return origin, key, sig
|
|
|
+ except Exception as e:
|
|
|
+ logger.warn(
|
|
|
+ "Error parsing auth header '%s': %s",
|
|
|
+ header_bytes.decode('ascii', 'replace'),
|
|
|
+ e,
|
|
|
+ )
|
|
|
+ raise AuthenticationError(
|
|
|
+ 400, "Malformed Authorization header", Codes.UNAUTHORIZED,
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
class BaseFederationServlet(object):
|
|
|
REQUIRE_AUTH = True
|
|
|
|