diff --git a/.gitignore b/.gitignore index 24ba9b9..5a5c7c9 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ docker-compose.yml Dockerfile +/config.yaml +*.bin # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/README.md b/README.md index 5a43594..249942c 100755 --- a/README.md +++ b/README.md @@ -1,9 +1,13 @@ [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/from-whole-slide-image-to-biomarker/classification-on-tcga)](https://paperswithcode.com/sota/classification-on-tcga?p=from-whole-slide-image-to-biomarker) + +> [!Important] +> STAMP v1.1.0 now uses PyTorch's FlashAttentionV2 implementation, which substantially improves memory efficiency when training. With this update, you *cannot* deploy a saved model from STAMP version ≤ 1.0.3 with this or subsequent versions. Therefore, it is recommended to only update to the latest version of STAMP when starting new experiments. Additionally, the optimizer has been updated from Adam to AdamW. Lastly, STAMP has built-in support for the [UNI Feature extractor](https://www.nature.com/articles/s41591-024-02857-3). Using it will require a Hugging Face account with granted access to the UNI model. For details on fair use, licensing and accessing the UNI model weights, refer to the [UNI GitHub repository](https://www.github.com/mahmoodlab/UNI.git). Note that the installation instructions and results within the STAMP [Nature Protocols paper](https://www.nature.com/articles/s41596-024-01047-2) refer to v1.0.3 of the software. The README file will always contain the most up-to-date installation instructions. + # STAMP protocol -A protocol for Solid Tumor Associative Modeling in Pathology. This repository contains the accompanying code for the steps described in the [preprint](https://arxiv.org/abs/2312.10944v1): +A protocol for Solid Tumor Associative Modeling in Pathology. This repository contains the accompanying code for the steps described in the [Nature Protocols paper](https://www.nature.com/articles/s41596-024-01047-2): ->From Whole Slide Image to Biomarker Prediction: A Protocol for End-to-End Deep Learning in Computational Pathology +>From whole-slide image to biomarker prediction: end-to-end weakly supervised deep learning in computational pathology The code can be executed either in a local environment, or in a containerized environment (preferred in clusters). @@ -30,11 +34,20 @@ pip install git+https://github.com/KatherLab/STAMP Once installed, you will be able to run the command line interface directly using the `stamp` command. -Finally, to download required resources such as the weights of the CTransPath feature extractor, run the following command: +Next, initialize STAMP and obtain the required configuration file, config.yaml, in your current working directory, by running the following command: + +```bash +stamp init +``` + +To download required resources such as the weights of the feature extractor, run the following command: ```bash stamp setup ``` +> [!Note] +> If you select a different feature extractor withing the configuration file, such as UNI, you will need to re-run the previous setup command to initiate the downloading step of the UNI feature extractor weights. This will trigger a prompt asking for your Hugging Face access key for the UNI model weights. + ## Using the container First, install Go and Singularity on your local machine using the [official installation instructions](https://docs.sylabs.io/guides/3.0/user-guide/installation.html). Note that the High-Performance Cluster (HPC) has Go and Singularity pre-installed, and do not require installation. @@ -61,6 +74,7 @@ Note that the binding of filesystems (-B) should be adapted to your own system. ## Running Available commands are: ```bash +stamp init # create a new configuration file in the current directory stamp setup # download required resources stamp config # print resolved configuration stamp preprocess # normalization and feature extraction with CTransPath @@ -71,19 +85,24 @@ stamp statistics # compute stats including ROC curves stamp heatmaps # generate heatmaps ``` -By default, stamp will use the configuration file `config.yaml` in the current working directory. If you want to use a different configuration file use the `--config` command line option, i.e. `stamp --config some/other/file.yaml train`. +> [!NOTE] +> By default, STAMP will use the configuration file `config.yaml` in the current working directory (or, if that does not exist, it will use the [default STAMP configuration file](stamp/config.yaml) shipped with this package). If you want to use a different configuration file, use the `--config` command line option, i.e. `stamp --config some/other/file.yaml train`. Note that the `--config` option must be supplied before any of the subcommands. You may also run `stamp init` to create a local `config.yaml` in the current working directory initialized to the default settings. ## Reference -If you find our work useful in your research or if you use parts of this code please consider citing our [preprint](https://arxiv.org/abs/2312.10944v1): +If you find our work useful in your research or if you use parts of this code please consider citing our [Nature Protocols publication](https://www.nature.com/articles/s41596-024-01047-2): ``` -@misc{nahhas2023wholeslide, - title={From Whole-slide Image to Biomarker Prediction: A Protocol for End-to-End Deep Learning in Computational Pathology}, - author={Omar S. M. El Nahhas and Marko van Treeck and Georg Wölflein and Michaela Unger and Marta Ligero and Tim Lenz and Sophia J. Wagner and Katherine J. Hewitt and Firas Khader and Sebastian Foersch and Daniel Truhn and Jakob Nikolas Kather}, - year={2023}, - eprint={2312.10944}, - archivePrefix={arXiv}, - primaryClass={cs.CV} +@Article{ElNahhas2024, +author={El Nahhas, Omar S. M. and van Treeck, Marko and W{\"o}lflein, Georg and Unger, Michaela and Ligero, Marta and Lenz, Tim and Wagner, Sophia J. and Hewitt, Katherine J. and Khader, Firas and Foersch, Sebastian and Truhn, Daniel and Kather, Jakob Nikolas}, +title={From whole-slide image to biomarker prediction: end-to-end weakly supervised deep learning in computational pathology}, +journal={Nature Protocols}, +year={2024}, +month={Sep}, +day={16}, +issn={1750-2799}, +doi={10.1038/s41596-024-01047-2}, +url={https://doi.org/10.1038/s41596-024-01047-2} } + ``` diff --git a/pyproject.toml b/pyproject.toml index c637ea4..a554c9c 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,9 +2,12 @@ requires = ["hatchling"] build-backend = "hatchling.build" +[tool.hatch.metadata] +allow-direct-references = true + [project] name = "stamp" -version = "0.0.2" +version = "1.1.1" authors = [ { name="Omar El Nahhas", email="omar.el_nahhas@tu-dresden.de" }, { name="Marko van Treeck", email="markovantreeck@gmail.com" }, @@ -27,7 +30,7 @@ dependencies = [ "scikit-learn~=1.3", "tqdm~=4.66", "fastai~=2.7", - "torch~=2.0", + "torch~=2.2", "torchvision~=0.15", "h5py~=3.10", "jaxtyping~=0.2", @@ -37,15 +40,16 @@ dependencies = [ "opencv-python~=4.8", "numba~=0.58", "gdown~=4.7", - "openpyxl~=3.1" + "openpyxl~=3.1", + "UNI@git+https://github.com/mahmoodlab/UNI.git" ] [project.scripts] "stamp" = "stamp.cli:main" [project.urls] -"Homepage" = "https://github.com/Avic3nna/STAMP" -"Bug Tracker" = "https://github.com/Avic3nna/STAMP/issues" +"Homepage" = "https://github.com/KatherLab/STAMP" +"Bug Tracker" = "https://github.com/KatherLab/STAMP/issues" [tool.hatch.build] include = [ diff --git a/setup/setup.sh b/setup/setup.sh index 77a6d1c..1fe156e 100644 --- a/setup/setup.sh +++ b/setup/setup.sh @@ -1,5 +1,7 @@ #!/bin/bash +set -eux + # get the directory of the script script_dir=$(realpath "$(dirname "${0}")") diff --git a/stamp/cli.py b/stamp/cli.py index 52457a6..4304dcb 100755 --- a/stamp/cli.py +++ b/stamp/cli.py @@ -1,17 +1,39 @@ -from omegaconf import OmegaConf +from omegaconf import OmegaConf, DictConfig +from omegaconf.listconfig import ListConfig import argparse from pathlib import Path -from omegaconf.dictconfig import DictConfig -from functools import partial import os from typing import Iterable, Optional +import shutil NORMALIZATION_TEMPLATE_URL = "https://github.com/Avic3nna/STAMP/blob/main/resources/normalization_template.jpg?raw=true" CTRANSPATH_WEIGHTS_URL = "https://drive.google.com/u/0/uc?id=1DoDx_70_TLj98gTf6YTXnu4tFhsFocDX&export=download" +DEFAULT_RESOURCES_DIR = Path(__file__).with_name("resources") +DEFAULT_CONFIG_FILE = Path("config.yaml") +STAMP_FACTORY_SETTINGS = Path(__file__).with_name("config.yaml") class ConfigurationError(Exception): pass +def check_path_exists(path): + directories = path.split(os.path.sep) + current_path = os.path.sep + for directory in directories: + current_path = os.path.join(current_path, directory) + if not os.path.exists(current_path): + return False, directory + return True, None + + +def check_and_handle_path(path, path_key, prefix): + exists, directory = check_path_exists(path) + if not exists: + print(f"From input path: '{path}'") + print(f"Directory '{directory}' does not exist.") + print(f"Check the input path of '{path_key}' from the '{prefix}' section.") + raise SystemExit(f"Stopping {prefix} due to faulty user input...") + + def _config_has_key(cfg: DictConfig, key: str): try: for k in key.split("."): @@ -22,28 +44,72 @@ def _config_has_key(cfg: DictConfig, key: str): return False return True -def require_configs(cfg: DictConfig, keys: Iterable[str], prefix: Optional[str] = None): - prefix = f"{prefix}." if prefix else "" - keys = [f"{prefix}{k}" for k in keys] +def require_configs(cfg: DictConfig, keys: Iterable[str], prefix: Optional[str] = None, + paths_to_check: Iterable[str] = []): + keys = [f"{prefix}.{k}" for k in keys] missing = [k for k in keys if not _config_has_key(cfg, k)] if len(missing) > 0: raise ConfigurationError(f"Missing required configuration keys: {missing}") + # Check if paths exist + for path_key in paths_to_check: + try: + #for all but modeling.statistics + path = cfg[prefix][path_key] + except: + #for modeling.statistics, handling the pred_csvs + path = OmegaConf.select(cfg, f"{prefix}.{path_key}") + if isinstance(path, ListConfig): + for p in path: + check_and_handle_path(p, path_key, prefix) + else: + check_and_handle_path(path, path_key, prefix) + + +def create_config_file(config_file: Optional[Path]): + """Create a new config file at the specified path (by copying the default config file).""" + config_file = config_file or DEFAULT_CONFIG_FILE + # Locate original config file + if not STAMP_FACTORY_SETTINGS.exists(): + raise ConfigurationError(f"Default STAMP config file not found at {STAMP_FACTORY_SETTINGS}") + # Copy original config file + shutil.copy(STAMP_FACTORY_SETTINGS, config_file) + print(f"Created new config file at {config_file.absolute()}") + +def resolve_config_file_path(config_file: Optional[Path]) -> Path: + """Resolve the path to the config file, falling back to the default config file if not specified.""" + if config_file is None: + if DEFAULT_CONFIG_FILE.exists(): + config_file = DEFAULT_CONFIG_FILE + else: + config_file = STAMP_FACTORY_SETTINGS + print(f"Falling back to default STAMP config file because {DEFAULT_CONFIG_FILE.absolute()} does not exist") + if not config_file.exists(): + raise ConfigurationError(f"Default STAMP config file not found at {config_file}") + if not config_file.exists(): + raise ConfigurationError(f"Config file {Path(config_file).absolute()} not found (run `stamp init` to create the config file or use the `--config` flag to specify a different config file)") + return config_file + def run_cli(args: argparse.Namespace): + # Handle init command + if args.command == "init": + create_config_file(args.config) + return + # Load YAML configuration - try: - cfg = OmegaConf.load(args.config) - except FileNotFoundError: - raise ConfigurationError(f"Config file {args.config} not found (use the --config flag to specify a different config file)") - + config_file = resolve_config_file_path(args.config) + cfg = OmegaConf.load(config_file) + # Set environment variables if "STAMP_RESOURCES_DIR" not in os.environ: - os.environ["STAMP_RESOURCES_DIR"] = str(Path(args.config).with_name("resources")) + os.environ["STAMP_RESOURCES_DIR"] = str(DEFAULT_RESOURCES_DIR) match args.command: + case "init": + return # this is handled above case "setup": # Download normalization template - normalization_template_path = Path(cfg.preprocessing.normalization_template) + normalization_template_path = Path(f"{os.environ['STAMP_RESOURCES_DIR']}/normalization_template.jpg") normalization_template_path.parent.mkdir(parents=True, exist_ok=True) if normalization_template_path.exists(): print(f"Skipping download, normalization template already exists at {normalization_template_path}") @@ -54,34 +120,50 @@ def run_cli(args: argparse.Namespace): with normalization_template_path.open("wb") as f: f.write(r.content) # Download feature extractor model - model_path = Path(cfg.preprocessing.model_path) + feat_extractor = cfg.preprocessing.feat_extractor + if feat_extractor == 'ctp': + model_path = Path(f"{os.environ['STAMP_RESOURCES_DIR']}/ctranspath.pth") + elif feat_extractor == 'uni': + model_path = Path(f"{os.environ['STAMP_RESOURCES_DIR']}/uni/vit_large_patch16_224.dinov2.uni_mass100k/pytorch_model.bin") model_path.parent.mkdir(parents=True, exist_ok=True) if model_path.exists(): print(f"Skipping download, feature extractor model already exists at {model_path}") else: - print(f"Downloading CTransPath weights to {model_path}") - import gdown - gdown.download(CTRANSPATH_WEIGHTS_URL, str(model_path)) + if feat_extractor == 'ctp': + print(f"Downloading CTransPath weights to {model_path}") + import gdown + gdown.download(CTRANSPATH_WEIGHTS_URL, str(model_path)) + elif feat_extractor == 'uni': + print(f"Downloading UNI weights") + from uni.get_encoder import get_encoder + get_encoder(enc_name='uni', checkpoint='pytorch_model.bin', assets_dir=f"{os.environ['STAMP_RESOURCES_DIR']}/uni") case "config": print(OmegaConf.to_yaml(cfg, resolve=True)) case "preprocess": require_configs( cfg, - ["output_dir", "wsi_dir", "model_path", "cache_dir", "microns", "cores", "norm", "del_slide", "only_feature_extraction", "device", "normalization_template"], - prefix="preprocessing" + ["output_dir", "wsi_dir", "cache_dir", "microns", "cores", "norm", "del_slide", "only_feature_extraction", "device", "feat_extractor"], + prefix="preprocessing", + paths_to_check=["wsi_dir"] ) c = cfg.preprocessing # Some checks - if not Path(c.normalization_template).exists(): - raise ConfigurationError(f"Normalization template {c.normalization_template} does not exist, please run `stamp setup` to download it.") - if not Path(c.model_path).exists(): - raise ConfigurationError(f"Feature extractor model {c.model_path} does not exist, please run `stamp setup` to download it.") + normalization_template_path = Path(f"{os.environ['STAMP_RESOURCES_DIR']}/normalization_template.jpg") + if c.norm and not Path(normalization_template_path).exists(): + raise ConfigurationError(f"Normalization template {normalization_template_path} does not exist, please run `stamp setup` to download it.") + if c.feat_extractor == 'ctp': + model_path = f"{os.environ['STAMP_RESOURCES_DIR']}/ctranspath.pth" + elif c.feat_extractor == 'uni': + model_path = f"{os.environ['STAMP_RESOURCES_DIR']}/uni/vit_large_patch16_224.dinov2.uni_mass100k/pytorch_model.bin" + if not Path(model_path).exists(): + raise ConfigurationError(f"Feature extractor model {model_path} does not exist, please run `stamp setup` to download it.") from .preprocessing.wsi_norm import preprocess preprocess( output_dir=Path(c.output_dir), wsi_dir=Path(c.wsi_dir), - model_path=Path(c.model_path), + model_path=Path(model_path), cache_dir=Path(c.cache_dir), + feat_extractor=c.feat_extractor, # patch_size=c.patch_size, target_microns=c.microns, cores=c.cores, @@ -89,14 +171,16 @@ def run_cli(args: argparse.Namespace): del_slide=c.del_slide, cache=c.cache if 'cache' in c else True, only_feature_extraction=c.only_feature_extraction, + keep_dir_structure=c.keep_dir_structure if 'keep_dir_structure' in c else False, device=c.device, - normalization_template=Path(c.normalization_template) + normalization_template=normalization_template_path ) case "train": require_configs( cfg, - ["output_dir", "feature_dir", "target_label", "cat_labels", "cont_labels"], - prefix="modeling" + ["clini_table", "slide_table", "output_dir", "feature_dir", "target_label", "cat_labels", "cont_labels"], + prefix="modeling", + paths_to_check=["clini_table", "slide_table", "feature_dir"] ) c = cfg.modeling from .modeling.marugoto.transformer.helpers import train_categorical_model_ @@ -111,8 +195,9 @@ def run_cli(args: argparse.Namespace): case "crossval": require_configs( cfg, - ["output_dir", "feature_dir", "target_label", "cat_labels", "cont_labels", "n_splits"], # this one requires the n_splits key! - prefix="modeling" + ["clini_table", "slide_table", "output_dir", "feature_dir", "target_label", "cat_labels", "cont_labels", "n_splits"], # this one requires the n_splits key! + prefix="modeling", + paths_to_check=["clini_table", "slide_table", "feature_dir"] ) c = cfg.modeling from .modeling.marugoto.transformer.helpers import categorical_crossval_ @@ -128,8 +213,9 @@ def run_cli(args: argparse.Namespace): case "deploy": require_configs( cfg, - ["output_dir", "deploy_feature_dir", "target_label", "cat_labels", "cont_labels", "model_path"], # this one requires the model_path key! - prefix="modeling" + ["clini_table", "slide_table", "output_dir", "deploy_feature_dir", "target_label", "cat_labels", "cont_labels", "model_path"], # this one requires the model_path key! + prefix="modeling", + paths_to_check=["clini_table", "slide_table", "deploy_feature_dir"] ) c = cfg.modeling from .modeling.marugoto.transformer.helpers import deploy_categorical_model_ @@ -141,11 +227,14 @@ def run_cli(args: argparse.Namespace): cat_labels=c.cat_labels, cont_labels=c.cont_labels, model_path=Path(c.model_path)) + print("Successfully deployed models") case "statistics": require_configs( cfg, ["pred_csvs", "target_label", "true_class", "output_dir"], - prefix="modeling.statistics") + prefix="modeling.statistics", + paths_to_check=["pred_csvs"] + ) from .modeling.statistics import compute_stats c = cfg.modeling.statistics if isinstance(c.pred_csvs,str): @@ -154,11 +243,14 @@ def run_cli(args: argparse.Namespace): target_label=c.target_label, true_class=c.true_class, output_dir=Path(c.output_dir)) + print("Successfully calculated statistics") case "heatmaps": require_configs( cfg, - ["feature_dir","wsi_dir","model_path","output_dir", "n_toptiles"], - prefix="heatmaps") + ["feature_dir","wsi_dir","model_path","output_dir", "n_toptiles", "overview"], + prefix="heatmaps", + paths_to_check=["feature_dir","wsi_dir","model_path"] + ) c = cfg.heatmaps from .heatmaps.__main__ import main main(slide_name=str(c.slide_name), @@ -166,22 +258,25 @@ def run_cli(args: argparse.Namespace): wsi_dir=Path(c.wsi_dir), model_path=Path(c.model_path), output_dir=Path(c.output_dir), - n_toptiles=int(c.n_toptiles)) + n_toptiles=int(c.n_toptiles), + overview=c.overview) + print("Successfully produced heatmaps") case _: raise ConfigurationError(f"Unknown command {args.command}") def main() -> None: parser = argparse.ArgumentParser(prog="stamp", description="STAMP: Solid Tumor Associative Modeling in Pathology") - parser.add_argument("--config", "-c", type=str, default="config.yaml", help="Path to config file") + parser.add_argument("--config", "-c", type=Path, default=None, help=f"Path to config file. Note that the --config option must be supplied before any of the subcommands. If unspecified, defaults to {DEFAULT_CONFIG_FILE.absolute()} or the default STAMP config file shipped with the package if {DEFAULT_CONFIG_FILE.absolute()} does not exist.") commands = parser.add_subparsers(dest="command") + commands.add_parser("init", help="Create a new STAMP configuration file at the path specified by --config") commands.add_parser("setup", help="Download required resources") - commands.add_parser("preprocess", help="Preprocess data") - commands.add_parser("train", help="Train a vision transformer model") - commands.add_parser("crossval", help="Train a vision transformer model with cross validation for modeling.n_splits folds") - commands.add_parser("deploy", help="Deploy a trained vision transformer model") - commands.add_parser("statistics", help="Generate ROC curves for a trained model") - commands.add_parser("config", help="Print the loaded configuation") + commands.add_parser("preprocess", help="Preprocess whole-slide images into feature vectors") + commands.add_parser("train", help="Train a Vision Transformer model") + commands.add_parser("crossval", help="Train a Vision Transformer model with cross validation for modeling.n_splits folds") + commands.add_parser("deploy", help="Deploy a trained Vision Transformer model") + commands.add_parser("statistics", help="Generate AUROCs and AUPRCs with 95%%CI for a trained Vision Transformer model") + commands.add_parser("config", help="Print the loaded configuration") commands.add_parser("heatmaps", help="Generate heatmaps for a trained model") args = parser.parse_args() @@ -199,4 +294,4 @@ def main() -> None: exit(1) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/config.yaml b/stamp/config.yaml similarity index 88% rename from config.yaml rename to stamp/config.yaml index 6aedd9d..bf56d51 100644 --- a/config.yaml +++ b/stamp/config.yaml @@ -1,7 +1,7 @@ # Main configuration file for STAMP. # # NOTE: you may use environment variables in this file, e.g. ${oc.env:STAMP_RESOURCES_DIR}. -# The STAMP_RESOURCES_DIR environment variable is a special environment variable that, if not set, will be set to the resources/ directory relative to the config.yaml file. +# The STAMP_RESOURCES_DIR environment variable is a special environment variable that, if not set, will be set to the resources/ directory relative to where STAMP is installed. # Only use absolute paths! @@ -10,9 +10,8 @@ preprocessing: wsi_dir: # Path of where the whole-slide images are. cache_dir: # Directory to store intermediate slide JPGs microns: 256 # Edge length in microns for each patch (default is 256, with pixel size 224, 256/224 = ~1.14MPP = ~9x magnification) - norm: true # Perform Macenko normalisation - normalization_template: ${oc.env:STAMP_RESOURCES_DIR}/normalization_template.jpg # Path to normalization template - model_path: ${oc.env:STAMP_RESOURCES_DIR}/ctranspath.pth # Path of where model for the feature extractor is + norm: false # Perform Macenko normalisation + feat_extractor: ctp # Use ctp for CTransPath (default) or uni for UNI (requires prior authentication) del_slide: false # Remove the original slide after processing cache: true # Save intermediate images (slide, background rejected, normalized) only_feature_extraction: false # Only perform feature extraction (intermediate images (background rejected, [normalized]) have to exist) @@ -46,4 +45,5 @@ heatmaps: wsi_dir: # Path to whole-slide image directory. model_path: /path/to/export.pkl # Path to saved model (only applicable to deployment) output_dir: # Path to output directory - n_toptiles: 8 # Number of toptiles, default is 8 \ No newline at end of file + n_toptiles: 8 # Number of toptiles, default is 8 + overview: true # Create final overview image \ No newline at end of file diff --git a/stamp/heatmaps/__main__.py b/stamp/heatmaps/__main__.py index 1a6c95a..c465700 100755 --- a/stamp/heatmaps/__main__.py +++ b/stamp/heatmaps/__main__.py @@ -139,6 +139,7 @@ def main( model_path: Path, output_dir: Path, n_toptiles: int = 8, + overview: bool = True, ) -> None: learn = load_learner(model_path) learn.model.eval() @@ -162,13 +163,13 @@ def main( preds, gradcam = gradcam_per_category( learn=learn, feats=feats, categories=categories ) - gradcam_2d = vals_to_im(gradcam.permute(-1, -2), coords // stride).detach() + gradcam_2d = vals_to_im(gradcam.permute(-1, -2), torch.div(coords, stride, rounding_mode='floor')).detach() scores = torch.softmax( learn.model(feats.unsqueeze(-2), torch.ones((len(feats)))), dim=1 ) - scores_2d = vals_to_im(scores, coords // stride).detach() - fig, axs = plt.subplots(nrows=2, ncols=min(2, len(categories)), figsize=(12, 8)) + scores_2d = vals_to_im(scores, torch.div(coords, stride, rounding_mode='floor')).detach() + fig, axs = plt.subplots(nrows=2, ncols=max(2, len(categories)), figsize=(12, 8)) show_class_map( class_ax=axs[0, 1], @@ -219,9 +220,10 @@ def main( ax.imshow(score_im) ax.set_title(f"{category} {preds[0,pos_idx]:1.2f}") - + target_size=np.array(score_im.shape[:2][::-1]) * 8 + # latest PIL requires shape to be a tuple (), not array [] Image.fromarray(np.uint8(score_im * 255)).resize( - np.array(score_im.shape[:2][::-1]) * 8, resample=Image.NEAREST + tuple(target_size), resample=Image.NEAREST ).save( slide_output_dir / f"scores-{h5_path.stem}--score_{category}={preds[0][pos_idx]:0.2f}.png" @@ -237,18 +239,19 @@ def main( n=n_toptiles, ) - thumb = show_thumb( - slide=slide, - thumb_ax=axs[0, 0], - attention=attention, - ) - Image.fromarray(thumb).save(slide_output_dir / f"thumbnail-{h5_path.stem}.png") + if overview: + thumb = show_thumb( + slide=slide, + thumb_ax=axs[0, 0], + attention=attention, + ) + Image.fromarray(thumb).save(slide_output_dir / f"thumbnail-{h5_path.stem}.png") - for ax in axs.ravel(): - ax.axis("off") + for ax in axs.ravel(): + ax.axis("off") - fig.savefig(slide_output_dir / f"overview-{h5_path.stem}.png") - plt.close(fig) + fig.savefig(slide_output_dir / f"overview-{h5_path.stem}.png") + plt.close(fig) if __name__ == "__main__": @@ -296,5 +299,12 @@ def main( required=False, help="Number of toptiles to generate, 8 by default", ) + parser.add_argument( + "--overview", + type=bool, + default=True, + required=False, + help="Generate final overview image", + ) args = parser.parse_args() main(**vars(args)) diff --git a/stamp/modeling/marugoto/transformer/TransMIL.py b/stamp/modeling/marugoto/transformer/TransMIL.py new file mode 100755 index 0000000..dcae465 --- /dev/null +++ b/stamp/modeling/marugoto/transformer/TransMIL.py @@ -0,0 +1,162 @@ +""" +In parts from https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py +""" + +import torch +from torch import nn +import torch.nn.functional as F +from einops import repeat + + + +class RMSNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.scale = dim ** 0.5 + self.gamma = nn.Parameter(torch.ones(dim)) + + def forward(self, x): + return F.normalize(x, dim = -1) * self.scale * self.gamma + + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, norm_layer=nn.LayerNorm, dropout=0.): + super().__init__() + self.mlp = nn.Sequential( + norm_layer(dim), + nn.Linear(dim, hidden_dim), + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + + def forward(self, x): + return self.mlp(x) + + +# class Attention(nn.Module): +# def __init__(self, dim, heads=8, dim_head=512 // 8, norm_layer=nn.LayerNorm, dropout=0.): +# super().__init__() +# inner_dim = dim_head * heads +# project_out = heads != 1 or dim_head != dim + +# self.heads = heads +# self.scale = dim_head ** -0.5 + +# self.norm = norm_layer(dim) + +# self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) +# self.to_out = nn.Sequential( +# nn.Linear(inner_dim, dim), +# nn.Dropout(dropout) +# ) if project_out else nn.Identity() + +# def forward(self, x, mask=None): +# x = self.norm(x) + +# qkv = self.to_qkv(x).chunk(3, dim=-1) +# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) +# dots = (q @ k.mT) * self.scale + +# if mask is not None: +# mask_value = torch.finfo(dots.dtype).min +# dots.masked_fill_(mask, mask_value) + +# # improve numerical stability of softmax +# dots = dots - torch.amax(dots, dim=-1, keepdim=True) +# attn = F.softmax(dots, dim=-1) + +# out = attn @ v +# out = rearrange(out, 'b h n d -> b n (h d)') +# return self.to_out(out), attn + + +class Attention(nn.Module): + def __init__(self, dim, heads=8, dim_head=512 // 8, norm_layer=nn.LayerNorm, dropout=0.): + super().__init__() + self.heads = heads + self.norm = norm_layer(dim) + self.mhsa = nn.MultiheadAttention(dim, heads, dropout, batch_first=True) + + def forward(self, x, mask=None): + if mask is not None: + mask = mask.repeat(self.heads, 1, 1) + + x = self.norm(x) + attn_output, _ = self.mhsa(x, x, x, need_weights=False, attn_mask=mask) + return attn_output + + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, norm_layer=nn.LayerNorm, dropout=0.): + super().__init__() + self.depth = depth + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + Attention(dim, heads=heads, dim_head=dim_head, norm_layer=norm_layer, dropout=dropout), + FeedForward(dim, mlp_dim, norm_layer=norm_layer, dropout=dropout) + ])) + self.norm = norm_layer(dim) + + def forward(self, x, mask=None): + for attn, ff in self.layers: + x_attn = attn(x, mask=mask) + x = x_attn + x + x = ff(x) + x + return self.norm(x) + + +class TransMIL(nn.Module): + def __init__(self, *, + num_classes: int, input_dim: int = 768, dim: int = 512, + depth: int = 2, heads: int = 8, dim_head: int = 64, mlp_dim: int = 2048, + pool: str ='cls', dropout: int = 0., emb_dropout: int = 0. + ): + super().__init__() + assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' + self.cls_token = nn.Parameter(torch.randn(dim)) + + self.fc = nn.Sequential(nn.Linear(input_dim, dim, bias=True), nn.GELU()) + self.dropout = nn.Dropout(emb_dropout) + + self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, nn.LayerNorm, dropout) + + self.pool = pool + self.mlp_head = nn.Sequential( + nn.Linear(dim, num_classes) + ) + + def forward(self, x, lens): + # remove unnecessary padding + # (deactivated for now, since the memory usage fluctuates more and is overall bigger) + # x = x[:, :torch.max(lens)].contiguous() + b, n, d = x.shape + + # map input sequence to latent space of TransMIL + x = self.dropout(self.fc(x)) + + add_cls = self.pool == 'cls' + if add_cls: + cls_tokens = repeat(self.cls_token, 'd -> b 1 d', b=b) + x = torch.cat((cls_tokens, x), dim=1) + lens = lens + 1 # account for cls token + + # mask indicating zero padded feature vectors + # (deactivated for now, since it seems to use more memory than without) + mask = None + if torch.amin(lens) != torch.amax(lens) and False: + mask = torch.arange(0, n + add_cls, dtype=torch.int32, device=x.device).repeat(b, 1) < lens[..., None] + mask = (~mask[:, None, :]).repeat(1, (n + add_cls), 1) # shape: (B, L, L) + # mask = (~mask[:, None, :]).expand(-1, (n + add_cls), -1) + + x = self.transformer(x, mask) + + if mask is not None and self.pool == 'mean': + x = torch.cumsum(x, dim=1)[torch.arange(b), lens - 1] + x = x / lens[..., None] + else: + x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0] + + return self.mlp_head(x) diff --git a/stamp/modeling/marugoto/transformer/ViT.py b/stamp/modeling/marugoto/transformer/ViT.py deleted file mode 100755 index 4d59aca..0000000 --- a/stamp/modeling/marugoto/transformer/ViT.py +++ /dev/null @@ -1,76 +0,0 @@ - -import torch -from einops import rearrange, repeat -from einops.layers.torch import Rearrange -from torch import nn - -from .transformer import PreNorm, Attention, FeedForward - - -class Transformer(nn.Module): - def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): - super().__init__() - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append(nn.ModuleList([ - PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), - PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) - ])) - - def forward(self, x): #, register_hook=False - for attn, ff in self.layers: - x = attn(x) + x # , register_hook=register_hook - x = ff(x) + x - return x - - -class ViT(nn.Module): - def __init__(self, *, num_classes, input_dim=768, dim=512, depth=2, heads=8, mlp_dim=512, pool='cls', channels=3, - dim_head=64, dropout=0., emb_dropout=0.): - super().__init__() - # image_height, image_width = pair(image_size) - # patch_height, patch_width = pair(patch_size) - # - # assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.' - # - # num_patches = (image_height // patch_height) * (image_width // patch_width) - # patch_dim = channels * patch_height * patch_width - assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)' - - # self.to_patch_embedding = nn.Sequential( - # Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width), - # nn.Linear(patch_dim, dim), - # ) - - # self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim)) - self.fc = nn.Sequential(nn.Linear(input_dim, 512, bias=True), nn.ReLU()) # added by me - - self.cls_token = nn.Parameter(torch.randn(1, 1, dim)) - self.dropout = nn.Dropout(emb_dropout) - - self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout) - - self.pool = pool - self.to_latent = nn.Identity() - - self.mlp_head = nn.Sequential( - nn.LayerNorm(dim), - nn.Linear(dim, num_classes) - ) - - def forward(self, x, register_hook=False): - # x = self.to_patch_embedding(img) - b, n, d = x.shape - - x = self.fc(x) - cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b) - x = torch.cat((cls_tokens, x), dim=1) - # x += self.pos_embedding[:, :(n + 1)] - x = self.dropout(x) - - x = self.transformer(x) # , register_hook=register_hook - - x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0] - - x = self.to_latent(x) - return self.mlp_head(x) diff --git a/stamp/modeling/marugoto/transformer/base.py b/stamp/modeling/marugoto/transformer/base.py index c791a7c..afc943b 100755 --- a/stamp/modeling/marugoto/transformer/base.py +++ b/stamp/modeling/marugoto/transformer/base.py @@ -1,18 +1,21 @@ from typing import Any, Iterable, Optional, Sequence, Tuple, TypeVar from pathlib import Path -import os +from functools import partial import torch from torch import nn import torch.nn.functional as F from fastai.vision.all import ( Learner, DataLoader, DataLoaders, RocAuc, - SaveModelCallback, CSVLogger, EarlyStoppingCallback) + SaveModelCallback, CSVLogger, EarlyStoppingCallback, + MixedPrecision, AMPMode, OptimWrapper +) import pandas as pd import numpy as np +import matplotlib.pyplot as plt from .data import make_dataset, SKLearnEncoder -from .ViT import ViT +from .TransMIL import TransMIL __all__ = ['train', 'deploy'] @@ -28,8 +31,11 @@ def train( add_features: Iterable[Tuple[SKLearnEncoder, Sequence[Any]]] = [], valid_idxs: np.ndarray, n_epoch: int = 32, - patience: int = 16, + patience: int = 8, path: Optional[Path] = None, + batch_size: int = 64, + cores: int = 8, + plot: bool = False ) -> Learner: """Train a MLP on image features. @@ -39,6 +45,11 @@ def train( add_features: An (encoder, targets) pair for each additional input. valid_idxs: Indices of the datasets to use for validation. """ + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if device.type == "cuda": + # allow for usage of TensorFloat32 as internal dtype for matmul on modern NVIDIA GPUs + torch.set_float32_matmul_precision("medium") + target_enc, targs = targets train_ds = make_dataset( bags=bags[~valid_idxs], @@ -55,39 +66,74 @@ def train( (enc, vals[valid_idxs]) for enc, vals in add_features], bag_size=None) - + # build dataloaders train_dl = DataLoader( - train_ds, batch_size=64, shuffle=True, num_workers=1, drop_last=True) + train_ds, batch_size=batch_size, shuffle=True, num_workers=cores, + drop_last=len(train_ds) > batch_size, + device=device, pin_memory=device.type == "cuda" + ) valid_dl = DataLoader( - valid_ds, batch_size=1, shuffle=False, num_workers=1) + valid_ds, batch_size=1, shuffle=False, num_workers=cores, + device=device, pin_memory=device.type == "cuda" + ) batch = train_dl.one_batch() - feature_dim=batch[0].shape[-1] - # for binary classification num_classes=2 - model = ViT(num_classes=len(target_enc.categories_[0]), input_dim=feature_dim) # Transformer(num_classes=2) - model.to(torch.device('cuda' if torch.cuda.is_available() else 'cpu')) # + feature_dim = batch[0].shape[-1] - # weigh inversely to class occurances + # for binary classification num_classes=2 + model = TransMIL( + num_classes=len(target_enc.categories_[0]), input_dim=feature_dim, + dim=512, depth=2, heads=8, mlp_dim=512, dropout=.0 + ) + # TODO: + # maybe increase mlp_dim? Not necessary 4*dim, but maybe a bit? + # maybe add at least some dropout? + + # model = torch.compile(model) + model.to(device) + print(f"Model: {model}", end=" ") + print(f"[Parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}]") + + # weigh inversely to class occurrences counts = pd.Series(targs[~valid_idxs]).value_counts() - weight = counts.sum() / counts weight /= weight.sum() # reorder according to vocab weight = torch.tensor( - list(map(weight.get, target_enc.categories_[0])), dtype=torch.float32) + list(map(weight.get, target_enc.categories_[0])), dtype=torch.float32, device=device) loss_func = nn.CrossEntropyLoss(weight=weight) - dls = DataLoaders(train_dl, valid_dl, device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')) # - learn = Learner(dls, model, loss_func=loss_func, - metrics=[RocAuc()], path=path) + dls = DataLoaders(train_dl, valid_dl, device=device) - cbs = [ - SaveModelCallback(fname=f'best_valid'), - EarlyStoppingCallback(monitor='roc_auc_score', - min_delta=0.01, patience=patience), - CSVLogger()] + learn = Learner( + dls, + model, + loss_func=loss_func, + opt_func = partial(OptimWrapper, opt=torch.optim.AdamW), + metrics=[RocAuc()], + path=path, + )#.to_bf16() - learn.fit_one_cycle(n_epoch=n_epoch, lr_max=1e-4, cbs=cbs) + cbs = [ + SaveModelCallback(monitor='valid_loss', fname=f'best_valid'), + EarlyStoppingCallback(monitor='valid_loss', patience=patience), + CSVLogger(), + # MixedPrecision(amp_mode=AMPMode.BF16) + ] + learn.fit_one_cycle(n_epoch=n_epoch, reset_opt=True, lr_max=1e-4, wd=1e-2, cbs=cbs) + + # Plot training and validation losses as well as learning rate schedule + if plot: + path_plots = path / "plots" + path_plots.mkdir(parents=True, exist_ok=True) + + learn.recorder.plot_loss() + plt.savefig(path_plots / 'losses_plot.png') + plt.close() + + learn.recorder.plot_sched() + plt.savefig(path_plots / 'lr_scheduler.png') + plt.close() return learn @@ -96,6 +142,7 @@ def deploy( test_df: pd.DataFrame, learn: Learner, *, target_label: Optional[str] = None, cat_labels: Optional[Sequence[str]] = None, cont_labels: Optional[Sequence[str]] = None, + device: torch.device = torch.device('cpu') ) -> pd.DataFrame: assert test_df.PATIENT.nunique() == len(test_df), 'duplicate patients!' #assert (len(add_label) @@ -122,7 +169,8 @@ def deploy( bag_size=None) test_dl = DataLoader( - test_ds, batch_size=1, shuffle=False, num_workers=1) + test_ds, batch_size=1, shuffle=False, num_workers=1, + device=device, pin_memory=device.type == "cuda") #removed softmax in forward, but add here to get 0-1 probabilities patient_preds, patient_targs = learn.get_preds(dl=test_dl, act=nn.Softmax(dim=1)) diff --git a/stamp/modeling/marugoto/transformer/helpers.py b/stamp/modeling/marugoto/transformer/helpers.py index f817096..06b11e7 100755 --- a/stamp/modeling/marugoto/transformer/helpers.py +++ b/stamp/modeling/marugoto/transformer/helpers.py @@ -1,7 +1,7 @@ from datetime import datetime import json +import os from pathlib import Path -from pyexpat import features from typing import Iterable, Optional, Sequence, Union import numpy as np @@ -23,6 +23,25 @@ PathLike = Union[str, Path] +class IncompatibleVersionError(Exception): + """Exception raised for loading a model with an incompatible version.""" + pass + + +def safe_load_learner(model_path, use_cpu): + try: + learn = load_learner(model_path, cpu=use_cpu) # if False will use GPU instead + return learn + except ModuleNotFoundError as e: + if e.name == "stamp.modeling.marugoto.transformer.ViT": + raise IncompatibleVersionError( + "The model checkpoint is incompatible with the current version of STAMP (>= 1.1.0). " + "Please use STAMP version <= 1.0.3 to deploy this checkpoint." + ) from e + else: + raise + + def train_categorical_model_( clini_table: PathLike, slide_table: PathLike, @@ -104,7 +123,7 @@ def train_categorical_model_( with open(output_path/'info.json', 'w') as f: json.dump(info, f) - target_enc = OneHotEncoder(sparse=False).fit(categories.reshape(-1, 1)) + target_enc = OneHotEncoder(sparse_output=False).fit(categories.reshape(-1, 1)) add_features = [] if cat_labels: add_features.append((_make_cat_enc(train_df, cat_labels), df[cat_labels].values)) @@ -116,6 +135,7 @@ def train_categorical_model_( add_features=add_features, valid_idxs=df.PATIENT.isin(valid_patients).values, path=output_path, + cores=max(1, os.cpu_count() // 4) ) # save some additional information to the learner to make deployment easier @@ -138,7 +158,7 @@ def _make_cat_enc(df, cats) -> SKLearnEncoder: fitting_cats.append(non_na_samples) cat_samples = np.stack(fitting_cats, axis=1) cat_enc = make_pipeline( - OneHotEncoder(sparse=False, handle_unknown='ignore'), + OneHotEncoder(sparse_output=False, handle_unknown='ignore'), StandardScaler(), ).fit(cat_samples) return cat_enc @@ -173,6 +193,14 @@ def deploy_categorical_model_( model_path: Path of the model to deploy. output_path: File to save model in. """ + use_cpu=True + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if device.type == "cuda": + # allow for usage of TensorFloat32 as internal dtype for matmul on modern NVIDIA GPUs + torch.set_float32_matmul_precision("high") + + use_cpu= (device.type == "cpu") # True or False + feature_dir = Path(feature_dir) model_path = Path(model_path) output_path = Path(output_path) @@ -180,7 +208,8 @@ def deploy_categorical_model_( print(f'{preds_csv} already exists! Skipping...') return - learn = load_learner(model_path) + + learn = safe_load_learner(model_path, use_cpu=use_cpu) target_enc = get_target_enc(learn) categories = target_enc.categories_[0] @@ -190,7 +219,7 @@ def deploy_categorical_model_( test_df = get_cohort_df(clini_table, slide_table, feature_dir, target_label, categories) - patient_preds_df = deploy(test_df=test_df, learn=learn, target_label=target_label) + patient_preds_df = deploy(test_df=test_df, learn=learn, target_label=target_label, device=device) output_path.mkdir(parents=True, exist_ok=True) patient_preds_df.to_csv(preds_csv, index=False) @@ -220,6 +249,11 @@ def categorical_crossval_( output_path = Path(output_path) output_path.mkdir(exist_ok=True, parents=True) + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + if device.type == "cuda": + # allow for usage of TensorFloat32 as internal dtype for matmul on modern NVIDIA GPUs + torch.set_float32_matmul_precision("high") + # just a big fat object to dump all kinds of info into for later reference # not used during actual training info = { @@ -261,7 +295,7 @@ def categorical_crossval_( info['class distribution'] = {'overall': { k: int(v) for k, v in df[target_label].value_counts().items()}} - target_enc = OneHotEncoder(sparse=False).fit(categories.reshape(-1, 1)) + target_enc = OneHotEncoder(sparse_output=False).fit(categories.reshape(-1, 1)) if (fold_path := output_path/'folds.pt').exists(): folds = torch.load(fold_path) @@ -283,12 +317,13 @@ def categorical_crossval_( json.dump(info, f) for fold, (train_idxs, test_idxs) in enumerate(folds): + print(f"\nFold: {fold+1}/{n_splits}") fold_path = output_path/f'fold-{fold}' if (preds_csv := fold_path/'patient-preds.csv').exists(): print(f'{preds_csv} already exists! Skipping...') continue elif (fold_path/'export.pkl').exists(): - learn = load_learner(fold_path/'export.pkl') + learn = safe_load_learner(fold_path/'export.pkl') else: fold_train_df = df.iloc[train_idxs] learn = _crossval_train( @@ -301,7 +336,8 @@ def categorical_crossval_( fold_test_df.drop(columns='slide_path').to_csv(fold_path/'test.csv', index=False) patient_preds_df = deploy( test_df=fold_test_df, learn=learn, - target_label=target_label, cat_labels=cat_labels, cont_labels=cont_labels) + target_label=target_label, cat_labels=cat_labels, + cont_labels=cont_labels, device=device) patient_preds_df.to_csv(preds_csv, index=False) @@ -336,7 +372,9 @@ def _crossval_train( targets=(target_enc, fold_df[target_label].values), add_features=add_features, valid_idxs=fold_df.PATIENT.isin(valid_patients), - path=fold_path) + path=fold_path, + cores=max(1, os.cpu_count() // 4) + ) learn.target_label = target_label learn.cat_labels, learn.cont_labels = cat_labels, cont_labels diff --git a/stamp/modeling/marugoto/transformer/transformer.py b/stamp/modeling/marugoto/transformer/transformer.py deleted file mode 100755 index 2bae9ef..0000000 --- a/stamp/modeling/marugoto/transformer/transformer.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -In parts from https://github.com/lucidrains -""" - -import torch -from einops import rearrange -from torch import nn - - -class PreNorm(nn.Module): - def __init__(self, dim, fn): - super().__init__() - self.norm = nn.LayerNorm(dim) - self.fn = fn - - def forward(self, x, **kwargs): - return self.fn(self.norm(x), **kwargs) - - -class Attention(nn.Module): - def __init__(self, dim=512, heads=8, dim_head=512 // 8, dropout=0.1): - super().__init__() - inner_dim = dim_head * heads - project_out = not (heads == 1 and dim_head == dim) - - self.heads = heads - self.scale = dim_head ** -0.5 - - self.attend = nn.Softmax(dim=-1) - self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) - - self.to_out = nn.Sequential( - nn.Linear(inner_dim, dim), - nn.Dropout(dropout) - ) if project_out else nn.Identity() - - def forward(self, x): - qkv = self.to_qkv(x).chunk(3, dim=-1) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.heads), qkv) - - dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale - - attn = self.attend(dots) - - out = torch.matmul(attn, v) - out = rearrange(out, 'b h n d -> b n (h d)') - return self.to_out(out) - - -class FeedForward(nn.Module): - def __init__(self, dim=512, hidden_dim=1024, dropout=0.1): - super().__init__() - self.net = nn.Sequential( - nn.Linear(dim, hidden_dim), - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(hidden_dim, dim), - nn.Dropout(dropout) - ) - - def forward(self, x): - return self.net(x) - - -class TransformerLayer(nn.Module): - def __init__(self, norm_layer=nn.LayerNorm, dim=512, heads=8, use_ff=True, use_norm=True): - super().__init__() - self.norm = norm_layer(dim) - self.attn = Attention(dim=dim, heads=heads, dim_head=dim // heads) - self.use_ff = use_ff - self.use_norm = use_norm - if self.use_ff: - self.ff = FeedForward() - - def forward(self, x): - if self.use_norm: - x = x + self.attn(self.norm(x)) - else: - x = x + self.attn(x) - if self.use_ff: - x = self.ff(x) + x - return x - - -class Transformer(nn.Module): - def __init__(self, num_classes): - super().__init__() - self.n_classes = num_classes - - self._fc1 = nn.Sequential(nn.Linear(2048, 512, bias=True), nn.ReLU()) - self.layer1 = TransformerLayer(dim=512, heads=8, use_ff=False, use_norm=True) - self.layer2 = TransformerLayer(dim=512, heads=8, use_ff=False, use_norm=True) - self._fc2 = nn.Linear(512, self.n_classes, bias=True) - - def forward(self, x,_): - - h = x - h = self._fc1(h) - h = self.layer1(h) - h = self.layer2(h) - h = h.mean(dim=1) - logits = self._fc2(h) - - return logits diff --git a/stamp/preprocessing/helpers/feature_extractors.py b/stamp/preprocessing/helpers/feature_extractors.py index dab9f6d..fb2c849 100755 --- a/stamp/preprocessing/helpers/feature_extractors.py +++ b/stamp/preprocessing/helpers/feature_extractors.py @@ -10,45 +10,76 @@ from tqdm import tqdm import json import h5py +import uni +import os from .swin_transformer import swin_tiny_patch4_window7_224, ConvStem __version__ = "001_01-10-2023" -class FeatureExtractor: - def __init__(self): - self.model_type = "CTransPath" - - def init_feat_extractor(self, checkpoint_path: str, device: str, **kwargs): +def get_digest(file: str): + sha256 = hashlib.sha256() + with open(file, 'rb') as f: + while True: + data = f.read(1 << 16) + if not data: + break + sha256.update(data) + return sha256.hexdigest() + +class FeatureExtractorCTP: + def __init__(self, checkpoint_path: str): + self.checkpoint_path = checkpoint_path + + def init_feat_extractor(self, device: str, **kwargs): """Extracts features from slide tiles. - Args: - checkpoint_path: Path to the model checkpoint file. """ - sha256 = hashlib.sha256() - with open(checkpoint_path, 'rb') as f: - while True: - data = f.read(1 << 16) - if not data: - break - sha256.update(data) - - assert sha256.hexdigest() == '7c998680060c8743551a412583fac689db43cec07053b72dfec6dcd810113539' + digest = get_digest(self.checkpoint_path) + assert digest == '7c998680060c8743551a412583fac689db43cec07053b72dfec6dcd810113539' - model = swin_tiny_patch4_window7_224(embed_layer=ConvStem, pretrained=False) - model.head = nn.Identity() + self.model = swin_tiny_patch4_window7_224(embed_layer=ConvStem, pretrained=False) + self.model.head = nn.Identity() - ctranspath = torch.load(checkpoint_path, map_location=torch.device('cpu')) - model.load_state_dict(ctranspath['model'], strict=True) + ctranspath = torch.load(self.checkpoint_path, map_location=torch.device('cpu')) + self.model.load_state_dict(ctranspath['model'], strict=True) if torch.cuda.is_available(): - model = model.to(device) + self.model = self.model.to(device) + + self.transform = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ]) - print("CTransPath model successfully initialised...\n") model_name='xiyuewang-ctranspath-7c998680' - return model, model_name + print("CTransPath model successfully initialised...\n") + return model_name +class FeatureExtractorUNI: + def init_feat_extractor(self, device: str, **kwargs): + """Extracts features from slide tiles. + Requirements: + Permission from authors via huggingface: https://huggingface.co/MahmoodLab/UNI + Huggingface account with valid login token + On first model initialization, you will be prompted to enter your login token. The token is + then stored in ./home//.cache/huggingface/token. Subsequent inits do not require you to re-enter the token. + + Args: + device: "cuda" or "cpu" + """ + asset_dir = f"{os.environ['STAMP_RESOURCES_DIR']}/uni" + model, transform = uni.get_encoder(enc_name="uni", device=device, assets_dir=asset_dir) + self.model = model + self.transform = transform + + digest = get_digest(f"{asset_dir}/vit_large_patch16_224.dinov2.uni_mass100k/pytorch_model.bin") + model_name = f"mahmood-uni-{digest[:8]}" + print("UNI model successfully initialised...\n") + return model_name class SlideTileDataset(Dataset): def __init__(self, patches: np.array, transform=None, *, repetitions: int = 1) -> None: @@ -72,7 +103,7 @@ def __getitem__(self, i): def extract_features_( *, - model, model_name, norm_wsi_img: np.ndarray, coords: list, wsi_name: str, outdir: Path, + model, model_name, transform, norm_wsi_img: np.ndarray, coords: list, wsi_name: str, outdir: Path, augmented_repetitions: int = 0, cores: int = 8, is_norm: bool = True, device: str = 'cpu', target_microns: int = 256, patch_size: int = 224 ) -> None: @@ -87,23 +118,18 @@ def extract_features_( only one, non-augmentation iteration will be done. """ - normal_transform = transforms.Compose([ - transforms.Resize(224), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ]) - augmenting_transform = transforms.Compose([ - transforms.Resize(224), - transforms.CenterCrop(224), - transforms.RandomHorizontalFlip(p=.5), - transforms.RandomVerticalFlip(p=.5), - transforms.RandomApply([transforms.GaussianBlur(3)], p=.5), - transforms.RandomApply([transforms.ColorJitter( - brightness=.1, contrast=.2, saturation=.25, hue=.125)], p=.5), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ]) + # Obsolete (?) + # augmenting_transform = transforms.Compose([ + # transforms.Resize(224), + # transforms.CenterCrop(224), + # transforms.RandomHorizontalFlip(p=.5), + # transforms.RandomVerticalFlip(p=.5), + # transforms.RandomApply([transforms.GaussianBlur(3)], p=.5), + # transforms.RandomApply([transforms.ColorJitter( + # brightness=.1, contrast=.2, saturation=.25, hue=.125)], p=.5), + # transforms.ToTensor(), + # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + # ]) extractor_string = f'STAMP-extract-{__version__}_{model_name}' with open(outdir.parent/'info.json', 'w') as f: @@ -113,7 +139,7 @@ def extract_features_( 'microns': target_microns, 'patch_size': patch_size}, f) - unaugmented_ds = SlideTileDataset(norm_wsi_img, normal_transform) + unaugmented_ds = SlideTileDataset(norm_wsi_img, transform) augmented_ds = [] #clean up memory diff --git a/stamp/preprocessing/helpers/loading_slides.py b/stamp/preprocessing/helpers/loading_slides.py index 69e0410..334cb0e 100755 --- a/stamp/preprocessing/helpers/loading_slides.py +++ b/stamp/preprocessing/helpers/loading_slides.py @@ -35,7 +35,7 @@ def load_slide(slide: openslide.OpenSlide, target_mpp: float = 256/224, cores: i for i in range(steps): # row for j in range(steps): # column future = executor.submit( - _load_tile, slide, (stride*(j, i)), stride, tile_target_size) + _load_tile, slide, (stride*(j, i)), stride, tuple(tile_target_size)) future_coords[future] = (i, j) # write the loaded tiles into an image as soon as they are loaded @@ -132,4 +132,4 @@ def test_get_raw_tile_list(): assert len(canny_patch_list) == 4 assert len(coords_list) == 4 assert canny_patch_list[0].shape == (224,224,3) - assert coords_list[0] == (0,0) \ No newline at end of file + assert coords_list[0] == (0,0) diff --git a/stamp/preprocessing/wsi_norm.py b/stamp/preprocessing/wsi_norm.py index 8b436b4..9036ec5 100755 --- a/stamp/preprocessing/wsi_norm.py +++ b/stamp/preprocessing/wsi_norm.py @@ -25,34 +25,36 @@ from .helpers.common import supported_extensions from .helpers.concurrent_canny_rejection import reject_background from .helpers.loading_slides import process_slide_jpg, load_slide, get_raw_tile_list -from .helpers.feature_extractors import FeatureExtractor, extract_features_ +from .helpers.feature_extractors import FeatureExtractorCTP, FeatureExtractorUNI, extract_features_ from .helpers.exceptions import MPPExtractionError PIL.Image.MAX_IMAGE_PIXELS = None +def clean_lockfile(file): + if os.path.exists(file): # Catch collision cases + os.remove(file) + @contextmanager def lock_file(slide_path: Path): try: - Path(f"{slide_path}.tmp").touch() + Path(f"{slide_path}.lock").touch() except PermissionError: pass # No write permissions for wsi directory try: yield finally: - if os.path.exists(f"{slide_path}.tmp"): # Catch collision cases - os.remove(f"{slide_path}.tmp") + clean_lockfile(f"{slide_path}.lock") def test_wsidir_write_permissions(wsi_dir: Path): try: - testfile = wsi_dir/f"test_{time.time()}.tmp" + testfile = wsi_dir/f"test_{str(os.getpid())}.tmp" Path(testfile).touch() except PermissionError: logging.warning("No write permissions for wsi directory! If multiple stamp processes are running " "in parallel, the final summary may show an incorrect number of slides processed.") finally: - if os.path.exists(testfile): - os.remove(testfile) + clean_lockfile(testfile) def save_image(image, path: Path): width, height = image.size @@ -62,20 +64,29 @@ def save_image(image, path: Path): return image.save(path) -def preprocess(output_dir: Path, wsi_dir: Path, model_path: Path, cache_dir: Path, - norm: bool, del_slide: bool, only_feature_extraction: bool, cache: bool = True, - cores: int = 8, target_microns: int = 256, patch_size: int = 224, - device: str = "cuda", normalization_template: Path = None): +def preprocess(output_dir: Path, wsi_dir: Path, model_path: Path, cache_dir: Path, norm: bool, + del_slide: bool, only_feature_extraction: bool, cache: bool = True, cores: int = 8, + target_microns: int = 256, patch_size: int = 224, keep_dir_structure: bool = False, + device: str = "cuda", normalization_template: Path = None, feat_extractor: str = "ctp"): + # Clean up potentially old leftover .lock files + for lockfile in wsi_dir.glob("**/*.lock"): + if time.time() - os.path.getmtime(lockfile) > 20: + clean_lockfile(lockfile) has_gpu = torch.cuda.is_available() target_mpp = target_microns/patch_size patch_shape = (patch_size, patch_size) #(224, 224) by default step_size = patch_size #have 0 overlap by default - + # Initialize the feature extraction model - print(f"Initialising CTransPath model as feature extractor...") - extractor = FeatureExtractor() - model, model_name = extractor.init_feat_extractor(checkpoint_path=model_path, device=device) + print(f"Initialising feature extractor {feat_extractor}...") + if feat_extractor == "ctp": + extractor = FeatureExtractorCTP(checkpoint_path=model_path) + elif feat_extractor == "uni": + extractor = FeatureExtractorUNI() + else: + raise Exception(f"Invalid feature extractor '{feat_extractor}' selected") + model_name = extractor.init_feat_extractor(device=device) # Create cache and output directories if cache: cache_dir.mkdir(exist_ok=True, parents=True) @@ -85,7 +96,7 @@ def preprocess(output_dir: Path, wsi_dir: Path, model_path: Path, cache_dir: Pat output_file_dir = output_dir/model_name_norm output_file_dir.mkdir(parents=True, exist_ok=True) # Create logfile and set up logging - logfile_name = "logfile_" + time.strftime("%Y-%m-%d_%H-%M-%S") + logfile_name = "logfile_" + time.strftime("%Y-%m-%d_%H-%M-%S") + "_" + str(os.getpid()) logdir = output_file_dir/logfile_name logging.basicConfig(filename=logdir, force=True, level=logging.INFO, format="[%(levelname)s] %(message)s") logging.getLogger().addHandler(logging.StreamHandler()) @@ -103,7 +114,7 @@ def preprocess(output_dir: Path, wsi_dir: Path, model_path: Path, cache_dir: Pat if norm: print("\nInitialising Macenko normaliser...") print(normalization_template) - target = cv2.imread(normalization_template) + target = cv2.imread(str(normalization_template)) target = cv2.cvtColor(target, cv2.COLOR_BGR2RGB) normalizer = stainNorm_Macenko.Normalizer() normalizer.fit(target) @@ -114,6 +125,7 @@ def preprocess(output_dir: Path, wsi_dir: Path, model_path: Path, cache_dir: Pat img_name = "norm_slide.jpg" if norm else "canny_slide.jpg" # Get list of slides, filter out slides that have already been processed + print("Scanning for existing feature files...") existing = [f.stem for f in output_file_dir.glob("**/*.h5")] if output_file_dir.exists() else [] if not only_feature_extraction: img_dir = [svs for ext in supported_extensions for svs in wsi_dir.glob(f"**/*{ext}")] @@ -141,8 +153,13 @@ def preprocess(output_dir: Path, wsi_dir: Path, model_path: Path, cache_dir: Pat print("\n") logging.info(f"===== Processing slide {slide_name} =====") - feat_out_dir = output_file_dir/slide_name - if not (os.path.exists((f"{feat_out_dir}.h5"))) and not os.path.exists(f"{slide_url}.tmp"): + slide_subdir = slide_url.parent.relative_to(wsi_dir) + if not keep_dir_structure or slide_subdir == Path("."): + feat_out_dir = output_file_dir/slide_name + else: + (output_file_dir/slide_subdir).mkdir(parents=True, exist_ok=True) + feat_out_dir = output_file_dir/slide_subdir/slide_name + if not (os.path.exists((f"{feat_out_dir}.h5"))) and not os.path.exists(f"{slide_url}.lock"): with lock_file(slide_url): if ( (only_feature_extraction and (slide_jpg := slide_url).exists()) or \ @@ -219,13 +236,13 @@ def preprocess(output_dir: Path, wsi_dir: Path, model_path: Path, cache_dir: Pat if os.path.exists(slide_url): os.remove(slide_url) - print("\nExtracting CTransPath features from slide...") + print(f"\nExtracting {model_name} features from slide...") start_time = time.time() if len(canny_norm_patch_list) > 0: - extract_features_(model=model, model_name=model_name, norm_wsi_img=canny_norm_patch_list, - coords=coords_list, wsi_name=slide_name, outdir=feat_out_dir, cores=cores, - is_norm=norm, device=device if has_gpu else "cpu", target_microns=target_microns, - patch_size=patch_size) + extract_features_(model=extractor.model, transform=extractor.transform, model_name=model_name, + norm_wsi_img=canny_norm_patch_list, coords=coords_list, wsi_name=slide_name, + outdir=feat_out_dir, cores=cores, is_norm=norm, device=device if has_gpu else "cpu", + target_microns=target_microns, patch_size=patch_size) logging.info(f"Extracted features from slide: {time.time() - start_time:.2f} seconds ({len(canny_norm_patch_list)} tiles)") num_processed += 1 else: