|
@@ -38,7 +38,7 @@ class MatrixRestError(Exception):
|
|
|
self.error = error
|
|
|
|
|
|
|
|
|
-def get_args(request, required_args):
|
|
|
+def get_args(request, args, required=True):
|
|
|
"""
|
|
|
Helper function to get arguments for an HTTP request.
|
|
|
Currently takes args from the top level keys of a json object or
|
|
@@ -49,9 +49,14 @@ def get_args(request, required_args):
|
|
|
|
|
|
:param request: The request received by the servlet.
|
|
|
:type request: twisted.web.server.Request
|
|
|
- :param required_args: The args that needs to be found in the
|
|
|
- request's parameters.
|
|
|
- :type required_args: tuple[unicode]
|
|
|
+ :param args: The args to look for in the request's parameters.
|
|
|
+ :type args: tuple[unicode]
|
|
|
+ :param required: Whether to raise a MatrixRestError with 400
|
|
|
+ M_MISSING_PARAMS if an argument is not found.
|
|
|
+ :type required: bool
|
|
|
+
|
|
|
+ :raises: MatrixRestError if required is True and a given parameter
|
|
|
+ was not found in the request's query parameters.
|
|
|
|
|
|
:return: A dict containing the requested args and their values. String values
|
|
|
are of type unicode.
|
|
@@ -59,7 +64,7 @@ def get_args(request, required_args):
|
|
|
"""
|
|
|
v1_path = request.path.startswith(b'/_matrix/identity/api/v1')
|
|
|
|
|
|
- args = None
|
|
|
+ request_args = None
|
|
|
# for v1 paths, only look for json args if content type is json
|
|
|
if (
|
|
|
request.method in (b'POST', b'PUT') and (
|
|
@@ -71,25 +76,25 @@ def get_args(request, required_args):
|
|
|
):
|
|
|
try:
|
|
|
# json.loads doesn't allow bytes in Python 3.5
|
|
|
- args = json.loads(request.content.read().decode("UTF-8"))
|
|
|
+ request_args = json.loads(request.content.read().decode("UTF-8"))
|
|
|
except ValueError:
|
|
|
raise MatrixRestError(400, 'M_BAD_JSON', 'Malformed JSON')
|
|
|
|
|
|
# If we didn't get anything from that, and it's a v1 api path, try the request args
|
|
|
# (riot-web's usage of the ed25519 sign servlet currently involves
|
|
|
# sending the params in the query string with a json body of 'null')
|
|
|
- if args is None and (v1_path or request.method == b'GET'):
|
|
|
- args_bytes = copy.copy(request.args)
|
|
|
+ if request_args is None and (v1_path or request.method == b'GET'):
|
|
|
+ request_args_bytes = copy.copy(request.args)
|
|
|
# Twisted supplies everything as an array because it's valid to
|
|
|
# supply the same params multiple times with www-form-urlencoded
|
|
|
# params. This make it incompatible with the json object though,
|
|
|
# so we need to convert one of them. Since this is the
|
|
|
# backwards-compat option, we convert this one.
|
|
|
- args = {}
|
|
|
- for k, v in args_bytes.items():
|
|
|
+ request_args = {}
|
|
|
+ for k, v in request_args_bytes.items():
|
|
|
if isinstance(v, list) and len(v) == 1:
|
|
|
try:
|
|
|
- args[k.decode("UTF-8")] = v[0].decode("UTF-8")
|
|
|
+ request_args[k.decode("UTF-8")] = v[0].decode("UTF-8")
|
|
|
except UnicodeDecodeError:
|
|
|
# Get a version of the key that has non-UTF-8 characters replaced by
|
|
|
# their \xNN escape sequence so it doesn't raise another exception.
|
|
@@ -100,20 +105,22 @@ def get_args(request, required_args):
|
|
|
"Parameter %s and its value must be valid UTF-8" % safe_k,
|
|
|
)
|
|
|
|
|
|
- elif args is None:
|
|
|
- args = {}
|
|
|
+ elif request_args is None:
|
|
|
+ request_args = {}
|
|
|
|
|
|
- missing = []
|
|
|
- for a in required_args:
|
|
|
- if a not in args:
|
|
|
- missing.append(a)
|
|
|
+ if required:
|
|
|
+ # Check for any missing arguments
|
|
|
+ missing = []
|
|
|
+ for a in args:
|
|
|
+ if a not in request_args:
|
|
|
+ missing.append(a)
|
|
|
|
|
|
- if len(missing) > 0:
|
|
|
- request.setResponseCode(400)
|
|
|
- msg = "Missing parameters: "+(",".join(missing))
|
|
|
- raise MatrixRestError(400, 'M_MISSING_PARAMS', msg)
|
|
|
+ if len(missing) > 0:
|
|
|
+ request.setResponseCode(400)
|
|
|
+ msg = "Missing parameters: "+(",".join(missing))
|
|
|
+ raise MatrixRestError(400, 'M_MISSING_PARAMS', msg)
|
|
|
|
|
|
- return args
|
|
|
+ return request_args
|
|
|
|
|
|
|
|
|
def jsonwrap(f):
|