test_msisdn.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. # Copyright 2021 The Matrix.org Foundation C.I.C.
  2. # Licensed under the Apache License, Version 2.0 (the "License");
  3. # you may not use this file except in compliance with the License.
  4. # You may obtain a copy of the License at
  5. #
  6. # http://www.apache.org/licenses/LICENSE-2.0
  7. #
  8. # Unless required by applicable law or agreed to in writing, software
  9. # distributed under the License is distributed on an "AS IS" BASIS,
  10. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  11. # See the License for the specific language governing permissions and
  12. # limitations under the License.
  13. import asyncio
  14. import os.path
  15. from typing import Optional
  16. from unittest.mock import Mock, patch
  17. import attr
  18. from twisted.trial import unittest
  19. from sydent.types import JsonDict
  20. from tests.utils import make_request, make_sydent
  21. @attr.s(auto_attribs=True)
  22. class FakeHeader:
  23. """
  24. A fake header object
  25. """
  26. headers: dict
  27. def getAllRawHeaders(self):
  28. return self.headers
  29. @attr.s(auto_attribs=True)
  30. class FakeResponse:
  31. """A fake twisted.web.IResponse object"""
  32. # HTTP response code
  33. code: int
  34. # Fake Header
  35. headers: FakeHeader
  36. class TestRequestCode(unittest.TestCase):
  37. def setUp(self) -> None:
  38. # Create a new sydent
  39. config = {
  40. "general": {
  41. "templates.path": os.path.join(
  42. os.path.dirname(os.path.dirname(__file__)), "res"
  43. ),
  44. },
  45. }
  46. self.sydent = make_sydent(test_config=config)
  47. def _make_request(self, url: str, body: Optional[JsonDict] = None) -> Mock:
  48. # Patch out the email sending so we can investigate the resulting email.
  49. with patch("sydent.sms.openmarket.OpenMarketSMS.sendTextSMS") as sendTextSMS:
  50. # We can't use AsyncMock until Python 3.8. Instead, mock the
  51. # function as returning a future.
  52. f = asyncio.Future()
  53. f.set_result(Mock())
  54. sendTextSMS.return_value = f
  55. request, channel = make_request(
  56. self.sydent.reactor,
  57. self.sydent.clientApiHttpServer.factory,
  58. "POST",
  59. url,
  60. body,
  61. )
  62. self.assertEqual(channel.code, 200)
  63. return sendTextSMS
  64. def test_request_code(self) -> None:
  65. self.sydent.run()
  66. sendSMS_mock = self._make_request(
  67. "/_matrix/identity/api/v1/validate/msisdn/requestToken",
  68. {
  69. "phone_number": "447700900750",
  70. "country": "GB",
  71. "client_secret": "oursecret",
  72. "send_attempt": 0,
  73. },
  74. )
  75. sendSMS_mock.assert_called_once()
  76. def test_request_code_via_url_query_params(self) -> None:
  77. self.sydent.run()
  78. url = (
  79. "/_matrix/identity/api/v1/validate/msisdn/requestToken?"
  80. "phone_number=447700900750"
  81. "&country=GB"
  82. "&client_secret=oursecret"
  83. "&send_attempt=0"
  84. )
  85. sendSMS_mock = self._make_request(url)
  86. sendSMS_mock.assert_called_once()
  87. @patch("sydent.http.httpclient.HTTPClient.post_json_maybe_get_json")
  88. def test_bad_api_response_raises_exception(self, post_json: Mock) -> None:
  89. """Test that an error response from OpenMarket raises an exception
  90. and that the requester receives an error code."""
  91. header = FakeHeader({})
  92. resp = FakeResponse(code=400, headers=header), {}
  93. post_json.return_value = resp
  94. self.sydent.run()
  95. request, channel = make_request(
  96. self.sydent.reactor,
  97. self.sydent.clientApiHttpServer.factory,
  98. "POST",
  99. "/_matrix/identity/api/v1/validate/msisdn/requestToken",
  100. {
  101. "phone_number": "447700900750",
  102. "country": "GB",
  103. "client_secret": "oursecret",
  104. "send_attempt": 0,
  105. },
  106. )
  107. self.assertEqual(channel.code, 500)