Skip to content

Commit

Permalink
Add type hints to tests/rest/client (matrix-org#12066)
Browse files Browse the repository at this point in the history
  • Loading branch information
dklimpel authored Feb 23, 2022
1 parent 5b2b368 commit 64c73c6
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 119 deletions.
1 change: 1 addition & 0 deletions changelog.d/12066.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to `tests/rest/client`.
70 changes: 38 additions & 32 deletions tests/rest/client/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from typing import Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from twisted.internet.defer import succeed
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource

import synapse.rest.admin
from synapse.api.constants import LoginType
from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.rest.client import account, auth, devices, login, logout, register
from synapse.rest.synapse.client import build_synapse_client_resource_tree
from synapse.server import HomeServer
from synapse.storage.database import LoggingTransaction
from synapse.types import JsonDict, UserID
from synapse.util import Clock

from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
Expand All @@ -33,11 +37,11 @@


class DummyRecaptchaChecker(UserInteractiveAuthChecker):
def __init__(self, hs):
def __init__(self, hs: HomeServer) -> None:
super().__init__(hs)
self.recaptcha_attempts = []
self.recaptcha_attempts: List[Tuple[dict, str]] = []

def check_auth(self, authdict, clientip):
def check_auth(self, authdict: dict, clientip: str) -> Any:
self.recaptcha_attempts.append((authdict, clientip))
return succeed(True)

Expand All @@ -50,7 +54,7 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
]
hijack_auth = False

def make_homeserver(self, reactor, clock):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:

config = self.default_config()

Expand All @@ -61,7 +65,7 @@ def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(config=config)
return hs

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.recaptcha_checker = DummyRecaptchaChecker(hs)
auth_handler = hs.get_auth_handler()
auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
Expand Down Expand Up @@ -101,7 +105,7 @@ def recaptcha(
self.assertEqual(len(attempts), 1)
self.assertEqual(attempts[0][0]["response"], "a")

def test_fallback_captcha(self):
def test_fallback_captcha(self) -> None:
"""Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec
channel = self.register(
Expand Down Expand Up @@ -132,7 +136,7 @@ def test_fallback_captcha(self):
# We're given a registered user.
self.assertEqual(channel.json_body["user_id"], "@user:test")

def test_complete_operation_unknown_session(self):
def test_complete_operation_unknown_session(self) -> None:
"""
Attempting to mark an invalid session as complete should error.
"""
Expand Down Expand Up @@ -165,7 +169,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets,
]

def default_config(self):
def default_config(self) -> Dict[str, Any]:
config = super().default_config()

# public_baseurl uses an http:// scheme because FakeChannel.isSecure() returns
Expand All @@ -182,12 +186,12 @@ def default_config(self):

return config

def create_resource_dict(self):
def create_resource_dict(self) -> Dict[str, Resource]:
resource_dict = super().create_resource_dict()
resource_dict.update(build_synapse_client_resource_tree(self.hs))
return resource_dict

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.device_id = "dev1"
Expand Down Expand Up @@ -229,7 +233,7 @@ def delete_devices(self, expected_response: int, body: JsonDict) -> FakeChannel:

return channel

def test_ui_auth(self):
def test_ui_auth(self) -> None:
"""
Test user interactive authentication outside of registration.
"""
Expand Down Expand Up @@ -259,7 +263,7 @@ def test_ui_auth(self):
},
)

def test_grandfathered_identifier(self):
def test_grandfathered_identifier(self) -> None:
"""Check behaviour without "identifier" dict
Synapse used to require clients to submit a "user" field for m.login.password
Expand All @@ -286,7 +290,7 @@ def test_grandfathered_identifier(self):
},
)

def test_can_change_body(self):
def test_can_change_body(self) -> None:
"""
The client dict can be modified during the user interactive authentication session.
Expand Down Expand Up @@ -325,7 +329,7 @@ def test_can_change_body(self):
},
)

def test_cannot_change_uri(self):
def test_cannot_change_uri(self) -> None:
"""
The initial requested URI cannot be modified during the user interactive authentication session.
"""
Expand Down Expand Up @@ -362,7 +366,7 @@ def test_cannot_change_uri(self):
)

@unittest.override_config({"ui_auth": {"session_timeout": "5s"}})
def test_can_reuse_session(self):
def test_can_reuse_session(self) -> None:
"""
The session can be reused if configured.
Expand Down Expand Up @@ -409,7 +413,7 @@ def test_can_reuse_session(self):

