1
0

test_send_email.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. # Copyright 2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import List, Tuple
  15. from zope.interface import implementer
  16. from twisted.internet import defer
  17. from twisted.internet.address import IPv4Address
  18. from twisted.internet.defer import ensureDeferred
  19. from twisted.mail import interfaces, smtp
  20. from tests.server import FakeTransport
  21. from tests.unittest import HomeserverTestCase, override_config
  22. @implementer(interfaces.IMessageDelivery)
  23. class _DummyMessageDelivery:
  24. def __init__(self):
  25. # (recipient, message) tuples
  26. self.messages: List[Tuple[smtp.Address, bytes]] = []
  27. def receivedHeader(self, helo, origin, recipients):
  28. return None
  29. def validateFrom(self, helo, origin):
  30. return origin
  31. def record_message(self, recipient: smtp.Address, message: bytes):
  32. self.messages.append((recipient, message))
  33. def validateTo(self, user: smtp.User):
  34. return lambda: _DummyMessage(self, user)
  35. @implementer(interfaces.IMessageSMTP)
  36. class _DummyMessage:
  37. """IMessageSMTP implementation which saves the message delivered to it
  38. to the _DummyMessageDelivery object.
  39. """
  40. def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User):
  41. self._delivery = delivery
  42. self._user = user
  43. self._buffer: List[bytes] = []
  44. def lineReceived(self, line):
  45. self._buffer.append(line)
  46. def eomReceived(self):
  47. message = b"\n".join(self._buffer) + b"\n"
  48. self._delivery.record_message(self._user.dest, message)
  49. return defer.succeed(b"saved")
  50. def connectionLost(self):
  51. pass
  52. class SendEmailHandlerTestCase(HomeserverTestCase):
  53. def test_send_email(self):
  54. """Happy-path test that we can send email to a non-TLS server."""
  55. h = self.hs.get_send_email_handler()
  56. d = ensureDeferred(
  57. h.send_email(
  58. "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
  59. )
  60. )
  61. # there should be an attempt to connect to localhost:25
  62. self.assertEqual(len(self.reactor.tcpClients), 1)
  63. (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
  64. 0
  65. ]
  66. self.assertEqual(host, "localhost")
  67. self.assertEqual(port, 25)
  68. # wire it up to an SMTP server
  69. message_delivery = _DummyMessageDelivery()
  70. server_protocol = smtp.ESMTP()
  71. server_protocol.delivery = message_delivery
  72. # make sure that the server uses the test reactor to set timeouts
  73. server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
  74. client_protocol = client_factory.buildProtocol(None)
  75. client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
  76. server_protocol.makeConnection(
  77. FakeTransport(
  78. client_protocol,
  79. self.reactor,
  80. peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
  81. )
  82. )
  83. # the message should now get delivered
  84. self.get_success(d, by=0.1)
  85. # check it arrived
  86. self.assertEqual(len(message_delivery.messages), 1)
  87. user, msg = message_delivery.messages.pop()
  88. self.assertEqual(str(user), "foo@bar.com")
  89. self.assertIn(b"Subject: test subject", msg)
  90. @override_config(
  91. {
  92. "email": {
  93. "notif_from": "noreply@test",
  94. "force_tls": True,
  95. },
  96. }
  97. )
  98. def test_send_email_force_tls(self):
  99. """Happy-path test that we can send email to an Implicit TLS server."""
  100. h = self.hs.get_send_email_handler()
  101. d = ensureDeferred(
  102. h.send_email(
  103. "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
  104. )
  105. )
  106. # there should be an attempt to connect to localhost:465
  107. self.assertEqual(len(self.reactor.sslClients), 1)
  108. (
  109. host,
  110. port,
  111. client_factory,
  112. contextFactory,
  113. _timeout,
  114. _bindAddress,
  115. ) = self.reactor.sslClients[0]
  116. self.assertEqual(host, "localhost")
  117. self.assertEqual(port, 465)
  118. # wire it up to an SMTP server
  119. message_delivery = _DummyMessageDelivery()
  120. server_protocol = smtp.ESMTP()
  121. server_protocol.delivery = message_delivery
  122. # make sure that the server uses the test reactor to set timeouts
  123. server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
  124. client_protocol = client_factory.buildProtocol(None)
  125. client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
  126. server_protocol.makeConnection(
  127. FakeTransport(
  128. client_protocol,
  129. self.reactor,
  130. peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
  131. )
  132. )
  133. # the message should now get delivered
  134. self.get_success(d, by=0.1)
  135. # check it arrived
  136. self.assertEqual(len(message_delivery.messages), 1)
  137. user, msg = message_delivery.messages.pop()
  138. self.assertEqual(str(user), "foo@bar.com")
  139. self.assertIn(b"Subject: test subject", msg)