Skip to content

Commit

Permalink
fix bug in setting state dict from relative path (SBU-BMI#123)
Browse files Browse the repository at this point in the history
  • Loading branch information
kaczmarj authored May 9, 2023
1 parent c0471c1 commit 45dfeea
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
11 changes: 8 additions & 3 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,12 +818,15 @@ def test_invalid_modeldefs(modeldef, tmp_path: Path):
def test_valid_modeldefs(tmp_path: Path):
from wsinfer._modellib.models import Weights

weights_file = tmp_path / "weights.pt"
# Put the weights in a different directory than the config to make sure that
# relative paths work.
weights_file = tmp_path / "ckpts" / "weights.pt"
weights_file.parent.mkdir()
modeldef = dict(
version="1.0",
name="foo",
architecture="resnet34",
file=str(weights_file),
file="ckpts/weights.pt",
num_classes=2,
transform=dict(resize_size=224, mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
patch_size_pixels=350,
Expand All @@ -838,7 +841,9 @@ def test_valid_modeldefs(tmp_path: Path):
Weights.from_yaml(path)

weights_file.touch()
assert Weights.from_yaml(path)
w = Weights.from_yaml(path)
assert w.file is not None
assert Path(w.file).exists()


def test_model_registration(tmp_path: Path):
Expand Down
9 changes: 7 additions & 2 deletions wsinfer/_modellib/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,22 +173,27 @@ def _validate_input(d: dict, config_path: Path) -> None:
@classmethod
def from_yaml(cls, path):
"""Create a new instance of Weights from a YAML file."""
path = Path(path)

with open(path) as f:
d = yaml.safe_load(f)
cls._validate_input(d, config_path=Path(path))
cls._validate_input(d, config_path=path)

transform = PatchClassification(
resize_size=d["transform"]["resize_size"],
mean=d["transform"]["mean"],
std=d["transform"]["std"],
)
if d.get("file") is not None:
file = path.parent / d.get("file")
else:
file = None
return Weights(
name=d["name"],
architecture=d["architecture"],
url=d.get("url"),
url_file_name=d.get("url_file_name"),
file=d.get("file"),
file=file,
num_classes=d["num_classes"],
transform=transform,
patch_size_pixels=d["patch_size_pixels"],
Expand Down

0 comments on commit 45dfeea

Please sign in to comment.