Skip to content

Commit

Permalink
Huge update
Browse files Browse the repository at this point in the history
  • Loading branch information
Mr-Milk committed Aug 30, 2024
1 parent 8dd29b9 commit dc3027d
Show file tree
Hide file tree
Showing 22 changed files with 451 additions and 113 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,5 @@ cython_debug/
.idea/
.DS_Store

work/
work/
.nextflow.log*
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[build-system]
requires = ["hatchling", "versioningit"]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
Expand Down Expand Up @@ -54,6 +54,13 @@ dev = [
"pytest",
]

# Optional dependencies for the GigaPath tile and slide encoders.
# The gigapath package is installed straight from GitHub, so it must be written
# as a PEP 508 direct reference ("name @ url"); a bare "git+https://..." string
# is not a valid requirement.
# NOTE(review): hatchling rejects direct references unless
# `allow-direct-references = true` is set under [tool.hatch.metadata] - confirm.
gigapath = [
    "timm>=1.0.3",
    "gigapath @ git+https://github.com/prov-gigapath/prov-gigapath.git",
    "fairscale",
    "einops",
]

# Define entry points
[project.scripts]
lazyslide = "lazyslide.__main__:app"
Expand Down
5 changes: 2 additions & 3 deletions src/lazyslide/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import pkg_resources
"""Efficient and Scalable Whole Slide Image (WSI) processing library."""
__version__ = "0.1.0"

from wsi_data import open_wsi
import lazyslide.pp as pp
import lazyslide.tl as tl
import lazyslide.pl as pl

version = __version__ = pkg_resources.get_distribution("lazyslide").version
40 changes: 28 additions & 12 deletions src/lazyslide/__main__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
from pathlib import Path
from typing import Optional

import pandas as pd
from rich import print
from typer import Typer, Argument, Option

Expand Down Expand Up @@ -49,20 +51,20 @@ def preprocess(
filter_artifacts=filter_artifacts,
)

zs.pp.tissue_qc(wsi, tissue_qc.split(","))
if filter_tissue:
tissue_tb = wsi.sdata["tissues"]
wsi.sdata["tissues"] = tissue_tb[tissue_tb["qc"]]
# zs.pp.tissue_qc(wsi, tissue_qc.split(","))
# if filter_tissue:
# tissue_tb = wsi.sdata["tissues"]
# wsi.sdata["tissues"] = tissue_tb[tissue_tb["qc"]]

zs.pp.tiles(wsi, tile_px=tile_px, stride_px=stride_px, mpp=mpp)

zs.pp.tiles_qc(wsi, tile_qc.split(","))
if filter_tiles:
tile_tb = wsi.sdata["tiles"]
wsi.sdata["tiles"] = tile_tb[tile_tb["qc"]]

if report:
zs.pl.qc_summary(wsi, ["brightness", "redness"], ["focus", "contrast"])
# zs.pp.tiles_qc(wsi, tile_qc.split(","))
# if filter_tiles:
# tile_tb = wsi.sdata["tiles"]
# wsi.sdata["tiles"] = tile_tb[tile_tb["qc"]]
#
# if report:
# zs.pl.qc_summary(wsi, ["brightness", "redness"], ["focus", "contrast"])

