|
@@ -1,5 +1,6 @@
|
|
|
# -*- coding: utf-8 -*-
|
|
|
# Copyright 2014-2016 OpenMarket Ltd
|
|
|
+# Copyright 2018 New Vector Ltd
|
|
|
#
|
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
# you may not use this file except in compliance with the License.
|
|
@@ -14,11 +15,13 @@
|
|
|
# limitations under the License.
|
|
|
|
|
|
from ._base import SQLBaseStore
|
|
|
+from .util.id_generators import StreamIdGenerator
|
|
|
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList, cached
|
|
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
|
|
|
|
|
from twisted.internet import defer
|
|
|
|
|
|
+import abc
|
|
|
import logging
|
|
|
import ujson as json
|
|
|
|
|
@@ -26,39 +29,36 @@ import ujson as json
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
-class ReceiptsStore(SQLBaseStore):
|
|
|
+class ReceiptsWorkerStore(SQLBaseStore):
|
|
|
+ """This is an abstract base class where subclasses must implement
|
|
|
+ `get_max_receipt_stream_id` which can be called in the initializer.
|
|
|
+ """
|
|
|
+
|
|
|
+ # This ABCMeta metaclass ensures that we cannot be instantiated without
|
|
|
+ # the abstract methods being implemented.
|
|
|
+ __metaclass__ = abc.ABCMeta
|
|
|
+
|
|
|
def __init__(self, db_conn, hs):
|
|
|
- super(ReceiptsStore, self).__init__(db_conn, hs)
|
|
|
+ super(ReceiptsWorkerStore, self).__init__(db_conn, hs)
|
|
|
|
|
|
self._receipts_stream_cache = StreamChangeCache(
|
|
|
- "ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
|
|
|
+ "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id()
|
|
|
)
|
|
|
|
|
|
+ @abc.abstractmethod
|
|
|
+ def get_max_receipt_stream_id(self):
|
|
|
+ """Get the current max stream ID for receipts stream
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ int
|
|
|
+ """
|
|
|
+ raise NotImplementedError()
|
|
|
+
|
|
|
@cachedInlineCallbacks()
|
|
|
def get_users_with_read_receipts_in_room(self, room_id):
|
|
|
receipts = yield self.get_receipts_for_room(room_id, "m.read")
|
|
|
defer.returnValue(set(r['user_id'] for r in receipts))
|
|
|
|
|
|
- def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
|
|
|
- user_id):
|
|
|
- if receipt_type != "m.read":
|
|
|
- return
|
|
|
-
|
|
|
- # Returns an ObservableDeferred
|
|
|
- res = self.get_users_with_read_receipts_in_room.cache.get(
|
|
|
- room_id, None, update_metrics=False,
|
|
|
- )
|
|
|
-
|
|
|
- if res:
|
|
|
- if isinstance(res, defer.Deferred) and res.called:
|
|
|
- res = res.result
|
|
|
- if user_id in res:
|
|
|
- # We'd only be adding to the set, so no point invalidating if the
|
|
|
- # user is already there
|
|
|
- return
|
|
|
-
|
|
|
- self.get_users_with_read_receipts_in_room.invalidate((room_id,))
|
|
|
-
|
|
|
@cached(num_args=2)
|
|
|
def get_receipts_for_room(self, room_id, receipt_type):
|
|
|
return self._simple_select_list(
|
|
@@ -270,9 +270,62 @@ class ReceiptsStore(SQLBaseStore):
|
|
|
}
|
|
|
defer.returnValue(results)
|
|
|
|
|
|
+ def get_all_updated_receipts(self, last_id, current_id, limit=None):
|
|
|
+ if last_id == current_id:
|
|
|
+ return defer.succeed([])
|
|
|
+
|
|
|
+ def get_all_updated_receipts_txn(txn):
|
|
|
+ sql = (
|
|
|
+ "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
|
|
|
+ " FROM receipts_linearized"
|
|
|
+ " WHERE ? < stream_id AND stream_id <= ?"
|
|
|
+ " ORDER BY stream_id ASC"
|
|
|
+ )
|
|
|
+ args = [last_id, current_id]
|
|
|
+ if limit is not None:
|
|
|
+ sql += " LIMIT ?"
|
|
|
+ args.append(limit)
|
|
|
+ txn.execute(sql, args)
|
|
|
+
|
|
|
+ return txn.fetchall()
|
|
|
+ return self.runInteraction(
|
|
|
+ "get_all_updated_receipts", get_all_updated_receipts_txn
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+class ReceiptsStore(ReceiptsWorkerStore):
|
|
|
+ def __init__(self, db_conn, hs):
|
|
|
+ # We instantiate this first as the ReceiptsWorkerStore constructor
|
|
|
+ # needs to be able to call get_max_receipt_stream_id
|
|
|
+ self._receipts_id_gen = StreamIdGenerator(
|
|
|
+ db_conn, "receipts_linearized", "stream_id"
|
|
|
+ )
|
|
|
+
|
|
|
+ super(ReceiptsStore, self).__init__(db_conn, hs)
|
|
|
+
|
|
|
def get_max_receipt_stream_id(self):
|
|
|
return self._receipts_id_gen.get_current_token()
|
|
|
|
|
|
+ def _invalidate_get_users_with_receipts_in_room(self, room_id, receipt_type,
|
|
|
+ user_id):
|
|
|
+ if receipt_type != "m.read":
|
|
|
+ return
|
|
|
+
|
|
|
+ # Returns an ObservableDeferred
|
|
|
+ res = self.get_users_with_read_receipts_in_room.cache.get(
|
|
|
+ room_id, None, update_metrics=False,
|
|
|
+ )
|
|
|
+
|
|
|
+ if res:
|
|
|
+ if isinstance(res, defer.Deferred) and res.called:
|
|
|
+ res = res.result
|
|
|
+ if user_id in res:
|
|
|
+ # We'd only be adding to the set, so no point invalidating if the
|
|
|
+ # user is already there
|
|
|
+ return
|
|
|
+
|
|
|
+ self.get_users_with_read_receipts_in_room.invalidate((room_id,))
|
|
|
+
|
|
|
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
|
|
|
user_id, event_id, data, stream_id):
|
|
|
txn.call_after(
|
|
@@ -457,25 +510,3 @@ class ReceiptsStore(SQLBaseStore):
|
|
|
"data": json.dumps(data),
|
|
|
}
|
|
|
)
|
|
|
-
|
|
|
- def get_all_updated_receipts(self, last_id, current_id, limit=None):
|
|
|
- if last_id == current_id:
|
|
|
- return defer.succeed([])
|
|
|
-
|
|
|
- def get_all_updated_receipts_txn(txn):
|
|
|
- sql = (
|
|
|
- "SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
|
|
|
- " FROM receipts_linearized"
|
|
|
- " WHERE ? < stream_id AND stream_id <= ?"
|
|
|
- " ORDER BY stream_id ASC"
|
|
|
- )
|
|
|
- args = [last_id, current_id]
|
|
|
- if limit is not None:
|
|
|
- sql += " LIMIT ?"
|
|
|
- args.append(limit)
|
|
|
- txn.execute(sql, args)
|
|
|
-
|
|
|
- return txn.fetchall()
|
|
|
- return self.runInteraction(
|
|
|
- "get_all_updated_receipts", get_all_updated_receipts_txn
|
|
|
- )
|