id_generators.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # -*- coding: utf-8 -*-
  2. # Copyright 2014-2016 OpenMarket Ltd
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from collections import deque
  16. import contextlib
  17. import threading
  18. class IdGenerator(object):
  19. def __init__(self, db_conn, table, column):
  20. self._lock = threading.Lock()
  21. self._next_id = _load_current_id(db_conn, table, column)
  22. def get_next(self):
  23. with self._lock:
  24. self._next_id += 1
  25. return self._next_id
  26. def _load_current_id(db_conn, table, column, step=1):
  27. """
  28. Args:
  29. db_conn (object):
  30. table (str):
  31. column (str):
  32. step (int):
  33. Returns:
  34. int
  35. """
  36. cur = db_conn.cursor()
  37. if step == 1:
  38. cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
  39. else:
  40. cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
  41. val, = cur.fetchone()
  42. cur.close()
  43. current_id = int(val) if val else step
  44. return (max if step > 0 else min)(current_id, step)
  45. class StreamIdGenerator(object):
  46. """Used to generate new stream ids when persisting events while keeping
  47. track of which transactions have been completed.
  48. This allows us to get the "current" stream id, i.e. the stream id such that
  49. all ids less than or equal to it have completed. This handles the fact that
  50. persistence of events can complete out of order.
  51. Args:
  52. db_conn(connection): A database connection to use to fetch the
  53. initial value of the generator from.
  54. table(str): A database table to read the initial value of the id
  55. generator from.
  56. column(str): The column of the database table to read the initial
  57. value from the id generator from.
  58. extra_tables(list): List of pairs of database tables and columns to
  59. use to source the initial value of the generator from. The value
  60. with the largest magnitude is used.
  61. step(int): which direction the stream ids grow in. +1 to grow
  62. upwards, -1 to grow downwards.
  63. Usage:
  64. with stream_id_gen.get_next() as stream_id:
  65. # ... persist event ...
  66. """
  67. def __init__(self, db_conn, table, column, extra_tables=[], step=1):
  68. assert step != 0
  69. self._lock = threading.Lock()
  70. self._step = step
  71. self._current = _load_current_id(db_conn, table, column, step)
  72. for table, column in extra_tables:
  73. self._current = (max if step > 0 else min)(
  74. self._current,
  75. _load_current_id(db_conn, table, column, step)
  76. )
  77. self._unfinished_ids = deque()
  78. def get_next(self):
  79. """
  80. Usage:
  81. with stream_id_gen.get_next() as stream_id:
  82. # ... persist event ...
  83. """
  84. with self._lock:
  85. self._current += self._step
  86. next_id = self._current
  87. self._unfinished_ids.append(next_id)
  88. @contextlib.contextmanager
  89. def manager():
  90. try:
  91. yield next_id
  92. finally:
  93. with self._lock:
  94. self._unfinished_ids.remove(next_id)
  95. return manager()
  96. def get_next_mult(self, n):
  97. """
  98. Usage:
  99. with stream_id_gen.get_next(n) as stream_ids:
  100. # ... persist events ...
  101. """
  102. with self._lock:
  103. next_ids = range(
  104. self._current + self._step,
  105. self._current + self._step * (n + 1),
  106. self._step
  107. )
  108. self._current += n * self._step
  109. for next_id in next_ids:
  110. self._unfinished_ids.append(next_id)
  111. @contextlib.contextmanager
  112. def manager():
  113. try:
  114. yield next_ids
  115. finally:
  116. with self._lock:
  117. for next_id in next_ids:
  118. self._unfinished_ids.remove(next_id)
  119. return manager()
  120. def get_current_token(self):
  121. """Returns the maximum stream id such that all stream ids less than or
  122. equal to it have been successfully persisted.
  123. Returns:
  124. int
  125. """
  126. with self._lock:
  127. if self._unfinished_ids:
  128. return self._unfinished_ids[0] - self._step
  129. return self._current
  130. class ChainedIdGenerator(object):
  131. """Used to generate new stream ids where the stream must be kept in sync
  132. with another stream. It generates pairs of IDs, the first element is an
  133. integer ID for this stream, the second element is the ID for the stream
  134. that this stream needs to be kept in sync with."""
  135. def __init__(self, chained_generator, db_conn, table, column):
  136. self.chained_generator = chained_generator
  137. self._lock = threading.Lock()
  138. self._current_max = _load_current_id(db_conn, table, column)
  139. self._unfinished_ids = deque()
  140. def get_next(self):
  141. """
  142. Usage:
  143. with stream_id_gen.get_next() as (stream_id, chained_id):
  144. # ... persist event ...
  145. """
  146. with self._lock:
  147. self._current_max += 1
  148. next_id = self._current_max
  149. chained_id = self.chained_generator.get_current_token()
  150. self._unfinished_ids.append((next_id, chained_id))
  151. @contextlib.contextmanager
  152. def manager():
  153. try:
  154. yield (next_id, chained_id)
  155. finally:
  156. with self._lock:
  157. self._unfinished_ids.remove((next_id, chained_id))
  158. return manager()
  159. def get_current_token(self):
  160. """Returns the maximum stream id such that all stream ids less than or
  161. equal to it have been successfully persisted.
  162. """
  163. with self._lock:
  164. if self._unfinished_ids:
  165. stream_id, chained_id = self._unfinished_ids[0]
  166. return (stream_id - 1, chained_id)
  167. return (self._current_max, self.chained_generator.get_current_token())