Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Put pyright in strict mode #143

Merged
merged 6 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Set pyright to strict mode and fix type errors
  • Loading branch information
callumforrester committed Sep 16, 2024
commit c05b5f10770f7cfb72ba70b095ac5864b0ad96da
1,863 changes: 1 addition & 1,862 deletions schema.json

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/scanspec/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
@click.version_option(prog_name="scanspec", message="%(version)s")
@click.pass_context
def cli(ctx, log_level: str):
def cli(ctx: click.Context, log_level: str):
"""Top level scanspec command line interface."""
level = getattr(logging, log_level.upper(), None)
logging.basicConfig(format="%(levelname)s:%(message)s", level=level)
Expand Down Expand Up @@ -50,7 +50,7 @@ def plot(spec: str):
@click.option(
"--port", default=8080, help="The port that the scanspec service will be hosted on."
)
def service(cors, port):
def service(cors: bool, port: int):
"""Run up a REST service."""
from scanspec.service import run_app

Expand Down
98 changes: 57 additions & 41 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,21 @@

from __future__ import annotations

import itertools
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import lru_cache
from inspect import isclass
from typing import Any, Generic, Literal, TypeVar

import numpy as np
import numpy.typing as npt
from pydantic import BaseModel, ConfigDict, Field, GetCoreSchemaHandler
from pydantic.dataclasses import is_pydantic_dataclass, rebuild_dataclass
from pydantic_core import CoreSchema
from pydantic_core.core_schema import tagged_union_schema

