# -*- coding: utf-8 -*-
# Copyright 2019 New Vector Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import email.mime.multipart
import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import List, Optional

from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import UserID
from synapse.util import stringutils
# The email templating helper is optional. If importing it fails (presumably
# because optional email dependencies are not installed -- TODO confirm), fall
# back to None so that AccountValidityHandler.__init__ skips all of its
# renewal-email setup.
try:
    from synapse.push.mailer import load_jinja2_templates
except ImportError:
    load_jinja2_templates = None

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
  31. class AccountValidityHandler(object):
  32. def __init__(self, hs):
  33. self.hs = hs
  34. self.config = hs.config
  35. self.store = self.hs.get_datastore()
  36. self.sendmail = self.hs.get_sendmail()
  37. self.clock = self.hs.get_clock()
  38. self._account_validity = self.hs.config.account_validity
  39. if self._account_validity.renew_by_email_enabled and load_jinja2_templates:
  40. # Don't do email-specific configuration if renewal by email is disabled.
  41. try:
  42. app_name = self.hs.config.email_app_name
  43. self._subject = self._account_validity.renew_email_subject % {
  44. "app": app_name
  45. }
  46. self._from_string = self.hs.config.email_notif_from % {"app": app_name}
  47. except Exception:
  48. # If substitution failed, fall back to the bare strings.
  49. self._subject = self._account_validity.renew_email_subject
  50. self._from_string = self.hs.config.email_notif_from
  51. self._raw_from = email.utils.parseaddr(self._from_string)[1]
  52. self._template_html, self._template_text = load_jinja2_templates(
  53. self.config.email_template_dir,
  54. [
  55. self.config.email_expiry_template_html,
  56. self.config.email_expiry_template_text,
  57. ],
  58. apply_format_ts_filter=True,
  59. apply_mxc_to_http_filter=True,
  60. public_baseurl=self.config.public_baseurl,
  61. )
  62. # Check the renewal emails to send and send them every 30min.
  63. def send_emails():
  64. # run as a background process to make sure that the database transactions
  65. # have a logcontext to report to
  66. return run_as_background_process(
  67. "send_renewals", self._send_renewal_emails
  68. )
  69. self.clock.looping_call(send_emails, 30 * 60 * 1000)
  70. async def _send_renewal_emails(self):
  71. """Gets the list of users whose account is expiring in the amount of time
  72. configured in the ``renew_at`` parameter from the ``account_validity``
  73. configuration, and sends renewal emails to all of these users as long as they
  74. have an email 3PID attached to their account.
  75. """
  76. expiring_users = await self.store.get_users_expiring_soon()
  77. if expiring_users:
  78. for user in expiring_users:
  79. await self._send_renewal_email(
  80. user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
  81. )
  82. async def send_renewal_email_to_user(self, user_id: str):
  83. expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
  84. await self._send_renewal_email(user_id, expiration_ts)
  85. async def _send_renewal_email(self, user_id: str, expiration_ts: int):
  86. """Sends out a renewal email to every email address attached to the given user
  87. with a unique link allowing them to renew their account.
  88. Args:
  89. user_id: ID of the user to send email(s) to.
  90. expiration_ts: Timestamp in milliseconds for the expiration date of
  91. this user's account (used in the email templates).
  92. """
  93. addresses = await self._get_email_addresses_for_user(user_id)
  94. # Stop right here if the user doesn't have at least one email address.
  95. # In this case, they will have to ask their server admin to renew their
  96. # account manually.
  97. # We don't need to do a specific check to make sure the account isn't
  98. # deactivated, as a deactivated account isn't supposed to have any
  99. # email address attached to it.
  100. if not addresses:
  101. return
  102. try:
  103. user_display_name = await self.store.get_profile_displayname(
  104. UserID.from_string(user_id).localpart
  105. )
  106. if user_display_name is None:
  107. user_display_name = user_id
  108. except StoreError:
  109. user_display_name = user_id
  110. renewal_token = await self._get_renewal_token(user_id)
  111. url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
  112. self.hs.config.public_baseurl,
  113. renewal_token,
  114. )
  115. template_vars = {
  116. "display_name": user_display_name,
  117. "expiration_ts": expiration_ts,
  118. "url": url,
  119. }
  120. html_text = self._template_html.render(**template_vars)
  121. html_part = MIMEText(html_text, "html", "utf8")
  122. plain_text = self._template_text.render(**template_vars)
  123. text_part = MIMEText(plain_text, "plain", "utf8")
  124. for address in addresses:
  125. raw_to = email.utils.parseaddr(address)[1]
  126. multipart_msg = MIMEMultipart("alternative")
  127. multipart_msg["Subject"] = self._subject
  128. multipart_msg["From"] = self._from_string
  129. multipart_msg["To"] = address
  130. multipart_msg["Date"] = email.utils.formatdate()
  131. multipart_msg["Message-ID"] = email.utils.make_msgid()
  132. multipart_msg.attach(text_part)
  133. multipart_msg.attach(html_part)
  134. logger.info("Sending renewal email to %s", address)
  135. await make_deferred_yieldable(
  136. self.sendmail(
  137. self.hs.config.email_smtp_host,
  138. self._raw_from,
  139. raw_to,
  140. multipart_msg.as_string().encode("utf8"),
  141. reactor=self.hs.get_reactor(),
  142. port=self.hs.config.email_smtp_port,
  143. requireAuthentication=self.hs.config.email_smtp_user is not None,
  144. username=self.hs.config.email_smtp_user,
  145. password=self.hs.config.email_smtp_pass,
  146. requireTransportSecurity=self.hs.config.require_transport_security,
  147. )
  148. )
  149. await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
  150. async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
  151. """Retrieve the list of email addresses attached to a user's account.
  152. Args:
  153. user_id: ID of the user to lookup email addresses for.
  154. Returns:
  155. Email addresses for this account.
  156. """
  157. threepids = await self.store.user_get_threepids(user_id)
  158. addresses = []
  159. for threepid in threepids:
  160. if threepid["medium"] == "email":
  161. addresses.append(threepid["address"])
  162. return addresses
  163. async def _get_renewal_token(self, user_id: str) -> str:
  164. """Generates a 32-byte long random string that will be inserted into the
  165. user's renewal email's unique link, then saves it into the database.
  166. Args:
  167. user_id: ID of the user to generate a string for.
  168. Returns:
  169. The generated string.
  170. Raises:
  171. StoreError(500): Couldn't generate a unique string after 5 attempts.
  172. """
  173. attempts = 0
  174. while attempts < 5:
  175. try:
  176. renewal_token = stringutils.random_string(32)
  177. await self.store.set_renewal_token_for_user(user_id, renewal_token)
  178. return renewal_token
  179. except StoreError:
  180. attempts += 1
  181. raise StoreError(500, "Couldn't generate a unique string as refresh string.")
  182. async def renew_account(self, renewal_token: str) -> bool:
  183. """Renews the account attached to a given renewal token by pushing back the
  184. expiration date by the current validity period in the server's configuration.
  185. Args:
  186. renewal_token: Token sent with the renewal request.
  187. Returns:
  188. Whether the provided token is valid.
  189. """
  190. try:
  191. user_id = await self.store.get_user_from_renewal_token(renewal_token)
  192. except StoreError:
  193. return False
  194. logger.debug("Renewing an account for user %s", user_id)
  195. await self.renew_account_for_user(user_id)
  196. return True
  197. async def renew_account_for_user(
  198. self, user_id: str, expiration_ts: int = None, email_sent: bool = False
  199. ) -> int:
  200. """Renews the account attached to a given user by pushing back the
  201. expiration date by the current validity period in the server's
  202. configuration.
  203. Args:
  204. renewal_token: Token sent with the renewal request.
  205. expiration_ts: New expiration date. Defaults to now + validity period.
  206. email_sen: Whether an email has been sent for this validity period.
  207. Defaults to False.
  208. Returns:
  209. New expiration date for this account, as a timestamp in
  210. milliseconds since epoch.
  211. """
  212. if expiration_ts is None:
  213. expiration_ts = self.clock.time_msec() + self._account_validity.period
  214. await self.store.set_account_validity_for_user(
  215. user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
  216. )
  217. return expiration_ts