Skip to content

Commit

Permalink
Aggregate unread notif count query for badge count calculation (matri…
Browse files Browse the repository at this point in the history
…x-org#14255)

Fetch the unread notification counts used by the badge counts
in push notifications for all rooms at once (instead of fetching
them per room).
  • Loading branch information
Fizzadar authored Nov 30, 2022
1 parent 4569eda commit e8bce89
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 27 deletions.
1 change: 1 addition & 0 deletions changelog.d/14255.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Optimise push badge count calculations. Contributed by Nick @ Beeper (@fizzadar).
28 changes: 9 additions & 19 deletions synapse/push/push_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage.controllers import StorageControllers
from synapse.storage.databases.main import DataStore
from synapse.util.async_helpers import concurrently_execute


async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -> int:
Expand All @@ -26,23 +25,12 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -

badge = len(invites)

room_notifs = []

async def get_room_unread_count(room_id: str) -> None:
room_notifs.append(
await store.get_unread_event_push_actions_by_room_for_user(
room_id,
user_id,
)
)

await concurrently_execute(get_room_unread_count, joins, 10)

for notifs in room_notifs:
# Combine the counts from all the threads.
notify_count = notifs.main_timeline.notify_count + sum(
n.notify_count for n in notifs.threads.values()
)
room_to_count = await store.get_unread_counts_by_room_for_user(user_id)
for room_id, notify_count in room_to_count.items():
# room_to_count may include rooms which the user has left,
# ignore those.
if room_id not in joins:
continue

if notify_count == 0:
continue
Expand All @@ -51,8 +39,10 @@ async def get_room_unread_count(room_id: str) -> None:
# return one badge count per conversation
badge += 1
else:
# increment the badge count by the number of unread messages in the room
# Increase badge by number of notifications in room
# NOTE: this includes threaded and unthreaded notifications.
badge += notify_count

return badge


Expand Down
149 changes: 149 additions & 0 deletions synapse/storage/databases/main/event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
"""

import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
Collection,
Expand All @@ -95,6 +96,7 @@
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
PostgresEngine,
)
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
Expand Down Expand Up @@ -463,6 +465,153 @@ def add_thread_id_summary_txn(txn: LoggingTransaction) -> int:

return result

async def get_unread_counts_by_room_for_user(self, user_id: str) -> Dict[str, int]:
"""Get the notification count by room for a user. Only considers notifications,
not highlight or unread counts, and threads are currently aggregated under their room.
This function is intentionally not cached because it is called to calculate the
unread badge for push notifications and thus the result is expected to change.
Note that this function assumes the user is a member of the room. Because
summary rows are not removed when a user leaves a room, the caller must
filter out those results from the result.
Returns:
A map of room ID to notification counts for the given user.
"""
return await self.db_pool.runInteraction(
"get_unread_counts_by_room_for_user",
self._get_unread_counts_by_room_for_user_txn,
user_id,
)

def _get_unread_counts_by_room_for_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> Dict[str, int]:
receipt_types_clause, args = make_in_list_sql_clause(
self.database_engine,
"receipt_type",
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
)
args.extend([user_id, user_id])

receipts_cte = f"""
WITH all_receipts AS (
SELECT room_id, thread_id, MAX(event_stream_ordering) AS max_receipt_stream_ordering
FROM receipts_linearized
LEFT JOIN events USING (room_id, event_id)
WHERE
{receipt_types_clause}
AND user_id = ?
GROUP BY room_id, thread_id
)
"""

receipts_joins = """
LEFT JOIN (
SELECT room_id, thread_id,
max_receipt_stream_ordering AS threaded_receipt_stream_ordering
FROM all_receipts
WHERE thread_id IS NOT NULL
) AS threaded_receipts USING (room_id, thread_id)
LEFT JOIN (
SELECT room_id, thread_id,
max_receipt_stream_ordering AS unthreaded_receipt_stream_ordering
FROM all_receipts
WHERE thread_id IS NULL
) AS unthreaded_receipts USING (room_id)
"""

# First get summary counts by room / thread for the user. We use the max receipt
# stream ordering of both threaded & unthreaded receipts to compare against the
# summary table.
#
# PostgreSQL and SQLite differ in comparing scalar numerics.
if isinstance(self.database_engine, PostgresEngine):
# GREATEST ignores NULLs.
max_clause = """GREATEST(
threaded_receipt_stream_ordering,
unthreaded_receipt_stream_ordering
)"""
else:
# MAX returns NULL if any are NULL, so COALESCE to 0 first.
max_clause = """MAX(
COALESCE(threaded_receipt_stream_ordering, 0),
COALESCE(unthreaded_receipt_stream_ordering, 0)
)"""

sql = f"""
{receipts_cte}
SELECT eps.room_id, eps.thread_id, notif_count
FROM event_push_summary AS eps
{receipts_joins}
WHERE user_id = ?
AND notif_count != 0
AND (
(last_receipt_stream_ordering IS NULL AND stream_ordering > {max_clause})
OR last_receipt_stream_ordering = {max_clause}
)
"""
txn.execute(sql, args)

seen_thread_ids = set()
room_to_count: Dict[str, int] = defaultdict(int)

for room_id, thread_id, notif_count in txn:
room_to_count[room_id] += notif_count
seen_thread_ids.add(thread_id)

# Now get any event push actions that haven't been rotated using the same OR
# join and filter by receipt and event push summary rotated up to stream ordering.
sql = f"""
{receipts_cte}
SELECT epa.room_id, epa.thread_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
FROM event_push_actions AS epa
{receipts_joins}
WHERE user_id = ?
AND epa.notif = 1
AND stream_ordering > (SELECT stream_ordering FROM event_push_summary_stream_ordering)
AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
GROUP BY epa.room_id, epa.thread_id
"""
txn.execute(sql, args)

for room_id, thread_id, notif_count in txn:
# Note: only count push actions we have valid summaries for with up to date receipt.
if thread_id not in seen_thread_ids:
continue
room_to_count[room_id] += notif_count

thread_id_clause, thread_ids_args = make_in_list_sql_clause(
self.database_engine, "epa.thread_id", seen_thread_ids
)

# Finally re-check event_push_actions for any rooms not in the summary, ignoring
# the rotated up-to position. This handles the case where a read receipt has arrived
# but not been rotated meaning the summary table is out of date, so we go back to
# the push actions table.
sql = f"""
{receipts_cte}
SELECT epa.room_id, COUNT(CASE WHEN epa.notif = 1 THEN 1 END) AS notif_count
FROM event_push_actions AS epa
{receipts_joins}
WHERE user_id = ?
AND NOT {thread_id_clause}
AND epa.notif = 1
AND (threaded_receipt_stream_ordering IS NULL OR stream_ordering > threaded_receipt_stream_ordering)
AND (unthreaded_receipt_stream_ordering IS NULL OR stream_ordering > unthreaded_receipt_stream_ordering)
GROUP BY epa.room_id
"""

args.extend(thread_ids_args)
txn.execute(sql, args)

for room_id, notif_count in txn:
room_to_count[room_id] += notif_count

return room_to_count

@cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,
Expand Down
47 changes: 39 additions & 8 deletions tests/storage/test_event_push_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_count_aggregation(self) -> None:

last_event_id: str

def _assert_counts(noitf_count: int, highlight_count: int) -> None:
def _assert_counts(notif_count: int, highlight_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
Expand All @@ -168,13 +168,22 @@ def _assert_counts(noitf_count: int, highlight_count: int) -> None:
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=noitf_count,
notify_count=notif_count,
unread_count=0,
highlight_count=highlight_count,
),
)
self.assertEqual(counts.threads, {})

aggregate_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-aggregate-unread-counts",
self.store._get_unread_counts_by_room_for_user_txn,
user_id,
)
)
self.assertEqual(aggregate_counts[room_id], notif_count)

def _create_event(highlight: bool = False) -> str:
result = self.helper.send_event(
room_id,
Expand Down Expand Up @@ -283,7 +292,7 @@ def test_count_aggregation_threads(self) -> None:
last_event_id: str

def _assert_counts(
noitf_count: int,
notif_count: int,
highlight_count: int,
thread_notif_count: int,
thread_highlight_count: int,
Expand All @@ -299,7 +308,7 @@ def _assert_counts(
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=noitf_count,
notify_count=notif_count,
unread_count=0,
highlight_count=highlight_count,
),
Expand All @@ -318,6 +327,17 @@ def _assert_counts(
else:
self.assertEqual(counts.threads, {})

aggregate_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-aggregate-unread-counts",
self.store._get_unread_counts_by_room_for_user_txn,
user_id,
)
)
self.assertEqual(
aggregate_counts[room_id], notif_count + thread_notif_count
)

def _create_event(
highlight: bool = False, thread_id: Optional[str] = None
) -> str:
Expand Down Expand Up @@ -454,7 +474,7 @@ def test_count_aggregation_mixed(self) -> None:
last_event_id: str

def _assert_counts(
noitf_count: int,
notif_count: int,
highlight_count: int,
thread_notif_count: int,
thread_highlight_count: int,
Expand All @@ -470,7 +490,7 @@ def _assert_counts(
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=noitf_count,
notify_count=notif_count,
unread_count=0,
highlight_count=highlight_count,
),
Expand All @@ -489,6 +509,17 @@ def _assert_counts(
else:
self.assertEqual(counts.threads, {})

aggregate_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-aggregate-unread-counts",
self.store._get_unread_counts_by_room_for_user_txn,
user_id,
)
)
self.assertEqual(
aggregate_counts[room_id], notif_count + thread_notif_count
)

def _create_event(
highlight: bool = False, thread_id: Optional[str] = None
) -> str:
Expand Down Expand Up @@ -646,7 +677,7 @@ def _create_event(type: str, content: JsonDict) -> str:
)
return result["event_id"]

def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
def _assert_counts(notif_count: int, thread_notif_count: int) -> None:
counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
Expand All @@ -658,7 +689,7 @@ def _assert_counts(noitf_count: int, thread_notif_count: int) -> None:
self.assertEqual(
counts.main_timeline,
NotifCounts(
notify_count=noitf_count, unread_count=0, highlight_count=0
notify_count=notif_count, unread_count=0, highlight_count=0
),
)
if thread_notif_count:
Expand Down

0 comments on commit e8bce89

Please sign in to comment.