__all__ = [
"if_instance_do",
"Axis",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was in __all__ so the docs didn't complain about a missing ref, please can you check the docs still look ok after this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since I agree with #143 (comment), I'm putting this back, however I was planning to manually check the docs and the plotting once you were happy with everything else, so will leave this thread open as a placeholder for that.

"AxesPoints",
"Frames",
"SnakedFrames",
Expand All @@ -31,6 +32,9 @@
StrictConfig: ConfigDict = {"extra": "forbid"}

C = TypeVar("C")
T = TypeVar("T")

GapArray = npt.NDArray[np.bool]


def discriminated_union_of_subclasses(
Expand Down Expand Up @@ -111,7 +115,7 @@ def calculate(self) -> int:
tagged_union = _TaggedUnion(super_cls, discriminator)
_tagged_unions[super_cls] = tagged_union

def add_subclass_to_union(subclass):
def add_subclass_to_union(subclass: type[C]):
# Add a discriminator field to a subclass so it can
# be identified when deserializing
subclass.__annotations__ = {
Expand All @@ -120,7 +124,9 @@ def add_subclass_to_union(subclass):
}
setattr(subclass, discriminator, Field(subclass.__name__, repr=False)) # type: ignore

def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):
def get_schema_of_union(
cls: type[C], source_type: Any, handler: GetCoreSchemaHandler
):
if cls is not super_cls:
tagged_union.add_member(cls)
return handler(cls)
Expand All @@ -138,7 +144,7 @@ def get_schema_of_union(cls, source_type: Any, handler: GetCoreSchemaHandler):


class _TaggedUnion:
def __init__(self, base_class: type, discriminator: str):
def __init__(self, base_class: type[Any], discriminator: str):
self._base_class = base_class
# Classes and their field names that refer to this tagged union
self._discriminator = discriminator
Expand All @@ -154,7 +160,7 @@ def add_member(self, cls: type):
_TaggedUnion._rebuild(member)

@staticmethod
def _rebuild(cls_or_func: type | Callable):
def _rebuild(cls_or_func: Callable[..., T]) -> None:
if isclass(cls_or_func):
if is_pydantic_dataclass(cls_or_func):
rebuild_dataclass(cls_or_func, force=True)
Expand All @@ -170,11 +176,13 @@ def schema(self, handler: GetCoreSchemaHandler) -> CoreSchema:


@lru_cache(1)
def _make_schema(members: tuple[type, ...], handler):
def _make_schema(
members: tuple[type[Any], ...], handler: Callable[[Any], CoreSchema]
) -> dict[str, CoreSchema]:
return {member.__name__: handler(member) for member in members}


def if_instance_do(x: Any, cls: type, func: Callable):
def if_instance_do(x: C, cls: type[C], func: Callable[[C], T]):
"""If x is of type cls then return func(x), otherwise return NotImplemented.

Used as a helper when implementing operator overloading.
Expand All @@ -188,9 +196,12 @@ def if_instance_do(x: Any, cls: type, func: Callable):
#: A type variable for an `axis_` that can be specified for a scan
Axis = TypeVar("Axis")

#: Alternative axis variable to be used when two are required in the same type binding
OtherAxis = TypeVar("OtherAxis")

#: Map of axes to float ndarray of points
#: E.g. {xmotor: array([0, 1, 2]), ymotor: array([2, 2, 2])}
AxesPoints = dict[Axis, np.ndarray]
AxesPoints = dict[Axis, npt.NDArray[np.floating[Any]]]


class Frames(Generic[Axis]):
Expand Down Expand Up @@ -226,7 +237,7 @@ def __init__(
midpoints: AxesPoints[Axis],
lower: AxesPoints[Axis] | None = None,
upper: AxesPoints[Axis] | None = None,
gap: np.ndarray | None = None,
gap: GapArray | None = None,
):
#: The midpoints of scan frames for each axis
self.midpoints = midpoints
Expand Down Expand Up @@ -274,7 +285,9 @@ def __len__(self) -> int:
# All axespoints arrays are same length, pick the first one
return len(self.gap)

def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]:
def extract(
self, indices: npt.NDArray[np.signedinteger[Any]], calculate_gap: bool = True
) -> Frames[Axis]:
"""Return a new Frames object restricted to the indices provided.

Args:
Expand All @@ -293,7 +306,7 @@ def extract_dict(ds: Iterable[AxesPoints[Axis]]) -> AxesPoints[Axis]:
return {k: v[dim_indices] for k, v in d.items()}
return {}

def extract_gap(gaps: Iterable[np.ndarray]) -> np.ndarray | None:
def extract_gap(gaps: Iterable[GapArray]) -> GapArray | None:
for gap in gaps:
if not calculate_gap:
return gap[dim_indices]
Expand Down Expand Up @@ -326,7 +339,7 @@ def concat_dict(ds: Sequence[AxesPoints[Axis]]) -> AxesPoints[Axis]:
# lower[ax] = np.concatenate(self.lower[ax], other.lower[ax])
return {a: np.concatenate([d[a] for d in ds]) for a in self.axes()}

def concat_gap(gaps: Sequence[np.ndarray]) -> np.ndarray:
def concat_gap(gaps: Sequence[GapArray]) -> GapArray:
g = np.concatenate(gaps)
# Calc the first frame
g[0] = gap_between_frames(other, self)
Expand Down Expand Up @@ -354,7 +367,7 @@ def zip_dict(ds: Sequence[AxesPoints[Axis]]) -> AxesPoints[Axis]:
# lower[ax] = {**self.lower[ax], **other.lower[ax]}
return dict(kv for d in ds for kv in d.items())

def zip_gap(gaps: Sequence[np.ndarray]) -> np.ndarray:
def zip_gap(gaps: Sequence[GapArray]) -> GapArray:
# Gap if either frames has a gap. E.g.
# gap[i] = self.gap[i] | other.gap[i]
return np.logical_or.reduce(gaps)
Expand All @@ -364,24 +377,24 @@ def zip_gap(gaps: Sequence[np.ndarray]) -> np.ndarray:

def _merge_frames(
*stack: Frames[Axis],
dict_merge=Callable[[Sequence[AxesPoints[Axis]]], AxesPoints[Axis]], # type: ignore
gap_merge=Callable[[Sequence[np.ndarray]], np.ndarray | None],
dict_merge: Callable[[Sequence[AxesPoints[Axis]]], AxesPoints[Axis]], # type: ignore
gap_merge: Callable[[Sequence[GapArray]], GapArray | None],
) -> Frames[Axis]:
types = {type(fs) for fs in stack}
assert len(types) == 1, f"Mismatching types for {stack}"
cls = types.pop()

# If any lower or upper are different, apply to those
kwargs = {}
for a in ("lower", "upper"):
if any(fs.midpoints is not getattr(fs, a) for fs in stack):
kwargs[a] = dict_merge([getattr(fs, a) for fs in stack])

# Apply to midpoints, force calculation of gap
return cls(
midpoints=dict_merge([fs.midpoints for fs in stack]),
gap=gap_merge([fs.gap for fs in stack]),
**kwargs,
# If any lower or upper are different, apply to those
lower=dict_merge([fs.lower for fs in stack])
if any(fs.midpoints is not fs.lower for fs in stack)
else None,
upper=dict_merge([fs.upper for fs in stack])
if any(fs.midpoints is not fs.upper for fs in stack)
else None,
)


Expand All @@ -393,19 +406,23 @@ def __init__(
midpoints: AxesPoints[Axis],
lower: AxesPoints[Axis] | None = None,
upper: AxesPoints[Axis] | None = None,
gap: np.ndarray | None = None,
gap: GapArray | None = None,
):
super().__init__(midpoints, lower=lower, upper=upper, gap=gap)
# Override first element of gap to be True, as subsequent runs
# of snake scans are always joined end -> start
self.gap[0] = False

@classmethod
def from_frames(cls, frames: Frames[Axis]) -> SnakedFrames[Axis]:
def from_frames(
cls: type[SnakedFrames[Any]], frames: Frames[OtherAxis]
) -> SnakedFrames[OtherAxis]:
"""Create a snaked version of a `Frames` object."""
return cls(frames.midpoints, frames.lower, frames.upper, frames.gap)

def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]:
def extract(
self, indices: npt.NDArray[np.int32], calculate_gap: bool = True
) -> Frames[Axis]:
"""Return a new Frames object restricted to the indices provided.

Args:
Expand Down Expand Up @@ -434,23 +451,23 @@ def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]:
cls = type(self)
gap = None

# If lower or upper are different, apply to those
kwargs = {}
if self.midpoints is not self.lower:
# If going backwards select from the opposite bound
kwargs["lower"] = {
# Apply to midpoints
return cls(
{k: v[snake_indices] for k, v in self.midpoints.items()},
gap=gap,
# If lower or upper are different, apply to those
lower={
k: np.where(backwards, self.upper[k][snake_indices], v[snake_indices])
for k, v in self.lower.items()
}
if self.midpoints is not self.upper:
kwargs["upper"] = {
if self.midpoints is not self.lower
else None,
upper={
k: np.where(backwards, self.lower[k][snake_indices], v[snake_indices])
for k, v in self.upper.items()
}

# Apply to midpoints
return cls(
{k: v[snake_indices] for k, v in self.midpoints.items()}, gap=gap, **kwargs
if self.midpoints is not self.upper
else None,
)


Expand All @@ -459,7 +476,9 @@ def gap_between_frames(frames1: Frames[Axis], frames2: Frames[Axis]) -> bool:
return any(frames1.upper[a][-1] != frames2.lower[a][0] for a in frames1.axes())


def squash_frames(stack: list[Frames[Axis]], check_path_changes=True) -> Frames[Axis]:
def squash_frames(
stack: list[Frames[Axis]], check_path_changes: bool = True
) -> Frames[Axis]:
"""Squash a stack of nested Frames into a single one.

Args:
Expand Down Expand Up @@ -624,10 +643,7 @@ def __init__(self, stack: list[Frames[Axis]]):
@property
def axes(self) -> list[Axis]:
"""The axes that will be present in each points dictionary."""
axes = []
for frames in self.stack:
axes += frames.axes()
return axes
return list(itertools.chain(*(frames.axes() for frames in self.stack)))

def __len__(self) -> int:
"""The number of dictionaries that will be produced if iterated over."""
Expand Down
Loading