@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_via_sso(self):
def test_ui_auth_via_sso(self) -> None:
"""Test a successful UI Auth flow via SSO
This includes:
Expand Down Expand Up @@ -452,7 +456,7 @@ def test_ui_auth_via_sso(self):

@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_does_not_offer_password_for_sso_user(self):
def test_does_not_offer_password_for_sso_user(self) -> None:
login_resp = self.helper.login_via_oidc("username")
user_tok = login_resp["access_token"]
device_id = login_resp["device_id"]
Expand All @@ -464,7 +468,7 @@ def test_does_not_offer_password_for_sso_user(self):
flows = channel.json_body["flows"]
self.assertEqual(flows, [{"stages": ["m.login.sso"]}])

def test_does_not_offer_sso_for_password_user(self):
def test_does_not_offer_sso_for_password_user(self) -> None:
channel = self.delete_device(
self.user_tok, self.device_id, HTTPStatus.UNAUTHORIZED
)
Expand All @@ -474,7 +478,7 @@ def test_does_not_offer_sso_for_password_user(self):

@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_offers_both_flows_for_upgraded_user(self):
def test_offers_both_flows_for_upgraded_user(self) -> None:
"""A user that had a password and then logged in with SSO should get both flows"""
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
self.assertEqual(login_resp["user_id"], self.user)
Expand All @@ -491,7 +495,7 @@ def test_offers_both_flows_for_upgraded_user(self):

@skip_unless(HAS_OIDC, "requires OIDC")
@override_config({"oidc_config": TEST_OIDC_CONFIG})
def test_ui_auth_fails_for_incorrect_sso_user(self):
def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
"""If the user tries to authenticate with the wrong SSO user, they get an error"""
# log the user in
login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
Expand Down Expand Up @@ -534,7 +538,7 @@ class RefreshAuthTests(unittest.HomeserverTestCase):
]
hijack_auth = False

def prepare(self, reactor, clock, hs):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)

Expand All @@ -548,7 +552,7 @@ def use_refresh_token(self, refresh_token: str) -> FakeChannel:
{"refresh_token": refresh_token},
)

def is_access_token_valid(self, access_token) -> bool:
def is_access_token_valid(self, access_token: str) -> bool:
"""
Checks whether an access token is valid, returning whether it is or not.
"""
Expand All @@ -561,7 +565,7 @@ def is_access_token_valid(self, access_token) -> bool:

return code == HTTPStatus.OK

def test_login_issue_refresh_token(self):
def test_login_issue_refresh_token(self) -> None:
"""
A login response should include a refresh_token only if asked.
"""
Expand Down Expand Up @@ -591,7 +595,7 @@ def test_login_issue_refresh_token(self):
self.assertIn("refresh_token", login_with_refresh.json_body)
self.assertIn("expires_in_ms", login_with_refresh.json_body)

def test_register_issue_refresh_token(self):
def test_register_issue_refresh_token(self) -> None:
"""
A register response should include a refresh_token only if asked.
"""
Expand Down Expand Up @@ -627,7 +631,7 @@ def test_register_issue_refresh_token(self):
self.assertIn("refresh_token", register_with_refresh.json_body)
self.assertIn("expires_in_ms", register_with_refresh.json_body)

def test_token_refresh(self):
def test_token_refresh(self) -> None:
"""
A refresh token can be used to issue a new access token.
"""
Expand Down Expand Up @@ -665,7 +669,7 @@ def test_token_refresh(self):
)

@override_config({"refreshable_access_token_lifetime": "1m"})
def test_refreshable_access_token_expiration(self):
def test_refreshable_access_token_expiration(self) -> None:
"""
The access token should have some time as specified in the config.
"""
Expand Down Expand Up @@ -722,7 +726,9 @@ def test_refreshable_access_token_expiration(self):
"nonrefreshable_access_token_lifetime": "10m",
}
)
def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self):
def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(
self,
) -> None:
"""
Tests that the expiry times for refreshable and non-refreshable access
tokens can be different.
Expand Down Expand Up @@ -782,7 +788,7 @@ def test_different_expiry_for_refreshable_and_nonrefreshable_access_tokens(self)
@override_config(
{"refreshable_access_token_lifetime": "1m", "refresh_token_lifetime": "2m"}
)
def test_refresh_token_expiry(self):
def test_refresh_token_expiry(self) -> None:
"""
The refresh token can be configured to have a limited lifetime.
When that lifetime has ended, the refresh token can no longer be used to
Expand Down Expand Up @@ -834,7 +840,7 @@ def test_refresh_token_expiry(self):
"session_lifetime": "3m",
}
)
def test_ultimate_session_expiry(self):
def test_ultimate_session_expiry(self) -> None:
"""
The session can be configured to have an ultimate, limited lifetime.
"""
Expand Down Expand Up @@ -882,7 +888,7 @@ def test_ultimate_session_expiry(self):
refresh_response.code, HTTPStatus.FORBIDDEN, refresh_response.result
)

def test_refresh_token_invalidation(self):
def test_refresh_token_invalidation(self) -> None:
"""Refresh tokens are invalidated after first use of the next token.
A refresh token is considered invalid if:
Expand Down Expand Up @@ -987,7 +993,7 @@ def test_refresh_token_invalidation(self):
fifth_refresh_response.code, HTTPStatus.OK, fifth_refresh_response.result
)

def test_many_token_refresh(self):
def test_many_token_refresh(self) -> None:
"""
If a refresh is performed many times during a session, there shouldn't be
extra 'cruft' built up over time.
Expand Down
Loading

0 comments on commit 64c73c6

Please sign in to comment.