Skip to content

Commit

Permalink
Move download of uni model weights to stamp setup
Browse files Browse the repository at this point in the history
  • Loading branch information
cornzz committed Apr 12, 2024
1 parent 1c52154 commit b949d28
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Next, initialize STAMP and obtain the required configuration file, config.yaml,
stamp init
```

To download required resources such as the weights of the CTransPath feature extractor, run the following command:
To download required resources such as the weights of the feature extractor, run the following command:
```bash
stamp setup
```
Expand Down
21 changes: 17 additions & 4 deletions stamp/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,23 @@ def run_cli(args: argparse.Namespace):
with normalization_template_path.open("wb") as f:
f.write(r.content)
# Download feature extractor model
model_path = Path(cfg.preprocessing.model_path)
feat_extractor = cfg.preprocessing.feat_extractor
if feat_extractor == 'ctp':
model_path = Path(cfg.preprocessing.model_path)
elif feat_extractor == 'uni':
model_path = Path(f"{os.environ['STAMP_RESOURCES_DIR']}/uni/vit_large_patch16_224.dinov2.uni_mass100k/pytorch_model.bin")
model_path.parent.mkdir(parents=True, exist_ok=True)
if model_path.exists():
print(f"Skipping download, feature extractor model already exists at {model_path}")
else:
print(f"Downloading CTransPath weights to {model_path}")
import gdown
gdown.download(CTRANSPATH_WEIGHTS_URL, str(model_path))
if feat_extractor == 'ctp':
print(f"Downloading CTransPath weights to {model_path}")
import gdown
gdown.download(CTRANSPATH_WEIGHTS_URL, str(model_path))
elif feat_extractor == 'uni':
print(f"Downloading UNI weights")
from uni.get_encoder import get_encoder
get_encoder(enc_name='uni', checkpoint='pytorch_model.bin', assets_dir=f"{os.environ['STAMP_RESOURCES_DIR']}/uni")
case "config":
print(OmegaConf.to_yaml(cfg, resolve=True))
case "preprocess":
Expand All @@ -107,6 +116,10 @@ def run_cli(args: argparse.Namespace):
raise ConfigurationError(f"Normalization template {c.normalization_template} does not exist, please run `stamp setup` to download it.")
if c.feat_extractor == 'ctp' and not Path(c.model_path).exists():
raise ConfigurationError(f"Feature extractor model {c.model_path} does not exist, please run `stamp setup` to download it.")
if c.feat_extractor == 'uni':
uni_path = f"{os.environ['STAMP_RESOURCES_DIR']}/uni/vit_large_patch16_224.dinov2.uni_mass100k/pytorch_model.bin"
if not Path(uni_path).exists():
raise ConfigurationError(f"Feature extractor model {uni_path} does not exist, please run `stamp setup` to download it.")
from .preprocessing.wsi_norm import preprocess
preprocess(
output_dir=Path(c.output_dir),
Expand Down

0 comments on commit b949d28

Please sign in to comment.