test_new_matrix_user.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2018 New Vector
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from mock import Mock
  16. from synapse._scripts.register_new_matrix_user import request_registration
  17. from tests.unittest import TestCase
  18. class RegisterTestCase(TestCase):
  19. def test_success(self):
  20. """
  21. The script will fetch a nonce, and then generate a MAC with it, and then
  22. post that MAC.
  23. """
  24. def get(url, verify=None):
  25. r = Mock()
  26. r.status_code = 200
  27. r.json = lambda: {"nonce": "a"}
  28. return r
  29. def post(url, json=None, verify=None):
  30. # Make sure we are sent the correct info
  31. self.assertEqual(json["username"], "user")
  32. self.assertEqual(json["password"], "pass")
  33. self.assertEqual(json["nonce"], "a")
  34. # We want a 40-char hex MAC
  35. self.assertEqual(len(json["mac"]), 40)
  36. r = Mock()
  37. r.status_code = 200
  38. return r
  39. requests = Mock()
  40. requests.get = get
  41. requests.post = post
  42. # The fake stdout will be written here
  43. out = []
  44. err_code = []
  45. request_registration(
  46. "user",
  47. "pass",
  48. "matrix.org",
  49. "shared",
  50. admin=False,
  51. requests=requests,
  52. _print=out.append,
  53. exit=err_code.append,
  54. )
  55. # We should get the success message making sure everything is OK.
  56. self.assertIn("Success!", out)
  57. # sys.exit shouldn't have been called.
  58. self.assertEqual(err_code, [])
  59. def test_failure_nonce(self):
  60. """
  61. If the script fails to fetch a nonce, it throws an error and quits.
  62. """
  63. def get(url, verify=None):
  64. r = Mock()
  65. r.status_code = 404
  66. r.reason = "Not Found"
  67. r.json = lambda: {"not": "error"}
  68. return r
  69. requests = Mock()
  70. requests.get = get
  71. # The fake stdout will be written here
  72. out = []
  73. err_code = []
  74. request_registration(
  75. "user",
  76. "pass",
  77. "matrix.org",
  78. "shared",
  79. admin=False,
  80. requests=requests,
  81. _print=out.append,
  82. exit=err_code.append,
  83. )
  84. # Exit was called
  85. self.assertEqual(err_code, [1])
  86. # We got an error message
  87. self.assertIn("ERROR! Received 404 Not Found", out)
  88. self.assertNotIn("Success!", out)
  89. def test_failure_post(self):
  90. """
  91. The script will fetch a nonce, and then if the final POST fails, will
  92. report an error and quit.
  93. """
  94. def get(url, verify=None):
  95. r = Mock()
  96. r.status_code = 200
  97. r.json = lambda: {"nonce": "a"}
  98. return r
  99. def post(url, json=None, verify=None):
  100. # Make sure we are sent the correct info
  101. self.assertEqual(json["username"], "user")
  102. self.assertEqual(json["password"], "pass")
  103. self.assertEqual(json["nonce"], "a")
  104. # We want a 40-char hex MAC
  105. self.assertEqual(len(json["mac"]), 40)
  106. r = Mock()
  107. # Then 500 because we're jerks
  108. r.status_code = 500
  109. r.reason = "Broken"
  110. return r
  111. requests = Mock()
  112. requests.get = get
  113. requests.post = post
  114. # The fake stdout will be written here
  115. out = []
  116. err_code = []
  117. request_registration(
  118. "user",
  119. "pass",
  120. "matrix.org",
  121. "shared",
  122. admin=False,
  123. requests=requests,
  124. _print=out.append,
  125. exit=err_code.append,
  126. )
  127. # Exit was called
  128. self.assertEqual(err_code, [1])
  129. # We got an error message
  130. self.assertIn("ERROR! Received 500 Broken", out)
  131. self.assertNotIn("Success!", out)