peer.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. # Copyright 2014 OpenMarket Ltd
  2. # Copyright 2019 New Vector Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import binascii
  16. import json
  17. import logging
  18. from typing import TYPE_CHECKING, Any, Dict
  19. import signedjson.key
  20. import signedjson.sign
  21. from twisted.internet import defer
  22. from twisted.internet.defer import Deferred
  23. from twisted.web.client import readBody
  24. from twisted.web.iweb import IResponse
  25. from unpaddedbase64 import decode_base64
  26. from sydent.config import ConfigError
  27. from sydent.db.hashing_metadata import HashingMetadataStore
  28. from sydent.db.threepid_associations import GlobalAssociationStore
  29. from sydent.threepid import threePidAssocFromDict
  30. from sydent.util import json_decoder
  31. from sydent.util.hash import sha256_and_url_safe_base64
  32. from sydent.util.stringutils import normalise_address
  33. if TYPE_CHECKING:
  34. from sydent.sydent import Sydent
  35. logger = logging.getLogger(__name__)
  36. SIGNING_KEY_ALGORITHM = "ed25519"
  37. class Peer:
  38. def __init__(self, servername, pubkeys):
  39. self.servername = servername
  40. self.pubkeys = pubkeys
  41. self.is_being_pushed_to = False
  42. def pushUpdates(self, sgAssocs) -> "Deferred":
  43. """
  44. :param sgAssocs: Sequence of (originId, sgAssoc) tuples where originId is the id on the creating server and
  45. sgAssoc is the json object of the signed association
  46. """
  47. pass
  48. class LocalPeer(Peer):
  49. """
  50. The local peer (ourselves: essentially copying from the local associations table to the global one)
  51. """
  52. def __init__(self, sydent: "Sydent") -> None:
  53. super().__init__(sydent.config.general.server_name, {})
  54. self.sydent = sydent
  55. self.hashing_store = HashingMetadataStore(sydent)
  56. globalAssocStore = GlobalAssociationStore(self.sydent)
  57. lastId = globalAssocStore.lastIdFromServer(self.servername)
  58. self.lastId = lastId if lastId is not None else -1
  59. def pushUpdates(self, sgAssocs: Dict[int, Dict[str, Any]]) -> "Deferred":
  60. """
  61. Saves the given associations in the global associations store. Only stores an
  62. association if its ID is greater than the last seen ID.
  63. :param sgAssocs: The associations to save.
  64. :return: True
  65. """
  66. globalAssocStore = GlobalAssociationStore(self.sydent)
  67. for localId in sgAssocs:
  68. if localId > self.lastId:
  69. assocObj = threePidAssocFromDict(sgAssocs[localId])
  70. # ensure we are casefolding email addresses
  71. assocObj.address = normalise_address(assocObj.address, assocObj.medium)
  72. if assocObj.mxid is not None:
  73. # Assign a lookup_hash to this association
  74. str_to_hash = " ".join(
  75. [
  76. assocObj.address,
  77. assocObj.medium,
  78. self.hashing_store.get_lookup_pepper(),
  79. ],
  80. )
  81. assocObj.lookup_hash = sha256_and_url_safe_base64(str_to_hash)
  82. # We can probably skip verification for the local peer (although it could
  83. # be good as a sanity check)
  84. globalAssocStore.addAssociation(
  85. assocObj,
  86. json.dumps(sgAssocs[localId]),
  87. self.sydent.config.general.server_name,
  88. localId,
  89. )
  90. else:
  91. globalAssocStore.removeAssociation(
  92. assocObj.medium, assocObj.address
  93. )
  94. d = defer.succeed(True)
  95. return d
  96. class RemotePeer(Peer):
  97. def __init__(
  98. self,
  99. sydent: "Sydent",
  100. server_name: str,
  101. port: int,
  102. pubkeys: Dict[str, str],
  103. lastSentVersion: int,
  104. ) -> None:
  105. """
  106. :param sydent: The current Sydent instance.
  107. :param server_name: The peer's server name.
  108. :param port: The peer's port.
  109. :param pubkeys: The peer's public keys in a dict[key_id, key_b64]
  110. :param lastSentVersion: The ID of the last association sent to the peer.
  111. """
  112. super().__init__(server_name, pubkeys)
  113. self.sydent = sydent
  114. self.port = port
  115. self.lastSentVersion = lastSentVersion
  116. # look up or build the replication URL
  117. replication_url = self.sydent.config.http.base_replication_urls.get(server_name)
  118. if replication_url is None:
  119. if not port:
  120. port = 1001
  121. replication_url = "https://%s:%i" % (server_name, port)
  122. if replication_url[-1:] != "/":
  123. replication_url += "/"
  124. replication_url += "_matrix/identity/replicate/v1/push"
  125. self.replication_url = replication_url
  126. # Get verify key for this peer
  127. # Check if their key is base64 or hex encoded
  128. pubkey = self.pubkeys[SIGNING_KEY_ALGORITHM]
  129. try:
  130. # Check for hex encoding
  131. int(pubkey, 16)
  132. # Decode hex into bytes
  133. pubkey_decoded = binascii.unhexlify(pubkey)
  134. logger.warning(
  135. "Peer public key of %s is hex encoded. Please update to base64 encoding",
  136. server_name,
  137. )
  138. except ValueError:
  139. # Check for base64 encoding
  140. try:
  141. pubkey_decoded = decode_base64(pubkey)
  142. except Exception as e:
  143. raise ConfigError(
  144. "Unable to decode public key for peer %s: %s" % (server_name, e),
  145. )
  146. self.verify_key = signedjson.key.decode_verify_key_bytes(
  147. SIGNING_KEY_ALGORITHM + ":", pubkey_decoded
  148. )
  149. # Attach metadata
  150. self.verify_key.alg = SIGNING_KEY_ALGORITHM
  151. self.verify_key.version = 0
  152. def verifySignedAssociation(self, assoc: Dict[Any, Any]) -> None:
  153. """Verifies a signature on a signed association. Raises an exception if the
  154. signature is incorrect or couldn't be verified.
  155. :param assoc: A signed association.
  156. """
  157. if "signatures" not in assoc:
  158. raise NoSignaturesException()
  159. key_ids = signedjson.sign.signature_ids(assoc, self.servername)
  160. if (
  161. not key_ids
  162. or len(key_ids) == 0
  163. or not key_ids[0].startswith(SIGNING_KEY_ALGORITHM + ":")
  164. ):
  165. e = NoMatchingSignatureException()
  166. e.foundSigs = assoc["signatures"].keys()
  167. e.requiredServername = self.servername
  168. raise e
  169. # Verify the JSON
  170. signedjson.sign.verify_signed_json(assoc, self.servername, self.verify_key)
  171. def pushUpdates(self, sgAssocs: Dict[int, Dict[str, Any]]) -> "Deferred":
  172. """
  173. Pushes the given associations to the peer.
  174. :param sgAssocs: The associations to push.
  175. :return: A deferred which results in the response to the push request.
  176. """
  177. body = {"sgAssocs": sgAssocs}
  178. reqDeferred = self.sydent.replicationHttpsClient.postJson(
  179. self.replication_url, body
  180. )
  181. # XXX: We'll also need to prune the deleted associations out of the
  182. # local associations table once they've been replicated to all peers
  183. # (ie. remove the record we kept in order to propagate the deletion to
  184. # other peers).
  185. updateDeferred = defer.Deferred()
  186. reqDeferred.addCallback(self._pushSuccess, updateDeferred=updateDeferred)
  187. reqDeferred.addErrback(self._pushFailed, updateDeferred=updateDeferred)
  188. return updateDeferred
  189. def _pushSuccess(
  190. self,
  191. result: "IResponse",
  192. updateDeferred: "Deferred",
  193. ) -> None:
  194. """
  195. Processes a successful push request. If the request resulted in a status code
  196. that's not a success, consider it a failure
  197. :param result: The HTTP response.
  198. :param updateDeferred: The deferred to make either succeed or fail depending on
  199. the status code.
  200. """
  201. if result.code >= 200 and result.code < 300:
  202. updateDeferred.callback(result)
  203. else:
  204. d = readBody(result)
  205. d.addCallback(self._failedPushBodyRead, updateDeferred=updateDeferred)
  206. d.addErrback(self._pushFailed, updateDeferred=updateDeferred)
  207. def _failedPushBodyRead(self, body: bytes, updateDeferred: "Deferred") -> None:
  208. """
  209. Processes a response body from a failed push request, then calls the error
  210. callback of the provided deferred.
  211. :param body: The response body.
  212. :param updateDeferred: The deferred to call the error callback of.
  213. """
  214. errObj = json_decoder.decode(body.decode("utf8"))
  215. e = RemotePeerError()
  216. e.errorDict = errObj
  217. updateDeferred.errback(e)
  218. def _pushFailed(
  219. self,
  220. failure,
  221. updateDeferred: "Deferred",
  222. ) -> None:
  223. """
  224. Processes a failed push request, by calling the error callback of the given
  225. deferred with it.
  226. :param failure: The failure to process.
  227. :type failure: twisted.python.failure.Failure
  228. :param updateDeferred: The deferred to call the error callback of.
  229. """
  230. updateDeferred.errback(failure)
  231. return None
  232. class NoSignaturesException(Exception):
  233. pass
  234. class NoMatchingSignatureException(Exception):
  235. def __str__(self):
  236. return "Found signatures: %s, required server name: %s" % (
  237. self.foundSigs,
  238. self.requiredServername,
  239. )
  240. class RemotePeerError(Exception):
  241. def __str__(self):
  242. return repr(self.errorDict)