Skip to content

Commit

Permalink
Convert user_get_threepids response to attrs. (matrix-org#16468)
Browse files Browse the repository at this point in the history
This improves type annotations by not having a dictionary of Any values.
  • Loading branch information
clokep authored Oct 12, 2023
1 parent a4904dc commit cc865ff
Show file tree
Hide file tree
Showing 9 changed files with 31 additions and 18 deletions.
1 change: 1 addition & 0 deletions changelog.d/16468.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve type hints.
4 changes: 2 additions & 2 deletions synapse/handlers/account_validity.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:

addresses = []
for threepid in threepids:
if threepid["medium"] == "email":
addresses.append(threepid["address"])
if threepid.medium == "email":
addresses.append(threepid.address)

return addresses

Expand Down
4 changes: 3 additions & 1 deletion synapse/handlers/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Sequence, Set

import attr

from synapse.api.constants import Direction, Membership
from synapse.events import EventBase
from synapse.types import JsonMapping, RoomStreamToken, StateMap, UserID, UserInfo
Expand Down Expand Up @@ -93,7 +95,7 @@ async def get_user(self, user: UserID) -> Optional[JsonMapping]:
]
user_info_dict["displayname"] = profile.display_name
user_info_dict["avatar_url"] = profile.avatar_url
user_info_dict["threepids"] = threepids
user_info_dict["threepids"] = [attr.asdict(t) for t in threepids]
user_info_dict["external_ids"] = external_ids
user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())

Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/deactivate_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@ async def deactivate_account(

# Remove any local threepid associations for this account.
local_threepids = await self.store.user_get_threepids(user_id)
for threepid in local_threepids:
for local_threepid in local_threepids:
await self._auth_handler.delete_local_threepid(
user_id, threepid["medium"], threepid["address"]
user_id, local_threepid.medium, local_threepid.address
)

# delete any devices belonging to the user, which will also
Expand Down
2 changes: 1 addition & 1 deletion synapse/module_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ async def get_threepids_for_user(self, user_id: str) -> List[Dict[str, str]]:
"msisdn" for phone numbers, and an "address" key which value is the
threepid's address.
"""
return await self._store.user_get_threepids(user_id)
return [attr.asdict(t) for t in await self._store.user_get_threepids(user_id)]

def check_user_exists(self, user_id: str) -> "defer.Deferred[Optional[str]]":
"""Check if user exists.
Expand Down
3 changes: 1 addition & 2 deletions synapse/rest/admin/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,9 +329,8 @@ async def on_PUT(

if threepids is not None:
# get changed threepids (added and removed)
# convert List[Dict[str, Any]] into Set[Tuple[str, str]]
cur_threepids = {
(threepid["medium"], threepid["address"])
(threepid.medium, threepid.address)
for threepid in await self.store.user_get_threepids(user_id)
}
add_threepids = new_threepids - cur_threepids
Expand Down
4 changes: 3 additions & 1 deletion synapse/rest/client/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from pydantic.v1 import StrictBool, StrictStr, constr
else:
from pydantic import StrictBool, StrictStr, constr

import attr
from typing_extensions import Literal

from twisted.web.server import Request
Expand Down Expand Up @@ -595,7 +597,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:

threepids = await self.datastore.user_get_threepids(requester.user.to_string())

return 200, {"threepids": threepids}
return 200, {"threepids": [attr.asdict(t) for t in threepids]}

# NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because
# the endpoint is deprecated. (If you really want to, you could do this by reusing
Expand Down
19 changes: 14 additions & 5 deletions synapse/storage/databases/main/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,14 @@ class LoginTokenLookupResult:
"""The session ID advertised by the SSO Identity Provider."""


@attr.s(frozen=True, slots=True, auto_attribs=True)
class ThreepidResult:
medium: str
address: str
validated_at: int
added_at: int


class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(
self,
Expand Down Expand Up @@ -988,13 +996,14 @@ async def user_add_threepid(
{"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
)

async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
async def user_get_threepids(self, user_id: str) -> List[ThreepidResult]:
results = await self.db_pool.simple_select_list(
"user_threepids",
{"user_id": user_id},
["medium", "address", "validated_at", "added_at"],
"user_get_threepids",
keyvalues={"user_id": user_id},
retcols=["medium", "address", "validated_at", "added_at"],
desc="user_get_threepids",
)
return [ThreepidResult(**r) for r in results]

async def user_delete_threepid(
self, user_id: str, medium: str, address: str
Expand Down
8 changes: 4 additions & 4 deletions tests/module_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ def test_can_register_user(self) -> None:
self.assertEqual(len(emails), 1)

email = emails[0]
self.assertEqual(email["medium"], "email")
self.assertEqual(email["address"], "[email protected]")
self.assertEqual(email.medium, "email")
self.assertEqual(email.address, "[email protected]")

# Should these be 0?
self.assertEqual(email["validated_at"], 0)
self.assertEqual(email["added_at"], 0)
self.assertEqual(email.validated_at, 0)
self.assertEqual(email.added_at, 0)

# Check that the displayname was assigned
displayname = self.get_success(
Expand Down

0 comments on commit cc865ff

Please sign in to comment.