Browse Source

Merge branch 'release-v0.14.0' of github.com:matrix-org/synapse

Erik Johnston 8 years ago
parent
commit
5fbdf2bcec
100 changed files with 3786 additions and 1819 deletions
  1. 65 0
      CHANGES.rst
  2. 1 0
      MANIFEST.in
  3. 38 2
      README.rst
  4. 22 0
      jenkins-flake8.sh
  5. 61 0
      jenkins-postgres.sh
  6. 55 0
      jenkins-sqlite.sh
  7. 25 0
      jenkins-unittests.sh
  8. 6 1
      jenkins.sh
  9. 22 1
      scripts-dev/definitions.py
  10. 67 0
      scripts-dev/tail-synapse.py
  11. 0 1
      scripts/gen_password
  12. 39 0
      scripts/hash_password
  13. 7 2
      scripts/synapse_port_db
  14. 1 1
      synapse/__init__.py
  15. 88 27
      synapse/api/auth.py
  16. 0 1
      synapse/api/constants.py
  17. 4 1
      synapse/api/filtering.py
  18. 6 2
      synapse/app/homeserver.py
  19. 3 3
      synapse/app/synctl.py
  20. 1 1
      synapse/config/__main__.py
  21. 1 1
      synapse/config/_base.py
  22. 40 0
      synapse/config/api.py
  23. 2 1
      synapse/config/homeserver.py
  24. 4 0
      synapse/config/registration.py
  25. 3 0
      synapse/config/repository.py
  26. 3 0
      synapse/crypto/keyclient.py
  27. 8 1
      synapse/events/__init__.py
  28. 2 1
      synapse/federation/federation_client.py
  29. 22 6
      synapse/federation/federation_server.py
  30. 1 0
      synapse/federation/transport/client.py
  31. 27 17
      synapse/federation/transport/server.py
  32. 171 37
      synapse/handlers/_base.py
  33. 70 15
      synapse/handlers/auth.py
  34. 102 20
      synapse/handlers/directory.py
  35. 35 78
      synapse/handlers/events.py
  36. 105 49
      synapse/handlers/federation.py
  37. 64 43
      synapse/handlers/message.py
  38. 520 364
      synapse/handlers/presence.py
  39. 24 24
      synapse/handlers/profile.py
  40. 0 2
      synapse/handlers/receipts.py
  41. 44 8
      synapse/handlers/register.py
  42. 395 336
      synapse/handlers/room.py
  43. 76 9
      synapse/handlers/sync.py
  44. 14 0
      synapse/handlers/typing.py
  45. 15 2
      synapse/http/client.py
  46. 26 5
      synapse/http/server.py
  47. 85 5
      synapse/http/servlet.py
  48. 55 1
      synapse/notifier.py
  49. 5 6
      synapse/push/__init__.py
  50. 3 3
      synapse/push/action_generator.py
  51. 50 7
      synapse/push/baserules.py
  52. 8 6
      synapse/push/bulk_push_rule_evaluator.py
  53. 112 0
      synapse/push/clientformat.py
  54. 1 2
      synapse/push/httppusher.py
  55. 7 12
      synapse/push/push_rule_evaluator.py
  56. 25 33
      synapse/push/pusherpool.py
  57. 2 2
      synapse/python_dependencies.py
  58. 14 0
      synapse/replication/__init__.py
  59. 367 0
      synapse/replication/resource.py
  60. 2 0
      synapse/rest/__init__.py
  61. 1 1
      synapse/rest/client/v1/admin.py
  62. 50 16
      synapse/rest/client/v1/directory.py
  63. 1 1
      synapse/rest/client/v1/initial_sync.py
  64. 10 17
      synapse/rest/client/v1/login.py
  65. 72 0
      synapse/rest/client/v1/logout.py
  66. 21 18
      synapse/rest/client/v1/presence.py
  67. 6 6
      synapse/rest/client/v1/profile.py
  68. 65 206
      synapse/rest/client/v1/push_rule.py
  69. 12 17
      synapse/rest/client/v1/pusher.py
  70. 3 13
      synapse/rest/client/v1/register.py
  71. 74 100
      synapse/rest/client/v1/room.py
  72. 1 1
      synapse/rest/client/v1/voip.py
  73. 0 22
      synapse/rest/client/v2_alpha/_base.py
  74. 6 6
      synapse/rest/client/v2_alpha/account.py
  75. 4 17
      synapse/rest/client/v2_alpha/account_data.py
  76. 3 2
      synapse/rest/client/v2_alpha/auth.py
  77. 2 8
      synapse/rest/client/v2_alpha/filter.py
  78. 7 15
      synapse/rest/client/v2_alpha/keys.py
  79. 3 0
      synapse/rest/client/v2_alpha/receipts.py
  80. 48 11
      synapse/rest/client/v2_alpha/register.py
  81. 10 6
      synapse/rest/client/v2_alpha/sync.py
  82. 3 9
      synapse/rest/client/v2_alpha/tags.py
  83. 3 3
      synapse/rest/client/v2_alpha/tokenrefresh.py
  84. 2 10
      synapse/rest/key/v2/remote_key_resource.py
  85. 2 2
      synapse/rest/media/v0/content_repository.py
  86. 2 2
      synapse/rest/media/v1/base_resource.py
  87. 68 75
      synapse/state.py
  88. 74 11
      synapse/storage/__init__.py
  89. 27 9
      synapse/storage/_base.py
  90. 38 6
      synapse/storage/account_data.py
  91. 21 13
      synapse/storage/appservice.py
  92. 15 2
      synapse/storage/directory.py
  93. 1 1
      synapse/storage/end_to_end_keys.py
  94. 3 2
      synapse/storage/engines/__init__.py
  95. 3 2
      synapse/storage/engines/postgres.py
  96. 3 2
      synapse/storage/engines/sqlite3.py
  97. 4 4
      synapse/storage/event_federation.py
  98. 4 5
      synapse/storage/event_push_actions.py
  99. 92 38
      synapse/storage/events.py
  100. 1 1
      synapse/storage/keys.py

+ 65 - 0
CHANGES.rst

