bind.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014 OpenMarket Ltd
  3. # Copyright 2018 New Vector Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import collections
  17. import json
  18. import logging
  19. import math
  20. import random
  21. import signedjson.sign
  22. from sydent.db.invite_tokens import JoinTokenStore
  23. from sydent.db.threepid_associations import LocalAssociationStore
  24. from sydent.util import time_msec
  25. from sydent.threepid.signer import Signer
  26. from sydent.http.httpclient import FederationHttpClient
  27. from sydent.threepid import ThreepidAssociation
  28. from OpenSSL import SSL
  29. from OpenSSL.SSL import VERIFY_NONE
  30. from StringIO import StringIO
  31. from twisted.internet import reactor, defer, ssl
  32. from twisted.names import client, dns
  33. from twisted.names.error import DNSNameError
  34. from twisted.web.client import FileBodyProducer, Agent
  35. from twisted.web.http_headers import Headers
  36. logger = logging.getLogger(__name__)
  37. class ThreepidBinder:
  38. # the lifetime of a 3pid association
  39. THREEPID_ASSOCIATION_LIFETIME_MS = 100 * 365 * 24 * 60 * 60 * 1000
  40. def __init__(self, sydent):
  41. self.sydent = sydent
  42. def addBinding(self, medium, address, mxid):
  43. """Binds the given 3pid to the given mxid.
  44. It's assumed that we have somehow validated that the given user owns
  45. the given 3pid
  46. Args:
  47. medium (str): the type of 3pid
  48. address (str): the 3pid
  49. mxid (str): the mxid to bind it to
  50. """
  51. localAssocStore = LocalAssociationStore(self.sydent)
  52. createdAt = time_msec()
  53. expires = createdAt + ThreepidBinder.THREEPID_ASSOCIATION_LIFETIME_MS
  54. assoc = ThreepidAssociation(medium, address, mxid, createdAt, createdAt, expires)
  55. localAssocStore.addOrUpdateAssociation(assoc)
  56. self.sydent.pusher.doLocalPush()
  57. joinTokenStore = JoinTokenStore(self.sydent)
  58. pendingJoinTokens = joinTokenStore.getTokens(medium, address)
  59. invites = []
  60. for token in pendingJoinTokens:
  61. token["mxid"] = mxid
  62. token["signed"] = {
  63. "mxid": mxid,
  64. "token": token["token"],
  65. }
  66. token["signed"] = signedjson.sign.sign_json(token["signed"], self.sydent.server_name, self.sydent.keyring.ed25519)
  67. invites.append(token)
  68. if invites:
  69. assoc.extra_fields["invites"] = invites
  70. joinTokenStore.markTokensAsSent(medium, address)
  71. signer = Signer(self.sydent)
  72. sgassoc = signer.signedThreePidAssociation(assoc)
  73. self._notify(sgassoc, 0)
  74. return sgassoc
  75. def removeBinding(self, threepid, mxid):
  76. localAssocStore = LocalAssociationStore(self.sydent)
  77. localAssocStore.removeAssociation(threepid, mxid)
  78. self.sydent.pusher.doLocalPush()
  79. @defer.inlineCallbacks
  80. def _notify(self, assoc, attempt):
  81. mxid = assoc["mxid"]
  82. domain = mxid.split(":")[-1]
  83. server = yield self._pickServer(domain)
  84. post_url = "https://%s/_matrix/federation/v1/3pid/onbind" % (
  85. server,
  86. )
  87. logger.info("Making bind callback to: %s", post_url)
  88. # Make a POST to the chosen Synapse server
  89. http_client = FederationHttpClient(self.sydent)
  90. try:
  91. response = yield http_client.post_json_get_nothing(post_url, assoc, {})
  92. except Exception as e:
  93. self._notifyErrback(assoc, attempt, e)
  94. return
  95. # If the request failed, try again with exponential backoff
  96. if response.code != 200:
  97. self._notifyErrback(
  98. assoc, attempt, "Non-OK error code received (%d)" % response.code
  99. )
  100. else:
  101. logger.info("Successfully notified on bind for %s" % (mxid,))
  102. def _notifyErrback(self, assoc, attempt, error):
  103. logger.warn("Error notifying on bind for %s: %s - rescheduling", assoc["mxid"], error)
  104. reactor.callLater(math.pow(2, attempt), self._notify, assoc, attempt + 1)
  105. # The below is lovingly ripped off of synapse/http/endpoint.py
  106. _Server = collections.namedtuple("_Server", "priority weight host port")
  107. @defer.inlineCallbacks
  108. def _pickServer(self, host):
  109. servers = yield self._fetchServers(host)
  110. if not servers:
  111. defer.returnValue("%s:8448" % (host,))
  112. min_priority = servers[0].priority
  113. weight_indexes = list(
  114. (index, server.weight + 1)
  115. for index, server in enumerate(servers)
  116. if server.priority == min_priority
  117. )
  118. total_weight = sum(weight for index, weight in weight_indexes)
  119. target_weight = random.randint(0, total_weight)
  120. for index, weight in weight_indexes:
  121. target_weight -= weight
  122. if target_weight <= 0:
  123. server = servers[index]
  124. defer.returnValue("%s:%d" % (server.host, server.port,))
  125. return
  126. @defer.inlineCallbacks
  127. def _fetchServers(self, host):
  128. try:
  129. service = "_matrix._tcp.%s" % host
  130. answers, auth, add = yield client.lookupService(service)
  131. except DNSNameError:
  132. answers = []
  133. if (len(answers) == 1
  134. and answers[0].type == dns.SRV
  135. and answers[0].payload
  136. and answers[0].payload.target == dns.Name(".")):
  137. raise DNSNameError("Service %s unavailable", service)
  138. servers = []
  139. for answer in answers:
  140. if answer.type != dns.SRV or not answer.payload:
  141. continue
  142. payload = answer.payload
  143. servers.append(ThreepidBinder._Server(
  144. host=str(payload.target),
  145. port=int(payload.port),
  146. priority=int(payload.priority),
  147. weight=int(payload.weight)
  148. ))
  149. servers.sort()
  150. defer.returnValue(servers)