Skip to content

Commit

Permalink
Add type hints for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
WhyNotHugo committed Jun 17, 2024
1 parent 088ae46 commit 5128c79
Show file tree
Hide file tree
Showing 12 changed files with 597 additions and 281 deletions.
44 changes: 27 additions & 17 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,17 @@
import os
import time
from datetime import datetime
from typing import Callable
from typing import Iterable
from typing import ParamSpec
from typing import TypeVar
from uuid import uuid4

import py
import pytest
import pytz
from click.testing import CliRunner
from click.testing import Result
from dateutil.tz import tzlocal
from hypothesis import HealthCheck
from hypothesis import Verbosity
Expand All @@ -19,15 +25,15 @@


@pytest.fixture()
def default_database(tmpdir):
def default_database(tmpdir: py.path.local) -> model.Database:
return model.Database(
[tmpdir.mkdir("default")],
tmpdir.mkdir(uuid4().hex).join("cache.sqlite3"),
)


@pytest.fixture()
def config(tmpdir, default_database):
def config(tmpdir: py.path.local, default_database: model.Database) -> py.path.local:
config_path = tmpdir.join("config.py")
config_path.write(
f'path = "{tmpdir}/*"\n'
Expand All @@ -38,23 +44,27 @@ def config(tmpdir, default_database):
return config_path


_T = TypeVar("_T")
_P = ParamSpec("_P")


@pytest.fixture()
def runner(config, sleep):
def runner(config: py.path.local, sleep: Callable[[], None]) -> CliRunner:
class SleepyCliRunner(CliRunner):
"""
Sleeps before invoking to make sure cache entries have expired.
"""

def invoke(self, *args, **kwargs):
def invoke(self, *args, **kwargs) -> Result:
sleep()
return super().invoke(*args, **kwargs)

return SleepyCliRunner(env={"TODOMAN_CONFIG": str(config)})


@pytest.fixture()
def create(tmpdir):
def inner(name, content, list_name="default"):
def create(tmpdir: py.path.local) -> Callable[[str, str, str], py.path.local]:
def inner(name: str, content: str, list_name: str = "default") -> py.path.local:
path = tmpdir.ensure_dir(list_name).join(name)
path.write(
"BEGIN:VCALENDAR\nBEGIN:VTODO\n" + content + "END:VTODO\nEND:VCALENDAR"
Expand All @@ -65,8 +75,8 @@ def inner(name, content, list_name="default"):


@pytest.fixture()
def now_for_tz():
def inner(tz="CET"):
def now_for_tz() -> Callable[[str], datetime]:
def inner(tz: str = "CET") -> datetime:
"""
Provides the current time cast to a given timezone.
Expand All @@ -80,8 +90,8 @@ def inner(tz="CET"):


@pytest.fixture()
def todo_factory(default_database):
def inner(**attributes):
def todo_factory(default_database: model.Database) -> Callable:
def inner(**attributes) -> model.Todo:
todo = model.Todo(new=True)
todo.list = next(iter(default_database.lists()))

Expand All @@ -97,17 +107,17 @@ def inner(**attributes):


@pytest.fixture()
def default_formatter():
def default_formatter() -> DefaultFormatter:
return DefaultFormatter(tz_override=pytz.timezone("CET"))


@pytest.fixture()
def humanized_formatter():
def humanized_formatter() -> HumanizedFormatter:
return HumanizedFormatter(tz_override=pytz.timezone("CET"))


@pytest.fixture(scope="session")
def sleep(tmpdir_factory):
def sleep(tmpdir_factory: pytest.TempdirFactory) -> Callable[[], None]:
"""
Sleeps as long as needed for the filesystem's mtime to pick up differences
Expand All @@ -119,12 +129,12 @@ def sleep(tmpdir_factory):
"""
tmpfile = tmpdir_factory.mktemp("sleep").join("touch_me")

def touch_and_mtime():
def touch_and_mtime() -> float:
tmpfile.open("w").close()
stat = os.stat(str(tmpfile))
return getattr(stat, "st_mtime_ns", stat.st_mtime)

def inner():
def inner() -> None:
time.sleep(i)

i = 0.00001
Expand All @@ -149,8 +159,8 @@ def inner():


@pytest.fixture()
def todos(default_database, sleep):
def inner(**filters):
def todos(default_database: model.Database, sleep: Callable[[], None]) -> Callable:
def inner(**filters) -> Iterable[model.Todo]:
sleep()
default_database.update_cache()
return default_database.todos(**filters)
Expand Down
4 changes: 2 additions & 2 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest


def is_fs_case_sensitive():
def is_fs_case_sensitive() -> bool | None:
with TemporaryDirectory() as tmpdir:
os.mkdir(os.path.join(tmpdir, "casesensitivetest"))
try:
Expand All @@ -20,7 +20,7 @@ def is_fs_case_sensitive():
return False


def is_pyicu_installed():
def is_pyicu_installed() -> bool:
try:
import icu # noqa: F401: This is an import to tests if it's installed.
except ImportError:
Expand Down
38 changes: 20 additions & 18 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,57 +2,51 @@

from datetime import date
from datetime import datetime
from typing import Callable

import icalendar
import pytest
import py
import pytz
from dateutil.tz import tzlocal
from freezegun import freeze_time

from todoman.model import Database
from todoman.model import Todo
from todoman.model import VtodoWriter


def test_datetime_serialization(todo_factory, tmpdir):
def test_datetime_serialization(todo_factory: Callable, tmpdir: py.local.path) -> None:
now = datetime(2017, 8, 31, 23, 49, 53, tzinfo=pytz.UTC)
todo = todo_factory(created_at=now)
filename = tmpdir.join("default").join(todo.filename)
with open(str(filename)) as f:
assert "CREATED:20170831T234953Z\n" in f.readlines()


def test_serialize_created_at(todo_factory):
def test_serialize_created_at(todo_factory: Callable) -> None:
now = datetime.now(tz=pytz.UTC)
todo = todo_factory(created_at=now)
vtodo = VtodoWriter(todo).serialize()

assert vtodo.get("created") is not None


def test_serialize_dtstart(todo_factory):
def test_serialize_dtstart(todo_factory: Callable) -> None:
now = datetime.now(tz=pytz.UTC)
todo = todo_factory(start=now)
vtodo = VtodoWriter(todo).serialize()

assert vtodo.get("dtstart") is not None


def test_serializer_raises(todo_factory):
todo = todo_factory()
writter = VtodoWriter(todo)

with pytest.raises(Exception, match="Unknown field nonexistant"):
writter.serialize_field("nonexistant", 7)


def test_supported_fields_are_serializeable():
def test_supported_fields_are_serializeable() -> None:
supported_fields = set(Todo.ALL_SUPPORTED_FIELDS)
serialized_fields = set(VtodoWriter.FIELD_MAP.keys())

assert supported_fields == serialized_fields


def test_vtodo_serialization(todo_factory):
def test_vtodo_serialization(todo_factory: Callable) -> None:
"""Test VTODO serialization: one field of each type."""
description = "A tea would be nice, thanks."
todo = todo_factory(
Expand All @@ -78,12 +72,20 @@ def test_vtodo_serialization(todo_factory):


@freeze_time("2017-04-04 20:11:57")
def test_update_last_modified(todo_factory, todos, tmpdir):
def test_update_last_modified(
todo_factory: Callable,
todos: Callable,
tmpdir: py.path.local,
) -> None:
todo = todo_factory()
assert todo.last_modified == datetime.now(tzlocal())


def test_sequence_increment(default_database, todo_factory, todos):
def test_sequence_increment(
default_database: Database,
todo_factory: Callable,
todos: Callable,
) -> None:
todo = todo_factory()
assert todo.sequence == 1

Expand All @@ -95,8 +97,8 @@ def test_sequence_increment(default_database, todo_factory, todos):
assert todo.sequence == 2


def test_normalize_datetime():
writter = VtodoWriter(None)
def test_normalize_datetime(todo_factory: Callable) -> None:
writter = VtodoWriter(todo_factory())
assert writter.normalize_datetime(date(2017, 6, 17)) == date(2017, 6, 17)
assert writter.normalize_datetime(datetime(2017, 6, 17)) == datetime(
2017, 6, 17, tzinfo=tzlocal()
Expand Down
Loading

0 comments on commit 5128c79

Please sign in to comment.