valsession.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
# -*- coding: utf-8 -*-
# Copyright 2014 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. import sydent.util.tokenutils
  16. from sydent.validators import ValidationSession, IncorrectClientSecretException, InvalidSessionIdException, \
  17. SessionExpiredException, SessionNotValidatedException
  18. from sydent.util import time_msec
  19. class ThreePidValSessionStore:
  20. def __init__(self, syd):
  21. self.sydent = syd
  22. def getOrCreateTokenSession(self, medium, address, clientSecret):
  23. cur = self.sydent.db.cursor()
  24. cur.execute("select s.id, s.medium, s.address, s.clientSecret, s.validated, s.mtime, "
  25. "t.token, t.sendAttemptNumber from threepid_validation_sessions s,threepid_token_auths t "
  26. "where s.medium = ? and s.address = ? and s.clientSecret = ? and t.validationSession = s.id",
  27. (medium, address, clientSecret))
  28. row = cur.fetchone()
  29. if row:
  30. s = ValidationSession(row[0], row[1], row[2], row[3], row[4], row[5])
  31. s.token = row[6]
  32. s.sendAttemptNumber = row[7]
  33. return s
  34. sid = self.addValSession(medium, address, clientSecret, time_msec(), commit=False)
  35. tokenString = sydent.util.tokenutils.generateTokenForMedium(medium)
  36. cur.execute("insert into threepid_token_auths (validationSession, token, sendAttemptNumber) values (?, ?, ?)",
  37. (sid, tokenString, -1))
  38. self.sydent.db.commit()
  39. s = ValidationSession(sid, medium, address, clientSecret, False, time_msec())
  40. s.token = tokenString
  41. s.sendAttemptNumber = -1
  42. return s
  43. def addValSession(self, medium, address, clientSecret, mtime, commit=True):
  44. cur = self.sydent.db.cursor()
  45. cur.execute("insert into threepid_validation_sessions ('medium', 'address', 'clientSecret', 'mtime')" +
  46. " values (?, ?, ?, ?)", (medium, address, clientSecret, mtime))
  47. if commit:
  48. self.sydent.db.commit()
  49. return cur.lastrowid
  50. def setSendAttemptNumber(self, sid, attemptNo):
  51. cur = self.sydent.db.cursor()
  52. cur.execute("update threepid_token_auths set sendAttemptNumber = ? where id = ?", (attemptNo, sid))
  53. self.sydent.db.commit()
  54. def setValidated(self, sid, validated):
  55. cur = self.sydent.db.cursor()
  56. cur.execute("update threepid_validation_sessions set validated = ? where id = ?", (validated, sid))
  57. self.sydent.db.commit()
  58. def setMtime(self, sid, mtime):
  59. cur = self.sydent.db.cursor()
  60. cur.execute("update threepid_validation_sessions set mtime = ? where id = ?", (mtime, sid))
  61. self.sydent.db.commit()
  62. def getSessionById(self, sid):
  63. cur = self.sydent.db.cursor()
  64. cur.execute("select id, medium, address, clientSecret, validated, mtime from "+
  65. "threepid_validation_sessions where id = ?", (sid,))
  66. row = cur.fetchone()
  67. if not row:
  68. return None
  69. return ValidationSession(row[0], row[1], row[2], row[3], row[4], row[5])
  70. def getTokenSessionById(self, sid):
  71. cur = self.sydent.db.cursor()
  72. cur.execute("select s.id, s.medium, s.address, s.clientSecret, s.validated, s.mtime, "
  73. "t.token, t.sendAttemptNumber from threepid_validation_sessions s,threepid_token_auths t "
  74. "where s.id = ? and t.validationSession = s.id", (sid,))
  75. row = cur.fetchone()
  76. if row:
  77. s = ValidationSession(row[0], row[1], row[2], row[3], row[4], row[5])
  78. s.token = row[6]
  79. s.sendAttemptNumber = row[7]
  80. return s
  81. return None
  82. def getValidatedSession(self, sid, clientSecret):
  83. """
  84. Retrieve a validated and still-valid session whose client secret matches the one passed in
  85. """
  86. s = self.getSessionById(sid)
  87. if not s:
  88. raise InvalidSessionIdException()
  89. if not s.clientSecret == clientSecret:
  90. raise IncorrectClientSecretException()
  91. if s.mtime + ValidationSession.THREEPID_SESSION_VALID_LIFETIME_MS < time_msec():
  92. raise SessionExpiredException()
  93. if not s.validated:
  94. raise SessionNotValidatedException()
  95. return s