test_send_email.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # Copyright 2021 The Matrix.org Foundation C.I.C.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Callable, List, Tuple
  15. from zope.interface import implementer
  16. from twisted.internet import defer
  17. from twisted.internet.address import IPv4Address
  18. from twisted.internet.defer import ensureDeferred
  19. from twisted.mail import interfaces, smtp
  20. from tests.server import FakeTransport
  21. from tests.unittest import HomeserverTestCase, override_config
  22. @implementer(interfaces.IMessageDelivery)
  23. class _DummyMessageDelivery:
  24. def __init__(self) -> None:
  25. # (recipient, message) tuples
  26. self.messages: List[Tuple[smtp.Address, bytes]] = []
  27. def receivedHeader(
  28. self,
  29. helo: Tuple[bytes, bytes],
  30. origin: smtp.Address,
  31. recipients: List[smtp.User],
  32. ) -> None:
  33. return None
  34. def validateFrom(
  35. self, helo: Tuple[bytes, bytes], origin: smtp.Address
  36. ) -> smtp.Address:
  37. return origin
  38. def record_message(self, recipient: smtp.Address, message: bytes) -> None:
  39. self.messages.append((recipient, message))
  40. def validateTo(self, user: smtp.User) -> Callable[[], interfaces.IMessageSMTP]:
  41. return lambda: _DummyMessage(self, user)
  42. @implementer(interfaces.IMessageSMTP)
  43. class _DummyMessage:
  44. """IMessageSMTP implementation which saves the message delivered to it
  45. to the _DummyMessageDelivery object.
  46. """
  47. def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User):
  48. self._delivery = delivery
  49. self._user = user
  50. self._buffer: List[bytes] = []
  51. def lineReceived(self, line: bytes) -> None:
  52. self._buffer.append(line)
  53. def eomReceived(self) -> "defer.Deferred[bytes]":
  54. message = b"\n".join(self._buffer) + b"\n"
  55. self._delivery.record_message(self._user.dest, message)
  56. return defer.succeed(b"saved")
  57. def connectionLost(self) -> None:
  58. pass
  59. class SendEmailHandlerTestCase(HomeserverTestCase):
  60. def test_send_email(self) -> None:
  61. """Happy-path test that we can send email to a non-TLS server."""
  62. h = self.hs.get_send_email_handler()
  63. d = ensureDeferred(
  64. h.send_email(
  65. "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
  66. )
  67. )
  68. # there should be an attempt to connect to localhost:25
  69. self.assertEqual(len(self.reactor.tcpClients), 1)
  70. (host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
  71. 0
  72. ]
  73. self.assertEqual(host, "localhost")
  74. self.assertEqual(port, 25)
  75. # wire it up to an SMTP server
  76. message_delivery = _DummyMessageDelivery()
  77. server_protocol = smtp.ESMTP()
  78. server_protocol.delivery = message_delivery
  79. # make sure that the server uses the test reactor to set timeouts
  80. server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
  81. client_protocol = client_factory.buildProtocol(None)
  82. client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
  83. server_protocol.makeConnection(
  84. FakeTransport(
  85. client_protocol,
  86. self.reactor,
  87. peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
  88. )
  89. )
  90. # the message should now get delivered
  91. self.get_success(d, by=0.1)
  92. # check it arrived
  93. self.assertEqual(len(message_delivery.messages), 1)
  94. user, msg = message_delivery.messages.pop()
  95. self.assertEqual(str(user), "foo@bar.com")
  96. self.assertIn(b"Subject: test subject", msg)
  97. @override_config(
  98. {
  99. "email": {
  100. "notif_from": "noreply@test",
  101. "force_tls": True,
  102. },
  103. }
  104. )
  105. def test_send_email_force_tls(self) -> None:
  106. """Happy-path test that we can send email to an Implicit TLS server."""
  107. h = self.hs.get_send_email_handler()
  108. d = ensureDeferred(
  109. h.send_email(
  110. "foo@bar.com", "test subject", "Tests", "HTML content", "Text content"
  111. )
  112. )
  113. # there should be an attempt to connect to localhost:465
  114. self.assertEqual(len(self.reactor.sslClients), 1)
  115. (
  116. host,
  117. port,
  118. client_factory,
  119. contextFactory,
  120. _timeout,
  121. _bindAddress,
  122. ) = self.reactor.sslClients[0]
  123. self.assertEqual(host, "localhost")
  124. self.assertEqual(port, 465)
  125. # wire it up to an SMTP server
  126. message_delivery = _DummyMessageDelivery()
  127. server_protocol = smtp.ESMTP()
  128. server_protocol.delivery = message_delivery
  129. # make sure that the server uses the test reactor to set timeouts
  130. server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]
  131. client_protocol = client_factory.buildProtocol(None)
  132. client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
  133. server_protocol.makeConnection(
  134. FakeTransport(
  135. client_protocol,
  136. self.reactor,
  137. peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
  138. )
  139. )
  140. # the message should now get delivered
  141. self.get_success(d, by=0.1)
  142. # check it arrived
  143. self.assertEqual(len(message_delivery.messages), 1)
  144. user, msg = message_delivery.messages.pop()
  145. self.assertEqual(str(user), "foo@bar.com")
  146. self.assertIn(b"Subject: test subject", msg)