wsi.save()
print(f"Saved to {wsi.sdata.path}")
Expand All @@ -72,13 +74,27 @@ def preprocess(
def feature(
slide: str = WSI,
model: str = Argument(..., help="A model name or the path to the model file"),
slide_encoder: str = None,
output: Optional[str] = OUTPUT,
):
import lazyslide as zs

wsi = zs.open_wsi(slide, backed_file=output)
print(f"Read slide file {slide}")
print(f"Extract features using model {model}")
zs.tl.feature_extraction(wsi, model)
zs.tl.feature_extraction(wsi, model, slide_encoder=slide_encoder)
wsi.save()
print(f"Write to {wsi.sdata.path}")


@app.command()
def agg_wsi(
    slide_table: Path = Argument(..., help="The slide table file"),
    output: Optional[str] = OUTPUT,
):
    """Aggregate the "features" of the slides listed in a table and write them as Zarr.

    The slide table is read as CSV and handed to ``wsi_data.agg_wsi`` together
    with the key ``"features"``; the returned object is written with its
    ``write_zarr`` method.

    Parameters
    ----------
    slide_table
        Path to the CSV file describing the slides.
    output
        Destination of the Zarr store. When omitted, the store is written next
        to the slide table, using the table's name with a ``.zarr`` suffix.
    """
    # Aliased on import so the helper does not shadow this command's own name.
    from wsi_data import agg_wsi as aggregate_slides

    print(f"Read slide table {slide_table}")
    slides_table = pd.read_csv(slide_table)
    data = aggregate_slides(slides_table, "features")

    # `output` is optional; passing None on to write_zarr would fail, so fall
    # back to a path derived from the table location.
    if output is None:
        output = str(Path(slide_table).with_suffix(".zarr"))
    data.write_zarr(output)
    print(f"Write to {output}")
1 change: 1 addition & 0 deletions src/lazyslide/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .plip import PLIP, PLIPVision
from .conch import CONCH, CONCHVision
from .gigapath import GigaPath, GigaPathSlideEncoder
4 changes: 2 additions & 2 deletions src/lazyslide/models/conch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class CONCH(torch.nn.Module):
def __init__(self, model_path=None, auth_token=None):
def __init__(self, model_path=None, token=None):
try:
from conch.open_clip_custom import create_model_from_pretrained
from conch.open_clip_custom import get_tokenizer
Expand All @@ -18,7 +18,7 @@ def __init__(self, model_path=None, auth_token=None):
model_path = "hf_hub:MahmoodLab/conch"

self.model, self.processor = create_model_from_pretrained(
"conch_ViT-B-16", model_path, hf_auth_token=auth_token
"conch_ViT-B-16", model_path, hf_auth_token=token
)
self.tokenizer = get_tokenizer()

Expand Down
25 changes: 20 additions & 5 deletions src/lazyslide/models/gigapath.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


class GigaPath(torch.nn.Module):
def __init__(self, model_path=None, auth_token=None):
def __init__(self, model_path=None, token=None):
try:
import timm
from huggingface_hub import login
Expand All @@ -14,8 +14,23 @@ def __init__(self, model_path=None, auth_token=None):

super().__init__()

if auth_token is not None:
login(auth_token)
# Version check: GigaPath needs timm >= 1.0.3.
# Only the optional `packaging` import is guarded. The version error itself is
# raised outside the `try`, so it can no longer be swallowed by the
# `except ImportError` that exists solely to tolerate a missing `packaging`.
try:
    from packaging import version
except ImportError:
    # If packaging is not installed, skip the version check
    pass
else:
    timm_version = version.parse(timm.__version__)
    minimum_version = version.parse("1.0.3")
    if timm_version < minimum_version:
        raise ImportError(
            f"Gigapath needs timm >= 1.0.3. You have version {timm_version}. "
            f"Run `pip install --upgrade timm` to install the latest version."
        )

if token is not None:
login(token)

model = timm.create_model("hf_hub:prov-gigapath/prov-gigapath", pretrained=True)
self.model = model
Expand All @@ -29,7 +44,7 @@ def __init__(self, model_path=None, auth_token=None):
try:
import timm
from huggingface_hub import login
import gigapath
from gigapath.slide_encoder import create_model
except ImportError:
raise ImportError(
"To use GigaPathSlideEncoder, you need to install timm and gigapath. You can install it using "
Expand All @@ -41,7 +56,7 @@ def __init__(self, model_path=None, auth_token=None):
if auth_token is not None:
login(auth_token)

model = gigapath.slide_encoder.create_model(
model = create_model(
"hf_hub:prov-gigapath/prov-gigapath", "gigapath_slide_enc12l768d", 1536
)
self.model = model
Expand Down
16 changes: 6 additions & 10 deletions src/lazyslide/models/plip.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,15 @@


class PLIP(torch.nn.Module):
def __init__(self, model_path=None, auth_token=None):
def __init__(self, model_path=None, token=None):
from transformers import CLIPModel, CLIPProcessor

super().__init__()

if model_path is None:
model_path = "vinid/plip"
self.model = CLIPModel.from_pretrained(model_path, use_auth_token=auth_token)
self.processor = CLIPProcessor.from_pretrained(
model_path, use_auth_token=auth_token
)
self.model = CLIPModel.from_pretrained(model_path, use_auth_token=token)
self.processor = CLIPProcessor.from_pretrained(model_path, use_auth_token=token)

def encode_image(self, image, normalize=True):
if not isinstance(image, torch.Tensor):
Expand Down Expand Up @@ -47,19 +45,17 @@ def forward(self, image):


class PLIPVision(torch.nn.Module):
def __init__(self, model_path=None, auth_token=None):
def __init__(self, model_path=None, token=None):
from transformers import CLIPVisionModelWithProjection, CLIPProcessor

super().__init__()

if model_path is None:
model_path = "vinid/plip"
self.model = CLIPVisionModelWithProjection.from_pretrained(
model_path, use_auth_token=auth_token
)
self.processor = CLIPProcessor.from_pretrained(
model_path, use_auth_token=auth_token
model_path, use_auth_token=token
)
self.processor = CLIPProcessor.from_pretrained(model_path, use_auth_token=token)

def encode_image(self, image, normalize=False):
inputs = self.processor(images=image, return_tensors="pt")
Expand Down
8 changes: 4 additions & 4 deletions src/lazyslide/pl/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@
ADOBE_SPECTRUM = [
"#0FB5AE",
"#F68511",
"#4046CA",
"#7326D3",
"#147AF3",
"#E8C600",
"#DE3D82",
"#72E06A",
"#7E84FA",
"#DE3D82",
"#7326D3",
"#008F5D",
"#CB5D00",
"#E8C600",
"#4046CA",
"#BCE931",
]

Expand Down
1 change: 1 addition & 0 deletions src/lazyslide/pp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@

from .tissue import find_tissue, tissue_qc
from .tiles import tiles, tiles_qc
from .graph import tile_graph
33 changes: 20 additions & 13 deletions src/lazyslide/pp/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,30 @@
from itertools import chain

import numpy as np
import pandas as pd
from numba import njit
from anndata import AnnData
from scipy.sparse import csr_matrix, spmatrix, isspmatrix_csr, SparseEfficiencyWarning
from scipy.spatial import Delaunay
from sklearn.metrics import euclidean_distances
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.neighbors import NearestNeighbors

from wsi_data import WSIData
from lazyslide._const import Key
from lazyslide.wsi import WSI


def tile_graph(
wsi: WSI,
wsi: WSIData,
n_neighs: int = 4,
n_rings: int = 1,
delaunay=False,
transform: str = None,
set_diag: bool = False,
tile_key: str = Key.tiles,
table_key: str = None,
):
import anndata as ad

coords = wsi.get_tiles_table(tile_key, ["x", "y"]).values
coords = wsi.sdata[tile_key][["x", "y"]].values
Adj, Dst = _spatial_neighbor(
coords, n_neighs, delaunay, n_rings, transform, set_diag
)
Expand All @@ -39,16 +40,22 @@ def tile_graph(
"transform": transform,
},
}

if f"{tile_key}_graph" in wsi.sdata.tables:
table = wsi.sdata.tables[f"{tile_key}_graph"]
# TODO: Store in an AnnData object
if table_key is None:
table_key = Key.tile_graph(tile_key)
if table_key not in wsi.sdata:
table = AnnData(
obs=pd.DataFrame(index=np.arange(coords.shape[0], dtype=int).astype(str)),
obsp={conns_key: Adj, dists_key: Dst},
uns=neighbors_dict,
)
else:
table = ad.AnnData()
table = wsi.sdata[table_key]
table.obsp[conns_key] = Adj
table.obsp[dists_key] = Dst
table.uns["spatial"] = neighbors_dict

table.obsp[conns_key] = Adj
table.obsp[dists_key] = Dst
table.uns["neighbors"] = neighbors_dict
wsi.sdata.tables[f"{tile_key}_graph"] = table
wsi.add_table(table_key, table)


def _spatial_neighbor(
Expand Down
41 changes: 41 additions & 0 deletions src/lazyslide/pp/load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from __future__ import annotations

from pathlib import Path
from typing import List

from geopandas import GeoDataFrame

from lazyslide._const import Key
from wsi_data import WSIData


def load_annotations(
    wsi: WSIData,
    annotations: str | Path | GeoDataFrame = None,
    join_with: str | List[str] = Key.tiles,
    key_added: str = "annotations",
):
    """Load annotations into the WSI data and join them onto existing shapes.

    Parameters
    ----------
    wsi
        The slide data object the annotations are added to.
    annotations
        Either a path to a file readable by ``geopandas.read_file`` (e.g. a
        GeoJSON file) or an already loaded ``GeoDataFrame``.
    join_with
        Key, or list of keys, of shapes in ``wsi.sdata`` to spatially join the
        annotations with. Keys not present in ``wsi.sdata`` are skipped.
    key_added
        Key under which the annotation shapes are stored.

    Returns
    -------
    The same ``wsi`` object that was passed in.

    Raises
    ------
    ValueError
        If ``annotations`` is neither a path nor a ``GeoDataFrame``
        (including the default ``None``).
    """
    import geopandas as gpd

    if isinstance(annotations, GeoDataFrame):
        anno_df = annotations
    elif isinstance(annotations, (str, Path)):
        anno_df = gpd.read_file(Path(annotations))
    else:
        raise ValueError(f"Invalid annotations: {annotations}")

    wsi.add_shapes(key_added, anno_df)

    # Accept a single key as shorthand for a one-element list.
    if isinstance(join_with, str):
        join_with = [join_with]

    for key in join_with:
        if key not in wsi.sdata:
            continue
        tile_df = wsi.sdata[key]
        # Left join keeps every existing shape, annotated or not.
        # `predicate` replaces the `op` keyword, which was deprecated in
        # geopandas 0.10 and removed in 1.0.
        gdf = gpd.sjoin(
            tile_df[["geometry"]], anno_df, how="left", predicate="intersects"
        )
        wsi.update_shapes_data(key, gdf)
    return wsi
5 changes: 4 additions & 1 deletion src/lazyslide/tl/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
from .features import feature_extraction
from .features import feature_extraction, encode_slide
from .tissue_props import tissue_props
from .utag import utag_feature
from .domain import spatial_domain
from .text_annotate import text_embedding
Loading

0 comments on commit dc3027d

Please sign in to comment.