__init__.py 5.2 KB

  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
import copy
import functools
import json
import logging

from twisted.internet import defer
from twisted.web import server
  20. logger = logging.getLogger(__name__)
  21. logger = logging.getLogger(__name__)
  22. class MatrixRestError(Exception):
  23. """
  24. Handled by the jsonwrap wrapper. Any servlets that don't use this
  25. wrapper should catch this exception themselves.
  26. """
  27. def __init__(self, httpStatus, errcode, error):
  28. super(Exception, self).__init__(error)
  29. self.httpStatus = httpStatus
  30. self.errcode = errcode
  31. self.error = error
  32. def get_args(request, required_args):
  33. """
  34. Helper function to get arguments for an HTTP request.
  35. Currently takes args from the top level keys of a json object or
  36. www-form-urlencoded for backwards compatability on v1 endpoints only.
  37. Returns a tuple (error, args) where if error is non-null,
  38. the request is malformed. Otherwise, args contains the
  39. parameters passed.
  40. """
  41. v1_path = request.path.startswith('/_matrix/identity/v1')
  42. args = None
  43. # for v1 paths, only look for json args if content type is json
  44. if (
  45. request.method == 'POST' and (
  46. not v1_path or (
  47. request.requestHeaders.hasHeader('Content-Type') and
  48. request.requestHeaders.getRawHeaders('Content-Type')[0].startswith('application/json')
  49. )
  50. )
  51. ):
  52. try:
  53. args = json.load(request.content)
  54. except ValueError:
  55. raise MatrixRestError(400, 'M_BAD_JSON', 'Malformed JSON')
  56. # If we didn't get anything from that, and it's a v1 api path, try the request args
  57. # (riot-web's usage of the ed25519 sign servlet currently involves
  58. # sending the params in the query string with a json body of 'null')
  59. if args is None and (v1_path or request.method == 'GET'):
  60. args = copy.copy(request.args)
  61. # Twisted supplies everything as an array because it's valid to
  62. # supply the same params multiple times with www-form-urlencoded
  63. # params. This make it incompatible with the json object though,
  64. # so we need to convert one of them. Since this is the
  65. # backwards-compat option, we convert this one.
  66. for k, v in args.items():
  67. if isinstance(v, list) and len(v) == 1:
  68. args[k] = v[0]
  69. elif args is None:
  70. args = {}
  71. missing = []
  72. for a in required_args:
  73. if a not in args:
  74. missing.append(a)
  75. if len(missing) > 0:
  76. request.setResponseCode(400)
  77. msg = "Missing parameters: "+(",".join(missing))
  78. raise MatrixRestError(400, 'M_MISSING_PARAMS', msg)
  79. return args
  80. def jsonwrap(f):
  81. @functools.wraps(f)
  82. def inner(*args, **kwargs):
  83. try:
  84. return json.dumps(f(*args, **kwargs)).encode("UTF-8")
  85. except MatrixRestError as e:
  86. request = args[1]
  87. request.setResponseCode(e.httpStatus)
  88. return json.dumps({
  89. "errcode": e.errcode,
  90. "error": e.error,
  91. })
  92. except Exception:
  93. logger.exception("Exception processing request");
  94. request = args[1]
  95. request.setResponseCode(500)
  96. return json.dumps({
  97. "errcode": "M_UNKNOWN",
  98. "error": "Internal Server Error",
  99. })
  100. return inner
  101. def deferjsonwrap(f):
  102. def reqDone(resp, request):
  103. request.setResponseCode(200)
  104. request.write(json.dumps(resp).encode("UTF-8"))
  105. request.finish()
  106. def reqErr(failure, request):
  107. if failure.check(MatrixRestError) is not None:
  108. request.setResponseCode(failure.value.httpStatus)
  109. request.write(json.dumps({'errcode': failure.value.errcode, 'error': failure.value.error}))
  110. else:
  111. logger.error("Request processing failed: %r, %s", failure, failure.getTraceback())
  112. request.setResponseCode(500)
  113. request.write(json.dumps({'errcode': 'M_UNKNOWN', 'error': 'Internal Server Error'}))
  114. request.finish()
  115. def inner(*args, **kwargs):
  116. request = args[1]
  117. d = defer.maybeDeferred(f, *args, **kwargs)
  118. d.addCallback(reqDone, request)
  119. d.addErrback(reqErr, request)
  120. return server.NOT_DONE_YET
  121. return inner
  122. def send_cors(request):
  123. request.setHeader(b"Content-Type", b"application/json")
  124. request.setHeader("Access-Control-Allow-Origin", "*")
  125. request.setHeader("Access-Control-Allow-Methods",
  127. request.setHeader("Access-Control-Allow-Headers",
  128. "Origin, X-Requested-With, Content-Type, Accept, Authorization")