filtering.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. from twisted.internet import defer
  16. from ._base import SQLBaseStore
  17. from synapse.util.caches.descriptors import cachedInlineCallbacks
  18. import simplejson as json
  19. class FilteringStore(SQLBaseStore):
  20. @cachedInlineCallbacks(num_args=2)
  21. def get_user_filter(self, user_localpart, filter_id):
  22. def_json = yield self._simple_select_one_onecol(
  23. table="user_filters",
  24. keyvalues={
  25. "user_id": user_localpart,
  26. "filter_id": filter_id,
  27. },
  28. retcol="filter_json",
  29. allow_none=False,
  30. desc="get_user_filter",
  31. )
  32. defer.returnValue(json.loads(str(def_json).decode("utf-8")))
  33. def add_user_filter(self, user_localpart, user_filter):
  34. def_json = json.dumps(user_filter).encode("utf-8")
  35. # Need an atomic transaction to SELECT the maximal ID so far then
  36. # INSERT a new one
  37. def _do_txn(txn):
  38. sql = (
  39. "SELECT MAX(filter_id) FROM user_filters "
  40. "WHERE user_id = ?"
  41. )
  42. txn.execute(sql, (user_localpart,))
  43. max_id = txn.fetchone()[0]
  44. if max_id is None:
  45. filter_id = 0
  46. else:
  47. filter_id = max_id + 1
  48. sql = (
  49. "INSERT INTO user_filters (user_id, filter_id, filter_json)"
  50. "VALUES(?, ?, ?)"
  51. )
  52. txn.execute(sql, (user_localpart, filter_id, def_json))
  53. return filter_id
  54. return self.runInteraction("add_user_filter", _do_txn)