@@ -1,3 +1,68 @@
+Changes in synapse v0.14.0 (2016-03-30)
+=======================================
+
+No changes from v0.14.0-rc2
+
+Changes in synapse v0.14.0-rc2 (2016-03-23)
+===========================================
+
+Features:
+
+* Add published room list API (PR #657)
+
+Changes:
+
+* Change various caches to consume less memory (PR #656, #658, #660, #662,
+  #663, #665)
+* Allow rooms to be published without requiring an alias (PR #664)
+* Intern common strings in caches to reduce memory footprint (#666)
+
+Bug fixes:
+
+* Fix reject invites over federation (PR #646)
+* Fix bug where registration was not idempotent (PR #649)
+* Update aliases event after deleting aliases (PR #652)
+* Fix unread notification count, which was sometimes wrong (PR #661)
+
+Changes in synapse v0.14.0-rc1 (2016-03-14)
+===========================================
+
+Features:
+
+* Add event_id to response to state event PUT (PR #581)
+* Allow guest users access to messages in rooms they have joined (PR #587)
+* Add config for what state is included in a room invite (PR #598)
+* Send the inviter's member event in room invite state (PR #607)
+* Add error codes for malformed/bad JSON in /login (PR #608)
+* Add support for changing the actions for default rules (PR #609)
+* Add environment variable SYNAPSE_CACHE_FACTOR, default it to 0.1 (PR #612)
+* Add ability for alias creators to delete aliases (PR #614)
+* Add profile information to invites (PR #624)
+
+Changes:
+
+* Enforce user_id exclusivity for AS registrations (PR #572)
+* Make adding push rules idempotent (PR #587)
+* Improve presence performance (PR #582, #586)
+* Change presence semantics for ``last_active_ago`` (PR #582, #586)
+* Don't allow ``m.room.create`` to be changed (PR #596)
+* Add 800x600 to default list of valid thumbnail sizes (PR #616)
+* Always include kicks and bans in full /sync (PR #625)
+* Send history visibility on boundary changes (PR #626)
+* Register endpoint now returns a refresh_token (PR #637)
+
+Bug fixes:
+
+* Fix bug where we returned incorrect state in /sync (PR #573)
+* Always return a JSON object from push rule API (PR #606)
+* Fix bug where registering without a user id sometimes failed (PR #610)
+* Report size of ExpiringCache in cache size metrics (PR #611)
+* Fix rejection of invites to empty rooms (PR #615)
+* Fix usage of ``bcrypt`` to not use ``checkpw`` (PR #619)
+* Pin ``pysaml2`` dependency (PR #634)
+* Fix bug in ``/sync`` where timeline order was incorrect for backfilled events
+  (PR #635)
+
 Changes in synapse v0.13.3 (2016-02-11)
 =======================================
 

+ 1 - 0
MANIFEST.in

@@ -21,5 +21,6 @@ recursive-include synapse/static *.html
 recursive-include synapse/static *.js
 
 exclude jenkins.sh
+exclude jenkins*.sh
 
 prune demo/etc

+ 38 - 2
README.rst

@@ -525,7 +525,6 @@ Logging In To An Existing Account
 Just enter the ``@localpart:my.domain.here`` Matrix user ID and password into
 the form and click the Login button.
 
-
 Identity Servers
 ================
 
@@ -545,6 +544,26 @@ as the primary means of identity and E2E encryption is not complete. As such,
 we are running a single identity server (https://matrix.org) at the current
 time.
 
+Password reset
+==============
+
+If a user has registered an email address to their account using an identity
+server, they can request a password-reset token via clients such as Vector.
+
+A manual password reset can be done via direct database access as follows.
+
+First calculate the hash of the new password:
+
+    $ source ~/.synapse/bin/activate
+    $ ./scripts/hash_password
+    Password: 
+    Confirm password: 
+    $2a$12$xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
+
+Then update the `users` table in the database:
+
+    UPDATE users SET password_hash='$2a$12$xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'
+        WHERE name='@test:test.com';
 
 Where's the spec?!
 ==================
@@ -565,4 +584,21 @@ sphinxcontrib-napoleon::
 Building internal API documentation::
 
     python setup.py build_sphinx
-    
+
+
+
+Halp!! Synapse eats all my RAM!
+===============================
+
+Synapse's architecture is quite RAM hungry currently - we deliberately
+cache a lot of recent room data and metadata in RAM in order to speed up
+common requests.  We'll improve this in future, but for now the easiest
+way to either reduce the RAM usage (at the risk of slowing things down)
+is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment
+variable.  Roughly speaking, a SYNAPSE_CACHE_FACTOR of 1.0 will max out
+at around 3-4GB of resident memory - this is what we currently run the
+matrix.org on.  The default setting is currently 0.1, which is probably
+around a ~700MB footprint.  You can dial it down further to 0.02 if
+desired, which targets roughly ~512MB.  Conversely you can dial it up if
+you need performance for lots of users and have a box with a lot of RAM.
+

+ 22 - 0
jenkins-flake8.sh

@@ -0,0 +1,22 @@
+#!/bin/bash
+
+set -eux
+
+: ${WORKSPACE:="$(pwd)"}
+
+export PYTHONDONTWRITEBYTECODE=yep
+export SYNAPSE_CACHE_FACTOR=1
+
+# Output test results as junit xml
+export TRIAL_FLAGS="--reporter=subunit"
+export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
+# Write coverage reports to a separate file for each process
+export COVERAGE_OPTS="-p"
+export DUMP_COVERAGE_COMMAND="coverage help"
+
+# Output flake8 violations to violations.flake8.log
+export PEP8SUFFIX="--output-file=violations.flake8.log"
+
+rm .coverage* || echo "No coverage files to remove"
+
+tox -e packaging -e pep8

+ 61 - 0
jenkins-postgres.sh

@@ -0,0 +1,61 @@
+#!/bin/bash
+
+set -eux
+
+: ${WORKSPACE:="$(pwd)"}
+
+export PYTHONDONTWRITEBYTECODE=yep
+export SYNAPSE_CACHE_FACTOR=1
+
+# Output test results as junit xml
+export TRIAL_FLAGS="--reporter=subunit"
+export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
+# Write coverage reports to a separate file for each process
+export COVERAGE_OPTS="-p"
+export DUMP_COVERAGE_COMMAND="coverage help"
+
+# Output flake8 violations to violations.flake8.log
+# Don't exit with non-0 status code on Jenkins,
+# so that the build steps continue and a later step can decided whether to
+# UNSTABLE or FAILURE this build.
+export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
+
+rm .coverage* || echo "No coverage files to remove"
+
+tox --notest -e py27
+
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+$TOX_BIN/pip install psycopg2
+
+: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
+
+if [[ ! -e .sytest-base ]]; then
+  git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
+else
+  (cd .sytest-base; git fetch -p)
+fi
+
+rm -rf sytest
+git clone .sytest-base sytest --shared
+cd sytest
+
+git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
+
+: ${PORT_BASE:=8000}
+
+./jenkins/prep_sytest_for_postgres.sh
+
+echo >&2 "Running sytest with PostgreSQL";
+./jenkins/install_and_run.sh --coverage \
+                             --python $TOX_BIN/python \
+                             --synapse-directory $WORKSPACE \
+                             --port-base $PORT_BASE
+
+cd ..
+cp sytest/.coverage.* .
+
+# Combine the coverage reports
+echo "Combining:" .coverage.*
+$TOX_BIN/python -m coverage combine
+# Output coverage to coverage.xml
+$TOX_BIN/coverage xml -o coverage.xml

+ 55 - 0
jenkins-sqlite.sh

@@ -0,0 +1,55 @@
+#!/bin/bash
+
+set -eux
+
+: ${WORKSPACE:="$(pwd)"}
+
+export PYTHONDONTWRITEBYTECODE=yep
+export SYNAPSE_CACHE_FACTOR=1
+
+# Output test results as junit xml
+export TRIAL_FLAGS="--reporter=subunit"
+export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
+# Write coverage reports to a separate file for each process
+export COVERAGE_OPTS="-p"
+export DUMP_COVERAGE_COMMAND="coverage help"
+
+# Output flake8 violations to violations.flake8.log
+# Don't exit with non-0 status code on Jenkins,
+# so that the build steps continue and a later step can decided whether to
+# UNSTABLE or FAILURE this build.
+export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
+
+rm .coverage* || echo "No coverage files to remove"
+
+tox --notest -e py27
+TOX_BIN=$WORKSPACE/.tox/py27/bin
+
+: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
+
+if [[ ! -e .sytest-base ]]; then
+  git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
+else
+  (cd .sytest-base; git fetch -p)
+fi
+
+rm -rf sytest
+git clone .sytest-base sytest --shared
+cd sytest
+
+git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
+
+: ${PORT_BASE:=8500}
+./jenkins/install_and_run.sh --coverage \
+                             --python $TOX_BIN/python \
+                             --synapse-directory $WORKSPACE \
+                             --port-base $PORT_BASE
+
+cd ..
+cp sytest/.coverage.* .
+
+# Combine the coverage reports
+echo "Combining:" .coverage.*
+$TOX_BIN/python -m coverage combine
+# Output coverage to coverage.xml
+$TOX_BIN/coverage xml -o coverage.xml

+ 25 - 0
jenkins-unittests.sh

@@ -0,0 +1,25 @@
+#!/bin/bash
+
+set -eux
+
+: ${WORKSPACE:="$(pwd)"}
+
+export PYTHONDONTWRITEBYTECODE=yep
+export SYNAPSE_CACHE_FACTOR=1
+
+# Output test results as junit xml
+export TRIAL_FLAGS="--reporter=subunit"
+export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
+# Write coverage reports to a separate file for each process
+export COVERAGE_OPTS="-p"
+export DUMP_COVERAGE_COMMAND="coverage help"
+
+# Output flake8 violations to violations.flake8.log
+# Don't exit with non-0 status code on Jenkins,
+# so that the build steps continue and a later step can decided whether to
+# UNSTABLE or FAILURE this build.
+export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
+
+rm .coverage* || echo "No coverage files to remove"
+
+tox -e py27

+ 6 - 1
jenkins.sh

@@ -1,6 +1,11 @@
-#!/bin/bash -eu
+#!/bin/bash
+
+set -eux
+
+: ${WORKSPACE:="$(pwd)"}
 
 export PYTHONDONTWRITEBYTECODE=yep
+export SYNAPSE_CACHE_FACTOR=1
 
 # Output test results as junit xml
 export TRIAL_FLAGS="--reporter=subunit"

+ 22 - 1
scripts-dev/definitions.py

@@ -86,9 +86,12 @@ def used_names(prefix, item, defs, names):
     for name, funcs in defs.get('class', {}).items():
         used_names(prefix + name + ".", name, funcs, names)
 
+    path = prefix.rstrip('.')
     for used in defs.get('uses', ()):
         if used in names:
-            names[used].setdefault('used', {}).setdefault(item, []).append(prefix.rstrip('.'))
+            if item:
+                names[item].setdefault('uses', []).append(used)
+            names[used].setdefault('used', {}).setdefault(item, []).append(path)
 
 
 if __name__ == '__main__':
@@ -113,6 +116,10 @@ if __name__ == '__main__':
         "--referrers", default=0, type=int,
         help="Include referrers up to the given depth"
     )
+    parser.add_argument(
+        "--referred", default=0, type=int,
+        help="Include referred down to the given depth"
+    )
     parser.add_argument(
         "--format", default="yaml",
         help="Output format, one of 'yaml' or 'dot'"
@@ -161,6 +168,20 @@ if __name__ == '__main__':
                 continue
             result[name] = definition
 
+    referred_depth = args.referred
+    referred = set()
+    while referred_depth:
+        referred_depth -= 1
+        for entry in result.values():
+            for uses in entry.get("uses", ()):
+                referred.add(uses)
+        for name, definition in names.items():
+            if not name in referred:
+                continue
+            if ignore and any(pattern.match(name) for pattern in ignore):
+                continue
+            result[name] = definition
+
     if args.format == 'yaml':
         yaml.dump(result, sys.stdout, default_flow_style=False)
     elif args.format == 'dot':

+ 67 - 0
scripts-dev/tail-synapse.py

@@ -0,0 +1,67 @@
+import requests
+import collections
+import sys
+import time
+import json
+
+Entry = collections.namedtuple("Entry", "name position rows")
+
+ROW_TYPES = {}
+
+
+def row_type_for_columns(name, column_names):
+    column_names = tuple(column_names)
+    row_type = ROW_TYPES.get((name, column_names))
+    if row_type is None:
+        row_type = collections.namedtuple(name, column_names)
+        ROW_TYPES[(name, column_names)] = row_type
+    return row_type
+
+
+def parse_response(content):
+    streams = json.loads(content)
+    result = {}
+    for name, value in streams.items():
+        row_type = row_type_for_columns(name, value["field_names"])
+        position = value["position"]
+        rows = [row_type(*row) for row in value["rows"]]
+        result[name] = Entry(name, position, rows)
+    return result
+
+
+def replicate(server, streams):
+    return parse_response(requests.get(
+        server + "/_synapse/replication",
+        verify=False,
+        params=streams
+    ).content)
+
+
+def main():
+    server = sys.argv[1]
+
+    streams = None
+    while not streams:
+        try:
+            streams = {
+                row.name: row.position
+                for row in replicate(server, {"streams":"-1"})["streams"].rows
+            }
+        except requests.exceptions.ConnectionError as e:
+            time.sleep(0.1)
+
+    while True:
+        try:
+            results = replicate(server, streams)
+        except:
+            sys.stdout.write("connection_lost("+ repr(streams) + ")\n")
+            break
+        for update in results.values():
+            for row in update.rows:
+                sys.stdout.write(repr(row) + "\n")
+            streams[update.name] = update.position
+
+
+
+if __name__=='__main__':
+    main()

+ 0 - 1
scripts/gen_password

@@ -1 +0,0 @@
-perl -MCrypt::Random -MCrypt::Eksblowfish::Bcrypt -e 'print Crypt::Eksblowfish::Bcrypt::bcrypt("secret", "\$2\$12\$" . Crypt::Eksblowfish::Bcrypt::en_base64(Crypt::Random::makerandom_octet(Length=>16)))."\n"'

+ 39 - 0
scripts/hash_password

@@ -0,0 +1,39 @@
+#!/usr/bin/env python
+
+import argparse
+import bcrypt
+import getpass
+
+bcrypt_rounds=12
+
+def prompt_for_pass():
+    password = getpass.getpass("Password: ")
+
+    if not password:
+        raise Exception("Password cannot be blank.")
+
+    confirm_password = getpass.getpass("Confirm password: ")
+
+    if password != confirm_password:
+        raise Exception("Passwords do not match.")
+
+    return password
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Calculate the hash of a new password, so that passwords"
+                    " can be reset")
+    parser.add_argument(
+        "-p", "--password",
+        default=None,
+        help="New password for user. Will prompt if omitted.",
+    )
+
+    args = parser.parse_args()
+    password = args.password
+
+    if not password:
+        password = prompt_for_pass()
+
+    print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds))
+

+ 7 - 2
scripts/synapse_port_db

@@ -309,8 +309,8 @@ class Porter(object):
                 **self.postgres_config["args"]
             )
 
-            sqlite_engine = create_engine("sqlite3")
-            postgres_engine = create_engine("psycopg2")
+            sqlite_engine = create_engine(FakeConfig(sqlite_config))
+            postgres_engine = create_engine(FakeConfig(postgres_config))
 
             self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
             self.postgres_store = Store(postgres_db_pool, postgres_engine)
@@ -792,3 +792,8 @@ if __name__ == "__main__":
     if end_error_exec_info:
         exc_type, exc_value, exc_traceback = end_error_exec_info
         traceback.print_exception(exc_type, exc_value, exc_traceback)
+
+
+class FakeConfig:
+    def __init__(self, database_config):
+        self.database_config = database_config

+ 1 - 1
synapse/__init__.py

@@ -16,4 +16,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.13.3"
+__version__ = "0.14.0"

+ 88 - 27
synapse/api/auth.py

@@ -434,31 +434,46 @@ class Auth(object):
 
         if event.user_id != invite_event.user_id:
             return False
-        try:
-            public_key = invite_event.content["public_key"]
-            if signed["mxid"] != event.state_key:
-                return False
-            if signed["token"] != token:
-                return False
-            for server, signature_block in signed["signatures"].items():
-                for key_name, encoded_signature in signature_block.items():
-                    if not key_name.startswith("ed25519:"):
-                        return False
-                    verify_key = decode_verify_key_bytes(
-                        key_name,
-                        decode_base64(public_key)
-                    )
-                    verify_signed_json(signed, server, verify_key)
 
-                    # We got the public key from the invite, so we know that the
-                    # correct server signed the signed bundle.
-                    # The caller is responsible for checking that the signing
-                    # server has not revoked that public key.
-                    return True
+        if signed["mxid"] != event.state_key:
             return False
-        except (KeyError, SignatureVerifyException,):
+        if signed["token"] != token:
             return False
 
+        for public_key_object in self.get_public_keys(invite_event):
+            public_key = public_key_object["public_key"]
+            try:
+                for server, signature_block in signed["signatures"].items():
+                    for key_name, encoded_signature in signature_block.items():
+                        if not key_name.startswith("ed25519:"):
+                            continue
+                        verify_key = decode_verify_key_bytes(
+                            key_name,
+                            decode_base64(public_key)
+                        )
+                        verify_signed_json(signed, server, verify_key)
+
+                        # We got the public key from the invite, so we know that the
+                        # correct server signed the signed bundle.
+                        # The caller is responsible for checking that the signing
+                        # server has not revoked that public key.
+                        return True
+            except (KeyError, SignatureVerifyException,):
+                continue
+        return False
+
+    def get_public_keys(self, invite_event):
+        public_keys = []
+        if "public_key" in invite_event.content:
+            o = {
+                "public_key": invite_event.content["public_key"],
+            }
+            if "key_validity_url" in invite_event.content:
+                o["key_validity_url"] = invite_event.content["key_validity_url"]
+            public_keys.append(o)
+        public_keys.extend(invite_event.content.get("public_keys", []))
+        return public_keys
+
     def _get_power_level_event(self, auth_events):
         key = (EventTypes.PowerLevels, "", )
         return auth_events.get(key)
@@ -519,7 +534,7 @@ class Auth(object):
                 )
 
             access_token = request.args["access_token"][0]
-            user_info = yield self._get_user_by_access_token(access_token)
+            user_info = yield self.get_user_by_access_token(access_token)
             user = user_info["user"]
             token_id = user_info["token_id"]
             is_guest = user_info["is_guest"]
@@ -580,7 +595,7 @@ class Auth(object):
         defer.returnValue(user_id)
 
     @defer.inlineCallbacks
-    def _get_user_by_access_token(self, token):
+    def get_user_by_access_token(self, token):
         """ Get a registered user's ID.
 
         Args:
@@ -799,17 +814,16 @@ class Auth(object):
 
         return auth_ids
 
-    @log_function
-    def _can_send_event(self, event, auth_events):
+    def _get_send_level(self, etype, state_key, auth_events):
         key = (EventTypes.PowerLevels, "", )
         send_level_event = auth_events.get(key)
         send_level = None
         if send_level_event:
             send_level = send_level_event.content.get("events", {}).get(
-                event.type
+                etype
             )
             if send_level is None:
-                if hasattr(event, "state_key"):
+                if state_key is not None:
                     send_level = send_level_event.content.get(
                         "state_default", 50
                     )
@@ -823,6 +837,13 @@ class Auth(object):
         else:
             send_level = 0
 
+        return send_level
+
+    @log_function
+    def _can_send_event(self, event, auth_events):
+        send_level = self._get_send_level(
+            event.type, event.get("state_key", None), auth_events
+        )
         user_level = self._get_user_power_level(event.user_id, auth_events)
 
         if user_level < send_level:
@@ -967,3 +988,43 @@ class Auth(object):
                     "You don't have permission to add ops level greater "
                     "than your own"
                 )
+
+    @defer.inlineCallbacks
+    def check_can_change_room_list(self, room_id, user):
+        """Check if the user is allowed to edit the room's entry in the
+        published room list.
+
+        Args:
+            room_id (str)
+            user (UserID)
+        """
+
+        is_admin = yield self.is_server_admin(user)
+        if is_admin:
+            defer.returnValue(True)
+
+        user_id = user.to_string()
+        yield self.check_joined_room(room_id, user_id)
+
+        # We currently require the user is a "moderator" in the room. We do this
+        # by checking if they would (theoretically) be able to change the
+        # m.room.aliases events
+        power_level_event = yield self.state.get_current_state(
+            room_id, EventTypes.PowerLevels, ""
+        )
+
+        auth_events = {}
+        if power_level_event:
+            auth_events[(EventTypes.PowerLevels, "")] = power_level_event
+
+        send_level = self._get_send_level(
+            EventTypes.Aliases, "", auth_events
+        )
+        user_level = self._get_user_power_level(user_id, auth_events)
+
+        if user_level < send_level:
+            raise AuthError(
+                403,
+                "This server requires you to be a moderator in the room to"
+                " edit its room list entry"
+            )

+ 0 - 1
synapse/api/constants.py

@@ -32,7 +32,6 @@ class PresenceState(object):
     OFFLINE = u"offline"
     UNAVAILABLE = u"unavailable"
     ONLINE = u"online"
-    FREE_FOR_CHAT = u"free_for_chat"
 
 
 class JoinRules(object):

+ 4 - 1
synapse/api/filtering.py

@@ -198,7 +198,10 @@ class Filter(object):
         sender = event.get("sender", None)
         if not sender:
             # Presence events have their 'sender' in content.user_id
-            sender = event.get("content", {}).get("user_id", None)
+            content = event.get("content")
+            # account_data has been allowed to have non-dict content, so check type first
+            if isinstance(content, dict):
+                sender = content.get("user_id")
 
         return self.check_fields(
             event.get("room_id", None),

+ 6 - 2
synapse/app/homeserver.py

@@ -63,6 +63,7 @@ from synapse.config.homeserver import HomeServerConfig
 from synapse.crypto import context_factory
 from synapse.util.logcontext import LoggingContext
 from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
+from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
 from synapse.federation.transport.server import TransportLayerServer
 
 from synapse import events
@@ -169,6 +170,9 @@ class SynapseHomeServer(HomeServer):
                 if name == "metrics" and self.get_config().enable_metrics:
                     resources[METRICS_PREFIX] = MetricsResource(self)
 
+                if name == "replication":
+                    resources[REPLICATION_PREFIX] = ReplicationResource(self)
+
         root_resource = create_resource_tree(resources)
         if tls:
             reactor.listenSSL(
@@ -382,7 +386,7 @@ def setup(config_options):
 
     tls_server_context_factory = context_factory.ServerContextFactory(config)
 
-    database_engine = create_engine(config.database_config["name"])
+    database_engine = create_engine(config)
     config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
 
     hs = SynapseHomeServer(
@@ -718,7 +722,7 @@ def run(hs):
     if hs.config.daemonize:
 
         if hs.config.print_pidfile:
-            print hs.config.pid_file
+            print (hs.config.pid_file)
 
         daemon = Daemonize(
             app="synapse-homeserver",

+ 3 - 3
synapse/app/synctl.py

@@ -29,13 +29,13 @@ NORMAL = "\x1b[m"
 
 
 def start(configfile):
-    print "Starting ...",
+    print ("Starting ...")
     args = SYNAPSE
     args.extend(["--daemonize", "-c", configfile])
 
     try:
         subprocess.check_call(args)
-        print GREEN + "started" + NORMAL
+        print (GREEN + "started" + NORMAL)
     except subprocess.CalledProcessError as e:
         print (
             RED +
@@ -48,7 +48,7 @@ def stop(pidfile):
     if os.path.exists(pidfile):
         pid = int(open(pidfile).read())
         os.kill(pid, signal.SIGTERM)
-        print GREEN + "stopped" + NORMAL
+        print (GREEN + "stopped" + NORMAL)
 
 
 def main():

+ 1 - 1
synapse/config/__main__.py

@@ -28,7 +28,7 @@ if __name__ == "__main__":
             sys.stderr.write("\n" + e.message + "\n")
             sys.exit(1)
 
-        print getattr(config, key)
+        print (getattr(config, key))
         sys.exit(0)
     else:
         sys.stderr.write("Unknown command %r\n" % (action,))

+ 1 - 1
synapse/config/_base.py

@@ -104,7 +104,7 @@ class Config(object):
         dir_path = cls.abspath(dir_path)
         try:
             os.makedirs(dir_path)
-        except OSError, e:
+        except OSError as e:
             if e.errno != errno.EEXIST:
                 raise
         if not os.path.isdir(dir_path):

+ 40 - 0
synapse/config/api.py

@@ -0,0 +1,40 @@
+# Copyright 2015, 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ._base import Config
+
+from synapse.api.constants import EventTypes
+
+
+class ApiConfig(Config):
+
+    def read_config(self, config):
+        self.room_invite_state_types = config.get("room_invite_state_types", [
+            EventTypes.JoinRules,
+            EventTypes.CanonicalAlias,
+            EventTypes.RoomAvatar,
+            EventTypes.Name,
+        ])
+
+    def default_config(cls, **kwargs):
+        return """\
+        ## API Configuration ##
+
+        # A list of event types that will be included in the room_invite_state
+        room_invite_state_types:
+            - "{JoinRules}"
+            - "{CanonicalAlias}"
+            - "{RoomAvatar}"
+            - "{Name}"
+        """.format(**vars(EventTypes))

+ 2 - 1
synapse/config/homeserver.py

@@ -23,6 +23,7 @@ from .captcha import CaptchaConfig
 from .voip import VoipConfig
 from .registration import RegistrationConfig
 from .metrics import MetricsConfig
+from .api import ApiConfig
 from .appservice import AppServiceConfig
 from .key import KeyConfig
 from .saml2 import SAML2Config
@@ -32,7 +33,7 @@ from .password import PasswordConfig
 
 class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
                        RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
-                       VoipConfig, RegistrationConfig, MetricsConfig,
+                       VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
                        AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
                        PasswordConfig,):
     pass

+ 4 - 0
synapse/config/registration.py

@@ -37,6 +37,10 @@ class RegistrationConfig(Config):
         self.trusted_third_party_id_servers = config["trusted_third_party_id_servers"]
         self.allow_guest_access = config.get("allow_guest_access", False)
 
+        self.invite_3pid_guest = (
+            self.allow_guest_access and config.get("invite_3pid_guest", False)
+        )
+
     def default_config(self, **kwargs):
         registration_shared_secret = random_string_with_symbols(50)
 

+ 3 - 0
synapse/config/repository.py

@@ -97,4 +97,7 @@ class ContentRepositoryConfig(Config):
         - width: 640
           height: 480
           method: scale
+        - width: 800
+          height: 600
+          method: scale
         """ % locals()

+ 3 - 0
synapse/crypto/keyclient.py

@@ -36,6 +36,7 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
 
     factory = SynapseKeyClientFactory()
     factory.path = path
+    factory.host = server_name
     endpoint = matrix_federation_endpoint(
         reactor, server_name, ssl_context_factory, timeout=30
     )
@@ -81,6 +82,8 @@ class SynapseKeyClientProtocol(HTTPClient):
         self.host = self.transport.getHost()
         logger.debug("Connected to %s", self.host)
         self.sendCommand(b"GET", self.path)
+        if self.host:
+            self.sendHeader(b"Host", self.host)
         self.endHeaders()
         self.timer = reactor.callLater(
             self.timeout,

+ 8 - 1
synapse/events/__init__.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from synapse.util.frozenutils import freeze
+from synapse.util.caches import intern_dict
 
 
 # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
@@ -140,6 +141,10 @@ class FrozenEvent(EventBase):
 
         unsigned = dict(event_dict.pop("unsigned", {}))
 
+        # We intern these strings because they turn up a lot (especially when
+        # caching).
+        event_dict = intern_dict(event_dict)
+
         if USE_FROZEN_DICTS:
             frozen_dict = freeze(event_dict)
         else:
@@ -168,5 +173,7 @@ class FrozenEvent(EventBase):
 
     def __repr__(self):
         return "<FrozenEvent event_id='%s', type='%s', state_key='%s'>" % (
-            self.event_id, self.type, self.get("state_key", None),
+            self.get("event_id", None),
+            self.get("type", None),
+            self.get("state_key", None),
         )

+ 2 - 1
synapse/federation/federation_client.py

@@ -114,7 +114,7 @@ class FederationClient(FederationBase):
 
     @log_function
     def make_query(self, destination, query_type, args,
-                   retry_on_dns_fail=True):
+                   retry_on_dns_fail=False):
         """Sends a federation Query to a remote homeserver of the given type
         and arguments.
 
@@ -418,6 +418,7 @@ class FederationClient(FederationBase):
                     "Failed to make_%s via %s: %s",
                     membership, destination, e.message
                 )
+                raise
 
         raise RuntimeError("Failed to send to any server.")
 

+ 22 - 6
synapse/federation/federation_server.py

@@ -137,8 +137,8 @@ class FederationServer(FederationBase):
                 logger.exception("Failed to handle PDU")
 
         if hasattr(transaction, "edus"):
-            for edu in [Edu(**x) for x in transaction.edus]:
-                self.received_edu(
+            for edu in (Edu(**x) for x in transaction.edus):
+                yield self.received_edu(
                     transaction.origin,
                     edu.edu_type,
                     edu.content
@@ -161,11 +161,17 @@ class FederationServer(FederationBase):
         )
         defer.returnValue((200, response))
 
+    @defer.inlineCallbacks
     def received_edu(self, origin, edu_type, content):
         received_edus_counter.inc()
 
         if edu_type in self.edu_handlers:
-            self.edu_handlers[edu_type](origin, content)
+            try:
+                yield self.edu_handlers[edu_type](origin, content)
+            except SynapseError as e:
+                logger.info("Failed to handle edu %r: %r", edu_type, e)
+            except Exception as e:
+                logger.exception("Failed to handle edu %r", edu_type, e)
         else:
             logger.warn("Received EDU of type %s with no handler", edu_type)
 
@@ -525,7 +531,6 @@ class FederationServer(FederationBase):
         yield self.handler.on_receive_pdu(
             origin,
             pdu,
-            backfilled=False,
             state=state,
             auth_chain=auth_chain,
         )
@@ -543,8 +548,19 @@ class FederationServer(FederationBase):
         return event
 
     @defer.inlineCallbacks
-    def exchange_third_party_invite(self, invite):
-        ret = yield self.handler.exchange_third_party_invite(invite)
+    def exchange_third_party_invite(
+            self,
+            sender_user_id,
+            target_user_id,
+            room_id,
+            signed,
+    ):
+        ret = yield self.handler.exchange_third_party_invite(
+            sender_user_id,
+            target_user_id,
+            room_id,
+            signed,
+        )
         defer.returnValue(ret)
 
     @defer.inlineCallbacks

+ 1 - 0
synapse/federation/transport/client.py

@@ -160,6 +160,7 @@ class TransportLayerClient(object):
             path=path,
             args=args,
             retry_on_dns_fail=retry_on_dns_fail,
+            timeout=10000,
         )
 
         defer.returnValue(content)

+ 27 - 17
synapse/federation/transport/server.py

@@ -18,6 +18,7 @@ from twisted.internet import defer
 from synapse.api.urls import FEDERATION_PREFIX as PREFIX
 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import JsonResource
+from synapse.http.servlet import parse_json_object_from_request
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 import functools
@@ -174,7 +175,7 @@ class BaseFederationServlet(object):
 
 
 class FederationSendServlet(BaseFederationServlet):
-    PATH = "/send/([^/]*)/"
+    PATH = "/send/(?P<transaction_id>[^/]*)/"
 
     def __init__(self, handler, server_name, **kwargs):
         super(FederationSendServlet, self).__init__(
@@ -249,7 +250,7 @@ class FederationPullServlet(BaseFederationServlet):
 
 
 class FederationEventServlet(BaseFederationServlet):
-    PATH = "/event/([^/]*)/"
+    PATH = "/event/(?P<event_id>[^/]*)/"
 
     # This is when someone asks for a data item for a given server data_id pair.
     def on_GET(self, origin, content, query, event_id):
@@ -257,7 +258,7 @@ class FederationEventServlet(BaseFederationServlet):
 
 
 class FederationStateServlet(BaseFederationServlet):
-    PATH = "/state/([^/]*)/"
+    PATH = "/state/(?P<context>[^/]*)/"
 
     # This is when someone asks for all data for a given context.
     def on_GET(self, origin, content, query, context):
@@ -269,7 +270,7 @@ class FederationStateServlet(BaseFederationServlet):
 
 
 class FederationBackfillServlet(BaseFederationServlet):
-    PATH = "/backfill/([^/]*)/"
+    PATH = "/backfill/(?P<context>[^/]*)/"
 
     def on_GET(self, origin, content, query, context):
         versions = query["v"]
@@ -284,7 +285,7 @@ class FederationBackfillServlet(BaseFederationServlet):
 
 
 class FederationQueryServlet(BaseFederationServlet):
-    PATH = "/query/([^/]*)"
+    PATH = "/query/(?P<query_type>[^/]*)"
 
     # This is when we receive a server-server Query
     def on_GET(self, origin, content, query, query_type):
@@ -295,7 +296,7 @@ class FederationQueryServlet(BaseFederationServlet):
 
 
 class FederationMakeJoinServlet(BaseFederationServlet):
-    PATH = "/make_join/([^/]*)/([^/]*)"
+    PATH = "/make_join/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, context, user_id):
@@ -304,7 +305,7 @@ class FederationMakeJoinServlet(BaseFederationServlet):
 
 
 class FederationMakeLeaveServlet(BaseFederationServlet):
-    PATH = "/make_leave/([^/]*)/([^/]*)"
+    PATH = "/make_leave/(?P<context>[^/]*)/(?P<user_id>[^/]*)"
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query, context, user_id):
@@ -313,7 +314,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
 
 
 class FederationSendLeaveServlet(BaseFederationServlet):
-    PATH = "/send_leave/([^/]*)/([^/]*)"
+    PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<txid>[^/]*)"
 
     @defer.inlineCallbacks
     def on_PUT(self, origin, content, query, room_id, txid):
@@ -322,14 +323,14 @@ class FederationSendLeaveServlet(BaseFederationServlet):
 
 
 class FederationEventAuthServlet(BaseFederationServlet):
-    PATH = "/event_auth/([^/]*)/([^/]*)"
+    PATH = "/event_auth(?P<context>[^/]*)/(?P<event_id>[^/]*)"
 
     def on_GET(self, origin, content, query, context, event_id):
         return self.handler.on_event_auth(origin, context, event_id)
 
 
 class FederationSendJoinServlet(BaseFederationServlet):
-    PATH = "/send_join/([^/]*)/([^/]*)"
+    PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
 
     @defer.inlineCallbacks
     def on_PUT(self, origin, content, query, context, event_id):
@@ -340,7 +341,7 @@ class FederationSendJoinServlet(BaseFederationServlet):
 
 
 class FederationInviteServlet(BaseFederationServlet):
-    PATH = "/invite/([^/]*)/([^/]*)"
+    PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
 
     @defer.inlineCallbacks
     def on_PUT(self, origin, content, query, context, event_id):
@@ -351,7 +352,7 @@ class FederationInviteServlet(BaseFederationServlet):
 
 
 class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
-    PATH = "/exchange_third_party_invite/([^/]*)"
+    PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
 
     @defer.inlineCallbacks
     def on_PUT(self, origin, content, query, room_id):
@@ -380,7 +381,7 @@ class FederationClientKeysClaimServlet(BaseFederationServlet):
 
 
 class FederationQueryAuthServlet(BaseFederationServlet):
-    PATH = "/query_auth/([^/]*)/([^/]*)"
+    PATH = "/query_auth/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
 
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, context, event_id):
@@ -393,7 +394,7 @@ class FederationQueryAuthServlet(BaseFederationServlet):
 
 class FederationGetMissingEventsServlet(BaseFederationServlet):
     # TODO(paul): Why does this path alone end with "/?" optional?
-    PATH = "/get_missing_events/([^/]*)/?"
+    PATH = "/get_missing_events/(?P<room_id>[^/]*)/?"
 
     @defer.inlineCallbacks
     def on_POST(self, origin, content, query, room_id):
@@ -419,13 +420,22 @@ class On3pidBindServlet(BaseFederationServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        content_bytes = request.content.read()
-        content = json.loads(content_bytes)
+        content = parse_json_object_from_request(request)
         if "invites" in content:
             last_exception = None
             for invite in content["invites"]:
                 try:
-                    yield self.handler.exchange_third_party_invite(invite)
+                    if "signed" not in invite or "token" not in invite["signed"]:
+                        message = ("Rejecting received notification of third-"
+                                   "party invite without signed: %s" % (invite,))
+                        logger.info(message)
+                        raise SynapseError(400, message)
+                    yield self.handler.exchange_third_party_invite(
+                        invite["sender"],
+                        invite["mxid"],
+                        invite["room_id"],
+                        invite["signed"],
+                    )
                 except Exception as e:
                     last_exception = e
             if last_exception:

+ 171 - 37
synapse/handlers/_base.py

@@ -18,7 +18,7 @@ from twisted.internet import defer
 from synapse.api.errors import LimitExceededError, SynapseError, AuthError
 from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.api.constants import Membership, EventTypes
-from synapse.types import UserID, RoomAlias
+from synapse.types import UserID, RoomAlias, Requester
 from synapse.push.action_generator import ActionGenerator
 
 from synapse.util.logcontext import PreserveLoggingContext
@@ -29,6 +29,14 @@ import logging
 logger = logging.getLogger(__name__)
 
 
+VISIBILITY_PRIORITY = (
+    "world_readable",
+    "shared",
+    "invited",
+    "joined",
+)
+
+
 class BaseHandler(object):
     """
     Common base class for the event handlers.
@@ -53,9 +61,15 @@ class BaseHandler(object):
         self.event_builder_factory = hs.get_event_builder_factory()
 
     @defer.inlineCallbacks
-    def _filter_events_for_clients(self, user_tuples, events, event_id_to_state):
+    def filter_events_for_clients(self, user_tuples, events, event_id_to_state):
         """ Returns dict of user_id -> list of events that user is allowed to
         see.
+
+        :param (str, bool) user_tuples: (user id, is_peeking) for each
+            user to be checked. is_peeking should be true if:
+              * the user is not currently a member of the room, and:
+              * the user has not been a member of the room since the given
+                events
         """
         forgotten = yield defer.gatherResults([
             self.store.who_forgot_in_room(
@@ -72,18 +86,38 @@ class BaseHandler(object):
         def allowed(event, user_id, is_peeking):
             state = event_id_to_state[event.event_id]
 
+            # get the room_visibility at the time of the event.
             visibility_event = state.get((EventTypes.RoomHistoryVisibility, ""), None)
             if visibility_event:
                 visibility = visibility_event.content.get("history_visibility", "shared")
             else:
                 visibility = "shared"
 
+            if visibility not in VISIBILITY_PRIORITY:
+                visibility = "shared"
+
+            # if it was world_readable, it's easy: everyone can read it
             if visibility == "world_readable":
                 return True
 
-            if is_peeking:
-                return False
+            # Always allow history visibility events on boundaries. This is done
+            # by setting the effective visibility to the least restrictive
+            # of the old vs new.
+            if event.type == EventTypes.RoomHistoryVisibility:
+                prev_content = event.unsigned.get("prev_content", {})
+                prev_visibility = prev_content.get("history_visibility", None)
+
+                if prev_visibility not in VISIBILITY_PRIORITY:
+                    prev_visibility = "shared"
+
+                new_priority = VISIBILITY_PRIORITY.index(visibility)
+                old_priority = VISIBILITY_PRIORITY.index(prev_visibility)
+                if old_priority < new_priority:
+                    visibility = prev_visibility
 
+            # get the user's membership at the time of the event. (or rather,
+            # just *after* the event. Which means that people can see their
+            # own join events, but not (currently) their own leave events.)
             membership_event = state.get((EventTypes.Member, user_id), None)
             if membership_event:
                 if membership_event.event_id in event_id_forgotten:
@@ -93,20 +127,29 @@ class BaseHandler(object):
             else:
                 membership = None
 
+            # if the user was a member of the room at the time of the event,
+            # they can see it.
             if membership == Membership.JOIN:
                 return True
 
-            if event.type == EventTypes.RoomHistoryVisibility:
-                return not is_peeking
+            if visibility == "joined":
+                # we weren't a member at the time of the event, so we can't
+                # see this event.
+                return False
 
-            if visibility == "shared":
-                return True
-            elif visibility == "joined":
-                return membership == Membership.JOIN
             elif visibility == "invited":
+                # user can also see the event if they were *invited* at the time
+                # of the event.
                 return membership == Membership.INVITE
 
-            return True
+            else:
+                # visibility is shared: user can also see the event if they have
+                # become a member since the event
+                #
+                # XXX: if the user has subsequently joined and then left again,
+                # ideally we would share history up to the point they left. But
+                # we don't know when they left.
+                return not is_peeking
 
         defer.returnValue({
             user_id: [
@@ -119,7 +162,17 @@ class BaseHandler(object):
 
     @defer.inlineCallbacks
     def _filter_events_for_client(self, user_id, events, is_peeking=False):
-        # Assumes that user has at some point joined the room if not is_guest.
+        """
+        Check which events a user is allowed to see
+
+        :param str user_id: user id to be checked
+        :param [synapse.events.EventBase] events: list of events to be checked
+        :param bool is_peeking should be True if:
+              * the user is not currently a member of the room, and:
+              * the user has not been a member of the room since the given
+                events
+        :rtype [synapse.events.EventBase]
+        """
         types = (
             (EventTypes.RoomHistoryVisibility, ""),
             (EventTypes.Member, user_id),
@@ -128,15 +181,15 @@ class BaseHandler(object):
             frozenset(e.event_id for e in events),
             types=types
         )
-        res = yield self._filter_events_for_clients(
+        res = yield self.filter_events_for_clients(
             [(user_id, is_peeking)], events, event_id_to_state
         )
         defer.returnValue(res.get(user_id, []))
 
-    def ratelimit(self, user_id):
+    def ratelimit(self, requester):
         time_now = self.clock.time()
         allowed, time_allowed = self.ratelimiter.send_message(
-            user_id, time_now,
+            requester.user.to_string(), time_now,
             msg_rate_hz=self.hs.config.rc_messages_per_second,
             burst_count=self.hs.config.rc_message_burst_count,
         )
@@ -147,7 +200,7 @@ class BaseHandler(object):
 
     @defer.inlineCallbacks
     def _create_new_client_event(self, builder):
-        latest_ret = yield self.store.get_latest_events_in_room(
+        latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
             builder.room_id,
         )
 
@@ -156,7 +209,10 @@ class BaseHandler(object):
         else:
             depth = 1
 
-        prev_events = [(e, h) for e, h, _ in latest_ret]
+        prev_events = [
+            (event_id, prev_hashes)
+            for event_id, prev_hashes, _ in latest_ret
+        ]
 
         builder.prev_events = prev_events
         builder.depth = depth
@@ -165,6 +221,50 @@ class BaseHandler(object):
 
         context = yield state_handler.compute_event_context(builder)
 
+        # If we've received an invite over federation, there are no latest
+        # events in the room, because we don't know enough about the graph
+        # fragment we received to treat it like a graph, so the above returned
+        # no relevant events. It may have returned some events (if we have
+        # joined and left the room), but not useful ones, like the invite.
+        if (
+            not self.is_host_in_room(context.current_state) and
+            builder.type == EventTypes.Member
+        ):
+            prev_member_event = yield self.store.get_room_member(
+                builder.sender, builder.room_id
+            )
+
+            # The prev_member_event may already be in context.current_state,
+            # despite us not being present in the room; in particular, if
+            # inviting user, and all other local users, have already left.
+            #
+            # In that case, we have all the information we need, and we don't
+            # want to drop "context" - not least because we may need to handle
+            # the invite locally, which will require us to have the whole
+            # context (not just prev_member_event) to auth it.
+            #
+            context_event_ids = (
+                e.event_id for e in context.current_state.values()
+            )
+
+            if (
+                prev_member_event and
+                prev_member_event.event_id not in context_event_ids
+            ):
+                # The prev_member_event is missing from context, so it must
+                # have arrived over federation and is an outlier. We forcibly
+                # set our context to the invite we received over federation
+                builder.prev_events = (
+                    prev_member_event.event_id,
+                    prev_member_event.prev_events
+                )
+
+                context = yield state_handler.compute_event_context(
+                    builder,
+                    old_state=(prev_member_event,),
+                    outlier=True
+                )
+
         if builder.is_state():
             builder.prev_state = yield self.store.add_event_hashes(
                 context.prev_state_events
@@ -187,10 +287,40 @@ class BaseHandler(object):
             (event, context,)
         )
 
+    def is_host_in_room(self, current_state):
+        room_members = [
+            (state_key, event.membership)
+            for ((event_type, state_key), event) in current_state.items()
+            if event_type == EventTypes.Member
+        ]
+        if len(room_members) == 0:
+            # Have we just created the room, and is this about to be the very
+            # first member event?
+            create_event = current_state.get(("m.room.create", ""))
+            if create_event:
+                return True
+        for (state_key, membership) in room_members:
+            if (
+                UserID.from_string(state_key).domain == self.hs.hostname
+                and membership == Membership.JOIN
+            ):
+                return True
+        return False
+
     @defer.inlineCallbacks
-    def handle_new_client_event(self, event, context, extra_users=[]):
+    def handle_new_client_event(
+        self,
+        requester,
+        event,
+        context,
+        ratelimit=True,
+        extra_users=[]
+    ):
         # We now need to go and hit out to wherever we need to hit out to.
 
+        if ratelimit:
+            self.ratelimit(requester)
+
         self.auth.check(event, auth_events=context.current_state)
 
         yield self.maybe_kick_guest_users(event, context.current_state.values())
@@ -215,6 +345,12 @@ class BaseHandler(object):
 
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.INVITE:
+                def is_inviter_member_event(e):
+                    return (
+                        e.type == EventTypes.Member and
+                        e.sender == event.sender
+                    )
+
                 event.unsigned["invite_room_state"] = [
                     {
                         "type": e.type,
@@ -223,12 +359,8 @@ class BaseHandler(object):
                         "sender": e.sender,
                     }
                     for k, e in context.current_state.items()
-                    if e.type in (
-                        EventTypes.JoinRules,
-                        EventTypes.CanonicalAlias,
-                        EventTypes.RoomAvatar,
-                        EventTypes.Name,
-                    )
+                    if e.type in self.hs.config.room_invite_state_types
+                    or is_inviter_member_event(e)
                 ]
 
                 invitee = UserID.from_string(event.state_key)
@@ -264,6 +396,12 @@ class BaseHandler(object):
                         "You don't have permission to redact events"
                     )
 
+        if event.type == EventTypes.Create and context.current_state:
+            raise AuthError(
+                403,
+                "Changing the room create event is forbidden",
+            )
+
         action_generator = ActionGenerator(self.hs)
         yield action_generator.handle_push_actions_for_event(
             event, context, self
@@ -316,7 +454,8 @@ class BaseHandler(object):
                 if member_event.type != EventTypes.Member:
                     continue
 
-                if not self.hs.is_mine(UserID.from_string(member_event.state_key)):
+                target_user = UserID.from_string(member_event.state_key)
+                if not self.hs.is_mine(target_user):
                     continue
 
                 if member_event.content["membership"] not in {
@@ -338,18 +477,13 @@ class BaseHandler(object):
                 # and having homeservers have their own users leave keeps more
                 # of that decision-making and control local to the guest-having
                 # homeserver.
-                message_handler = self.hs.get_handlers().message_handler
-                yield message_handler.create_and_send_event(
-                    {
-                        "type": EventTypes.Member,
-                        "state_key": member_event.state_key,
-                        "content": {
-                            "membership": Membership.LEAVE,
-                            "kind": "guest"
-                        },
-                        "room_id": member_event.room_id,
-                        "sender": member_event.state_key
-                    },
+                requester = Requester(target_user, "", True)
+                handler = self.hs.get_handlers().room_member_handler
+                yield handler.update_membership(
+                    requester,
+                    target_user,
+                    member_event.room_id,
+                    "leave",
                     ratelimit=False,
                 )
             except Exception as e:

+ 70 - 15
synapse/handlers/auth.py

@@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
 
 
 class AuthHandler(BaseHandler):
+    SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
     def __init__(self, hs):
         super(AuthHandler, self).__init__(hs)
@@ -66,15 +67,18 @@ class AuthHandler(BaseHandler):
                         'auth' key: this method prompts for auth if none is sent.
             clientip (str): The IP address of the client.
         Returns:
-            A tuple of (authed, dict, dict) where authed is true if the client
-            has successfully completed an auth flow. If it is true, the first
-            dict contains the authenticated credentials of each stage.
+            A tuple of (authed, dict, dict, session_id) where authed is true if
+            the client has successfully completed an auth flow. If it is true
+            the first dict contains the authenticated credentials of each stage.
 
             If authed is false, the first dictionary is the server response to
             the login request and should be passed back to the client.
 
             In either case, the second dict contains the parameters for this
             request (which may have been given only in a previous call).
+
+            session_id is the ID of this session, either passed in by the client
+            or assigned by the call to check_auth
         """
 
         authdict = None
@@ -103,7 +107,10 @@ class AuthHandler(BaseHandler):
 
         if not authdict:
             defer.returnValue(
-                (False, self._auth_dict_for_flows(flows, session), clientdict)
+                (
+                    False, self._auth_dict_for_flows(flows, session),
+                    clientdict, session['id']
+                )
             )
 
         if 'creds' not in session:
@@ -122,12 +129,11 @@ class AuthHandler(BaseHandler):
         for f in flows:
             if len(set(f) - set(creds.keys())) == 0:
                 logger.info("Auth completed with creds: %r", creds)
-                self._remove_session(session)
-                defer.returnValue((True, creds, clientdict))
+                defer.returnValue((True, creds, clientdict, session['id']))
 
         ret = self._auth_dict_for_flows(flows, session)
         ret['completed'] = creds.keys()
-        defer.returnValue((False, ret, clientdict))
+        defer.returnValue((False, ret, clientdict, session['id']))
 
     @defer.inlineCallbacks
     def add_oob_auth(self, stagetype, authdict, clientip):
@@ -154,6 +160,43 @@ class AuthHandler(BaseHandler):
             defer.returnValue(True)
         defer.returnValue(False)
 
+    def get_session_id(self, clientdict):
+        """
+        Gets the session ID for a client given the client dictionary
+        :param clientdict: The dictionary sent by the client in the request
+        :return: The string session ID the client sent. If the client did not
+                 send a session ID, returns None.
+        """
+        sid = None
+        if clientdict and 'auth' in clientdict:
+            authdict = clientdict['auth']
+            if 'session' in authdict:
+                sid = authdict['session']
+        return sid
+
+    def set_session_data(self, session_id, key, value):
+        """
+        Store a key-value pair into the sessions data associated with this
+        request. This data is stored server-side and cannot be modified by
+        the client.
+        :param session_id: (string) The ID of this session as returned from check_auth
+        :param key: (string) The key to store the data under
+        :param value: (any) The data to store
+        """
+        sess = self._get_session_info(session_id)
+        sess.setdefault('serverdict', {})[key] = value
+        self._save_session(sess)
+
+    def get_session_data(self, session_id, key, default=None):
+        """
+        Retrieve data stored with set_session_data
+        :param session_id: (string) The ID of this session as returned from check_auth
+        :param key: (string) The key to store the data under
+        :param default: (any) Value to return if the key has not been set
+        """
+        sess = self._get_session_info(session_id)
+        return sess.setdefault('serverdict', {}).get(key, default)
+
     @defer.inlineCallbacks
     def _check_password_auth(self, authdict, _):
         if "user" not in authdict or "password" not in authdict:
@@ -432,13 +475,18 @@ class AuthHandler(BaseHandler):
         )
 
     @defer.inlineCallbacks
-    def set_password(self, user_id, newpassword):
+    def set_password(self, user_id, newpassword, requester=None):
         password_hash = self.hash(newpassword)
 
+        except_access_token_ids = [requester.access_token_id] if requester else []
+
         yield self.store.user_set_password_hash(user_id, password_hash)
-        yield self.store.user_delete_access_tokens(user_id)
-        yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
-        yield self.store.flush_user(user_id)
+        yield self.store.user_delete_access_tokens(
+            user_id, except_access_token_ids
+        )
+        yield self.hs.get_pusherpool().remove_pushers_by_user(
+            user_id, except_access_token_ids
+        )
 
     @defer.inlineCallbacks
     def add_threepid(self, user_id, medium, address, validated_at):
@@ -450,11 +498,18 @@ class AuthHandler(BaseHandler):
     def _save_session(self, session):
         # TODO: Persistent storage
         logger.debug("Saving session %s", session)
+        session["last_used"] = self.hs.get_clock().time_msec()
         self.sessions[session["id"]] = session
+        self._prune_sessions()
 
-    def _remove_session(self, session):
-        logger.debug("Removing session %s", session)
-        del self.sessions[session["id"]]
+    def _prune_sessions(self):
+        for sid, sess in self.sessions.items():
+            last_used = 0
+            if 'last_used' in sess:
+                last_used = sess['last_used']
+            now = self.hs.get_clock().time_msec()
+            if last_used < now - AuthHandler.SESSION_EXPIRE_MS:
+                del self.sessions[sid]
 
     def hash(self, password):
         """Computes a secure hash of password.
@@ -477,4 +532,4 @@ class AuthHandler(BaseHandler):
         Returns:
             Whether self.hash(password) == stored_hash (bool).
         """
-        return bcrypt.checkpw(password, stored_hash)
+        return bcrypt.hashpw(password, stored_hash) == stored_hash

+ 102 - 20
synapse/handlers/directory.py

@@ -17,9 +17,9 @@
 from twisted.internet import defer
 from ._base import BaseHandler
 
-from synapse.api.errors import SynapseError, Codes, CodeMessageException
+from synapse.api.errors import SynapseError, Codes, CodeMessageException, AuthError
 from synapse.api.constants import EventTypes
-from synapse.types import RoomAlias
+from synapse.types import RoomAlias, UserID
 
 import logging
 import string
@@ -32,13 +32,15 @@ class DirectoryHandler(BaseHandler):
     def __init__(self, hs):
         super(DirectoryHandler, self).__init__(hs)
 
+        self.state = hs.get_state_handler()
+
         self.federation = hs.get_replication_layer()
         self.federation.register_query_handler(
             "directory", self.on_directory_query
         )
 
     @defer.inlineCallbacks
-    def _create_association(self, room_alias, room_id, servers=None):
+    def _create_association(self, room_alias, room_id, servers=None, creator=None):
         # general association creation for both human users and app services
 
         for wchar in string.whitespace:
@@ -60,7 +62,8 @@ class DirectoryHandler(BaseHandler):
         yield self.store.create_room_alias_association(
             room_alias,
             room_id,
-            servers
+            servers,
+            creator=creator,
         )
 
     @defer.inlineCallbacks
@@ -77,7 +80,7 @@ class DirectoryHandler(BaseHandler):
                 400, "This alias is reserved by an application service.",
                 errcode=Codes.EXCLUSIVE
             )
-        yield self._create_association(room_alias, room_id, servers)
+        yield self._create_association(room_alias, room_id, servers, creator=user_id)
 
     @defer.inlineCallbacks
     def create_appservice_association(self, service, room_alias, room_id,
@@ -92,10 +95,14 @@ class DirectoryHandler(BaseHandler):
         yield self._create_association(room_alias, room_id, servers)
 
     @defer.inlineCallbacks
-    def delete_association(self, user_id, room_alias):
+    def delete_association(self, requester, user_id, room_alias):
         # association deletion for human users
 
-        # TODO Check if server admin
+        can_delete = yield self._user_can_delete_alias(room_alias, user_id)
+        if not can_delete:
+            raise AuthError(
+                403, "You don't have permission to delete the alias.",
+            )
 
         can_delete = yield self.can_modify_alias(
             room_alias,
@@ -107,7 +114,25 @@ class DirectoryHandler(BaseHandler):
                 errcode=Codes.EXCLUSIVE
             )
 
-        yield self._delete_association(room_alias)
+        room_id = yield self._delete_association(room_alias)
+
+        try:
+            yield self.send_room_alias_update_event(
+                requester,
+                requester.user.to_string(),
+                room_id
+            )
+
+            yield self._update_canonical_alias(
+                requester,
+                requester.user.to_string(),
+                room_id,
+                room_alias,
+            )
+        except AuthError as e:
+            logger.info("Failed to update alias events: %s", e)
+
+        defer.returnValue(room_id)
 
     @defer.inlineCallbacks
     def delete_appservice_association(self, service, room_alias):
@@ -124,11 +149,9 @@ class DirectoryHandler(BaseHandler):
         if not self.hs.is_mine(room_alias):
             raise SynapseError(400, "Room alias must be local")
 
-        yield self.store.delete_room_alias(room_alias)
+        room_id = yield self.store.delete_room_alias(room_alias)
 
-        # TODO - Looks like _update_room_alias_event has never been implemented
-        # if room_id:
-        #    yield self._update_room_alias_events(user_id, room_id)
+        defer.returnValue(room_id)
 
     @defer.inlineCallbacks
     def get_association(self, room_alias):
@@ -212,17 +235,44 @@ class DirectoryHandler(BaseHandler):
             )
 
     @defer.inlineCallbacks
-    def send_room_alias_update_event(self, user_id, room_id):
+    def send_room_alias_update_event(self, requester, user_id, room_id):
         aliases = yield self.store.get_aliases_for_room(room_id)
 
         msg_handler = self.hs.get_handlers().message_handler
-        yield msg_handler.create_and_send_event({
-            "type": EventTypes.Aliases,
-            "state_key": self.hs.hostname,
-            "room_id": room_id,
-            "sender": user_id,
-            "content": {"aliases": aliases},
-        }, ratelimit=False)
+        yield msg_handler.create_and_send_nonmember_event(
+            requester,
+            {
+                "type": EventTypes.Aliases,
+                "state_key": self.hs.hostname,
+                "room_id": room_id,
+                "sender": user_id,
+                "content": {"aliases": aliases},
+            },
+            ratelimit=False
+        )
+
+    @defer.inlineCallbacks
+    def _update_canonical_alias(self, requester, user_id, room_id, room_alias):
+        alias_event = yield self.state.get_current_state(
+            room_id, EventTypes.CanonicalAlias, ""
+        )
+
+        alias_str = room_alias.to_string()
+        if not alias_event or alias_event.content.get("alias", "") != alias_str:
+            return
+
+        msg_handler = self.hs.get_handlers().message_handler
+        yield msg_handler.create_and_send_nonmember_event(
+            requester,
+            {
+                "type": EventTypes.CanonicalAlias,
+                "state_key": "",
+                "room_id": room_id,
+                "sender": user_id,
+                "content": {},
+            },
+            ratelimit=False
+        )
 
     @defer.inlineCallbacks
     def get_association_from_room_alias(self, room_alias):
@@ -257,3 +307,35 @@ class DirectoryHandler(BaseHandler):
                 return
         # either no interested services, or no service with an exclusive lock
         defer.returnValue(True)
+
+    @defer.inlineCallbacks
+    def _user_can_delete_alias(self, alias, user_id):
+        creator = yield self.store.get_room_alias_creator(alias.to_string())
+
+        if creator and creator == user_id:
+            defer.returnValue(True)
+
+        is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
+        defer.returnValue(is_admin)
+
+    @defer.inlineCallbacks
+    def edit_published_room_list(self, requester, room_id, visibility):
+        """Edit the entry of the room in the published room list.
+
+        requester
+        room_id (str)
+        visibility (str): "public" or "private"
+        """
+        if requester.is_guest:
+            raise AuthError(403, "Guests cannot edit the published room list")
+
+        if visibility not in ["public", "private"]:
+            raise SynapseError(400, "Invalid visibility setting")
+
+        room = yield self.store.get_room(room_id)
+        if room is None:
+            raise SynapseError(400, "Unknown room")
+
+        yield self.auth.check_can_change_room_list(room_id, requester.user)
+
+        yield self.store.set_room_is_public(room_id, visibility == "public")

+ 35 - 78
synapse/handlers/events.py

@@ -18,7 +18,8 @@ from twisted.internet import defer
 from synapse.util.logutils import log_function
 from synapse.types import UserID
 from synapse.events.utils import serialize_event
-from synapse.util.logcontext import preserve_context_over_fn
+from synapse.api.constants import Membership, EventTypes
+from synapse.events import EventBase
 
 from ._base import BaseHandler
 
@@ -29,20 +30,6 @@ import random
 logger = logging.getLogger(__name__)
 
 
-def started_user_eventstream(distributor, user):
-    return preserve_context_over_fn(
-        distributor.fire,
-        "started_user_eventstream", user
-    )
-
-
-def stopped_user_eventstream(distributor, user):
-    return preserve_context_over_fn(
-        distributor.fire,
-        "stopped_user_eventstream", user
-    )
-
-
 class EventStreamHandler(BaseHandler):
 
     def __init__(self, hs):
@@ -61,61 +48,6 @@ class EventStreamHandler(BaseHandler):
 
         self.notifier = hs.get_notifier()
 
-    @defer.inlineCallbacks
-    def started_stream(self, user):
-        """Tells the presence handler that we have started an eventstream for
-        the user:
-
-        Args:
-            user (User): The user who started a stream.
-        Returns:
-            A deferred that completes once their presence has been updated.
-        """
-        if user not in self._streams_per_user:
-            # Make sure we set the streams per user to 1 here rather than
-            # setting it to zero and incrementing the value below.
-            # Otherwise this may race with stopped_stream causing the
-            # user to be erased from the map before we have a chance
-            # to increment it.
-            self._streams_per_user[user] = 1
-            if user in self._stop_timer_per_user:
-                try:
-                    self.clock.cancel_call_later(
-                        self._stop_timer_per_user.pop(user)
-                    )
-                except:
-                    logger.exception("Failed to cancel event timer")
-            else:
-                yield started_user_eventstream(self.distributor, user)
-        else:
-            self._streams_per_user[user] += 1
-
-    def stopped_stream(self, user):
-        """If there are no streams for a user this starts a timer that will
-        notify the presence handler that we haven't got an event stream for
-        the user unless the user starts a new stream in 30 seconds.
-
-        Args:
-            user (User): The user who stopped a stream.
-        """
-        self._streams_per_user[user] -= 1
-        if not self._streams_per_user[user]:
-            del self._streams_per_user[user]
-
-            # 30 seconds of grace to allow the client to reconnect again
-            #   before we think they're gone
-            def _later():
-                logger.debug("_later stopped_user_eventstream %s", user)
-
-                self._stop_timer_per_user.pop(user, None)
-
-                return stopped_user_eventstream(self.distributor, user)
-
-            logger.debug("Scheduling _later: for %s", user)
-            self._stop_timer_per_user[user] = (
-                self.clock.call_later(30, _later)
-            )
-
     @defer.inlineCallbacks
     @log_function
     def get_stream(self, auth_user_id, pagin_config, timeout=0,
@@ -126,11 +58,12 @@ class EventStreamHandler(BaseHandler):
         If `only_keys` is not None, events from keys will be sent down.
         """
         auth_user = UserID.from_string(auth_user_id)
+        presence_handler = self.hs.get_handlers().presence_handler
 
-        try:
-            if affect_presence:
-                yield self.started_stream(auth_user)
-
+        context = yield presence_handler.user_syncing(
+            auth_user_id, affect_presence=affect_presence,
+        )
+        with context:
             if timeout:
                 # If they've set a timeout set a minimum limit.
                 timeout = max(timeout, 500)
@@ -145,6 +78,34 @@ class EventStreamHandler(BaseHandler):
                 is_guest=is_guest, explicit_room_id=room_id
             )
 
+            # When the user joins a new room, or another user joins a currently
+            # joined room, we need to send down presence for those users.
+            to_add = []
+            for event in events:
+                if not isinstance(event, EventBase):
+                    continue
+                if event.type == EventTypes.Member:
+                    if event.membership != Membership.JOIN:
+                        continue
+                    # Send down presence.
+                    if event.state_key == auth_user_id:
+                        # Send down presence for everyone in the room.
+                        users = yield self.store.get_users_in_room(event.room_id)
+                        states = yield presence_handler.get_states(
+                            users,
+                            as_event=True,
+                        )
+                        to_add.extend(states)
+                    else:
+
+                        ev = yield presence_handler.get_state(
+                            UserID.from_string(event.state_key),
+                            as_event=True,
+                        )
+                        to_add.append(ev)
+
+            events.extend(to_add)
+
             time_now = self.clock.time_msec()
 
             chunks = [
@@ -159,10 +120,6 @@ class EventStreamHandler(BaseHandler):
 
             defer.returnValue(chunk)
 
-        finally:
-            if affect_presence:
-                self.stopped_stream(auth_user)
-
 
 class EventHandler(BaseHandler):
 

+ 105 - 49
synapse/handlers/federation.py

@@ -14,6 +14,9 @@
 # limitations under the License.
 
 """Contains handlers for federation events."""
+from signedjson.key import decode_verify_key_bytes
+from signedjson.sign import verify_signed_json
+from unpaddedbase64 import decode_base64
 
 from ._base import BaseHandler
 
@@ -99,7 +102,7 @@ class FederationHandler(BaseHandler):
 
     @log_function
     @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, backfilled, state=None,
+    def on_receive_pdu(self, origin, pdu, state=None,
                        auth_chain=None):
         """ Called by the ReplicationLayer when we have a new pdu. We need to
         do auth checks and put it through the StateHandler.
@@ -120,7 +123,6 @@ class FederationHandler(BaseHandler):
 
         # FIXME (erikj): Awful hack to make the case where we are not currently
         # in the room work
-        current_state = None
         is_in_room = yield self.auth.check_host_in_room(
             event.room_id,
             self.server_name
@@ -183,8 +185,6 @@ class FederationHandler(BaseHandler):
                     origin,
                     event,
                     state=state,
-                    backfilled=backfilled,
-                    current_state=current_state,
                 )
             except AuthError as e:
                 raise FederationError(
@@ -213,18 +213,17 @@ class FederationHandler(BaseHandler):
             except StoreError:
                 logger.exception("Failed to store room.")
 
-        if not backfilled:
-            extra_users = []
-            if event.type == EventTypes.Member:
-                target_user_id = event.state_key
-                target_user = UserID.from_string(target_user_id)
-                extra_users.append(target_user)
+        extra_users = []
+        if event.type == EventTypes.Member:
+            target_user_id = event.state_key
+            target_user = UserID.from_string(target_user_id)
+            extra_users.append(target_user)
 
-            with PreserveLoggingContext():
-                self.notifier.on_new_room_event(
-                    event, event_stream_id, max_stream_id,
-                    extra_users=extra_users
-                )
+        with PreserveLoggingContext():
+            self.notifier.on_new_room_event(
+                event, event_stream_id, max_stream_id,
+                extra_users=extra_users
+            )
 
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
@@ -469,7 +468,7 @@ class FederationHandler(BaseHandler):
                         limit=100,
                         extremities=[e for e in extremities.keys()]
                     )
-                except SynapseError:
+                except SynapseError as e:
                     logger.info(
                         "Failed to backfill from %s because %s",
                         dom, e,
@@ -644,7 +643,7 @@ class FederationHandler(BaseHandler):
                     continue
 
                 try:
-                    self.on_receive_pdu(origin, p, backfilled=False)
+                    self.on_receive_pdu(origin, p)
                 except:
                     logger.exception("Couldn't handle pdu")
 
@@ -776,7 +775,6 @@ class FederationHandler(BaseHandler):
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event,
             context=context,
-            backfilled=False,
         )
 
         target_user = UserID.from_string(event.state_key)
@@ -810,7 +808,21 @@ class FederationHandler(BaseHandler):
             target_hosts,
             signed_event
         )
-        defer.returnValue(None)
+
+        context = yield self.state_handler.compute_event_context(event)
+
+        event_stream_id, max_stream_id = yield self.store.persist_event(
+            event,
+            context=context,
+        )
+
+        target_user = UserID.from_string(event.state_key)
+        self.notifier.on_new_room_event(
+            event, event_stream_id, max_stream_id,
+            extra_users=[target_user],
+        )
+
+        defer.returnValue(event)
 
     @defer.inlineCallbacks
     def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
@@ -1056,8 +1068,7 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     @log_function
-    def _handle_new_event(self, origin, event, state=None, backfilled=False,
-                          current_state=None, auth_events=None):
+    def _handle_new_event(self, origin, event, state=None, auth_events=None):
 
         outlier = event.internal_metadata.is_outlier()
 
@@ -1067,7 +1078,7 @@ class FederationHandler(BaseHandler):
             auth_events=auth_events,
         )
 
-        if not backfilled and not event.internal_metadata.is_outlier():
+        if not event.internal_metadata.is_outlier():
             action_generator = ActionGenerator(self.hs)
             yield action_generator.handle_push_actions_for_event(
                 event, context, self
@@ -1076,9 +1087,7 @@ class FederationHandler(BaseHandler):
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event,
             context=context,
-            backfilled=backfilled,
-            is_new_state=(not outlier and not backfilled),
-            current_state=current_state,
+            is_new_state=not outlier,
         )
 
         defer.returnValue((context, event_stream_id, max_stream_id))
@@ -1176,7 +1185,6 @@ class FederationHandler(BaseHandler):
 
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event, new_event_context,
-            backfilled=False,
             is_new_state=True,
             current_state=state,
         )
@@ -1620,19 +1628,15 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     @log_function
-    def exchange_third_party_invite(self, invite):
-        sender = invite["sender"]
-        room_id = invite["room_id"]
-
-        if "signed" not in invite or "token" not in invite["signed"]:
-            logger.info(
-                "Discarding received notification of third party invite "
-                "without signed: %s" % (invite,)
-            )
-            return
-
+    def exchange_third_party_invite(
+            self,
+            sender_user_id,
+            target_user_id,
+            room_id,
+            signed,
+    ):
         third_party_invite = {
-            "signed": invite["signed"],
+            "signed": signed,
         }
 
         event_dict = {
@@ -1642,8 +1646,8 @@ class FederationHandler(BaseHandler):
                 "third_party_invite": third_party_invite,
             },
             "room_id": room_id,
-            "sender": sender,
-            "state_key": invite["mxid"],
+            "sender": sender_user_id,
+            "state_key": target_user_id,
         }
 
         if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
@@ -1656,11 +1660,11 @@ class FederationHandler(BaseHandler):
             )
 
             self.auth.check(event, context.current_state)
-            yield self._validate_keyserver(event, auth_events=context.current_state)
+            yield self._check_signature(event, auth_events=context.current_state)
             member_handler = self.hs.get_handlers().room_member_handler
-            yield member_handler.send_membership_event(event, context)
+            yield member_handler.send_membership_event(None, event, context)
         else:
-            destinations = set([x.split(":", 1)[-1] for x in (sender, room_id)])
+            destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
             yield self.replication_layer.forward_third_party_invite(
                 destinations,
                 room_id,
@@ -1681,13 +1685,13 @@ class FederationHandler(BaseHandler):
         )
 
         self.auth.check(event, auth_events=context.current_state)
-        yield self._validate_keyserver(event, auth_events=context.current_state)
+        yield self._check_signature(event, auth_events=context.current_state)
 
         returned_invite = yield self.send_invite(origin, event)
         # TODO: Make sure the signatures actually are correct.
         event.signatures.update(returned_invite.signatures)
         member_handler = self.hs.get_handlers().room_member_handler
-        yield member_handler.send_membership_event(event, context)
+        yield member_handler.send_membership_event(None, event, context)
 
     @defer.inlineCallbacks
     def add_display_name_to_third_party_invite(self, event_dict, event, context):
@@ -1711,17 +1715,69 @@ class FederationHandler(BaseHandler):
         defer.returnValue((event, context))
 
     @defer.inlineCallbacks
-    def _validate_keyserver(self, event, auth_events):
-        token = event.content["third_party_invite"]["signed"]["token"]
+    def _check_signature(self, event, auth_events):
+        """
+        Checks that the signature in the event is consistent with its invite.
+        :param event (Event): The m.room.member event to check
+        :param auth_events (dict<(event type, state_key), event>)
+
+        :raises
+            AuthError if signature didn't match any keys, or key has been
+                revoked,
+            SynapseError if a transient error meant a key couldn't be checked
+                for revocation.
+        """
+        signed = event.content["third_party_invite"]["signed"]
+        token = signed["token"]
 
         invite_event = auth_events.get(
             (EventTypes.ThirdPartyInvite, token,)
         )
 
+        if not invite_event:
+            raise AuthError(403, "Could not find invite")
+
+        last_exception = None
+        for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
+            try:
+                for server, signature_block in signed["signatures"].items():
+                    for key_name, encoded_signature in signature_block.items():
+                        if not key_name.startswith("ed25519:"):
+                            continue
+
+                        public_key = public_key_object["public_key"]
+                        verify_key = decode_verify_key_bytes(
+                            key_name,
+                            decode_base64(public_key)
+                        )
+                        verify_signed_json(signed, server, verify_key)
+                        if "key_validity_url" in public_key_object:
+                            yield self._check_key_revocation(
+                                public_key,
+                                public_key_object["key_validity_url"]
+                            )
+                        return
+            except Exception as e:
+                last_exception = e
+        raise last_exception
+
+    @defer.inlineCallbacks
+    def _check_key_revocation(self, public_key, url):
+        """
+        Checks whether public_key has been revoked.
+
+        :param public_key (str): base-64 encoded public key.
+        :param url (str): Key revocation URL.
+
+        :raises
+            AuthError if they key has been revoked.
+            SynapseError if a transient error meant a key couldn't be checked
+                for revocation.
+        """
         try:
             response = yield self.hs.get_simple_http_client().get_json(
-                invite_event.content["key_validity_url"],
-                {"public_key": invite_event.content["public_key"]}
+                url,
+                {"public_key": public_key}
             )
         except Exception:
             raise SynapseError(

+ 64 - 43
synapse/handlers/message.py

@@ -16,12 +16,11 @@
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, Membership
-from synapse.api.errors import AuthError, Codes
+from synapse.api.errors import AuthError, Codes, SynapseError
 from synapse.streams.config import PaginationConfig
 from synapse.events.utils import serialize_event
 from synapse.events.validator import EventValidator
 from synapse.util import unwrapFirstError
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.util.caches.snapshot_cache import SnapshotCache
 from synapse.types import UserID, RoomStreamToken, StreamToken
 
@@ -197,12 +196,25 @@ class MessageHandler(BaseHandler):
 
         if builder.type == EventTypes.Member:
             membership = builder.content.get("membership", None)
+            target = UserID.from_string(builder.state_key)
+
             if membership == Membership.JOIN:
-                joinee = UserID.from_string(builder.state_key)
                 # If event doesn't include a display name, add one.
                 yield collect_presencelike_data(
-                    self.distributor, joinee, builder.content
+                    self.distributor, target, builder.content
                 )
+            elif membership == Membership.INVITE:
+                profile = self.hs.get_handlers().profile_handler
+                content = builder.content
+
+                try:
+                    content["displayname"] = yield profile.get_displayname(target)
+                    content["avatar_url"] = yield profile.get_avatar_url(target)
+                except Exception as e:
+                    logger.info(
+                        "Failed to get profile information for %r: %s",
+                        target, e
+                    )
 
         if token_id is not None:
             builder.internal_metadata.token_id = token_id
@@ -216,7 +228,7 @@ class MessageHandler(BaseHandler):
         defer.returnValue((event, context))
 
     @defer.inlineCallbacks
-    def send_event(self, event, context, ratelimit=True, is_guest=False):
+    def send_nonmember_event(self, requester, event, context, ratelimit=True):
         """
         Persists and notifies local clients and federation of an event.
 
@@ -226,55 +238,70 @@ class MessageHandler(BaseHandler):
             ratelimit (bool): Whether to rate limit this send.
             is_guest (bool): Whether the sender is a guest.
         """
+        if event.type == EventTypes.Member:
+            raise SynapseError(
+                500,
+                "Tried to send member event through non-member codepath"
+            )
+
         user = UserID.from_string(event.sender)
 
         assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
 
-        if ratelimit:
-            self.ratelimit(event.sender)
-
         if event.is_state():
-            prev_state = context.current_state.get((event.type, event.state_key))
-            if prev_state and event.user_id == prev_state.user_id:
-                prev_content = encode_canonical_json(prev_state.content)
-                next_content = encode_canonical_json(event.content)
-                if prev_content == next_content:
-                    # Duplicate suppression for state updates with same sender
-                    # and content.
-                    defer.returnValue(prev_state)
-
-        if event.type == EventTypes.Member:
-            member_handler = self.hs.get_handlers().room_member_handler
-            yield member_handler.send_membership_event(event, context, is_guest=is_guest)
-        else:
-            yield self.handle_new_client_event(
-                event=event,
-                context=context,
-            )
+            prev_state = self.deduplicate_state_event(event, context)
+            if prev_state is not None:
+                defer.returnValue(prev_state)
+
+        yield self.handle_new_client_event(
+            requester=requester,
+            event=event,
+            context=context,
+            ratelimit=ratelimit,
+        )
 
         if event.type == EventTypes.Message:
             presence = self.hs.get_handlers().presence_handler
-            with PreserveLoggingContext():
-                presence.bump_presence_active_time(user)
+            yield presence.bump_presence_active_time(user)
+
+    def deduplicate_state_event(self, event, context):
+        """
+        Checks whether event is in the latest resolved state in context.
+
+        If so, returns the version of the event in context.
+        Otherwise, returns None.
+        """
+        prev_event = context.current_state.get((event.type, event.state_key))
+        if prev_event and event.user_id == prev_event.user_id:
+            prev_content = encode_canonical_json(prev_event.content)
+            next_content = encode_canonical_json(event.content)
+            if prev_content == next_content:
+                return prev_event
+        return None
 
     @defer.inlineCallbacks
-    def create_and_send_event(self, event_dict, ratelimit=True,
-                              token_id=None, txn_id=None, is_guest=False):
+    def create_and_send_nonmember_event(
+        self,
+        requester,
+        event_dict,
+        ratelimit=True,
+        txn_id=None
+    ):
         """
         Creates an event, then sends it.
 
-        See self.create_event and self.send_event.
+        See self.create_event and self.send_nonmember_event.
         """
         event, context = yield self.create_event(
             event_dict,
-            token_id=token_id,
+            token_id=requester.access_token_id,
             txn_id=txn_id
         )
-        yield self.send_event(
+        yield self.send_nonmember_event(
+            requester,
             event,
             context,
             ratelimit=ratelimit,
-            is_guest=is_guest
         )
         defer.returnValue(event)
 
@@ -635,8 +662,8 @@ class MessageHandler(BaseHandler):
             user_id, messages, is_peeking=is_peeking
         )
 
-        start_token = StreamToken(token[0], 0, 0, 0, 0)
-        end_token = StreamToken(token[1], 0, 0, 0, 0)
+        start_token = StreamToken.START.copy_and_replace("room_key", token[0])
+        end_token = StreamToken.START.copy_and_replace("room_key", token[1])
 
         time_now = self.clock.time_msec()
 
@@ -660,10 +687,6 @@ class MessageHandler(BaseHandler):
             room_id=room_id,
         )
 
-        # TODO(paul): I wish I was called with user objects not user_id
-        #   strings...
-        auth_user = UserID.from_string(user_id)
-
         # TODO: These concurrently
         time_now = self.clock.time_msec()
         state = [
@@ -688,13 +711,11 @@ class MessageHandler(BaseHandler):
         @defer.inlineCallbacks
         def get_presence():
             states = yield presence_handler.get_states(
-                target_users=[UserID.from_string(m.user_id) for m in room_members],
-                auth_user=auth_user,
+                [m.user_id for m in room_members],
                 as_event=True,
-                check_auth=False,
             )
 
-            defer.returnValue(states.values())
+            defer.returnValue(states)
 
         @defer.inlineCallbacks
         def get_receipts():

File diff suppressed because it is too large
+ 520 - 364
synapse/handlers/presence.py


+ 24 - 24
synapse/handlers/profile.py

@@ -16,8 +16,7 @@
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, AuthError, CodeMessageException
-from synapse.api.constants import EventTypes, Membership
-from synapse.types import UserID
+from synapse.types import UserID, Requester
 from synapse.util import unwrapFirstError
 
 from ._base import BaseHandler
@@ -49,6 +48,9 @@ class ProfileHandler(BaseHandler):
         distributor = hs.get_distributor()
         self.distributor = distributor
 
+        distributor.declare("collect_presencelike_data")
+        distributor.declare("changed_presencelike_data")
+
         distributor.observe("registered_user", self.registered_user)
 
         distributor.observe(
@@ -87,13 +89,13 @@ class ProfileHandler(BaseHandler):
                 defer.returnValue(result["displayname"])
 
     @defer.inlineCallbacks
-    def set_displayname(self, target_user, auth_user, new_displayname):
+    def set_displayname(self, target_user, requester, new_displayname):
         """target_user is the user whose displayname is to be changed;
         auth_user is the user attempting to make this change."""
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this Home Server")
 
-        if target_user != auth_user:
+        if target_user != requester.user:
             raise AuthError(400, "Cannot set another user's displayname")
 
         if new_displayname == '':
@@ -107,7 +109,7 @@ class ProfileHandler(BaseHandler):
             "displayname": new_displayname,
         })
 
-        yield self._update_join_states(target_user)
+        yield self._update_join_states(requester)
 
     @defer.inlineCallbacks
     def get_avatar_url(self, target_user):
@@ -137,13 +139,13 @@ class ProfileHandler(BaseHandler):
             defer.returnValue(result["avatar_url"])
 
     @defer.inlineCallbacks
-    def set_avatar_url(self, target_user, auth_user, new_avatar_url):
+    def set_avatar_url(self, target_user, requester, new_avatar_url):
         """target_user is the user whose avatar_url is to be changed;
         auth_user is the user attempting to make this change."""
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this Home Server")
 
-        if target_user != auth_user:
+        if target_user != requester.user:
             raise AuthError(400, "Cannot set another user's avatar_url")
 
         yield self.store.set_profile_avatar_url(
@@ -154,7 +156,7 @@ class ProfileHandler(BaseHandler):
             "avatar_url": new_avatar_url,
         })
 
-        yield self._update_join_states(target_user)
+        yield self._update_join_states(requester)
 
     @defer.inlineCallbacks
     def collect_presencelike_data(self, user, state):
@@ -197,32 +199,30 @@ class ProfileHandler(BaseHandler):
         defer.returnValue(response)
 
     @defer.inlineCallbacks
-    def _update_join_states(self, user):
+    def _update_join_states(self, requester):
+        user = requester.user
         if not self.hs.is_mine(user):
             return
 
-        self.ratelimit(user.to_string())
+        self.ratelimit(requester)
 
         joins = yield self.store.get_rooms_for_user(
             user.to_string(),
         )
 
         for j in joins:
-            content = {
-                "membership": Membership.JOIN,
-            }
-
-            yield collect_presencelike_data(self.distributor, user, content)
-
-            msg_handler = self.hs.get_handlers().message_handler
+            handler = self.hs.get_handlers().room_member_handler
             try:
-                yield msg_handler.create_and_send_event({
-                    "type": EventTypes.Member,
-                    "room_id": j.room_id,
-                    "state_key": user.to_string(),
-                    "content": content,
-                    "sender": user.to_string()
-                }, ratelimit=False)
+                # Assume the user isn't a guest because we don't let guests set
+                # profile or avatar data.
+                requester = Requester(user, "", False)
+                yield handler.update_membership(
+                    requester,
+                    user,
+                    j.room_id,
+                    "join",  # We treat a profile update like a join.
+                    ratelimit=False,  # Try to hide that these events aren't atomic.
+                )
             except Exception as e:
                 logger.warn(
                     "Failed to update join event for room %s - %s",

+ 0 - 2
synapse/handlers/receipts.py

@@ -36,8 +36,6 @@ class ReceiptsHandler(BaseHandler):
         )
         self.clock = self.hs.get_clock()
 
-        self._receipt_cache = None
-
     @defer.inlineCallbacks
     def received_client_receipt(self, room_id, receipt_type, user_id,
                                 event_id):

+ 44 - 8
synapse/handlers/register.py

@@ -47,7 +47,8 @@ class RegistrationHandler(BaseHandler):
         self._next_generated_user_id = None
 
     @defer.inlineCallbacks
-    def check_username(self, localpart, guest_access_token=None):
+    def check_username(self, localpart, guest_access_token=None,
+                       assigned_user_id=None):
         yield run_on_reactor()
 
         if urllib.quote(localpart.encode('utf-8')) != localpart:
@@ -60,7 +61,16 @@ class RegistrationHandler(BaseHandler):
         user = UserID(localpart, self.hs.hostname)
         user_id = user.to_string()
 
-        yield self.check_user_id_is_valid(user_id)
+        if assigned_user_id:
+            if user_id == assigned_user_id:
+                return
+            else:
+                raise SynapseError(
+                    400,
+                    "A different user ID has already been registered for this session",
+                )
+
+        yield self.check_user_id_not_appservice_exclusive(user_id)
 
         users = yield self.store.get_users_by_id_case_insensitive(user_id)
         if users:
@@ -145,7 +155,7 @@ class RegistrationHandler(BaseHandler):
                 localpart = yield self._generate_user_id(attempts > 0)
                 user = UserID(localpart, self.hs.hostname)
                 user_id = user.to_string()
-                yield self.check_user_id_is_valid(user_id)
+                yield self.check_user_id_not_appservice_exclusive(user_id)
                 if generate_token:
                     token = self.auth_handler().generate_access_token(user_id)
                 try:
@@ -157,6 +167,7 @@ class RegistrationHandler(BaseHandler):
                     )
                 except SynapseError:
                     # if user id is taken, just generate another
+                    user = None
                     user_id = None
                     token = None
                     attempts += 1
@@ -180,11 +191,19 @@ class RegistrationHandler(BaseHandler):
                 400, "Invalid user localpart for this application service.",
                 errcode=Codes.EXCLUSIVE
             )
+
+        service_id = service.id if service.is_exclusive_user(user_id) else None
+
+        yield self.check_user_id_not_appservice_exclusive(
+            user_id, allowed_appservice=service
+        )
+
         token = self.auth_handler().generate_access_token(user_id)
         yield self.store.register(
             user_id=user_id,
             token=token,
-            password_hash=""
+            password_hash="",
+            appservice_id=service_id,
         )
         yield registered_user(self.distributor, user)
         defer.returnValue((user_id, token))
@@ -226,7 +245,7 @@ class RegistrationHandler(BaseHandler):
         user = UserID(localpart, self.hs.hostname)
         user_id = user.to_string()
 
-        yield self.check_user_id_is_valid(user_id)
+        yield self.check_user_id_not_appservice_exclusive(user_id)
         token = self.auth_handler().generate_access_token(user_id)
         try:
             yield self.store.register(
@@ -235,7 +254,7 @@ class RegistrationHandler(BaseHandler):
                 password_hash=None
             )
             yield registered_user(self.distributor, user)
-        except Exception, e:
+        except Exception as e:
             yield self.store.add_access_token_to_user(user_id, token)
             # Ignore Registration errors
             logger.exception(e)
@@ -278,12 +297,14 @@ class RegistrationHandler(BaseHandler):
             yield identity_handler.bind_threepid(c, user_id)
 
     @defer.inlineCallbacks
-    def check_user_id_is_valid(self, user_id):
+    def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
         # valid user IDs must not clash with any user ID namespaces claimed by
         # application services.
         services = yield self.store.get_app_services()
         interested_services = [
-            s for s in services if s.is_interested_in_user(user_id)
+            s for s in services
+            if s.is_interested_in_user(user_id)
+            and s != allowed_appservice
         ]
         for service in interested_services:
             if service.is_exclusive_user(user_id):
@@ -342,3 +363,18 @@ class RegistrationHandler(BaseHandler):
 
     def auth_handler(self):
         return self.hs.get_handlers().auth_handler
+
+    @defer.inlineCallbacks
+    def guest_access_token_for(self, medium, address, inviter_user_id):
+        access_token = yield self.store.get_3pid_guest_access_token(medium, address)
+        if access_token:
+            defer.returnValue(access_token)
+
+        _, access_token = yield self.register(
+            generate_token=True,
+            make_guest=True
+        )
+        access_token = yield self.store.save_or_get_3pid_guest_access_token(
+            medium, address, access_token, inviter_user_id
+        )
+        defer.returnValue(access_token)

File diff suppressed because it is too large
+ 395 - 336
synapse/handlers/room.py


+ 76 - 9
synapse/handlers/sync.py

@@ -20,6 +20,7 @@ from synapse.api.constants import Membership, EventTypes
 from synapse.util import unwrapFirstError
 from synapse.util.logcontext import LoggingContext, preserve_fn
 from synapse.util.metrics import Measure
+from synapse.push.clientformat import format_push_rules_for_user
 
 from twisted.internet import defer
 
@@ -121,7 +122,11 @@ class SyncResult(collections.namedtuple("SyncResult", [
         events.
         """
         return bool(
-            self.presence or self.joined or self.invited or self.archived
+            self.presence or
+            self.joined or
+            self.invited or
+            self.archived or
+            self.account_data
         )
 
 
@@ -205,9 +210,9 @@ class SyncHandler(BaseHandler):
             key=None
         )
 
-        membership_list = (Membership.INVITE, Membership.JOIN)
-        if sync_config.filter_collection.include_leave:
-            membership_list += (Membership.LEAVE, Membership.BAN)
+        membership_list = (
+            Membership.INVITE, Membership.JOIN, Membership.LEAVE, Membership.BAN
+        )
 
         room_list = yield self.store.get_rooms_for_user_where_membership_is(
             user_id=sync_config.user.to_string(),
@@ -220,6 +225,10 @@ class SyncHandler(BaseHandler):
             )
         )
 
+        account_data['m.push_rules'] = yield self.push_rules_for_user(
+            sync_config.user
+        )
+
         tags_by_room = yield self.store.get_tags_for_user(
             sync_config.user.to_string()
         )
@@ -253,6 +262,12 @@ class SyncHandler(BaseHandler):
                         invite=invite,
                     ))
                 elif event.membership in (Membership.LEAVE, Membership.BAN):
+                    # Always send down rooms we were banned or kicked from.
+                    if not sync_config.filter_collection.include_leave:
+                        if event.membership == Membership.LEAVE:
+                            if sync_config.user.to_string() == event.sender:
+                                continue
+
                     leave_token = now_token.copy_and_replace(
                         "room_key", "s%d" % (event.stream_ordering,)
                     )
@@ -318,6 +333,14 @@ class SyncHandler(BaseHandler):
 
         defer.returnValue(room_sync)
 
+    @defer.inlineCallbacks
+    def push_rules_for_user(self, user):
+        user_id = user.to_string()
+        rawrules = yield self.store.get_push_rules_for_user(user_id)
+        enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
+        rules = format_push_rules_for_user(user, rawrules, enabled_map)
+        defer.returnValue(rules)
+
     def account_data_for_user(self, account_data):
         account_data_events = []
 
@@ -477,6 +500,15 @@ class SyncHandler(BaseHandler):
             )
         )
 
+        push_rules_changed = yield self.store.have_push_rules_changed_for_user(
+            user_id, int(since_token.push_rules_key)
+        )
+
+        if push_rules_changed:
+            account_data["m.push_rules"] = yield self.push_rules_for_user(
+                sync_config.user
+            )
+
         # Get a list of membership change events that have happened.
         rooms_changed = yield self.store.get_membership_changes_for_user(
             user_id, since_token.room_key, now_token.room_key
@@ -582,6 +614,28 @@ class SyncHandler(BaseHandler):
             if room_sync:
                 joined.append(room_sync)
 
+        # For each newly joined room, we want to send down presence of
+        # existing users.
+        presence_handler = self.hs.get_handlers().presence_handler
+        extra_presence_users = set()
+        for room_id in newly_joined_rooms:
+            users = yield self.store.get_users_in_room(event.room_id)
+            extra_presence_users.update(users)
+
+        # For each new member, send down presence.
+        for joined_sync in joined:
+            it = itertools.chain(joined_sync.timeline.events, joined_sync.state.values())
+            for event in it:
+                if event.type == EventTypes.Member:
+                    if event.membership == Membership.JOIN:
+                        extra_presence_users.add(event.state_key)
+
+        states = yield presence_handler.get_states(
+            [u for u in extra_presence_users if u != user_id],
+            as_event=True,
+        )
+        presence.extend(states)
+
         account_data_for_user = sync_config.filter_collection.filter_account_data(
             self.account_data_for_user(account_data)
         )
@@ -623,7 +677,6 @@ class SyncHandler(BaseHandler):
                 recents = yield self._filter_events_for_client(
                     sync_config.user.to_string(),
                     recents,
-                    is_peeking=sync_config.is_guest,
                 )
             else:
                 recents = []
@@ -645,7 +698,6 @@ class SyncHandler(BaseHandler):
                 loaded_recents = yield self._filter_events_for_client(
                     sync_config.user.to_string(),
                     loaded_recents,
-                    is_peeking=sync_config.is_guest,
                 )
                 loaded_recents.extend(recents)
                 recents = loaded_recents
@@ -825,14 +877,20 @@ class SyncHandler(BaseHandler):
         with Measure(self.clock, "compute_state_delta"):
             if full_state:
                 if batch:
+                    current_state = yield self.store.get_state_for_event(
+                        batch.events[-1].event_id
+                    )
+
                     state = yield self.store.get_state_for_event(
                         batch.events[0].event_id
                     )
                 else:
-                    state = yield self.get_state_at(
+                    current_state = yield self.get_state_at(
                         room_id, stream_position=now_token
                     )
 
+                    state = current_state
+
                 timeline_state = {
                     (event.type, event.state_key): event
                     for event in batch.events if event.is_state()
@@ -842,12 +900,17 @@ class SyncHandler(BaseHandler):
                     timeline_contains=timeline_state,
                     timeline_start=state,
                     previous={},
+                    current=current_state,
                 )
             elif batch.limited:
                 state_at_previous_sync = yield self.get_state_at(
                     room_id, stream_position=since_token
                 )
 
+                current_state = yield self.store.get_state_for_event(
+                    batch.events[-1].event_id
+                )
+
                 state_at_timeline_start = yield self.store.get_state_for_event(
                     batch.events[0].event_id
                 )
@@ -861,6 +924,7 @@ class SyncHandler(BaseHandler):
                     timeline_contains=timeline_state,
                     timeline_start=state_at_timeline_start,
                     previous=state_at_previous_sync,
+                    current=current_state,
                 )
             else:
                 state = {}
@@ -920,7 +984,7 @@ def _action_has_highlight(actions):
     return False
 
 
-def _calculate_state(timeline_contains, timeline_start, previous):
+def _calculate_state(timeline_contains, timeline_start, previous, current):
     """Works out what state to include in a sync response.
 
     Args:
@@ -928,6 +992,7 @@ def _calculate_state(timeline_contains, timeline_start, previous):
         timeline_start (dict): state at the start of the timeline
         previous (dict): state at the end of the previous sync (or empty dict
             if this is an initial sync)
+        current (dict): state at the end of the timeline
 
     Returns:
         dict
@@ -938,14 +1003,16 @@ def _calculate_state(timeline_contains, timeline_start, previous):
             timeline_contains.values(),
             previous.values(),
             timeline_start.values(),
+            current.values(),
         )
     }
 
+    c_ids = set(e.event_id for e in current.values())
     tc_ids = set(e.event_id for e in timeline_contains.values())
     p_ids = set(e.event_id for e in previous.values())
     ts_ids = set(e.event_id for e in timeline_start.values())
 
-    state_ids = (ts_ids - p_ids) - tc_ids
+    state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
 
     evs = (event_id_to_state[e] for e in state_ids)
     return {

+ 14 - 0
synapse/handlers/typing.py

@@ -25,6 +25,7 @@ from synapse.types import UserID
 import logging
 
 from collections import namedtuple
+import ujson as json
 
 logger = logging.getLogger(__name__)
 
@@ -219,6 +220,19 @@ class TypingNotificationHandler(BaseHandler):
                 "typing_key", self._latest_room_serial, rooms=[room_id]
             )
 
+    def get_all_typing_updates(self, last_id, current_id):
+        # TODO: Work out a way to do this without scanning the entire state.
+        rows = []
+        for room_id, serial in self._room_serials.items():
+            if last_id < serial and serial <= current_id:
+                typing = self._room_typing[room_id]
+                typing_bytes = json.dumps([
+                    u.to_string() for u in typing
+                ], ensure_ascii=False)
+                rows.append((serial, room_id, typing_bytes))
+        rows.sort()
+        return rows
+
 
 class TypingNotificationEventSource(object):
     def __init__(self, hs):

+ 15 - 2
synapse/http/client.py

@@ -103,7 +103,7 @@ class SimpleHttpClient(object):
         # TODO: Do we ever want to log message contents?
         logger.debug("post_urlencoded_get_json args: %s", args)
 
-        query_bytes = urllib.urlencode(args, True)
+        query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
 
         response = yield self.request(
             "POST",
@@ -249,7 +249,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
 
     @defer.inlineCallbacks
     def post_urlencoded_get_raw(self, url, args={}):
-        query_bytes = urllib.urlencode(args, True)
+        query_bytes = urllib.urlencode(encode_urlencode_args(args), True)
 
         response = yield self.request(
             "POST",
@@ -269,6 +269,19 @@ class CaptchaServerHttpClient(SimpleHttpClient):
             defer.returnValue(e.response)
 
 
+def encode_urlencode_args(args):
+    return {k: encode_urlencode_arg(v) for k, v in args.items()}
+
+
+def encode_urlencode_arg(arg):
+    if isinstance(arg, unicode):
+        return arg.encode('utf-8')
+    elif isinstance(arg, list):
+        return [encode_urlencode_arg(i) for i in arg]
+    else:
+        return arg
+
+
 def _print_ex(e):
     if hasattr(e, "reasons") and e.reasons:
         for ex in e.reasons:

+ 26 - 5
synapse/http/server.py

@@ -18,6 +18,7 @@ from synapse.api.errors import (
     cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError, Codes
 )
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
+from synapse.util.caches import intern_dict
 import synapse.metrics
 import synapse.events
 
@@ -229,11 +230,12 @@ class JsonResource(HttpServer, resource.Resource):
             else:
                 servlet_classname = "%r" % callback
 
-            args = [
-                urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
-            ]
+            kwargs = intern_dict({
+                name: urllib.unquote(value).decode("UTF-8") if value else value
+                for name, value in m.groupdict().items()
+            })
 
-            callback_return = yield callback(request, *args)
+            callback_return = yield callback(request, **kwargs)
             if callback_return is not None:
                 code, response = callback_return
                 self._send_response(request, code, response)
@@ -367,10 +369,29 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
                           "Origin, X-Requested-With, Content-Type, Accept")
 
     request.write(json_bytes)
-    request.finish()
+    finish_request(request)
     return NOT_DONE_YET
 
 
+def finish_request(request):
+    """ Finish writing the response to the request.
+
+    Twisted throws a RuntimeException if the connection closed before the
+    response was written but doesn't provide a convenient or reliable way to
+    determine if the connection was closed. So we catch and log the RuntimeException
+
+    You might think that ``request.notifyFinish`` could be used to tell if the
+    request was finished. However the deferred it returns won't fire if the
+    connection was already closed, meaning we'd have to have called the method
+    right at the start of the request. By the time we want to write the response
+    it will already be too late.
+    """
+    try:
+        request.finish()
+    except RuntimeError as e:
+        logger.info("Connection disconnected before response was written: %r", e)
+
+
 def _request_user_agent_is_curl(request):
     user_agents = request.requestHeaders.getRawHeaders(
         "User-Agent", default=[]

+ 85 - 5
synapse/http/servlet.py

@@ -15,14 +15,27 @@
 
 """ This module contains base REST classes for constructing REST servlets. """
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, Codes
 
 import logging
+import simplejson
 
 logger = logging.getLogger(__name__)
 
 
 def parse_integer(request, name, default=None, required=False):
+    """Parse an integer parameter from the request string
+
+    :param request: the twisted HTTP request.
+    :param name (str): the name of the query parameter.
+    :param default: value to use if the parameter is absent, defaults to None.
+    :param required (bool): whether to raise a 400 SynapseError if the
+        parameter is absent, defaults to False.
+    :return: An int value or the default.
+    :raises
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present and not an integer.
+    """
     if name in request.args:
         try:
             return int(request.args[name][0])
@@ -32,12 +45,25 @@ def parse_integer(request, name, default=None, required=False):
     else:
         if required:
             message = "Missing integer query parameter %r" % (name,)
-            raise SynapseError(400, message)
+            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
             return default
 
 
 def parse_boolean(request, name, default=None, required=False):
+    """Parse a boolean parameter from the request query string
+
+    :param request: the twisted HTTP request.
+    :param name (str): the name of the query parameter.
+    :param default: value to use if the parameter is absent, defaults to None.
+    :param required (bool): whether to raise a 400 SynapseError if the
+        parameter is absent, defaults to False.
+    :return: A bool value or the default.
+    :raises
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present and not one of "true" or "false".
+    """
+
     if name in request.args:
         try:
             return {
@@ -53,30 +79,84 @@ def parse_boolean(request, name, default=None, required=False):
     else:
         if required:
             message = "Missing boolean query parameter %r" % (name,)
-            raise SynapseError(400, message)
+            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
             return default
 
 
 def parse_string(request, name, default=None, required=False,
                  allowed_values=None, param_type="string"):
+    """Parse a string parameter from the request query string.
+
+    :param request: the twisted HTTP request.
+    :param name (str): the name of the query parameter.
+    :param default: value to use if the parameter is absent, defaults to None.
+    :param required (bool): whether to raise a 400 SynapseError if the
+        parameter is absent, defaults to False.
+    :param allowed_values (list): List of allowed values for the string,
+        or None if any value is allowed, defaults to None
+    :return: A string value or the default.
+    :raises
+        SynapseError if the parameter is absent and required, or if the
+            parameter is present, must be one of a list of allowed values and
+            is not one of those allowed values.
+    """
+
     if name in request.args:
         value = request.args[name][0]
         if allowed_values is not None and value not in allowed_values:
             message = "Query parameter %r must be one of [%s]" % (
                 name, ", ".join(repr(v) for v in allowed_values)
             )
-            raise SynapseError(message)
+            raise SynapseError(400, message)
         else:
             return value
     else:
         if required:
             message = "Missing %s query parameter %r" % (param_type, name)
-            raise SynapseError(400, message)
+            raise SynapseError(400, message, errcode=Codes.MISSING_PARAM)
         else:
             return default
 
 
+def parse_json_value_from_request(request):
+    """Parse a JSON value from the body of a twisted HTTP request.
+
+    :param request: the twisted HTTP request.
+    :returns: The JSON value.
+    :raises
+        SynapseError if the request body couldn't be decoded as JSON.
+    """
+    try:
+        content_bytes = request.content.read()
+    except:
+        raise SynapseError(400, "Error reading JSON content.")
+
+    try:
+        content = simplejson.loads(content_bytes)
+    except simplejson.JSONDecodeError:
+        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+
+    return content
+
+
+def parse_json_object_from_request(request):
+    """Parse a JSON object from the body of a twisted HTTP request.
+
+    :param request: the twisted HTTP request.
+    :raises
+        SynapseError if the request body couldn't be decoded as JSON or
+            if it wasn't a JSON object.
+    """
+    content = parse_json_value_from_request(request)
+
+    if type(content) != dict:
+        message = "Content must be a JSON object."
+        raise SynapseError(400, message, errcode=Codes.BAD_JSON)
+
+    return content
+
+
 class RestServlet(object):
 
     """ A Synapse REST Servlet.

+ 55 - 1
synapse/notifier.py

@@ -159,6 +159,8 @@ class Notifier(object):
             self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
         )
 
+        self.replication_deferred = ObservableDeferred(defer.Deferred())
+
         # This is not a very cheap test to perform, but it's only executed
         # when rendering the metrics page, which is likely once per minute at
         # most when scraping it.
@@ -207,6 +209,8 @@ class Notifier(object):
             ))
             self._notify_pending_new_room_events(max_room_stream_id)
 
+            self.notify_replication()
+
     def _notify_pending_new_room_events(self, max_room_stream_id):
         """Notify for the room events that were queued waiting for a previous
         event to be persisted.
@@ -276,9 +280,17 @@ class Notifier(object):
                 except:
                     logger.exception("Failed to notify listener")
 
+            self.notify_replication()
+
+    def on_new_replication_data(self):
+        """Used to inform replication listeners that something has happend
+        without waking up any of the normal user event streams"""
+        with PreserveLoggingContext():
+            self.notify_replication()
+
     @defer.inlineCallbacks
     def wait_for_events(self, user_id, timeout, callback, room_ids=None,
-                        from_token=StreamToken("s0", "0", "0", "0", "0")):
+                        from_token=StreamToken.START):
         """Wait until the callback returns a non empty response or the
         timeout fires.
         """
@@ -479,3 +491,45 @@ class Notifier(object):
             room_streams = self.room_to_user_streams.setdefault(room_id, set())
             room_streams.add(new_user_stream)
             new_user_stream.rooms.add(room_id)
+
+    def notify_replication(self):
+        """Notify the any replication listeners that there's a new event"""
+        with PreserveLoggingContext():
+            deferred = self.replication_deferred
+            self.replication_deferred = ObservableDeferred(defer.Deferred())
+            deferred.callback(None)
+
+    @defer.inlineCallbacks
+    def wait_for_replication(self, callback, timeout):
+        """Wait for an event to happen.
+
+        :param callback:
+            Gets called whenever an event happens. If this returns a truthy
+            value then ``wait_for_replication`` returns, otherwise it waits
+            for another event.
+        :param int timeout:
+            How many milliseconds to wait for callback return a truthy value.
+        :returns:
+            A deferred that resolves with the value returned by the callback.
+        """
+        listener = _NotificationListener(None)
+
+        def timed_out():
+            listener.deferred.cancel()
+
+        timer = self.clock.call_later(timeout / 1000., timed_out)
+        while True:
+            listener.deferred = self.replication_deferred.observe()
+            result = yield callback()
+            if result:
+                break
+
+            try:
+                with PreserveLoggingContext():
+                    yield listener.deferred
+            except defer.CancelledError:
+                break
+
+        self.clock.cancel_call_later(timer, ignore_errs=True)
+
+        defer.returnValue(result)

+ 5 - 6
synapse/push/__init__.py

@@ -21,7 +21,7 @@ from synapse.util.logcontext import LoggingContext
 from synapse.util.metrics import Measure
 
 import synapse.util.async
-import push_rule_evaluator as push_rule_evaluator
+from .push_rule_evaluator import evaluator_for_user_id
 
 import logging
 import random
@@ -47,14 +47,13 @@ class Pusher(object):
     MAX_BACKOFF = 60 * 60 * 1000
     GIVE_UP_AFTER = 24 * 60 * 60 * 1000
 
-    def __init__(self, _hs, profile_tag, user_id, app_id,
+    def __init__(self, _hs, user_id, app_id,
                  app_display_name, device_display_name, pushkey, pushkey_ts,
                  data, last_token, last_success, failing_since):
         self.hs = _hs
         self.evStreamHandler = self.hs.get_handlers().event_stream_handler
         self.store = self.hs.get_datastore()
         self.clock = self.hs.get_clock()
-        self.profile_tag = profile_tag
         self.user_id = user_id
         self.app_id = app_id
         self.app_display_name = app_display_name
@@ -186,8 +185,8 @@ class Pusher(object):
         processed = False
 
         rule_evaluator = yield \
-            push_rule_evaluator.evaluator_for_user_id_and_profile_tag(
-                self.user_id, self.profile_tag, single_event['room_id'], self.store
+            evaluator_for_user_id(
+                self.user_id, single_event['room_id'], self.store
             )
 
         actions = yield rule_evaluator.actions_for_event(single_event)
@@ -318,7 +317,7 @@ class Pusher(object):
     @defer.inlineCallbacks
     def _get_badge_count(self):
         invites, joins = yield defer.gatherResults([
-            self.store.get_invites_for_user(self.user_id),
+            self.store.get_invited_rooms_for_user(self.user_id),
             self.store.get_rooms_for_user(self.user_id),
         ], consumeErrors=True)
 

+ 3 - 3
synapse/push/action_generator.py

@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-import bulk_push_rule_evaluator
+from .bulk_push_rule_evaluator import evaluator_for_room_id
 
 import logging
 
@@ -35,7 +35,7 @@ class ActionGenerator:
 
     @defer.inlineCallbacks
     def handle_push_actions_for_event(self, event, context, handler):
-        bulk_evaluator = yield bulk_push_rule_evaluator.evaluator_for_room_id(
+        bulk_evaluator = yield evaluator_for_room_id(
             event.room_id, self.hs, self.store
         )
 
@@ -44,5 +44,5 @@ class ActionGenerator:
         )
 
         context.push_actions = [
-            (uid, None, actions) for uid, actions in actions_by_user.items()
+            (uid, actions) for uid, actions in actions_by_user.items()
         ]

+ 50 - 7
synapse/push/baserules.py

@@ -13,46 +13,67 @@
 # limitations under the License.
 
 from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
+import copy
 
 
 def list_with_base_rules(rawrules):
+    """Combine the list of rules set by the user with the default push rules
+
+    :param list rawrules: The rules the user has modified or set.
+    :returns: A new list with the rules set by the user combined with the
+        defaults.
+    """
     ruleslist = []
 
+    # Grab the base rules that the user has modified.
+    # The modified base rules have a priority_class of -1.
+    modified_base_rules = {
+        r['rule_id']: r for r in rawrules if r['priority_class'] < 0
+    }
+
+    # Remove the modified base rules from the list, They'll be added back
+    # in the default postions in the list.
+    rawrules = [r for r in rawrules if r['priority_class'] >= 0]
+
     # shove the server default rules for each kind onto the end of each
     current_prio_class = PRIORITY_CLASS_INVERSE_MAP.keys()[-1]
 
     ruleslist.extend(make_base_prepend_rules(
-        PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+        PRIORITY_CLASS_INVERSE_MAP[current_prio_class], modified_base_rules
     ))
 
     for r in rawrules:
         if r['priority_class'] < current_prio_class:
             while r['priority_class'] < current_prio_class:
                 ruleslist.extend(make_base_append_rules(
-                    PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+                    PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+                    modified_base_rules,
                 ))
                 current_prio_class -= 1
                 if current_prio_class > 0:
                     ruleslist.extend(make_base_prepend_rules(
-                        PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+                        PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+                        modified_base_rules,
                     ))
 
         ruleslist.append(r)
 
     while current_prio_class > 0:
         ruleslist.extend(make_base_append_rules(
-            PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+            PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+            modified_base_rules,
         ))
         current_prio_class -= 1
         if current_prio_class > 0:
             ruleslist.extend(make_base_prepend_rules(
-                PRIORITY_CLASS_INVERSE_MAP[current_prio_class]
+                PRIORITY_CLASS_INVERSE_MAP[current_prio_class],
+                modified_base_rules,
             ))
 
     return ruleslist
 
 
-def make_base_append_rules(kind):
+def make_base_append_rules(kind, modified_base_rules):
     rules = []
 
     if kind == 'override':
@@ -62,15 +83,31 @@ def make_base_append_rules(kind):
     elif kind == 'content':
         rules = BASE_APPEND_CONTENT_RULES
 
+    # Copy the rules before modifying them
+    rules = copy.deepcopy(rules)
+    for r in rules:
+        # Only modify the actions, keep the conditions the same.
+        modified = modified_base_rules.get(r['rule_id'])
+        if modified:
+            r['actions'] = modified['actions']
+
     return rules
 
 
-def make_base_prepend_rules(kind):
+def make_base_prepend_rules(kind, modified_base_rules):
     rules = []
 
     if kind == 'override':
         rules = BASE_PREPEND_OVERRIDE_RULES
 
+    # Copy the rules before modifying them
+    rules = copy.deepcopy(rules)
+    for r in rules:
+        # Only modify the actions, keep the conditions the same.
+        modified = modified_base_rules.get(r['rule_id'])
+        if modified:
+            r['actions'] = modified['actions']
+
     return rules
 
 
@@ -263,18 +300,24 @@ BASE_APPEND_UNDERRIDE_RULES = [
 ]
 
 
+BASE_RULE_IDS = set()
+
 for r in BASE_APPEND_CONTENT_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['content']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
 
 for r in BASE_PREPEND_OVERRIDE_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['override']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
 
 for r in BASE_APPEND_OVRRIDE_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['override']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])
 
 for r in BASE_APPEND_UNDERRIDE_RULES:
     r['priority_class'] = PRIORITY_CLASS_MAP['underride']
     r['default'] = True
+    BASE_RULE_IDS.add(r['rule_id'])

+ 8 - 6
synapse/push/bulk_push_rule_evaluator.py

@@ -18,8 +18,8 @@ import ujson as json
 
 from twisted.internet import defer
 
-import baserules
-from push_rule_evaluator import PushRuleEvaluatorForEvent
+from .baserules import list_with_base_rules
+from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
 from synapse.api.constants import EventTypes
 
@@ -39,7 +39,7 @@ def _get_rules(room_id, user_ids, store):
     rules_enabled_by_user = yield store.bulk_get_push_rules_enabled(user_ids)
 
     rules_by_user = {
-        uid: baserules.list_with_base_rules([
+        uid: list_with_base_rules([
             decode_rule_json(rule_list)
             for rule_list in rules_by_user.get(uid, [])
         ])
@@ -103,11 +103,13 @@ class BulkPushRuleEvaluator:
 
         users_dict = yield self.store.are_guests(self.rules_by_user.keys())
 
-        filtered_by_user = yield handler._filter_events_for_clients(
+        filtered_by_user = yield handler.filter_events_for_clients(
             users_dict.items(), [event], {event.event_id: current_state}
         )
 
-        evaluator = PushRuleEvaluatorForEvent(event, len(self.users_in_room))
+        room_members = yield self.store.get_users_in_room(self.room_id)
+
+        evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
 
         condition_cache = {}
 
@@ -152,7 +154,7 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
             elif res is True:
                 continue
 
-        res = evaluator.matches(cond, uid, display_name, None)
+        res = evaluator.matches(cond, uid, display_name)
         if _id:
             cache[_id] = bool(res)
 

+ 112 - 0
synapse/push/clientformat.py

@@ -0,0 +1,112 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.push.baserules import list_with_base_rules
+
+from synapse.push.rulekinds import (
+    PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
+)
+
+import copy
+import simplejson as json
+
+
+def format_push_rules_for_user(user, rawrules, enabled_map):
+    """Converts a list of rawrules and a enabled map into nested dictionaries
+    to match the Matrix client-server format for push rules"""
+
+    ruleslist = []
+    for rawrule in rawrules:
+        rule = dict(rawrule)
+        rule["conditions"] = json.loads(rawrule["conditions"])
+        rule["actions"] = json.loads(rawrule["actions"])
+        ruleslist.append(rule)
+
+    # We're going to be mutating this a lot, so do a deep copy
+    ruleslist = copy.deepcopy(list_with_base_rules(ruleslist))
+
+    rules = {'global': {}, 'device': {}}
+
+    rules['global'] = _add_empty_priority_class_arrays(rules['global'])
+
+    for r in ruleslist:
+        rulearray = None
+
+        template_name = _priority_class_to_template_name(r['priority_class'])
+
+        # Remove internal stuff.
+        for c in r["conditions"]:
+            c.pop("_id", None)
+
+            pattern_type = c.pop("pattern_type", None)
+            if pattern_type == "user_id":
+                c["pattern"] = user.to_string()
+            elif pattern_type == "user_localpart":
+                c["pattern"] = user.localpart
+
+        rulearray = rules['global'][template_name]
+
+        template_rule = _rule_to_template(r)
+        if template_rule:
+            if r['rule_id'] in enabled_map:
+                template_rule['enabled'] = enabled_map[r['rule_id']]
+            elif 'enabled' in r:
+                template_rule['enabled'] = r['enabled']
+            else:
+                template_rule['enabled'] = True
+            rulearray.append(template_rule)
+
+    return rules
+
+
+def _add_empty_priority_class_arrays(d):
+    for pc in PRIORITY_CLASS_MAP.keys():
+        d[pc] = []
+    return d
+
+
+def _rule_to_template(rule):
+    unscoped_rule_id = None
+    if 'rule_id' in rule:
+        unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
+
+    template_name = _priority_class_to_template_name(rule['priority_class'])
+    if template_name in ['override', 'underride']:
+        templaterule = {k: rule[k] for k in ["conditions", "actions"]}
+    elif template_name in ["sender", "room"]:
+        templaterule = {'actions': rule['actions']}
+        unscoped_rule_id = rule['conditions'][0]['pattern']
+    elif template_name == 'content':
+        if len(rule["conditions"]) != 1:
+            return None
+        thecond = rule["conditions"][0]
+        if "pattern" not in thecond:
+            return None
+        templaterule = {'actions': rule['actions']}
+        templaterule["pattern"] = thecond["pattern"]
+
+    if unscoped_rule_id:
+            templaterule['rule_id'] = unscoped_rule_id
+    if 'default' in rule:
+        templaterule['default'] = rule['default']
+    return templaterule
+
+
+def _rule_id_from_namespaced(in_rule_id):
+    return in_rule_id.split('/')[-1]
+
+
+def _priority_class_to_template_name(pc):
+    return PRIORITY_CLASS_INVERSE_MAP[pc]

+ 1 - 2
synapse/push/httppusher.py

@@ -23,12 +23,11 @@ logger = logging.getLogger(__name__)
 
 
 class HttpPusher(Pusher):
-    def __init__(self, _hs, profile_tag, user_id, app_id,
+    def __init__(self, _hs, user_id, app_id,
                  app_display_name, device_display_name, pushkey, pushkey_ts,
                  data, last_token, last_success, failing_since):
         super(HttpPusher, self).__init__(
             _hs,
-            profile_tag,
             user_id,
             app_id,
             app_display_name,

+ 7 - 12
synapse/push/push_rule_evaluator.py

@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-import baserules
+from .baserules import list_with_base_rules
 
 import logging
 import simplejson as json
@@ -33,7 +33,7 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
 
 
 @defer.inlineCallbacks
-def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
+def evaluator_for_user_id(user_id, room_id, store):
     rawrules = yield store.get_push_rules_for_user(user_id)
     enabled_map = yield store.get_push_rules_enabled_for_user(user_id)
     our_member_event = yield store.get_current_state(
@@ -43,7 +43,7 @@ def evaluator_for_user_id_and_profile_tag(user_id, profile_tag, room_id, store):
     )
 
     defer.returnValue(PushRuleEvaluator(
-        user_id, profile_tag, rawrules, enabled_map,
+        user_id, rawrules, enabled_map,
         room_id, our_member_event, store
     ))
 
@@ -77,10 +77,9 @@ def _room_member_count(ev, condition, room_member_count):
 class PushRuleEvaluator:
     DEFAULT_ACTIONS = []
 
-    def __init__(self, user_id, profile_tag, raw_rules, enabled_map, room_id,
+    def __init__(self, user_id, raw_rules, enabled_map, room_id,
                  our_member_event, store):
         self.user_id = user_id
-        self.profile_tag = profile_tag
         self.room_id = room_id
         self.our_member_event = our_member_event
         self.store = store
@@ -92,7 +91,7 @@ class PushRuleEvaluator:
             rule['actions'] = json.loads(raw_rule['actions'])
             rules.append(rule)
 
-        self.rules = baserules.list_with_base_rules(rules)
+        self.rules = list_with_base_rules(rules)
 
         self.enabled_map = enabled_map
 
@@ -152,7 +151,7 @@ class PushRuleEvaluator:
             matches = True
             for c in conditions:
                 matches = evaluator.matches(
-                    c, self.user_id, my_display_name, self.profile_tag
+                    c, self.user_id, my_display_name
                 )
                 if not matches:
                     break
@@ -189,13 +188,9 @@ class PushRuleEvaluatorForEvent(object):
         # Maps strings of e.g. 'content.body' -> event["content"]["body"]
         self._value_cache = _flatten_dict(event)
 
-    def matches(self, condition, user_id, display_name, profile_tag):
+    def matches(self, condition, user_id, display_name):
         if condition['kind'] == 'event_match':
             return self._event_match(condition, user_id)
-        elif condition['kind'] == 'device':
-            if 'profile_tag' not in condition:
-                return True
-            return condition['profile_tag'] == profile_tag
         elif condition['kind'] == 'contains_display_name':
             return self._contains_display_name(display_name)
         elif condition['kind'] == 'room_member_count':

+ 25 - 33
synapse/push/pusherpool.py

@@ -16,7 +16,7 @@
 
 from twisted.internet import defer
 
-from httppusher import HttpPusher
+from .httppusher import HttpPusher
 from synapse.push import PusherConfigException
 from synapse.util.logcontext import preserve_fn
 
@@ -29,6 +29,7 @@ class PusherPool:
     def __init__(self, _hs):
         self.hs = _hs
         self.store = self.hs.get_datastore()
+        self.clock = self.hs.get_clock()
         self.pushers = {}
         self.last_pusher_started = -1
 
@@ -38,8 +39,11 @@ class PusherPool:
         self._start_pushers(pushers)
 
     @defer.inlineCallbacks
-    def add_pusher(self, user_id, access_token, profile_tag, kind, app_id,
-                   app_display_name, device_display_name, pushkey, lang, data):
+    def add_pusher(self, user_id, access_token, kind, app_id,
+                   app_display_name, device_display_name, pushkey, lang, data,
+                   profile_tag=""):
+        time_now_msec = self.clock.time_msec()
+
         # we try to create the pusher just to validate the config: it
         # will then get pulled out of the database,
         # recreated, added and started: this means we have only one
@@ -47,23 +51,31 @@ class PusherPool:
         self._create_pusher({
             "user_name": user_id,
             "kind": kind,
-            "profile_tag": profile_tag,
             "app_id": app_id,
             "app_display_name": app_display_name,
             "device_display_name": device_display_name,
             "pushkey": pushkey,
-            "ts": self.hs.get_clock().time_msec(),
+            "ts": time_now_msec,
             "lang": lang,
             "data": data,
             "last_token": None,
             "last_success": None,
             "failing_since": None
         })
-        yield self._add_pusher_to_store(
-            user_id, access_token, profile_tag, kind, app_id,
-            app_display_name, device_display_name,
-            pushkey, lang, data
+        yield self.store.add_pusher(
+            user_id=user_id,
+            access_token=access_token,
+            kind=kind,
+            app_id=app_id,
+            app_display_name=app_display_name,
+            device_display_name=device_display_name,
+            pushkey=pushkey,
+            pushkey_ts=time_now_msec,
+            lang=lang,
+            data=data,
+            profile_tag=profile_tag,
         )
+        yield self._refresh_pusher(app_id, pushkey, user_id)
 
     @defer.inlineCallbacks
     def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
@@ -80,44 +92,24 @@ class PusherPool:
                 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
     @defer.inlineCallbacks
-    def remove_pushers_by_user(self, user_id):
+    def remove_pushers_by_user(self, user_id, except_token_ids=[]):
         all = yield self.store.get_all_pushers()
         logger.info(
-            "Removing all pushers for user %s",
-            user_id,
+            "Removing all pushers for user %s except access tokens ids %r",
+            user_id, except_token_ids
         )
         for p in all:
-            if p['user_name'] == user_id:
+            if p['user_name'] == user_id and p['access_token'] not in except_token_ids:
                 logger.info(
                     "Removing pusher for app id %s, pushkey %s, user %s",
                     p['app_id'], p['pushkey'], p['user_name']
                 )
                 yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
 
-    @defer.inlineCallbacks
-    def _add_pusher_to_store(self, user_id, access_token, profile_tag, kind,
-                             app_id, app_display_name, device_display_name,
-                             pushkey, lang, data):
-        yield self.store.add_pusher(
-            user_id=user_id,
-            access_token=access_token,
-            profile_tag=profile_tag,
-            kind=kind,
-            app_id=app_id,
-            app_display_name=app_display_name,
-            device_display_name=device_display_name,
-            pushkey=pushkey,
-            pushkey_ts=self.hs.get_clock().time_msec(),
-            lang=lang,
-            data=data,
-        )
-        yield self._refresh_pusher(app_id, pushkey, user_id)
-
     def _create_pusher(self, pusherdict):
         if pusherdict['kind'] == 'http':
             return HttpPusher(
                 self.hs,
-                profile_tag=pusherdict['profile_tag'],
                 user_id=pusherdict['user_name'],
                 app_id=pusherdict['app_id'],
                 app_display_name=pusherdict['app_display_name'],

+ 2 - 2
synapse/python_dependencies.py

@@ -19,7 +19,7 @@ logger = logging.getLogger(__name__)
 
 REQUIREMENTS = {
     "frozendict>=0.4": ["frozendict"],
-    "unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"],
+    "unpaddedbase64>=1.1.0": ["unpaddedbase64>=1.1.0"],
     "canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
     "signedjson>=1.0.0": ["signedjson>=1.0.0"],
     "pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
@@ -34,7 +34,7 @@ REQUIREMENTS = {
     "pydenticon": ["pydenticon"],
     "ujson": ["ujson"],
     "blist": ["blist"],
-    "pysaml2": ["saml2"],
+    "pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
     "pymacaroons-pynacl": ["pymacaroons"],
 }
 CONDITIONAL_REQUIREMENTS = {

+ 14 - 0
synapse/replication/__init__.py

@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+ 367 - 0
synapse/replication/resource.py

@@ -0,0 +1,367 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.http.servlet import parse_integer, parse_string
+from synapse.http.server import request_handler, finish_request
+
+from twisted.web.resource import Resource
+from twisted.web.server import NOT_DONE_YET
+from twisted.internet import defer
+
+import ujson as json
+
+import collections
+import logging
+
+logger = logging.getLogger(__name__)
+
+REPLICATION_PREFIX = "/_synapse/replication"
+
+STREAM_NAMES = (
+    ("events",),
+    ("presence",),
+    ("typing",),
+    ("receipts",),
+    ("user_account_data", "room_account_data", "tag_account_data",),
+    ("backfill",),
+    ("push_rules",),
+    ("pushers",),
+)
+
+
+class ReplicationResource(Resource):
+    """
+    HTTP endpoint for extracting data from synapse.
+
+    The streams of data returned by the endpoint are controlled by the
+    parameters given to the API. To return a given stream pass a query
+    parameter with a position in the stream to return data from or the
+    special value "-1" to return data from the start of the stream.
+
+    If there is no data for any of the supplied streams after the given
+    position then the request will block until there is data for one
+    of the streams. This allows clients to long-poll this API.
+
+    The possible streams are:
+
+    * "streams": A special stream returing the positions of other streams.
+    * "events": The new events seen on the server.
+    * "presence": Presence updates.
+    * "typing": Typing updates.
+    * "receipts": Receipt updates.
+    * "user_account_data": Top-level per user account data.
+    * "room_account_data: Per room per user account data.
+    * "tag_account_data": Per room per user tags.
+    * "backfill": Old events that have been backfilled from other servers.
+    * "push_rules": Per user changes to push rules.
+    * "pushers": Per user changes to their pushers.
+
+    The API takes two additional query parameters:
+
+    * "timeout": How long to wait before returning an empty response.
+    * "limit": The maximum number of rows to return for the selected streams.
+
+    The response is a JSON object with keys for each stream with updates. Under
+    each key is a JSON object with:
+
+    * "postion": The current position of the stream.
+    * "field_names": The names of the fields in each row.
+    * "rows": The updates as an array of arrays.
+
+    There are a number of ways this API could be used:
+
+    1) To replicate the contents of the backing database to another database.
+    2) To be notified when the contents of a shared backing database changes.
+    3) To "tail" the activity happening on a server for debugging.
+
+    In the first case the client would track all of the streams and store it's
+    own copy of the data.
+
+    In the second case the client might theoretically just be able to follow
+    the "streams" stream to track where the other streams are. However in
+    practise it will probably need to get the contents of the streams in
+    order to expire the any in-memory caches. Whether it gets the contents
+    of the streams from this replication API or directly from the backing
+    store is a matter of taste.
+
+    In the third case the client would use the "streams" stream to find what
+    streams are available and their current positions. Then it can start
+    long-polling this replication API for new data on those streams.
+    """
+
+    isLeaf = True
+
+    def __init__(self, hs):
+        Resource.__init__(self)  # Resource is old-style, so no super()
+
+        self.version_string = hs.version_string
+        self.store = hs.get_datastore()
+        self.sources = hs.get_event_sources()
+        self.presence_handler = hs.get_handlers().presence_handler
+        self.typing_handler = hs.get_handlers().typing_notification_handler
+        self.notifier = hs.notifier
+
+    def render_GET(self, request):
+        self._async_render_GET(request)
+        return NOT_DONE_YET
+
+    @defer.inlineCallbacks
+    def current_replication_token(self):
+        stream_token = yield self.sources.get_current_token()
+        backfill_token = yield self.store.get_current_backfill_token()
+        push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
+        pushers_token = self.store.get_pushers_stream_token()
+
+        defer.returnValue(_ReplicationToken(
+            room_stream_token,
+            int(stream_token.presence_key),
+            int(stream_token.typing_key),
+            int(stream_token.receipt_key),
+            int(stream_token.account_data_key),
+            backfill_token,
+            push_rules_token,
+            pushers_token,
+        ))
+
+    @request_handler
+    @defer.inlineCallbacks
+    def _async_render_GET(self, request):
+        limit = parse_integer(request, "limit", 100)
+        timeout = parse_integer(request, "timeout", 10 * 1000)
+
+        request.setHeader(b"Content-Type", b"application/json")
+        writer = _Writer(request)
+
+        @defer.inlineCallbacks
+        def replicate():
+            current_token = yield self.current_replication_token()
+            logger.info("Replicating up to %r", current_token)
+
+            yield self.account_data(writer, current_token, limit)
+            yield self.events(writer, current_token, limit)
+            yield self.presence(writer, current_token)  # TODO: implement limit
+            yield self.typing(writer, current_token)  # TODO: implement limit
+            yield self.receipts(writer, current_token, limit)
+            yield self.push_rules(writer, current_token, limit)
+            yield self.pushers(writer, current_token, limit)
+            self.streams(writer, current_token)
+
+            logger.info("Replicated %d rows", writer.total)
+            defer.returnValue(writer.total)
+
+        yield self.notifier.wait_for_replication(replicate, timeout)
+
+        writer.finish()
+
+    def streams(self, writer, current_token):
+        request_token = parse_string(writer.request, "streams")
+
+        streams = []
+
+        if request_token is not None:
+            if request_token == "-1":
+                for names, position in zip(STREAM_NAMES, current_token):
+                    streams.extend((name, position) for name in names)
+            else:
+                items = zip(
+                    STREAM_NAMES,
+                    current_token,
+                    _ReplicationToken(request_token)
+                )
+                for names, current_id, last_id in items:
+                    if last_id < current_id:
+                        streams.extend((name, current_id) for name in names)
+
+            if streams:
+                writer.write_header_and_rows(
+                    "streams", streams, ("name", "position"),
+                    position=str(current_token)
+                )
+
+    @defer.inlineCallbacks
+    def events(self, writer, current_token, limit):
+        request_events = parse_integer(writer.request, "events")
+        request_backfill = parse_integer(writer.request, "backfill")
+
+        if request_events is not None or request_backfill is not None:
+            if request_events is None:
+                request_events = current_token.events
+            if request_backfill is None:
+                request_backfill = current_token.backfill
+            events_rows, backfill_rows = yield self.store.get_all_new_events(
+                request_backfill, request_events,
+                current_token.backfill, current_token.events,
+                limit
+            )
+            writer.write_header_and_rows(
+                "events", events_rows, ("position", "internal", "json")
+            )
+            writer.write_header_and_rows(
+                "backfill", backfill_rows, ("position", "internal", "json")
+            )
+
+    @defer.inlineCallbacks
+    def presence(self, writer, current_token):
+        current_position = current_token.presence
+
+        request_presence = parse_integer(writer.request, "presence")
+
+        if request_presence is not None:
+            presence_rows = yield self.presence_handler.get_all_presence_updates(
+                request_presence, current_position
+            )
+            writer.write_header_and_rows("presence", presence_rows, (
+                "position", "user_id", "state", "last_active_ts",
+                "last_federation_update_ts", "last_user_sync_ts",
+                "status_msg", "currently_active",
+            ))
+
+    @defer.inlineCallbacks
+    def typing(self, writer, current_token):
+        current_position = current_token.presence
+
+        request_typing = parse_integer(writer.request, "typing")
+
+        if request_typing is not None:
+            typing_rows = yield self.typing_handler.get_all_typing_updates(
+                request_typing, current_position
+            )
+            writer.write_header_and_rows("typing", typing_rows, (
+                "position", "room_id", "typing"
+            ))
+
+    @defer.inlineCallbacks
+    def receipts(self, writer, current_token, limit):
+        current_position = current_token.receipts
+
+        request_receipts = parse_integer(writer.request, "receipts")
+
+        if request_receipts is not None:
+            receipts_rows = yield self.store.get_all_updated_receipts(
+                request_receipts, current_position, limit
+            )
+            writer.write_header_and_rows("receipts", receipts_rows, (
+                "position", "room_id", "receipt_type", "user_id", "event_id", "data"
+            ))
+
+    @defer.inlineCallbacks
+    def account_data(self, writer, current_token, limit):
+        current_position = current_token.account_data
+
+        user_account_data = parse_integer(writer.request, "user_account_data")
+        room_account_data = parse_integer(writer.request, "room_account_data")
+        tag_account_data = parse_integer(writer.request, "tag_account_data")
+
+        if user_account_data is not None or room_account_data is not None:
+            if user_account_data is None:
+                user_account_data = current_position
+            if room_account_data is None:
+                room_account_data = current_position
+            user_rows, room_rows = yield self.store.get_all_updated_account_data(
+                user_account_data, room_account_data, current_position, limit
+            )
+            writer.write_header_and_rows("user_account_data", user_rows, (
+                "position", "user_id", "type", "content"
+            ))
+            writer.write_header_and_rows("room_account_data", room_rows, (
+                "position", "user_id", "room_id", "type", "content"
+            ))
+
+        if tag_account_data is not None:
+            tag_rows = yield self.store.get_all_updated_tags(
+                tag_account_data, current_position, limit
+            )
+            writer.write_header_and_rows("tag_account_data", tag_rows, (
+                "position", "user_id", "room_id", "tags"
+            ))
+
+    @defer.inlineCallbacks
+    def push_rules(self, writer, current_token, limit):
+        current_position = current_token.push_rules
+
+        push_rules = parse_integer(writer.request, "push_rules")
+
+        if push_rules is not None:
+            rows = yield self.store.get_all_push_rule_updates(
+                push_rules, current_position, limit
+            )
+            writer.write_header_and_rows("push_rules", rows, (
+                "position", "event_stream_ordering", "user_id", "rule_id", "op",
+                "priority_class", "priority", "conditions", "actions"
+            ))
+
+    @defer.inlineCallbacks
+    def pushers(self, writer, current_token, limit):
+        current_position = current_token.pushers
+
+        pushers = parse_integer(writer.request, "pushers")
+        if pushers is not None:
+            updated, deleted = yield self.store.get_all_updated_pushers(
+                pushers, current_position, limit
+            )
+            writer.write_header_and_rows("pushers", updated, (
+                "position", "user_id", "access_token", "profile_tag", "kind",
+                "app_id", "app_display_name", "device_display_name", "pushkey",
+                "ts", "lang", "data"
+            ))
+            writer.write_header_and_rows("deleted", deleted, (
+                "position", "user_id", "app_id", "pushkey"
+            ))
+
+
+class _Writer(object):
+    """Writes the streams as a JSON object as the response to the request"""
+    def __init__(self, request):
+        self.streams = {}
+        self.request = request
+        self.total = 0
+
+    def write_header_and_rows(self, name, rows, fields, position=None):
+        if not rows:
+            return
+
+        if position is None:
+            position = rows[-1][0]
+
+        self.streams[name] = {
+            "position": str(position),
+            "field_names": fields,
+            "rows": rows,
+        }
+
+        self.total += len(rows)
+
+    def finish(self):
+        self.request.write(json.dumps(self.streams, ensure_ascii=False))
+        finish_request(self.request)
+
+
+class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
+    "events", "presence", "typing", "receipts", "account_data", "backfill",
+    "push_rules", "pushers"
+))):
+    __slots__ = []
+
+    def __new__(cls, *args):
+        if len(args) == 1:
+            streams = [int(value) for value in args[0].split("_")]
+            if len(streams) < len(cls._fields):
+                streams.extend([0] * (len(cls._fields) - len(streams)))
+            return cls(*streams)
+        else:
+            return super(_ReplicationToken, cls).__new__(cls, *args)
+
+    def __str__(self):
+        return "_".join(str(value) for value in self)

+ 2 - 0
synapse/rest/__init__.py

@@ -30,6 +30,7 @@ from synapse.rest.client.v1 import (
     push_rule,
     register as v1_register,
     login as v1_login,
+    logout,
 )
 
 from synapse.rest.client.v2_alpha import (
@@ -72,6 +73,7 @@ class ClientRestResource(JsonResource):
         admin.register_servlets(hs, client_resource)
         pusher.register_servlets(hs, client_resource)
         push_rule.register_servlets(hs, client_resource)
+        logout.register_servlets(hs, client_resource)
 
         # "v2"
         sync.register_servlets(hs, client_resource)

+ 1 - 1
synapse/rest/client/v1/admin.py

@@ -18,7 +18,7 @@ from twisted.internet import defer
 from synapse.api.errors import AuthError, SynapseError
 from synapse.types import UserID
 
-from base import ClientV1RestServlet, client_path_patterns
+from .base import ClientV1RestServlet, client_path_patterns
 
 import logging
 

+ 50 - 16
synapse/rest/client/v1/directory.py

@@ -18,9 +18,10 @@ from twisted.internet import defer
 
 from synapse.api.errors import AuthError, SynapseError, Codes
 from synapse.types import RoomAlias
+from synapse.http.servlet import parse_json_object_from_request
+
 from .base import ClientV1RestServlet, client_path_patterns
 
-import simplejson as json
 import logging
 
 
@@ -29,6 +30,7 @@ logger = logging.getLogger(__name__)
 
 def register_servlets(hs, http_server):
     ClientDirectoryServer(hs).register(http_server)
+    ClientDirectoryListServer(hs).register(http_server)
 
 
 class ClientDirectoryServer(ClientV1RestServlet):
@@ -45,7 +47,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_PUT(self, request, room_alias):
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
         if "room_id" not in content:
             raise SynapseError(400, "Missing room_id key",
                                errcode=Codes.BAD_JSON)
@@ -75,7 +77,11 @@ class ClientDirectoryServer(ClientV1RestServlet):
                 yield dir_handler.create_association(
                     user_id, room_alias, room_id, servers
                 )
-                yield dir_handler.send_room_alias_update_event(user_id, room_id)
+                yield dir_handler.send_room_alias_update_event(
+                    requester,
+                    user_id,
+                    room_id
+                )
             except SynapseError as e:
                 raise e
             except:
@@ -118,15 +124,13 @@ class ClientDirectoryServer(ClientV1RestServlet):
 
         requester = yield self.auth.get_user_by_req(request)
         user = requester.user
-        is_admin = yield self.auth.is_server_admin(user)
-        if not is_admin:
-            raise AuthError(403, "You need to be a server admin")
 
         room_alias = RoomAlias.from_string(room_alias)
 
         yield dir_handler.delete_association(
-            user.to_string(), room_alias
+            requester, user.to_string(), room_alias
         )
+
         logger.info(
             "User %s deleted alias %s",
             user.to_string(),
@@ -136,12 +140,42 @@ class ClientDirectoryServer(ClientV1RestServlet):
         defer.returnValue((200, {}))
 
 
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        if type(content) != dict:
-            raise SynapseError(400, "Content must be a JSON object.",
-                               errcode=Codes.NOT_JSON)
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+class ClientDirectoryListServer(ClientV1RestServlet):
+    PATTERNS = client_path_patterns("/directory/list/room/(?P<room_id>[^/]*)$")
+
+    def __init__(self, hs):
+        super(ClientDirectoryListServer, self).__init__(hs)
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def on_GET(self, request, room_id):
+        room = yield self.store.get_room(room_id)
+        if room is None:
+            raise SynapseError(400, "Unknown room")
+
+        defer.returnValue((200, {
+            "visibility": "public" if room["is_public"] else "private"
+        }))
+
+    @defer.inlineCallbacks
+    def on_PUT(self, request, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+
+        content = parse_json_object_from_request(request)
+        visibility = content.get("visibility", "public")
+
+        yield self.handlers.directory_handler.edit_published_room_list(
+            requester, room_id, visibility,
+        )
+
+        defer.returnValue((200, {}))
+
+    @defer.inlineCallbacks
+    def on_DELETE(self, request, room_id):
+        requester = yield self.auth.get_user_by_req(request)
+
+        yield self.handlers.directory_handler.edit_published_room_list(
+            requester, room_id, "private",
+        )
+
+        defer.returnValue((200, {}))

+ 1 - 1
synapse/rest/client/v1/initial_sync.py

@@ -16,7 +16,7 @@
 from twisted.internet import defer
 
 from synapse.streams.config import PaginationConfig
-from base import ClientV1RestServlet, client_path_patterns
+from .base import ClientV1RestServlet, client_path_patterns
 
 
 # TODO: Needs unit testing

+ 10 - 17
synapse/rest/client/v1/login.py

@@ -17,7 +17,10 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, LoginError, Codes
 from synapse.types import UserID
-from base import ClientV1RestServlet, client_path_patterns
+from synapse.http.server import finish_request
+from synapse.http.servlet import parse_json_object_from_request
+
+from .base import ClientV1RestServlet, client_path_patterns
 
 import simplejson as json
 import urllib
@@ -77,7 +80,7 @@ class LoginRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        login_submission = _parse_json(request)
+        login_submission = parse_json_object_from_request(request)
         try:
             if login_submission["type"] == LoginRestServlet.PASS_TYPE:
                 if not self.password_enabled:
@@ -250,7 +253,7 @@ class SAML2RestServlet(ClientV1RestServlet):
             SP = Saml2Client(conf)
             saml2_auth = SP.parse_authn_request_response(
                 request.args['SAMLResponse'][0], BINDING_HTTP_POST)
-        except Exception, e:        # Not authenticated
+        except Exception as e:        # Not authenticated
             logger.exception(e)
         if saml2_auth and saml2_auth.status_ok() and not saml2_auth.not_signed:
             username = saml2_auth.name_id.text
@@ -263,7 +266,7 @@ class SAML2RestServlet(ClientV1RestServlet):
                                  '?status=authenticated&access_token=' +
                                  token + '&user_id=' + user_id + '&ava=' +
                                  urllib.quote(json.dumps(saml2_auth.ava)))
-                request.finish()
+                finish_request(request)
                 defer.returnValue(None)
             defer.returnValue((200, {"status": "authenticated",
                                      "user_id": user_id, "token": token,
@@ -272,7 +275,7 @@ class SAML2RestServlet(ClientV1RestServlet):
             request.redirect(urllib.unquote(
                              request.args['RelayState'][0]) +
                              '?status=not_authenticated')
-            request.finish()
+            finish_request(request)
             defer.returnValue(None)
         defer.returnValue((200, {"status": "not_authenticated"}))
 
@@ -309,7 +312,7 @@ class CasRedirectServlet(ClientV1RestServlet):
             "service": "%s?%s" % (hs_redirect_url, client_redirect_url_param)
         })
         request.redirect("%s?%s" % (self.cas_server_url, service_param))
-        request.finish()
+        finish_request(request)
 
 
 class CasTicketServlet(ClientV1RestServlet):
@@ -362,7 +365,7 @@ class CasTicketServlet(ClientV1RestServlet):
         redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
                                                             login_token)
         request.redirect(redirect_url)
-        request.finish()
+        finish_request(request)
 
     def add_login_token_to_redirect_url(self, url, token):
         url_parts = list(urlparse.urlparse(url))
@@ -398,16 +401,6 @@ class CasTicketServlet(ClientV1RestServlet):
         return (user, attributes)
 
 
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        if type(content) != dict:
-            raise SynapseError(400, "Content must be a JSON object.")
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.")
-
-
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
     if hs.config.saml2_enabled:

+ 72 - 0
synapse/rest/client/v1/logout.py

@@ -0,0 +1,72 @@
+# -*- coding: utf-8 -*-
+# Copyright 2016 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+from synapse.api.errors import AuthError, Codes
+
+from .base import ClientV1RestServlet, client_path_patterns
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class LogoutRestServlet(ClientV1RestServlet):
+    PATTERNS = client_path_patterns("/logout$")
+
+    def __init__(self, hs):
+        super(LogoutRestServlet, self).__init__(hs)
+        self.store = hs.get_datastore()
+
+    def on_OPTIONS(self, request):
+        return (200, {})
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        try:
+            access_token = request.args["access_token"][0]
+        except KeyError:
+            raise AuthError(
+                self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
+                errcode=Codes.MISSING_TOKEN
+            )
+        yield self.store.delete_access_token(access_token)
+        defer.returnValue((200, {}))
+
+
+class LogoutAllRestServlet(ClientV1RestServlet):
+    PATTERNS = client_path_patterns("/logout/all$")
+
+    def __init__(self, hs):
+        super(LogoutAllRestServlet, self).__init__(hs)
+        self.store = hs.get_datastore()
+        self.auth = hs.get_auth()
+
+    def on_OPTIONS(self, request):
+        return (200, {})
+
+    @defer.inlineCallbacks
+    def on_POST(self, request):
+        requester = yield self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
+        yield self.store.user_delete_access_tokens(user_id)
+        defer.returnValue((200, {}))
+
+
+def register_servlets(hs, http_server):
+    LogoutRestServlet(hs).register(http_server)
+    LogoutAllRestServlet(hs).register(http_server)

+ 21 - 18
synapse/rest/client/v1/presence.py

@@ -17,11 +17,11 @@
 """
 from twisted.internet import defer
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, AuthError
 from synapse.types import UserID
+from synapse.http.servlet import parse_json_object_from_request
 from .base import ClientV1RestServlet, client_path_patterns
 
-import simplejson as json
 import logging
 
 logger = logging.getLogger(__name__)
@@ -35,8 +35,15 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
 
-        state = yield self.handlers.presence_handler.get_state(
-            target_user=user, auth_user=requester.user)
+        if requester.user != user:
+            allowed = yield self.handlers.presence_handler.is_visible(
+                observed_user=user, observer_user=requester.user,
+            )
+
+            if not allowed:
+                raise AuthError(403, "You are not allowed to see their presence.")
+
+        state = yield self.handlers.presence_handler.get_state(target_user=user)
 
         defer.returnValue((200, state))
 
@@ -45,10 +52,14 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
 
+        if requester.user != user:
+            raise AuthError(403, "Can only set your own presence state")
+
         state = {}
-        try:
-            content = json.loads(request.content.read())
 
+        content = parse_json_object_from_request(request)
+
+        try:
             state["presence"] = content.pop("presence")
 
             if "status_msg" in content:
@@ -63,8 +74,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
         except:
             raise SynapseError(400, "Unable to parse state")
 
-        yield self.handlers.presence_handler.set_state(
-            target_user=user, auth_user=requester.user, state=state)
+        yield self.handlers.presence_handler.set_state(user, state)
 
         defer.returnValue((200, {}))
 
@@ -87,11 +97,8 @@ class PresenceListRestServlet(ClientV1RestServlet):
             raise SynapseError(400, "Cannot get another user's presence list")
 
         presence = yield self.handlers.presence_handler.get_presence_list(
-            observer_user=user, accepted=True)
-
-        for p in presence:
-            observed_user = p.pop("observed_user")
-            p["user_id"] = observed_user.to_string()
+            observer_user=user, accepted=True
+        )
 
         defer.returnValue((200, presence))
 
@@ -107,11 +114,7 @@ class PresenceListRestServlet(ClientV1RestServlet):
             raise SynapseError(
                 400, "Cannot modify another user's presence list")
 
-        try:
-            content = json.loads(request.content.read())
-        except:
-            logger.exception("JSON parse error")
-            raise SynapseError(400, "Unable to parse content")
+        content = parse_json_object_from_request(request)
 
         if "invite" in content:
             for u in content["invite"]:

+ 6 - 6
synapse/rest/client/v1/profile.py

@@ -18,8 +18,7 @@ from twisted.internet import defer
 
 from .base import ClientV1RestServlet, client_path_patterns
 from synapse.types import UserID
-
-import simplejson as json
+from synapse.http.servlet import parse_json_object_from_request
 
 
 class ProfileDisplaynameRestServlet(ClientV1RestServlet):
@@ -44,14 +43,15 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
         user = UserID.from_string(user_id)
 
+        content = parse_json_object_from_request(request)
+
         try:
-            content = json.loads(request.content.read())
             new_name = content["displayname"]
         except:
             defer.returnValue((400, "Unable to parse name"))
 
         yield self.handlers.profile_handler.set_displayname(
-            user, requester.user, new_name)
+            user, requester, new_name)
 
         defer.returnValue((200, {}))
 
@@ -81,14 +81,14 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
         requester = yield self.auth.get_user_by_req(request)
         user = UserID.from_string(user_id)
 
+        content = parse_json_object_from_request(request)
         try:
-            content = json.loads(request.content.read())
             new_name = content["avatar_url"]
         except:
             defer.returnValue((400, "Unable to parse name"))
 
         yield self.handlers.profile_handler.set_avatar_url(
-            user, requester.user, new_name)
+            user, requester, new_name)
 
         defer.returnValue((200, {}))
 

+ 65 - 206
synapse/rest/client/v1/push_rule.py

@@ -16,19 +16,16 @@
 from twisted.internet import defer
 
 from synapse.api.errors import (
-    SynapseError, Codes, UnrecognizedRequestError, NotFoundError, StoreError
+    SynapseError, UnrecognizedRequestError, NotFoundError, StoreError
 )
 from .base import ClientV1RestServlet, client_path_patterns
 from synapse.storage.push_rule import (
     InconsistentRuleException, RuleNotFoundException
 )
-import synapse.push.baserules as baserules
-from synapse.push.rulekinds import (
-    PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
-)
-
-import copy
-import simplejson as json
+from synapse.push.clientformat import format_push_rules_for_user
+from synapse.push.baserules import BASE_RULE_IDS
+from synapse.push.rulekinds import PRIORITY_CLASS_MAP
+from synapse.http.servlet import parse_json_value_from_request
 
 
 class PushRuleRestServlet(ClientV1RestServlet):
@@ -36,6 +33,11 @@ class PushRuleRestServlet(ClientV1RestServlet):
     SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = (
         "Unrecognised request: You probably wanted a trailing slash")
 
+    def __init__(self, hs):
+        super(PushRuleRestServlet, self).__init__(hs)
+        self.store = hs.get_datastore()
+        self.notifier = hs.get_notifier()
+
     @defer.inlineCallbacks
     def on_PUT(self, request):
         spec = _rule_spec_from_path(request.postpath)
@@ -49,18 +51,24 @@ class PushRuleRestServlet(ClientV1RestServlet):
         if '/' in spec['rule_id'] or '\\' in spec['rule_id']:
             raise SynapseError(400, "rule_id may not contain slashes")
 
-        content = _parse_json(request)
+        content = parse_json_value_from_request(request)
+
+        user_id = requester.user.to_string()
 
         if 'attr' in spec:
-            yield self.set_rule_attr(requester.user.to_string(), spec, content)
+            yield self.set_rule_attr(user_id, spec, content)
+            self.notify_user(user_id)
             defer.returnValue((200, {}))
 
+        if spec['rule_id'].startswith('.'):
+            # Rule ids starting with '.' are reserved for server default rules.
+            raise SynapseError(400, "cannot add new rule_ids that start with '.'")
+
         try:
             (conditions, actions) = _rule_tuple_from_request_object(
                 spec['template'],
                 spec['rule_id'],
                 content,
-                device=spec['device'] if 'device' in spec else None
             )
         except InvalidRuleException as e:
             raise SynapseError(400, e.message)
@@ -74,8 +82,8 @@ class PushRuleRestServlet(ClientV1RestServlet):
             after = _namespaced_rule_id(spec, after[0])
 
         try:
-            yield self.hs.get_datastore().add_push_rule(
-                user_id=requester.user.to_string(),
+            yield self.store.add_push_rule(
+                user_id=user_id,
                 rule_id=_namespaced_rule_id_from_spec(spec),
                 priority_class=priority_class,
                 conditions=conditions,
@@ -83,6 +91,7 @@ class PushRuleRestServlet(ClientV1RestServlet):
                 before=before,
                 after=after
             )
+            self.notify_user(user_id)
         except InconsistentRuleException as e:
             raise SynapseError(400, e.message)
         except RuleNotFoundException as e:
@@ -95,13 +104,15 @@ class PushRuleRestServlet(ClientV1RestServlet):
         spec = _rule_spec_from_path(request.postpath)
 
         requester = yield self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
 
         namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
 
         try:
-            yield self.hs.get_datastore().delete_push_rule(
-                requester.user.to_string(), namespaced_rule_id
+            yield self.store.delete_push_rule(
+                user_id, namespaced_rule_id
             )
+            self.notify_user(user_id)
             defer.returnValue((200, {}))
         except StoreError as e:
             if e.code == 404:
@@ -112,74 +123,16 @@ class PushRuleRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_GET(self, request):
         requester = yield self.auth.get_user_by_req(request)
-        user = requester.user
+        user_id = requester.user.to_string()
 
         # we build up the full structure and then decide which bits of it
         # to send which means doing unnecessary work sometimes but is
         # is probably not going to make a whole lot of difference
-        rawrules = yield self.hs.get_datastore().get_push_rules_for_user(
-            user.to_string()
-        )
+        rawrules = yield self.store.get_push_rules_for_user(user_id)
 
-        ruleslist = []
-        for rawrule in rawrules:
-            rule = dict(rawrule)
-            rule["conditions"] = json.loads(rawrule["conditions"])
-            rule["actions"] = json.loads(rawrule["actions"])
-            ruleslist.append(rule)
-
-        # We're going to be mutating this a lot, so do a deep copy
-        ruleslist = copy.deepcopy(baserules.list_with_base_rules(ruleslist))
-
-        rules = {'global': {}, 'device': {}}
-
-        rules['global'] = _add_empty_priority_class_arrays(rules['global'])
-
-        enabled_map = yield self.hs.get_datastore().\
-            get_push_rules_enabled_for_user(user.to_string())
-
-        for r in ruleslist:
-            rulearray = None
-
-            template_name = _priority_class_to_template_name(r['priority_class'])
-
-            # Remove internal stuff.
-            for c in r["conditions"]:
-                c.pop("_id", None)
-
-                pattern_type = c.pop("pattern_type", None)
-                if pattern_type == "user_id":
-                    c["pattern"] = user.to_string()
-                elif pattern_type == "user_localpart":
-                    c["pattern"] = user.localpart
-
-            if r['priority_class'] > PRIORITY_CLASS_MAP['override']:
-                # per-device rule
-                profile_tag = _profile_tag_from_conditions(r["conditions"])
-                r = _strip_device_condition(r)
-                if not profile_tag:
-                    continue
-                if profile_tag not in rules['device']:
-                    rules['device'][profile_tag] = {}
-                    rules['device'][profile_tag] = (
-                        _add_empty_priority_class_arrays(
-                            rules['device'][profile_tag]
-                        )
-                    )
-
-                rulearray = rules['device'][profile_tag][template_name]
-            else:
-                rulearray = rules['global'][template_name]
-
-            template_rule = _rule_to_template(r)
-            if template_rule:
-                if r['rule_id'] in enabled_map:
-                    template_rule['enabled'] = enabled_map[r['rule_id']]
-                elif 'enabled' in r:
-                    template_rule['enabled'] = r['enabled']
-                else:
-                    template_rule['enabled'] = True
-                rulearray.append(template_rule)
+        enabled_map = yield self.store.get_push_rules_enabled_for_user(user_id)
+
+        rules = format_push_rules_for_user(requester.user, rawrules, enabled_map)
 
         path = request.postpath[1:]
 
@@ -195,30 +148,18 @@ class PushRuleRestServlet(ClientV1RestServlet):
             path = path[1:]
             result = _filter_ruleset_with_path(rules['global'], path)
             defer.returnValue((200, result))
-        elif path[0] == 'device':
-            path = path[1:]
-            if path == []:
-                raise UnrecognizedRequestError(
-                    PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR
-                )
-            if path[0] == '':
-                defer.returnValue((200, rules['device']))
-
-            profile_tag = path[0]
-            path = path[1:]
-            if profile_tag not in rules['device']:
-                ret = {}
-                ret = _add_empty_priority_class_arrays(ret)
-                defer.returnValue((200, ret))
-            ruleset = rules['device'][profile_tag]
-            result = _filter_ruleset_with_path(ruleset, path)
-            defer.returnValue((200, result))
         else:
             raise UnrecognizedRequestError()
 
     def on_OPTIONS(self, _):
         return 200, {}
 
+    def notify_user(self, user_id):
+        stream_id, _ = self.store.get_push_rules_stream_token()
+        self.notifier.on_new_event(
+            "push_rules_key", stream_id, users=[user_id]
+        )
+
     def set_rule_attr(self, user_id, spec, val):
         if spec['attr'] == 'enabled':
             if isinstance(val, dict) and "enabled" in val:
@@ -229,16 +170,20 @@ class PushRuleRestServlet(ClientV1RestServlet):
                 # bools directly, so let's not break them.
                 raise SynapseError(400, "Value for 'enabled' must be boolean")
             namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
-            return self.hs.get_datastore().set_push_rule_enabled(
+            return self.store.set_push_rule_enabled(
                 user_id, namespaced_rule_id, val
             )
-        else:
-            raise UnrecognizedRequestError()
-
-    def get_rule_attr(self, user_id, namespaced_rule_id, attr):
-        if attr == 'enabled':
-            return self.hs.get_datastore().get_push_rule_enabled_by_user_rule_id(
-                user_id, namespaced_rule_id
+        elif spec['attr'] == 'actions':
+            actions = val.get('actions')
+            _check_actions(actions)
+            namespaced_rule_id = _namespaced_rule_id_from_spec(spec)
+            rule_id = spec['rule_id']
+            is_default_rule = rule_id.startswith(".")
+            if is_default_rule:
+                if namespaced_rule_id not in BASE_RULE_IDS:
+                    raise SynapseError(404, "Unknown rule %r" % (namespaced_rule_id,))
+            return self.store.set_push_rule_actions(
+                user_id, namespaced_rule_id, actions, is_default_rule
             )
         else:
             raise UnrecognizedRequestError()
@@ -252,16 +197,9 @@ def _rule_spec_from_path(path):
 
     scope = path[1]
     path = path[2:]
-    if scope not in ['global', 'device']:
+    if scope != 'global':
         raise UnrecognizedRequestError()
 
-    device = None
-    if scope == 'device':
-        if len(path) == 0:
-            raise UnrecognizedRequestError()
-        device = path[0]
-        path = path[1:]
-
     if len(path) == 0:
         raise UnrecognizedRequestError()
 
@@ -278,8 +216,6 @@ def _rule_spec_from_path(path):
         'template': template,
         'rule_id': rule_id
     }
-    if device:
-        spec['profile_tag'] = device
 
     path = path[1:]
 
@@ -289,7 +225,7 @@ def _rule_spec_from_path(path):
     return spec
 
 
-def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None):
+def _rule_tuple_from_request_object(rule_template, rule_id, req_obj):
     if rule_template in ['override', 'underride']:
         if 'conditions' not in req_obj:
             raise InvalidRuleException("Missing 'conditions'")
@@ -322,16 +258,19 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None
     else:
         raise InvalidRuleException("Unknown rule template: %s" % (rule_template,))
 
-    if device:
-        conditions.append({
-            'kind': 'device',
-            'profile_tag': device
-        })
-
     if 'actions' not in req_obj:
         raise InvalidRuleException("No actions found")
     actions = req_obj['actions']
 
+    _check_actions(actions)
+
+    return conditions, actions
+
+
+def _check_actions(actions):
+    if not isinstance(actions, list):
+        raise InvalidRuleException("No actions found")
+
     for a in actions:
         if a in ['notify', 'dont_notify', 'coalesce']:
             pass
@@ -340,25 +279,6 @@ def _rule_tuple_from_request_object(rule_template, rule_id, req_obj, device=None
         else:
             raise InvalidRuleException("Unrecognised action")
 
-    return conditions, actions
-
-
-def _add_empty_priority_class_arrays(d):
-    for pc in PRIORITY_CLASS_MAP.keys():
-        d[pc] = []
-    return d
-
-
-def _profile_tag_from_conditions(conditions):
-    """
-    Given a list of conditions, return the profile tag of the
-    device rule if there is one
-    """
-    for c in conditions:
-        if c['kind'] == 'device':
-            return c['profile_tag']
-    return None
-
 
 def _filter_ruleset_with_path(ruleset, path):
     if path == []:
@@ -393,93 +313,32 @@ def _filter_ruleset_with_path(ruleset, path):
 
     attr = path[0]
     if attr in the_rule:
-        return the_rule[attr]
+        # Make sure we return a JSON object as the attribute may be a
+        # JSON value.
+        return {attr: the_rule[attr]}
     else:
         raise UnrecognizedRequestError()
 
 
 def _priority_class_from_spec(spec):
     if spec['template'] not in PRIORITY_CLASS_MAP.keys():
-        raise InvalidRuleException("Unknown template: %s" % (spec['kind']))
+        raise InvalidRuleException("Unknown template: %s" % (spec['template']))
     pc = PRIORITY_CLASS_MAP[spec['template']]
 
-    if spec['scope'] == 'device':
-        pc += len(PRIORITY_CLASS_MAP)
-
     return pc
 
 
-def _priority_class_to_template_name(pc):
-    if pc > PRIORITY_CLASS_MAP['override']:
-        # per-device
-        prio_class_index = pc - len(PRIORITY_CLASS_MAP)
-        return PRIORITY_CLASS_INVERSE_MAP[prio_class_index]
-    else:
-        return PRIORITY_CLASS_INVERSE_MAP[pc]
-
-
-def _rule_to_template(rule):
-    unscoped_rule_id = None
-    if 'rule_id' in rule:
-        unscoped_rule_id = _rule_id_from_namespaced(rule['rule_id'])
-
-    template_name = _priority_class_to_template_name(rule['priority_class'])
-    if template_name in ['override', 'underride']:
-        templaterule = {k: rule[k] for k in ["conditions", "actions"]}
-    elif template_name in ["sender", "room"]:
-        templaterule = {'actions': rule['actions']}
-        unscoped_rule_id = rule['conditions'][0]['pattern']
-    elif template_name == 'content':
-        if len(rule["conditions"]) != 1:
-            return None
-        thecond = rule["conditions"][0]
-        if "pattern" not in thecond:
-            return None
-        templaterule = {'actions': rule['actions']}
-        templaterule["pattern"] = thecond["pattern"]
-
-    if unscoped_rule_id:
-            templaterule['rule_id'] = unscoped_rule_id
-    if 'default' in rule:
-        templaterule['default'] = rule['default']
-    return templaterule
-
-
-def _strip_device_condition(rule):
-    for i, c in enumerate(rule['conditions']):
-        if c['kind'] == 'device':
-            del rule['conditions'][i]
-    return rule
-
-
 def _namespaced_rule_id_from_spec(spec):
     return _namespaced_rule_id(spec, spec['rule_id'])
 
 
 def _namespaced_rule_id(spec, rule_id):
-    if spec['scope'] == 'global':
-        scope = 'global'
-    else:
-        scope = 'device/%s' % (spec['profile_tag'])
-    return "%s/%s/%s" % (scope, spec['template'], rule_id)
-
-
-def _rule_id_from_namespaced(in_rule_id):
-    return in_rule_id.split('/')[-1]
+    return "global/%s/%s" % (spec['template'], rule_id)
 
 
 class InvalidRuleException(Exception):
     pass
 
 
-# XXX: C+ped from rest/room.py - surely this should be common?
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
-
-
 def register_servlets(hs, http_server):
     PushRuleRestServlet(hs).register(http_server)

+ 12 - 17
synapse/rest/client/v1/pusher.py

@@ -17,9 +17,10 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, Codes
 from synapse.push import PusherConfigException
+from synapse.http.servlet import parse_json_object_from_request
+
 from .base import ClientV1RestServlet, client_path_patterns
 
-import simplejson as json
 import logging
 
 logger = logging.getLogger(__name__)
@@ -28,12 +29,16 @@ logger = logging.getLogger(__name__)
 class PusherRestServlet(ClientV1RestServlet):
     PATTERNS = client_path_patterns("/pushers/set$")
 
+    def __init__(self, hs):
+        super(PusherRestServlet, self).__init__(hs)
+        self.notifier = hs.get_notifier()
+
     @defer.inlineCallbacks
     def on_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
         user = requester.user
 
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         pusher_pool = self.hs.get_pusherpool()
 
@@ -45,7 +50,7 @@ class PusherRestServlet(ClientV1RestServlet):
             )
             defer.returnValue((200, {}))
 
-        reqd = ['profile_tag', 'kind', 'app_id', 'app_display_name',
+        reqd = ['kind', 'app_id', 'app_display_name',
                 'device_display_name', 'pushkey', 'lang', 'data']
         missing = []
         for i in reqd:
@@ -73,36 +78,26 @@ class PusherRestServlet(ClientV1RestServlet):
             yield pusher_pool.add_pusher(
                 user_id=user.to_string(),
                 access_token=requester.access_token_id,
-                profile_tag=content['profile_tag'],
                 kind=content['kind'],
                 app_id=content['app_id'],
                 app_display_name=content['app_display_name'],
                 device_display_name=content['device_display_name'],
                 pushkey=content['pushkey'],
                 lang=content['lang'],
-                data=content['data']
+                data=content['data'],
+                profile_tag=content.get('profile_tag', ""),
             )
         except PusherConfigException as pce:
             raise SynapseError(400, "Config Error: " + pce.message,
                                errcode=Codes.MISSING_PARAM)
 
+        self.notifier.on_new_replication_data()
+
         defer.returnValue((200, {}))
 
     def on_OPTIONS(self, _):
         return 200, {}
 
 
-# XXX: C+ped from rest/room.py - surely this should be common?
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        if type(content) != dict:
-            raise SynapseError(400, "Content must be a JSON object.",
-                               errcode=Codes.NOT_JSON)
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
-
-
 def register_servlets(hs, http_server):
     PusherRestServlet(hs).register(http_server)

+ 3 - 13
synapse/rest/client/v1/register.py

@@ -18,14 +18,14 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, Codes
 from synapse.api.constants import LoginType
-from base import ClientV1RestServlet, client_path_patterns
+from .base import ClientV1RestServlet, client_path_patterns
 import synapse.util.stringutils as stringutils
+from synapse.http.servlet import parse_json_object_from_request
 
 from synapse.util.async import run_on_reactor
 
 from hashlib import sha1
 import hmac
-import simplejson as json
 import logging
 
 logger = logging.getLogger(__name__)
@@ -98,7 +98,7 @@ class RegisterRestServlet(ClientV1RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        register_json = _parse_json(request)
+        register_json = parse_json_object_from_request(request)
 
         session = (register_json["session"]
                    if "session" in register_json else None)
@@ -355,15 +355,5 @@ class RegisterRestServlet(ClientV1RestServlet):
             )
 
 
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        if type(content) != dict:
-            raise SynapseError(400, "Content must be a JSON object.")
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.")
-
-
 def register_servlets(hs, http_server):
     RegisterRestServlet(hs).register(http_server)

+ 74 - 100
synapse/rest/client/v1/room.py

@@ -16,14 +16,14 @@
 """ This module contains REST servlets to do with rooms: /rooms/<paths> """
 from twisted.internet import defer
 
-from base import ClientV1RestServlet, client_path_patterns
+from .base import ClientV1RestServlet, client_path_patterns
 from synapse.api.errors import SynapseError, Codes, AuthError
 from synapse.streams.config import PaginationConfig
 from synapse.api.constants import EventTypes, Membership
 from synapse.types import UserID, RoomID, RoomAlias
 from synapse.events.utils import serialize_event
+from synapse.http.servlet import parse_json_object_from_request
 
-import simplejson as json
 import logging
 import urllib
 
@@ -63,35 +63,18 @@ class RoomCreateRestServlet(ClientV1RestServlet):
     def on_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
 
-        room_config = self.get_room_config(request)
-        info = yield self.make_room(
-            room_config,
-            requester.user,
-            None,
-        )
-        room_config.update(info)
-        defer.returnValue((200, info))
-
-    @defer.inlineCallbacks
-    def make_room(self, room_config, auth_user, room_id):
         handler = self.handlers.room_creation_handler
         info = yield handler.create_room(
-            user_id=auth_user.to_string(),
-            room_id=room_id,
-            config=room_config
+            requester, self.get_room_config(request)
         )
-        defer.returnValue(info)
+
+        defer.returnValue((200, info))
 
     def get_room_config(self, request):
-        try:
-            user_supplied_config = json.loads(request.content.read())
-            if "visibility" not in user_supplied_config:
-                # default visibility
-                user_supplied_config["visibility"] = "public"
-            return user_supplied_config
-        except (ValueError, TypeError):
-            raise SynapseError(400, "Body must be JSON.",
-                               errcode=Codes.BAD_JSON)
+        user_supplied_config = parse_json_object_from_request(request)
+        # default visibility
+        user_supplied_config.setdefault("visibility", "public")
+        return user_supplied_config
 
     def on_OPTIONS(self, request):
         return (200, {})
@@ -149,7 +132,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
     def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
         requester = yield self.auth.get_user_by_req(request)
 
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         event_dict = {
             "type": event_type,
@@ -162,11 +145,22 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
             event_dict["state_key"] = state_key
 
         msg_handler = self.handlers.message_handler
-        yield msg_handler.create_and_send_event(
-            event_dict, token_id=requester.access_token_id, txn_id=txn_id,
+        event, context = yield msg_handler.create_event(
+            event_dict,
+            token_id=requester.access_token_id,
+            txn_id=txn_id,
         )
 
-        defer.returnValue((200, {}))
+        if event_type == EventTypes.Member:
+            yield self.handlers.room_member_handler.send_membership_event(
+                requester,
+                event,
+                context,
+            )
+        else:
+            yield msg_handler.send_nonmember_event(requester, event, context)
+
+        defer.returnValue((200, {"event_id": event.event_id}))
 
 
 # TODO: Needs unit testing for generic events + feedback
@@ -180,17 +174,17 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, event_type, txn_id=None):
         requester = yield self.auth.get_user_by_req(request, allow_guest=True)
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         msg_handler = self.handlers.message_handler
-        event = yield msg_handler.create_and_send_event(
+        event = yield msg_handler.create_and_send_nonmember_event(
+            requester,
             {
                 "type": event_type,
                 "content": content,
                 "room_id": room_id,
                 "sender": requester.user.to_string(),
             },
-            token_id=requester.access_token_id,
             txn_id=txn_id,
         )
 
@@ -229,46 +223,37 @@ class JoinRoomAliasServlet(ClientV1RestServlet):
             allow_guest=True,
         )
 
-        # the identifier could be a room alias or a room id. Try one then the
-        # other if it fails to parse, without swallowing other valid
-        # SynapseErrors.
-
-        identifier = None
-        is_room_alias = False
         try:
-            identifier = RoomAlias.from_string(room_identifier)
-            is_room_alias = True
-        except SynapseError:
-            identifier = RoomID.from_string(room_identifier)
-
-        # TODO: Support for specifying the home server to join with?
-
-        if is_room_alias:
+            content = parse_json_object_from_request(request)
+        except:
+            # Turns out we used to ignore the body entirely, and some clients
+            # cheekily send invalid bodies.
+            content = {}
+
+        if RoomID.is_valid(room_identifier):
+            room_id = room_identifier
+            remote_room_hosts = None
+        elif RoomAlias.is_valid(room_identifier):
             handler = self.handlers.room_member_handler
-            ret_dict = yield handler.join_room_alias(
-                requester.user,
-                identifier,
-            )
-            defer.returnValue((200, ret_dict))
-        else:  # room id
-            msg_handler = self.handlers.message_handler
-            content = {"membership": Membership.JOIN}
-            if requester.is_guest:
-                content["kind"] = "guest"
-            yield msg_handler.create_and_send_event(
-                {
-                    "type": EventTypes.Member,
-                    "content": content,
-                    "room_id": identifier.to_string(),
-                    "sender": requester.user.to_string(),
-                    "state_key": requester.user.to_string(),
-                },
-                token_id=requester.access_token_id,
-                txn_id=txn_id,
-                is_guest=requester.is_guest,
-            )
+            room_alias = RoomAlias.from_string(room_identifier)
+            room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
+            room_id = room_id.to_string()
+        else:
+            raise SynapseError(400, "%s was not legal room ID or room alias" % (
+                room_identifier,
+            ))
+
+        yield self.handlers.room_member_handler.update_membership(
+            requester=requester,
+            target=requester.user,
+            room_id=room_id,
+            action="join",
+            txn_id=txn_id,
+            remote_room_hosts=remote_room_hosts,
+            third_party_signed=content.get("third_party_signed", None),
+        )
 
-            defer.returnValue((200, {"room_id": identifier.to_string()}))
+        defer.returnValue((200, {"room_id": room_id}))
 
     @defer.inlineCallbacks
     def on_PUT(self, request, room_identifier, txn_id):
@@ -316,18 +301,6 @@ class RoomMemberListRestServlet(ClientV1RestServlet):
             if event["type"] != EventTypes.Member:
                 continue
             chunk.append(event)
-            # FIXME: should probably be state_key here, not user_id
-            target_user = UserID.from_string(event["user_id"])
-            # Presence is an optional cache; don't fail if we can't fetch it
-            try:
-                presence_handler = self.handlers.presence_handler
-                presence_state = yield presence_handler.get_state(
-                    target_user=target_user,
-                    auth_user=requester.user,
-                )
-                event["content"].update(presence_state)
-            except:
-                pass
 
         defer.returnValue((200, {
             "chunk": chunk
@@ -454,7 +427,12 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
         }:
             raise AuthError(403, "Guest access not allowed")
 
-        content = _parse_json(request)
+        try:
+            content = parse_json_object_from_request(request)
+        except:
+            # Turns out we used to ignore the body entirely, and some clients
+            # cheekily send invalid bodies.
+            content = {}
 
         if membership_action == "invite" and self._has_3pid_invite_keys(content):
             yield self.handlers.room_member_handler.do_3pid_invite(
@@ -463,7 +441,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
                 content["medium"],
                 content["address"],
                 content["id_server"],
-                requester.access_token_id,
+                requester,
                 txn_id
             )
             defer.returnValue((200, {}))
@@ -481,6 +459,7 @@ class RoomMembershipRestServlet(ClientV1RestServlet):
             room_id=room_id,
             action=membership_action,
             txn_id=txn_id,
+            third_party_signed=content.get("third_party_signed", None),
         )
 
         defer.returnValue((200, {}))
@@ -516,10 +495,11 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, event_id, txn_id=None):
         requester = yield self.auth.get_user_by_req(request)
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         msg_handler = self.handlers.message_handler
-        event = yield msg_handler.create_and_send_event(
+        event = yield msg_handler.create_and_send_nonmember_event(
+            requester,
             {
                 "type": EventTypes.Redaction,
                 "content": content,
@@ -527,7 +507,6 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
                 "sender": requester.user.to_string(),
                 "redacts": event_id,
             },
-            token_id=requester.access_token_id,
             txn_id=txn_id,
         )
 
@@ -553,6 +532,10 @@ class RoomTypingRestServlet(ClientV1RestServlet):
         "/rooms/(?P<room_id>[^/]*)/typing/(?P<user_id>[^/]*)$"
     )
 
+    def __init__(self, hs):
+        super(RoomTypingRestServlet, self).__init__(hs)
+        self.presence_handler = hs.get_handlers().presence_handler
+
     @defer.inlineCallbacks
     def on_PUT(self, request, room_id, user_id):
         requester = yield self.auth.get_user_by_req(request)
@@ -560,10 +543,12 @@ class RoomTypingRestServlet(ClientV1RestServlet):
         room_id = urllib.unquote(room_id)
         target_user = UserID.from_string(urllib.unquote(user_id))
 
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         typing_handler = self.handlers.typing_notification_handler
 
+        yield self.presence_handler.bump_presence_active_time(requester.user)
+
         if content["typing"]:
             yield typing_handler.started_typing(
                 target_user=target_user,
@@ -590,7 +575,7 @@ class SearchRestServlet(ClientV1RestServlet):
     def on_POST(self, request):
         requester = yield self.auth.get_user_by_req(request)
 
-        content = _parse_json(request)
+        content = parse_json_object_from_request(request)
 
         batch = request.args.get("next_batch", [None])[0]
         results = yield self.handlers.search_handler.search(
@@ -602,17 +587,6 @@ class SearchRestServlet(ClientV1RestServlet):
         defer.returnValue((200, results))
 
 
-def _parse_json(request):
-    try:
-        content = json.loads(request.content.read())
-        if type(content) != dict:
-            raise SynapseError(400, "Content must be a JSON object.",
-                               errcode=Codes.NOT_JSON)
-        return content
-    except ValueError:
-        raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
-
-
 def register_txn_path(servlet, regex_string, http_server, with_get=False):
     """Registers a transaction-based path.
 

+ 1 - 1
synapse/rest/client/v1/voip.py

@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-from base import ClientV1RestServlet, client_path_patterns
+from .base import ClientV1RestServlet, client_path_patterns
 
 
 import hmac

+ 0 - 22
synapse/rest/client/v2_alpha/_base.py

@@ -17,11 +17,9 @@
 """
 
 from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
-from synapse.api.errors import SynapseError
 import re
 
 import logging
-import simplejson
 
 
 logger = logging.getLogger(__name__)
@@ -44,23 +42,3 @@ def client_v2_patterns(path_regex, releases=(0,)):
         new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release)
         patterns.append(re.compile("^" + new_prefix + path_regex))
     return patterns
-
-
-def parse_request_allow_empty(request):
-    content = request.content.read()
-    if content is None or content == '':
-        return None
-    try:
-        return simplejson.loads(content)
-    except simplejson.JSONDecodeError:
-        raise SynapseError(400, "Content not JSON.")
-
-
-def parse_json_dict_from_request(request):
-    try:
-        content = simplejson.loads(request.content.read())
-        if type(content) != dict:
-            raise SynapseError(400, "Content must be a JSON object.")
-        return content
-    except simplejson.JSONDecodeError:
-        raise SynapseError(400, "Content not JSON.")

+ 6 - 6
synapse/rest/client/v2_alpha/account.py

@@ -17,10 +17,10 @@ from twisted.internet import defer
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import LoginError, SynapseError, Codes
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.util.async import run_on_reactor
 
-from ._base import client_v2_patterns, parse_json_dict_from_request
+from ._base import client_v2_patterns
 
 import logging
 
@@ -41,9 +41,9 @@ class PasswordRestServlet(RestServlet):
     def on_POST(self, request):
         yield run_on_reactor()
 
-        body = parse_json_dict_from_request(request)
+        body = parse_json_object_from_request(request)
 
-        authed, result, params = yield self.auth_handler.check_auth([
+        authed, result, params, _ = yield self.auth_handler.check_auth([
             [LoginType.PASSWORD],
             [LoginType.EMAIL_IDENTITY]
         ], body, self.hs.get_ip_from_request(request))
@@ -79,7 +79,7 @@ class PasswordRestServlet(RestServlet):
         new_password = params['new_password']
 
         yield self.auth_handler.set_password(
-            user_id, new_password
+            user_id, new_password, requester
         )
 
         defer.returnValue((200, {}))
@@ -114,7 +114,7 @@ class ThreepidRestServlet(RestServlet):
     def on_POST(self, request):
         yield run_on_reactor()
 
-        body = parse_json_dict_from_request(request)
+        body = parse_json_object_from_request(request)
 
         threePidCreds = body.get('threePidCreds')
         threePidCreds = body.get('three_pid_creds', threePidCreds)

+ 4 - 17
synapse/rest/client/v2_alpha/account_data.py

@@ -15,15 +15,13 @@
 
 from ._base import client_v2_patterns
 
-from synapse.http.servlet import RestServlet
-from synapse.api.errors import AuthError, SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.api.errors import AuthError
 
 from twisted.internet import defer
 
 import logging
 
-import simplejson as json
-
 logger = logging.getLogger(__name__)
 
 
@@ -47,11 +45,7 @@ class AccountDataServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
 
-        try:
-            content_bytes = request.content.read()
-            body = json.loads(content_bytes)
-        except:
-            raise SynapseError(400, "Invalid JSON")
+        body = parse_json_object_from_request(request)
 
         max_id = yield self.store.add_account_data_for_user(
             user_id, account_data_type, body
@@ -86,14 +80,7 @@ class RoomAccountDataServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
 
-        try:
-            content_bytes = request.content.read()
-            body = json.loads(content_bytes)
-        except:
-            raise SynapseError(400, "Invalid JSON")
-
-        if not isinstance(body, dict):
-            raise ValueError("Expected a JSON object")
+        body = parse_json_object_from_request(request)
 
         max_id = yield self.store.add_account_data_to_room(
             user_id, room_id, account_data_type, body

+ 3 - 2
synapse/rest/client/v2_alpha/auth.py

@@ -18,6 +18,7 @@ from twisted.internet import defer
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError
 from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
+from synapse.http.server import finish_request
 from synapse.http.servlet import RestServlet
 
 from ._base import client_v2_patterns
@@ -130,7 +131,7 @@ class AuthRestServlet(RestServlet):
             request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
 
             request.write(html_bytes)
-            request.finish()
+            finish_request(request)
             defer.returnValue(None)
         else:
             raise SynapseError(404, "Unknown auth stage type")
@@ -176,7 +177,7 @@ class AuthRestServlet(RestServlet):
             request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
 
             request.write(html_bytes)
-            request.finish()
+            finish_request(request)
 
             defer.returnValue(None)
         else:

+ 2 - 8
synapse/rest/client/v2_alpha/filter.py

@@ -16,12 +16,11 @@
 from twisted.internet import defer
 
 from synapse.api.errors import AuthError, SynapseError
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.types import UserID
 
 from ._base import client_v2_patterns
 
-import simplejson as json
 import logging
 
 
@@ -84,12 +83,7 @@ class CreateFilterRestServlet(RestServlet):
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "Can only create filters for local users")
 
-        try:
-            content = json.loads(request.content.read())
-
-            # TODO(paul): check for required keys and invalid keys
-        except:
-            raise SynapseError(400, "Invalid filter definition")
+        content = parse_json_object_from_request(request)
 
         filter_id = yield self.filtering.add_user_filter(
             user_localpart=target_user.localpart,

+ 7 - 15
synapse/rest/client/v2_alpha/keys.py

@@ -15,16 +15,15 @@
 
 from twisted.internet import defer
 
-from synapse.api.errors import SynapseError
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.types import UserID
 
 from canonicaljson import encode_canonical_json
 
 from ._base import client_v2_patterns
 
-import simplejson as json
 import logging
+import simplejson as json
 
 logger = logging.getLogger(__name__)
 
@@ -68,10 +67,9 @@ class KeyUploadServlet(RestServlet):
         user_id = requester.user.to_string()
         # TODO: Check that the device_id matches that in the authentication
         # or derive the device_id from the authentication instead.
-        try:
-            body = json.loads(request.content.read())
-        except:
-            raise SynapseError(400, "Invalid key JSON")
+
+        body = parse_json_object_from_request(request)
+
         time_now = self.clock.time_msec()
 
         # TODO: Validate the JSON to make sure it has the right keys.
@@ -173,10 +171,7 @@ class KeyQueryServlet(RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id):
         yield self.auth.get_user_by_req(request)
-        try:
-            body = json.loads(request.content.read())
-        except:
-            raise SynapseError(400, "Invalid key JSON")
+        body = parse_json_object_from_request(request)
         result = yield self.handle_request(body)
         defer.returnValue(result)
 
@@ -272,10 +267,7 @@ class OneTimeKeyServlet(RestServlet):
     @defer.inlineCallbacks
     def on_POST(self, request, user_id, device_id, algorithm):
         yield self.auth.get_user_by_req(request)
-        try:
-            body = json.loads(request.content.read())
-        except:
-            raise SynapseError(400, "Invalid key JSON")
+        body = parse_json_object_from_request(request)
         result = yield self.handle_request(body)
         defer.returnValue(result)
 

+ 3 - 0
synapse/rest/client/v2_alpha/receipts.py

@@ -37,6 +37,7 @@ class ReceiptRestServlet(RestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_handlers().receipts_handler
+        self.presence_handler = hs.get_handlers().presence_handler
 
     @defer.inlineCallbacks
     def on_POST(self, request, room_id, receipt_type, event_id):
@@ -45,6 +46,8 @@ class ReceiptRestServlet(RestServlet):
         if receipt_type != "m.read":
             raise SynapseError(400, "Receipt type must be 'm.read'")
 
+        yield self.presence_handler.bump_presence_active_time(requester.user)
+
         yield self.receipts_handler.received_client_receipt(
             room_id,
             receipt_type,

+ 48 - 11
synapse/rest/client/v2_alpha/register.py

@@ -17,9 +17,9 @@ from twisted.internet import defer
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 
-from ._base import client_v2_patterns, parse_json_dict_from_request
+from ._base import client_v2_patterns
 
 import logging
 import hmac
@@ -73,7 +73,7 @@ class RegisterRestServlet(RestServlet):
             ret = yield self.onEmailTokenRequest(request)
             defer.returnValue(ret)
 
-        body = parse_json_dict_from_request(request)
+        body = parse_json_object_from_request(request)
 
         # we do basic sanity checks here because the auth layer will store these
         # in sessions. Pull out the username/password provided to us.
@@ -122,10 +122,22 @@ class RegisterRestServlet(RestServlet):
 
         guest_access_token = body.get("guest_access_token", None)
 
+        session_id = self.auth_handler.get_session_id(body)
+        registered_user_id = None
+        if session_id:
+            # if we get a registered user id out of here, it means we previously
+            # registered a user for this session, so we could just return the
+            # user here. We carry on and go through the auth checks though,
+            # for paranoia.
+            registered_user_id = self.auth_handler.get_session_data(
+                session_id, "registered_user_id", None
+            )
+
         if desired_username is not None:
             yield self.registration_handler.check_username(
                 desired_username,
-                guest_access_token=guest_access_token
+                guest_access_token=guest_access_token,
+                assigned_user_id=registered_user_id,
             )
 
         if self.hs.config.enable_registration_captcha:
@@ -139,7 +151,7 @@ class RegisterRestServlet(RestServlet):
                 [LoginType.EMAIL_IDENTITY]
             ]
 
-        authed, result, params = yield self.auth_handler.check_auth(
+        authed, result, params, session_id = yield self.auth_handler.check_auth(
             flows, body, self.hs.get_ip_from_request(request)
         )
 
@@ -147,6 +159,22 @@ class RegisterRestServlet(RestServlet):
             defer.returnValue((401, result))
             return
 
+        if registered_user_id is not None:
+            logger.info(
+                "Already registered user ID %r for this session",
+                registered_user_id
+            )
+            access_token = yield self.auth_handler.issue_access_token(registered_user_id)
+            refresh_token = yield self.auth_handler.issue_refresh_token(
+                registered_user_id
+            )
+            defer.returnValue((200, {
+                "user_id": registered_user_id,
+                "access_token": access_token,
+                "home_server": self.hs.hostname,
+                "refresh_token": refresh_token,
+            }))
+
         # NB: This may be from the auth handler and NOT from the POST
         if 'password' not in params:
             raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM)
@@ -161,6 +189,12 @@ class RegisterRestServlet(RestServlet):
             guest_access_token=guest_access_token,
         )
 
+        # remember that we've now registered that user account, and with what
+        # user ID (since the user may not have specified)
+        self.auth_handler.set_session_data(
+            session_id, "registered_user_id", user_id
+        )
+
         if result and LoginType.EMAIL_IDENTITY in result:
             threepid = result[LoginType.EMAIL_IDENTITY]
 
@@ -187,7 +221,7 @@ class RegisterRestServlet(RestServlet):
             else:
                 logger.info("bind_email not specified: not binding email")
 
-        result = self._create_registration_details(user_id, token)
+        result = yield self._create_registration_details(user_id, token)
         defer.returnValue((200, result))
 
     def on_OPTIONS(self, _):
@@ -198,7 +232,7 @@ class RegisterRestServlet(RestServlet):
         (user_id, token) = yield self.registration_handler.appservice_register(
             username, as_token
         )
-        defer.returnValue(self._create_registration_details(user_id, token))
+        defer.returnValue((yield self._create_registration_details(user_id, token)))
 
     @defer.inlineCallbacks
     def _do_shared_secret_registration(self, username, password, mac):
@@ -225,18 +259,21 @@ class RegisterRestServlet(RestServlet):
         (user_id, token) = yield self.registration_handler.register(
             localpart=username, password=password
         )
-        defer.returnValue(self._create_registration_details(user_id, token))
+        defer.returnValue((yield self._create_registration_details(user_id, token)))
 
+    @defer.inlineCallbacks
     def _create_registration_details(self, user_id, token):
-        return {
+        refresh_token = yield self.auth_handler.issue_refresh_token(user_id)
+        defer.returnValue({
             "user_id": user_id,
             "access_token": token,
             "home_server": self.hs.hostname,
-        }
+            "refresh_token": refresh_token,
+        })
 
     @defer.inlineCallbacks
     def onEmailTokenRequest(self, request):
-        body = parse_json_dict_from_request(request)
+        body = parse_json_object_from_request(request)
 
         required = ['id_server', 'client_secret', 'email', 'send_attempt']
         absent = []

+ 10 - 6
synapse/rest/client/v2_alpha/sync.py

@@ -25,6 +25,7 @@ from synapse.events.utils import (
 )
 from synapse.api.filtering import FilterCollection, DEFAULT_FILTER_COLLECTION
 from synapse.api.errors import SynapseError
+from synapse.api.constants import PresenceState
 from ._base import client_v2_patterns
 
 import copy
@@ -82,6 +83,7 @@ class SyncRestServlet(RestServlet):
         self.sync_handler = hs.get_handlers().sync_handler
         self.clock = hs.get_clock()
         self.filtering = hs.get_filtering()
+        self.presence_handler = hs.get_handlers().presence_handler
 
     @defer.inlineCallbacks
     def on_GET(self, request):
@@ -139,17 +141,19 @@ class SyncRestServlet(RestServlet):
         else:
             since_token = None
 
-        if set_presence == "online":
-            yield self.event_stream_handler.started_stream(user)
+        affect_presence = set_presence != PresenceState.OFFLINE
 
-        try:
+        if affect_presence:
+            yield self.presence_handler.set_state(user, {"presence": set_presence})
+
+        context = yield self.presence_handler.user_syncing(
+            user.to_string(), affect_presence=affect_presence,
+        )
+        with context:
             sync_result = yield self.sync_handler.wait_for_sync_for_user(
                 sync_config, since_token=since_token, timeout=timeout,
                 full_state=full_state
             )
-        finally:
-            if set_presence == "online":
-                self.event_stream_handler.stopped_stream(user)
 
         time_now = self.clock.time_msec()
 

+ 3 - 9
synapse/rest/client/v2_alpha/tags.py

@@ -15,15 +15,13 @@
 
 from ._base import client_v2_patterns
 
-from synapse.http.servlet import RestServlet
-from synapse.api.errors import AuthError, SynapseError
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
+from synapse.api.errors import AuthError
 
 from twisted.internet import defer
 
 import logging
 
-import simplejson as json
-
 logger = logging.getLogger(__name__)
 
 
@@ -72,11 +70,7 @@ class TagServlet(RestServlet):
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add tags for other users.")
 
-        try:
-            content_bytes = request.content.read()
-            body = json.loads(content_bytes)
-        except:
-            raise SynapseError(400, "Invalid tag JSON")
+        body = parse_json_object_from_request(request)
 
         max_id = yield self.store.add_tag_to_room(user_id, room_id, tag, body)
 

+ 3 - 3
synapse/rest/client/v2_alpha/tokenrefresh.py

@@ -16,9 +16,9 @@
 from twisted.internet import defer
 
 from synapse.api.errors import AuthError, StoreError, SynapseError
-from synapse.http.servlet import RestServlet
+from synapse.http.servlet import RestServlet, parse_json_object_from_request
 
-from ._base import client_v2_patterns, parse_json_dict_from_request
+from ._base import client_v2_patterns
 
 
 class TokenRefreshRestServlet(RestServlet):
@@ -35,7 +35,7 @@ class TokenRefreshRestServlet(RestServlet):
 
     @defer.inlineCallbacks
     def on_POST(self, request):
-        body = parse_json_dict_from_request(request)
+        body = parse_json_object_from_request(request)
         try:
             old_refresh_token = body["refresh_token"]
             auth_handler = self.hs.get_handlers().auth_handler

+ 2 - 10
synapse/rest/key/v2/remote_key_resource.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from synapse.http.server import request_handler, respond_with_json_bytes
-from synapse.http.servlet import parse_integer
+from synapse.http.servlet import parse_integer, parse_json_object_from_request
 from synapse.api.errors import SynapseError, Codes
 
 from twisted.web.resource import Resource
@@ -22,7 +22,6 @@ from twisted.internet import defer
 
 
 from io import BytesIO
-import json
 import logging
 logger = logging.getLogger(__name__)
 
@@ -126,14 +125,7 @@ class RemoteKey(Resource):
     @request_handler
     @defer.inlineCallbacks
     def async_render_POST(self, request):
-        try:
-            content = json.loads(request.content.read())
-            if type(content) != dict:
-                raise ValueError()
-        except ValueError:
-            raise SynapseError(
-                400, "Content must be JSON object.", errcode=Codes.NOT_JSON
-            )
+        content = parse_json_object_from_request(request)
 
         query = content["server_keys"]
 

+ 2 - 2
synapse/rest/media/v0/content_repository.py

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from synapse.http.server import respond_with_json_bytes
+from synapse.http.server import respond_with_json_bytes, finish_request
 
 from synapse.util.stringutils import random_string
 from synapse.api.errors import (
@@ -144,7 +144,7 @@ class ContentRepoResource(resource.Resource):
             # after the file has been sent, clean up and finish the request
             def cbFinished(ignored):
                 f.close()
-                request.finish()
+                finish_request(request)
             d.addCallback(cbFinished)
         else:
             respond_with_json_bytes(

+ 2 - 2
synapse/rest/media/v1/base_resource.py

@@ -16,7 +16,7 @@
 from .thumbnailer import Thumbnailer
 
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
-from synapse.http.server import respond_with_json
+from synapse.http.server import respond_with_json, finish_request
 from synapse.util.stringutils import random_string
 from synapse.api.errors import (
     cs_error, Codes, SynapseError
@@ -238,7 +238,7 @@ class BaseMediaResource(Resource):
             with open(file_path, "rb") as f:
                 yield FileSender().beginFileTransfer(f, request)
 
-            request.finish()
+            finish_request(request)
         else:
             self._respond_404(request)
 

+ 68 - 75
synapse/state.py

@@ -18,6 +18,7 @@ from twisted.internet import defer
 
 from synapse.util.logutils import log_function
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.metrics import Measure
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.auth import AuthEventTypes
@@ -27,6 +28,7 @@ from collections import namedtuple
 
 import logging
 import hashlib
+import os
 
 logger = logging.getLogger(__name__)
 
@@ -34,8 +36,11 @@ logger = logging.getLogger(__name__)
 KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 
 
-SIZE_OF_CACHE = 1000
-EVICTION_TIMEOUT_SECONDS = 20
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
+
+
+SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
+EVICTION_TIMEOUT_SECONDS = 60 * 60
 
 
 class _StateCacheEntry(object):
@@ -85,16 +90,8 @@ class StateHandler(object):
         """
         event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
 
-        cache = None
-        if self._state_cache is not None:
-            cache = self._state_cache.get(frozenset(event_ids), None)
-
-        if cache:
-            cache.ts = self.clock.time_msec()
-            state = cache.state
-        else:
-            res = yield self.resolve_state_groups(room_id, event_ids)
-            state = res[1]
+        res = yield self.resolve_state_groups(room_id, event_ids)
+        state = res[1]
 
         if event_type:
             defer.returnValue(state.get((event_type, state_key)))
@@ -186,20 +183,6 @@ class StateHandler(object):
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
-        if self._state_cache is not None:
-            cache = self._state_cache.get(frozenset(event_ids), None)
-            if cache and cache.state_group:
-                cache.ts = self.clock.time_msec()
-                prev_state = cache.state.get((event_type, state_key), None)
-                if prev_state:
-                    prev_state = prev_state.event_id
-                    prev_states = [prev_state]
-                else:
-                    prev_states = []
-                defer.returnValue(
-                    (cache.state_group, cache.state, prev_states)
-                )
-
         state_groups = yield self.store.get_state_groups(
             room_id, event_ids
         )
@@ -209,7 +192,7 @@ class StateHandler(object):
             state_groups.keys()
         )
 
-        group_names = set(state_groups.keys())
+        group_names = frozenset(state_groups.keys())
         if len(group_names) == 1:
             name, state_list = state_groups.items().pop()
             state = {
@@ -223,16 +206,25 @@ class StateHandler(object):
             else:
                 prev_states = []
 
-            if self._state_cache is not None:
-                cache = _StateCacheEntry(
-                    state=state,
-                    state_group=name,
-                    ts=self.clock.time_msec()
-                )
+            defer.returnValue((name, state, prev_states))
 
-                self._state_cache[frozenset(event_ids)] = cache
+        if self._state_cache is not None:
+            cache = self._state_cache.get(group_names, None)
+            if cache and cache.state_group:
+                cache.ts = self.clock.time_msec()
 
-            defer.returnValue((name, state, prev_states))
+                event_dict = yield self.store.get_events(cache.state.values())
+                state = {(e.type, e.state_key): e for e in event_dict.values()}
+
+                prev_state = state.get((event_type, state_key), None)
+                if prev_state:
+                    prev_state = prev_state.event_id
+                    prev_states = [prev_state]
+                else:
+                    prev_states = []
+                defer.returnValue(
+                    (cache.state_group, state, prev_states)
+                )
 
         new_state, prev_states = self._resolve_events(
             state_groups.values(), event_type, state_key
@@ -240,12 +232,12 @@ class StateHandler(object):
 
         if self._state_cache is not None:
             cache = _StateCacheEntry(
-                state=new_state,
+                state={key: event.event_id for key, event in new_state.items()},
                 state_group=None,
                 ts=self.clock.time_msec()
             )
 
-            self._state_cache[frozenset(event_ids)] = cache
+            self._state_cache[group_names] = cache
 
         defer.returnValue((None, new_state, prev_states))
 
@@ -263,48 +255,49 @@ class StateHandler(object):
         from (type, state_key) to event. prev_states is a list of event_ids.
         :rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str])
         """
-        state = {}
-        for st in state_sets:
-            for e in st:
-                state.setdefault(
-                    (e.type, e.state_key),
-                    {}
-                )[e.event_id] = e
-
-        unconflicted_state = {
-            k: v.values()[0] for k, v in state.items()
-            if len(v.values()) == 1
-        }
-
-        conflicted_state = {
-            k: v.values()
-            for k, v in state.items()
-            if len(v.values()) > 1
-        }
+        with Measure(self.clock, "state._resolve_events"):
+            state = {}
+            for st in state_sets:
+                for e in st:
+                    state.setdefault(
+                        (e.type, e.state_key),
+                        {}
+                    )[e.event_id] = e
+
+            unconflicted_state = {
+                k: v.values()[0] for k, v in state.items()
+                if len(v.values()) == 1
+            }
 
-        if event_type:
-            prev_states_events = conflicted_state.get(
-                (event_type, state_key), []
-            )
-            prev_states = [s.event_id for s in prev_states_events]
-        else:
-            prev_states = []
+            conflicted_state = {
+                k: v.values()
+                for k, v in state.items()
+                if len(v.values()) > 1
+            }
+
+            if event_type:
+                prev_states_events = conflicted_state.get(
+                    (event_type, state_key), []
+                )
+                prev_states = [s.event_id for s in prev_states_events]
+            else:
+                prev_states = []
 
-        auth_events = {
-            k: e for k, e in unconflicted_state.items()
-            if k[0] in AuthEventTypes
-        }
+            auth_events = {
+                k: e for k, e in unconflicted_state.items()
+                if k[0] in AuthEventTypes
+            }
 
-        try:
-            resolved_state = self._resolve_state_events(
-                conflicted_state, auth_events
-            )
-        except:
-            logger.exception("Failed to resolve state")
-            raise
+            try:
+                resolved_state = self._resolve_state_events(
+                    conflicted_state, auth_events
+                )
+            except:
+                logger.exception("Failed to resolve state")
+                raise
 
-        new_state = unconflicted_state
-        new_state.update(resolved_state)
+            new_state = unconflicted_state
+            new_state.update(resolved_state)
 
         return new_state, prev_states
 

+ 74 - 11
synapse/storage/__init__.py

@@ -20,7 +20,7 @@ from .appservice import (
 from ._base import Cache
 from .directory import DirectoryStore
 from .events import EventsStore
-from .presence import PresenceStore
+from .presence import PresenceStore, UserPresenceState
 from .profile import ProfileStore
 from .registration import RegistrationStore
 from .room import RoomStore
@@ -45,8 +45,9 @@ from .search import SearchStore
 from .tags import TagsStore
 from .account_data import AccountDataStore
 
-from util.id_generators import IdGenerator, StreamIdGenerator
+from .util.id_generators import IdGenerator, StreamIdGenerator, ChainedIdGenerator
 
+from synapse.api.constants import PresenceState
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
 
@@ -110,16 +111,25 @@ class DataStore(RoomMemberStore, RoomStore,
         self._account_data_id_gen = StreamIdGenerator(
             db_conn, "account_data_max_stream_id", "stream_id"
         )
+        self._presence_id_gen = StreamIdGenerator(
+            db_conn, "presence_stream", "stream_id"
+        )
 
-        self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
-        self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
-        self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
-        self._refresh_tokens_id_gen = IdGenerator("refresh_tokens", "id", self)
-        self._pushers_id_gen = IdGenerator("pushers", "id", self)
-        self._push_rule_id_gen = IdGenerator("push_rules", "id", self)
-        self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self)
+        self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
+        self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
+        self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
+        self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
+        self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
+        self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
+        self._push_rules_stream_id_gen = ChainedIdGenerator(
+            self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
+        )
+        self._pushers_id_gen = StreamIdGenerator(
+            db_conn, "pushers", "id",
+            extra_tables=[("deleted_pushers", "stream_id")],
+        )
 
-        events_max = self._stream_id_gen.get_max_token(None)
+        events_max = self._stream_id_gen.get_max_token()
         event_cache_prefill, min_event_val = self._get_cache_dict(
             db_conn, "events",
             entity_column="room_id",
@@ -135,13 +145,43 @@ class DataStore(RoomMemberStore, RoomStore,
             "MembershipStreamChangeCache", events_max,
         )
 
-        account_max = self._account_data_id_gen.get_max_token(None)
+        account_max = self._account_data_id_gen.get_max_token()
         self._account_data_stream_cache = StreamChangeCache(
             "AccountDataAndTagsChangeCache", account_max,
         )
 
+        self.__presence_on_startup = self._get_active_presence(db_conn)
+
+        presence_cache_prefill, min_presence_val = self._get_cache_dict(
+            db_conn, "presence_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._presence_id_gen.get_max_token(),
+        )
+        self.presence_stream_cache = StreamChangeCache(
+            "PresenceStreamChangeCache", min_presence_val,
+            prefilled_cache=presence_cache_prefill
+        )
+
+        push_rules_prefill, push_rules_id = self._get_cache_dict(
+            db_conn, "push_rules_stream",
+            entity_column="user_id",
+            stream_column="stream_id",
+            max_value=self._push_rules_stream_id_gen.get_max_token()[0],
+        )
+
+        self.push_rules_stream_cache = StreamChangeCache(
+            "PushRulesStreamChangeCache", push_rules_id,
+            prefilled_cache=push_rules_prefill,
+        )
+
         super(DataStore, self).__init__(hs)
 
+    def take_presence_startup_info(self):
+        active_on_startup = self.__presence_on_startup
+        self.__presence_on_startup = None
+        return active_on_startup
+
     def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
         # Fetch a mapping of room_id -> max stream position for "recent" rooms.
         # It doesn't really matter how many we get, the StreamChangeCache will
@@ -161,6 +201,7 @@ class DataStore(RoomMemberStore, RoomStore,
         txn = db_conn.cursor()
         txn.execute(sql, (int(max_value),))
         rows = txn.fetchall()
+        txn.close()
 
         cache = {
             row[0]: int(row[1])
@@ -174,6 +215,28 @@ class DataStore(RoomMemberStore, RoomStore,
 
         return cache, min_val
 
+    def _get_active_presence(self, db_conn):
+        """Fetch non-offline presence from the database so that we can register
+        the appropriate time outs.
+        """
+
+        sql = (
+            "SELECT user_id, state, last_active_ts, last_federation_update_ts,"
+            " last_user_sync_ts, status_msg, currently_active FROM presence_stream"
+            " WHERE state != ?"
+        )
+        sql = self.database_engine.convert_param_style(sql)
+
+        txn = db_conn.cursor()
+        txn.execute(sql, (PresenceState.OFFLINE,))
+        rows = self.cursor_to_dict(txn)
+        txn.close()
+
+        for row in rows:
+            row["currently_active"] = bool(row["currently_active"])
+
+        return [UserPresenceState(**row) for row in rows]
+
     @defer.inlineCallbacks
     def insert_client_ip(self, user, access_token, ip, user_agent):
         now = int(self._clock.time_msec())

+ 27 - 9
synapse/storage/_base.py

@@ -18,6 +18,7 @@ from synapse.api.errors import StoreError
 from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
 from synapse.util.caches.dictionary_cache import DictionaryCache
 from synapse.util.caches.descriptors import Cache
+from synapse.util.caches import intern_dict
 import synapse.metrics
 
 
@@ -26,6 +27,10 @@ from twisted.internet import defer
 import sys
 import time
 import threading
+import os
+
+
+CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
 
 
 logger = logging.getLogger(__name__)
@@ -163,7 +168,9 @@ class SQLBaseStore(object):
         self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
                                       max_entries=hs.config.event_cache_size)
 
-        self._state_group_cache = DictionaryCache("*stateGroupCache*", 2000)
+        self._state_group_cache = DictionaryCache(
+            "*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR
+        )
 
         self._event_fetch_lock = threading.Condition()
         self._event_fetch_list = []
@@ -344,7 +351,7 @@ class SQLBaseStore(object):
         """
         col_headers = list(column[0] for column in cursor.description)
         results = list(
-            dict(zip(col_headers, row)) for row in cursor.fetchall()
+            intern_dict(dict(zip(col_headers, row))) for row in cursor.fetchall()
         )
         return results
 
@@ -766,6 +773,19 @@ class SQLBaseStore(object):
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
+        Args:
+            table : string giving the table name
+            keyvalues : dict of column names and values to select the row with
+        """
+        return self.runInteraction(
+            desc, self._simple_delete_one_txn, table, keyvalues
+        )
+
+    @staticmethod
+    def _simple_delete_one_txn(txn, table, keyvalues):
+        """Executes a DELETE query on the named table, expecting to delete a
+        single row.
+
         Args:
             table : string giving the table name
             keyvalues : dict of column names and values to select the row with
@@ -775,13 +795,11 @@ class SQLBaseStore(object):
             " AND ".join("%s = ?" % (k, ) for k in keyvalues)
         )
 
-        def func(txn):
-            txn.execute(sql, keyvalues.values())
-            if txn.rowcount == 0:
-                raise StoreError(404, "No row found")
-            if txn.rowcount > 1:
-                raise StoreError(500, "more than one row matched")
-        return self.runInteraction(desc, func)
+        txn.execute(sql, keyvalues.values())
+        if txn.rowcount == 0:
+            raise StoreError(404, "No row found")
+        if txn.rowcount > 1:
+            raise StoreError(500, "more than one row matched")
 
     @staticmethod
     def _simple_delete_txn(txn, table, keyvalues):

+ 38 - 6
synapse/storage/account_data.py

@@ -83,8 +83,40 @@ class AccountDataStore(SQLBaseStore):
             "get_account_data_for_room", get_account_data_for_room_txn
         )
 
-    def get_updated_account_data_for_user(self, user_id, stream_id, room_ids=None):
-        """Get all the client account_data for a that's changed.
+    def get_all_updated_account_data(self, last_global_id, last_room_id,
+                                     current_id, limit):
+        """Get all the client account_data that has changed on the server
+        Args:
+            last_global_id(int): The position to fetch from for top level data
+            last_room_id(int): The position to fetch from for per room data
+            current_id(int): The position to fetch up to.
+        Returns:
+            A deferred pair of lists of tuples of stream_id int, user_id string,
+            room_id string, type string, and content string.
+        """
+        def get_updated_account_data_txn(txn):
+            sql = (
+                "SELECT stream_id, user_id, account_data_type, content"
+                " FROM account_data WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_global_id, current_id, limit))
+            global_results = txn.fetchall()
+
+            sql = (
+                "SELECT stream_id, user_id, room_id, account_data_type, content"
+                " FROM room_account_data WHERE ? < stream_id AND stream_id <= ?"
+                " ORDER BY stream_id ASC LIMIT ?"
+            )
+            txn.execute(sql, (last_room_id, current_id, limit))
+            room_results = txn.fetchall()
+            return (global_results, room_results)
+        return self.runInteraction(
+            "get_all_updated_account_data_txn", get_updated_account_data_txn
+        )
+
+    def get_updated_account_data_for_user(self, user_id, stream_id):
+        """Get all the client account_data for a that's changed for a user
 
         Args:
             user_id(str): The user to get the account_data for.
@@ -163,12 +195,12 @@ class AccountDataStore(SQLBaseStore):
             )
             self._update_max_stream_id(txn, next_id)
 
-        with (yield self._account_data_id_gen.get_next(self)) as next_id:
+        with self._account_data_id_gen.get_next() as next_id:
             yield self.runInteraction(
                 "add_room_account_data", add_account_data_txn, next_id
             )
 
-        result = yield self._account_data_id_gen.get_max_token(self)
+        result = self._account_data_id_gen.get_max_token()
         defer.returnValue(result)
 
     @defer.inlineCallbacks
@@ -202,12 +234,12 @@ class AccountDataStore(SQLBaseStore):
             )
             self._update_max_stream_id(txn, next_id)
 
-        with (yield self._account_data_id_gen.get_next(self)) as next_id:
+        with self._account_data_id_gen.get_next() as next_id:
             yield self.runInteraction(
                 "add_user_account_data", add_account_data_txn, next_id
             )
 
-        result = yield self._account_data_id_gen.get_max_token(self)
+        result = self._account_data_id_gen.get_max_token()
         defer.returnValue(result)
 
     def _update_max_stream_id(self, txn, next_id):

+ 21 - 13
synapse/storage/appservice.py

@@ -34,8 +34,8 @@ class ApplicationServiceStore(SQLBaseStore):
     def __init__(self, hs):
         super(ApplicationServiceStore, self).__init__(hs)
         self.hostname = hs.hostname
-        self.services_cache = []
-        self._populate_appservice_cache(
+        self.services_cache = ApplicationServiceStore.load_appservices(
+            hs.hostname,
             hs.config.app_service_config_files
         )
 
@@ -144,21 +144,23 @@ class ApplicationServiceStore(SQLBaseStore):
 
         return rooms_for_user_matching_user_id
 
-    def _load_appservice(self, as_info):
+    @classmethod
+    def _load_appservice(cls, hostname, as_info, config_filename):
         required_string_fields = [
-            # TODO: Add id here when it's stable to release
-            "url", "as_token", "hs_token", "sender_localpart"
+            "id", "url", "as_token", "hs_token", "sender_localpart"
         ]
         for field in required_string_fields:
             if not isinstance(as_info.get(field), basestring):
-                raise KeyError("Required string field: '%s'", field)
+                raise KeyError("Required string field: '%s' (%s)" % (
+                    field, config_filename,
+                ))
 
         localpart = as_info["sender_localpart"]
         if urllib.quote(localpart) != localpart:
             raise ValueError(
                 "sender_localpart needs characters which are not URL encoded."
             )
-        user = UserID(localpart, self.hostname)
+        user = UserID(localpart, hostname)
         user_id = user.to_string()
 
         # namespace checks
@@ -188,25 +190,30 @@ class ApplicationServiceStore(SQLBaseStore):
             namespaces=as_info["namespaces"],
             hs_token=as_info["hs_token"],
             sender=user_id,
-            id=as_info["id"] if "id" in as_info else as_info["as_token"],
+            id=as_info["id"],
         )
 
-    def _populate_appservice_cache(self, config_files):
-        """Populates a cache of Application Services from the config files."""
+    @classmethod
+    def load_appservices(cls, hostname, config_files):
+        """Returns a list of Application Services from the config files."""
         if not isinstance(config_files, list):
             logger.warning(
                 "Expected %s to be a list of AS config files.", config_files
             )
-            return
+            return []
 
         # Dicts of value -> filename
         seen_as_tokens = {}
         seen_ids = {}
 
+        appservices = []
+
         for config_file in config_files:
             try:
                 with open(config_file, 'r') as f:
-                    appservice = self._load_appservice(yaml.load(f))
+                    appservice = ApplicationServiceStore._load_appservice(
+                        hostname, yaml.load(f), config_file
+                    )
                     if appservice.id in seen_ids:
                         raise ConfigError(
                             "Cannot reuse ID across application services: "
@@ -226,11 +233,12 @@ class ApplicationServiceStore(SQLBaseStore):
                         )
                     seen_as_tokens[appservice.token] = config_file
                     logger.info("Loaded application service: %s", appservice)
-                    self.services_cache.append(appservice)
+                    appservices.append(appservice)
             except Exception as e:
                 logger.error("Failed to load appservice from '%s'", config_file)
                 logger.exception(e)
                 raise
+        return appservices
 
 
 class ApplicationServiceTransactionStore(SQLBaseStore):

+ 15 - 2
synapse/storage/directory.py

@@ -70,13 +70,14 @@ class DirectoryStore(SQLBaseStore):
         )
 
     @defer.inlineCallbacks
-    def create_room_alias_association(self, room_alias, room_id, servers):
+    def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
         """ Creates an associatin between  a room alias and room_id/servers
 
         Args:
             room_alias (RoomAlias)
             room_id (str)
             servers (list)
+            creator (str): Optional user_id of creator.
 
         Returns:
             Deferred
@@ -87,6 +88,7 @@ class DirectoryStore(SQLBaseStore):
                 {
                     "room_alias": room_alias.to_string(),
                     "room_id": room_id,
+                    "creator": creator,
                 },
                 desc="create_room_alias_association",
             )
@@ -107,6 +109,17 @@ class DirectoryStore(SQLBaseStore):
             )
         self.get_aliases_for_room.invalidate((room_id,))
 
+    def get_room_alias_creator(self, room_alias):
+        return self._simple_select_one_onecol(
+            table="room_aliases",
+            keyvalues={
+                "room_alias": room_alias,
+            },
+            retcol="creator",
+            desc="get_room_alias_creator",
+            allow_none=True
+        )
+
     @defer.inlineCallbacks
     def delete_room_alias(self, room_alias):
         room_id = yield self.runInteraction(
@@ -142,7 +155,7 @@ class DirectoryStore(SQLBaseStore):
 
         return room_id
 
-    @cached()
+    @cached(max_entries=5000)
     def get_aliases_for_room(self, room_id):
         return self._simple_select_onecol(
             "room_aliases",

+ 1 - 1
synapse/storage/end_to_end_keys.py

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from _base import SQLBaseStore
+from ._base import SQLBaseStore
 
 
 class EndToEndKeyStore(SQLBaseStore):

+ 3 - 2
synapse/storage/engines/__init__.py

@@ -26,12 +26,13 @@ SUPPORTED_MODULE = {
 }
 
 
-def create_engine(name):
+def create_engine(config):
+    name = config.database_config["name"]
     engine_class = SUPPORTED_MODULE.get(name, None)
 
     if engine_class:
         module = importlib.import_module(name)
-        return engine_class(module)
+        return engine_class(module, config=config)
 
     raise RuntimeError(
         "Unsupported database engine '%s'" % (name,)

+ 3 - 2
synapse/storage/engines/postgres.py

@@ -21,9 +21,10 @@ from ._base import IncorrectDatabaseSetup
 class PostgresEngine(object):
     single_threaded = False
 
-    def __init__(self, database_module):
+    def __init__(self, database_module, config):
         self.module = database_module
         self.module.extensions.register_type(self.module.extensions.UNICODE)
+        self.config = config
 
     def check_database(self, txn):
         txn.execute("SHOW SERVER_ENCODING")
@@ -44,7 +45,7 @@ class PostgresEngine(object):
         )
 
     def prepare_database(self, db_conn):
-        prepare_database(db_conn, self)
+        prepare_database(db_conn, self, config=self.config)
 
     def is_deadlock(self, error):
         if isinstance(error, self.module.DatabaseError):

+ 3 - 2
synapse/storage/engines/sqlite3.py

@@ -23,8 +23,9 @@ import struct
 class Sqlite3Engine(object):
     single_threaded = True
 
-    def __init__(self, database_module):
+    def __init__(self, database_module, config):
         self.module = database_module
+        self.config = config
 
     def check_database(self, txn):
         pass
@@ -38,7 +39,7 @@ class Sqlite3Engine(object):
 
     def prepare_database(self, db_conn):
         prepare_sqlite3_database(db_conn)
-        prepare_database(db_conn, self)
+        prepare_database(db_conn, self, config=self.config)
 
     def is_deadlock(self, error):
         return False

+ 4 - 4
synapse/storage/event_federation.py

@@ -114,10 +114,10 @@ class EventFederationStore(SQLBaseStore):
             retcol="event_id",
         )
 
-    def get_latest_events_in_room(self, room_id):
+    def get_latest_event_ids_and_hashes_in_room(self, room_id):
         return self.runInteraction(
-            "get_latest_events_in_room",
-            self._get_latest_events_in_room,
+            "get_latest_event_ids_and_hashes_in_room",
+            self._get_latest_event_ids_and_hashes_in_room,
             room_id,
         )
 
@@ -132,7 +132,7 @@ class EventFederationStore(SQLBaseStore):
             desc="get_latest_event_ids_in_room",
         )
 
-    def _get_latest_events_in_room(self, txn, room_id):
+    def _get_latest_event_ids_and_hashes_in_room(self, txn, room_id):
         sql = (
             "SELECT e.event_id, e.depth FROM events as e "
             "INNER JOIN event_forward_extremities as f "

+ 4 - 5
synapse/storage/event_push_actions.py

@@ -27,15 +27,14 @@ class EventPushActionsStore(SQLBaseStore):
     def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
         """
         :param event: the event set actions for
-        :param tuples: list of tuples of (user_id, profile_tag, actions)
+        :param tuples: list of tuples of (user_id, actions)
         """
         values = []
-        for uid, profile_tag, actions in tuples:
+        for uid, actions in tuples:
             values.append({
                 'room_id': event.room_id,
                 'event_id': event.event_id,
                 'user_id': uid,
-                'profile_tag': profile_tag,
                 'actions': json.dumps(actions),
                 'stream_ordering': event.internal_metadata.stream_ordering,
                 'topological_ordering': event.depth,
@@ -43,14 +42,14 @@ class EventPushActionsStore(SQLBaseStore):
                 'highlight': 1 if _action_has_highlight(actions) else 0,
             })
 
-        for uid, _, __ in tuples:
+        for uid, __ in tuples:
             txn.call_after(
                 self.get_unread_event_push_actions_by_room_for_user.invalidate_many,
                 (event.room_id, uid)
             )
         self._simple_insert_many_txn(txn, "event_push_actions", values)
 
-    @cachedInlineCallbacks(num_args=3, lru=True, tree=True)
+    @cachedInlineCallbacks(num_args=3, lru=True, tree=True, max_entries=5000)
     def get_unread_event_push_actions_by_room_for_user(
             self, room_id, user_id, last_read_event_id
     ):

+ 92 - 38
synapse/storage/events.py

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from _base import SQLBaseStore, _RollbackButIsFineException
+from ._base import SQLBaseStore, _RollbackButIsFineException
 
 from twisted.internet import defer, reactor
 
@@ -75,8 +75,8 @@ class EventsStore(SQLBaseStore):
                 yield stream_orderings
             stream_ordering_manager = stream_ordering_manager()
         else:
-            stream_ordering_manager = yield self._stream_id_gen.get_next_mult(
-                self, len(events_and_contexts)
+            stream_ordering_manager = self._stream_id_gen.get_next_mult(
+                len(events_and_contexts)
             )
 
         with stream_ordering_manager as stream_orderings:
@@ -101,37 +101,23 @@ class EventsStore(SQLBaseStore):
 
     @defer.inlineCallbacks
     @log_function
-    def persist_event(self, event, context, backfilled=False,
+    def persist_event(self, event, context,
                       is_new_state=True, current_state=None):
-        stream_ordering = None
-        if backfilled:
-            self.min_stream_token -= 1
-            stream_ordering = self.min_stream_token
-
-        if stream_ordering is None:
-            stream_ordering_manager = yield self._stream_id_gen.get_next(self)
-        else:
-            @contextmanager
-            def stream_ordering_manager():
-                yield stream_ordering
-            stream_ordering_manager = stream_ordering_manager()
-
         try:
-            with stream_ordering_manager as stream_ordering:
+            with self._stream_id_gen.get_next() as stream_ordering:
                 event.internal_metadata.stream_ordering = stream_ordering
                 yield self.runInteraction(
                     "persist_event",
                     self._persist_event_txn,
                     event=event,
                     context=context,
-                    backfilled=backfilled,
                     is_new_state=is_new_state,
                     current_state=current_state,
                 )
         except _RollbackButIsFineException:
             pass
 
-        max_persisted_id = yield self._stream_id_gen.get_max_token(self)
+        max_persisted_id = yield self._stream_id_gen.get_max_token()
         defer.returnValue((stream_ordering, max_persisted_id))
 
     @defer.inlineCallbacks
@@ -165,13 +151,38 @@ class EventsStore(SQLBaseStore):
 
         defer.returnValue(events[0] if events else None)
 
+    @defer.inlineCallbacks
+    def get_events(self, event_ids, check_redacted=True,
+                   get_prev_content=False, allow_rejected=False):
+        """Get events from the database
+
+        Args:
+            event_ids (list): The event_ids of the events to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous states content in the unsigned field.
+            allow_rejected (bool): If True return rejected events.
+
+        Returns:
+            Deferred : Dict from event_id to event.
+        """
+        events = yield self._get_events(
+            event_ids,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
+
+        defer.returnValue({e.event_id: e for e in events})
+
     @log_function
-    def _persist_event_txn(self, txn, event, context, backfilled,
+    def _persist_event_txn(self, txn, event, context,
                            is_new_state=True, current_state=None):
         # We purposefully do this first since if we include a `current_state`
         # key, we *want* to update the `current_state_events` table
         if current_state:
-            txn.call_after(self.get_current_state_for_key.invalidate_all)
+            txn.call_after(self._get_current_state_for_key.invalidate_all)
             txn.call_after(self.get_rooms_for_user.invalidate_all)
             txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
             txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
@@ -198,7 +209,7 @@ class EventsStore(SQLBaseStore):
         return self._persist_events_txn(
             txn,
             [(event, context)],
-            backfilled=backfilled,
+            backfilled=False,
             is_new_state=is_new_state,
         )
 
@@ -455,7 +466,7 @@ class EventsStore(SQLBaseStore):
             for event, _ in state_events_and_contexts:
                 if not context.rejected:
                     txn.call_after(
-                        self.get_current_state_for_key.invalidate,
+                        self._get_current_state_for_key.invalidate,
                         (event.room_id, event.type, event.state_key,)
                     )
 
@@ -526,6 +537,9 @@ class EventsStore(SQLBaseStore):
         if not event_ids:
             defer.returnValue([])
 
+        event_id_list = event_ids
+        event_ids = set(event_ids)
+
         event_map = self._get_events_from_cache(
             event_ids,
             check_redacted=check_redacted,
@@ -535,23 +549,18 @@ class EventsStore(SQLBaseStore):
 
         missing_events_ids = [e for e in event_ids if e not in event_map]
 
-        if not missing_events_ids:
-            defer.returnValue([
-                event_map[e_id] for e_id in event_ids
-                if e_id in event_map and event_map[e_id]
-            ])
-
-        missing_events = yield self._enqueue_events(
-            missing_events_ids,
-            check_redacted=check_redacted,
-            get_prev_content=get_prev_content,
-            allow_rejected=allow_rejected,
-        )
+        if missing_events_ids:
+            missing_events = yield self._enqueue_events(
+                missing_events_ids,
+                check_redacted=check_redacted,
+                get_prev_content=get_prev_content,
+                allow_rejected=allow_rejected,
+            )
 
-        event_map.update(missing_events)
+            event_map.update(missing_events)
 
         defer.returnValue([
-            event_map[e_id] for e_id in event_ids
+            event_map[e_id] for e_id in event_id_list
             if e_id in event_map and event_map[e_id]
         ])
 
@@ -1064,3 +1073,48 @@ class EventsStore(SQLBaseStore):
             yield self._end_background_update(self.EVENT_ORIGIN_SERVER_TS_NAME)
 
         defer.returnValue(result)
+
+    def get_current_backfill_token(self):
+        """The current minimum token that backfilled events have reached"""
+
+        # TODO: Fix race with the persit_event txn by using one of the
+        # stream id managers
+        return -self.min_stream_token
+
+    def get_all_new_events(self, last_backfill_id, last_forward_id,
+                           current_backfill_id, current_forward_id, limit):
+        """Get all the new events that have arrived at the server either as
+        new events or as backfilled events"""
+        def get_all_new_events_txn(txn):
+            sql = (
+                "SELECT e.stream_ordering, ej.internal_metadata, ej.json"
+                " FROM events as e"
+                " JOIN event_json as ej"
+                " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
+                " WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
+                " ORDER BY e.stream_ordering ASC"
+                " LIMIT ?"
+            )
+            if last_forward_id != current_forward_id:
+                txn.execute(sql, (last_forward_id, current_forward_id, limit))
+                new_forward_events = txn.fetchall()
+            else:
+                new_forward_events = []
+
+            sql = (
+                "SELECT -e.stream_ordering, ej.internal_metadata, ej.json"
+                " FROM events as e"
+                " JOIN event_json as ej"
+                " ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
+                " WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
+                " ORDER BY e.stream_ordering DESC"
+                " LIMIT ?"
+            )
+            if last_backfill_id != current_backfill_id:
+                txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
+                new_backfill_events = txn.fetchall()
+            else:
+                new_backfill_events = []
+
+            return (new_forward_events, new_backfill_events)
+        return self.runInteraction("get_all_new_events", get_all_new_events_txn)

+ 1 - 1
synapse/storage/keys.py

@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from _base import SQLBaseStore
+from ._base import SQLBaseStore
 from synapse.util.caches.descriptors import cachedInlineCallbacks
 
 from twisted.internet import defer

Some files were not shown because too many files changed in this diff