bind.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014 OpenMarket Ltd
  3. # Copyright 2018 New Vector Ltd
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import collections
  17. import json
  18. import logging
  19. import math
  20. import random
  21. import signedjson.sign
  22. from sydent.db.invite_tokens import JoinTokenStore
  23. from sydent.db.threepid_associations import LocalAssociationStore
  24. from sydent.util import time_msec
  25. from sydent.threepid.signer import Signer
  26. from sydent.threepid import ThreepidAssociation
  27. from OpenSSL import SSL
  28. from OpenSSL.SSL import VERIFY_NONE
  29. from StringIO import StringIO
  30. from twisted.internet import reactor, defer, ssl
  31. from twisted.names import client, dns
  32. from twisted.names.error import DNSNameError
  33. from twisted.web.client import FileBodyProducer, Agent
  34. from twisted.web.http_headers import Headers
  35. logger = logging.getLogger(__name__)
  36. class ThreepidBinder:
  37. # the lifetime of a 3pid association
  38. THREEPID_ASSOCIATION_LIFETIME_MS = 100 * 365 * 24 * 60 * 60 * 1000
  39. def __init__(self, sydent):
  40. self.sydent = sydent
  41. def addBinding(self, medium, address, mxid):
  42. """Binds the given 3pid to the given mxid.
  43. It's assumed that we have somehow validated that the given user owns
  44. the given 3pid
  45. Args:
  46. medium (str): the type of 3pid
  47. address (str): the 3pid
  48. mxid (str): the mxid to bind it to
  49. """
  50. localAssocStore = LocalAssociationStore(self.sydent)
  51. createdAt = time_msec()
  52. expires = createdAt + ThreepidBinder.THREEPID_ASSOCIATION_LIFETIME_MS
  53. assoc = ThreepidAssociation(medium, address, mxid, createdAt, createdAt, expires)
  54. localAssocStore.addOrUpdateAssociation(assoc)
  55. self.sydent.pusher.doLocalPush()
  56. joinTokenStore = JoinTokenStore(self.sydent)
  57. pendingJoinTokens = joinTokenStore.getTokens(medium, address)
  58. invites = []
  59. for token in pendingJoinTokens:
  60. token["mxid"] = mxid
  61. token["signed"] = {
  62. "mxid": mxid,
  63. "token": token["token"],
  64. }
  65. token["signed"] = signedjson.sign.sign_json(token["signed"], self.sydent.server_name, self.sydent.keyring.ed25519)
  66. invites.append(token)
  67. if invites:
  68. assoc.extra_fields["invites"] = invites
  69. joinTokenStore.markTokensAsSent(medium, address)
  70. signer = Signer(self.sydent)
  71. sgassoc = signer.signedThreePidAssociation(assoc)
  72. self._notify(sgassoc, 0)
  73. return sgassoc
  74. def removeBinding(self, threepid, mxid):
  75. localAssocStore = LocalAssociationStore(self.sydent)
  76. localAssocStore.removeAssociation(threepid, mxid)
  77. self.sydent.pusher.doLocalPush()
  78. @defer.inlineCallbacks
  79. def _notify(self, assoc, attempt):
  80. mxid = assoc["mxid"]
  81. domain = mxid.split(":")[-1]
  82. server = yield self._pickServer(domain)
  83. callbackUrl = "https://%s/_matrix/federation/v1/3pid/onbind" % (
  84. server,
  85. )
  86. logger.info("Making bind callback to: %s", callbackUrl)
  87. # TODO: Not be woefully insecure
  88. agent = Agent(reactor, InsecureInterceptableContextFactory())
  89. reqDeferred = agent.request(
  90. "POST",
  91. callbackUrl.encode("utf8"),
  92. Headers({
  93. "Content-Type": ["application/json"],
  94. "User-Agent": ["Sydent"],
  95. }),
  96. FileBodyProducer(StringIO(json.dumps(assoc)))
  97. )
  98. reqDeferred.addCallback(
  99. lambda _: logger.info("Successfully notified on bind for %s" % (mxid,))
  100. )
  101. reqDeferred.addErrback(
  102. lambda err: self._notifyErrback(assoc, attempt, err)
  103. )
  104. def _notifyErrback(self, assoc, attempt, error):
  105. logger.warn("Error notifying on bind for %s: %s - rescheduling", assoc["mxid"], error)
  106. reactor.callLater(math.pow(2, attempt), self._notify, assoc, attempt + 1)
  107. # The below is lovingly ripped off of synapse/http/endpoint.py
  108. _Server = collections.namedtuple("_Server", "priority weight host port")
  109. @defer.inlineCallbacks
  110. def _pickServer(self, host):
  111. servers = yield self._fetchServers(host)
  112. if not servers:
  113. defer.returnValue("%s:8448" % (host,))
  114. min_priority = servers[0].priority
  115. weight_indexes = list(
  116. (index, server.weight + 1)
  117. for index, server in enumerate(servers)
  118. if server.priority == min_priority
  119. )
  120. total_weight = sum(weight for index, weight in weight_indexes)
  121. target_weight = random.randint(0, total_weight)
  122. for index, weight in weight_indexes:
  123. target_weight -= weight
  124. if target_weight <= 0:
  125. server = servers[index]
  126. defer.returnValue("%s:%d" % (server.host, server.port,))
  127. return
  128. @defer.inlineCallbacks
  129. def _fetchServers(self, host):
  130. try:
  131. service = "_matrix._tcp.%s" % host
  132. answers, auth, add = yield client.lookupService(service)
  133. except DNSNameError:
  134. answers = []
  135. if (len(answers) == 1
  136. and answers[0].type == dns.SRV
  137. and answers[0].payload
  138. and answers[0].payload.target == dns.Name(".")):
  139. raise DNSNameError("Service %s unavailable", service)
  140. servers = []
  141. for answer in answers:
  142. if answer.type != dns.SRV or not answer.payload:
  143. continue
  144. payload = answer.payload
  145. servers.append(ThreepidBinder._Server(
  146. host=str(payload.target),
  147. port=int(payload.port),
  148. priority=int(payload.priority),
  149. weight=int(payload.weight)
  150. ))
  151. servers.sort()
  152. defer.returnValue(servers)
  153. class InsecureInterceptableContextFactory(ssl.ContextFactory):
  154. """
  155. Factory for PyOpenSSL SSL contexts which accepts any certificate for any domain.
  156. Do not use this since it allows an attacker to intercept your communications.
  157. """
  158. def __init__(self):
  159. self._context = SSL.Context(SSL.SSLv23_METHOD)
  160. self._context.set_verify(VERIFY_NONE, lambda *_: None)
  161. def getContext(self, hostname, port):
  162. return self._context