id_generators.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import contextlib
  16. import threading
  17. from collections import deque
  18. class IdGenerator(object):
  19. def __init__(self, db_conn, table, column):
  20. self._lock = threading.Lock()
  21. self._next_id = _load_current_id(db_conn, table, column)
  22. def get_next(self):
  23. with self._lock:
  24. self._next_id += 1
  25. return self._next_id
  26. def _load_current_id(db_conn, table, column, step=1):
  27. """
  28. Args:
  29. db_conn (object):
  30. table (str):
  31. column (str):
  32. step (int):
  33. Returns:
  34. int
  35. """
  36. cur = db_conn.cursor()
  37. if step == 1:
  38. cur.execute("SELECT MAX(%s) FROM %s" % (column, table))
  39. else:
  40. cur.execute("SELECT MIN(%s) FROM %s" % (column, table))
  41. (val,) = cur.fetchone()
  42. cur.close()
  43. current_id = int(val) if val else step
  44. return (max if step > 0 else min)(current_id, step)
  45. class StreamIdGenerator(object):
  46. """Used to generate new stream ids when persisting events while keeping
  47. track of which transactions have been completed.
  48. This allows us to get the "current" stream id, i.e. the stream id such that
  49. all ids less than or equal to it have completed. This handles the fact that
  50. persistence of events can complete out of order.
  51. Args:
  52. db_conn(connection): A database connection to use to fetch the
  53. initial value of the generator from.
  54. table(str): A database table to read the initial value of the id
  55. generator from.
  56. column(str): The column of the database table to read the initial
  57. value from the id generator from.
  58. extra_tables(list): List of pairs of database tables and columns to
  59. use to source the initial value of the generator from. The value
  60. with the largest magnitude is used.
  61. step(int): which direction the stream ids grow in. +1 to grow
  62. upwards, -1 to grow downwards.
  63. Usage:
  64. with stream_id_gen.get_next() as stream_id:
  65. # ... persist event ...
  66. """
  67. def __init__(self, db_conn, table, column, extra_tables=[], step=1):
  68. assert step != 0
  69. self._lock = threading.Lock()
  70. self._step = step
  71. self._current = _load_current_id(db_conn, table, column, step)
  72. for table, column in extra_tables:
  73. self._current = (max if step > 0 else min)(
  74. self._current, _load_current_id(db_conn, table, column, step)
  75. )
  76. self._unfinished_ids = deque()
  77. def get_next(self):
  78. """
  79. Usage:
  80. with stream_id_gen.get_next() as stream_id:
  81. # ... persist event ...
  82. """
  83. with self._lock:
  84. self._current += self._step
  85. next_id = self._current
  86. self._unfinished_ids.append(next_id)
  87. @contextlib.contextmanager
  88. def manager():
  89. try:
  90. yield next_id
  91. finally:
  92. with self._lock:
  93. self._unfinished_ids.remove(next_id)
  94. return manager()
  95. def get_next_mult(self, n):
  96. """
  97. Usage:
  98. with stream_id_gen.get_next(n) as stream_ids:
  99. # ... persist events ...
  100. """
  101. with self._lock:
  102. next_ids = range(
  103. self._current + self._step,
  104. self._current + self._step * (n + 1),
  105. self._step,
  106. )
  107. self._current += n * self._step
  108. for next_id in next_ids:
  109. self._unfinished_ids.append(next_id)
  110. @contextlib.contextmanager
  111. def manager():
  112. try:
  113. yield next_ids
  114. finally:
  115. with self._lock:
  116. for next_id in next_ids:
  117. self._unfinished_ids.remove(next_id)
  118. return manager()
  119. def get_current_token(self):
  120. """Returns the maximum stream id such that all stream ids less than or
  121. equal to it have been successfully persisted.
  122. Returns:
  123. int
  124. """
  125. with self._lock:
  126. if self._unfinished_ids:
  127. return self._unfinished_ids[0] - self._step
  128. return self._current
  129. class ChainedIdGenerator(object):
  130. """Used to generate new stream ids where the stream must be kept in sync
  131. with another stream. It generates pairs of IDs, the first element is an
  132. integer ID for this stream, the second element is the ID for the stream
  133. that this stream needs to be kept in sync with."""
  134. def __init__(self, chained_generator, db_conn, table, column):
  135. self.chained_generator = chained_generator
  136. self._lock = threading.Lock()
  137. self._current_max = _load_current_id(db_conn, table, column)
  138. self._unfinished_ids = deque()
  139. def get_next(self):
  140. """
  141. Usage:
  142. with stream_id_gen.get_next() as (stream_id, chained_id):
  143. # ... persist event ...
  144. """
  145. with self._lock:
  146. self._current_max += 1
  147. next_id = self._current_max
  148. chained_id = self.chained_generator.get_current_token()
  149. self._unfinished_ids.append((next_id, chained_id))
  150. @contextlib.contextmanager
  151. def manager():
  152. try:
  153. yield (next_id, chained_id)
  154. finally:
  155. with self._lock:
  156. self._unfinished_ids.remove((next_id, chained_id))
  157. return manager()
  158. def get_current_token(self):
  159. """Returns the maximum stream id such that all stream ids less than or
  160. equal to it have been successfully persisted.
  161. """
  162. with self._lock:
  163. if self._unfinished_ids:
  164. stream_id, chained_id = self._unfinished_ids[0]
  165. return stream_id - 1, chained_id
  166. return self._current_max, self.chained_generator.get_current_token()