Skip to content

Commit

Permalink
[nnUNet/PyT] Upgrade PLT
Browse files Browse the repository at this point in the history
  • Loading branch information
michal2409 committed Apr 21, 2022
1 parent 4ae2ae4 commit 3e8897f
Show file tree
Hide file tree
Showing 20 changed files with 186 additions and 538 deletions.
2 changes: 1 addition & 1 deletion PyTorch/Segmentation/nnUNet/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ FROM ${FROM_IMAGE_NAME}

ADD ./requirements.txt .
RUN pip install --disable-pip-version-check -r requirements.txt
RUN pip install monai==0.8.0 --no-dependencies
RUN pip install monai==0.8.1 --no-dependencies
RUN pip uninstall -y torchtext
RUN pip install numpy --upgrade

Expand Down
2 changes: 1 addition & 1 deletion PyTorch/Segmentation/nnUNet/data_loading/dali_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
11 changes: 7 additions & 4 deletions PyTorch/Segmentation/nnUNet/data_loading/data_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -77,8 +77,11 @@ def get_split(data, idx):
return list(np.array(data)[idx])


def load_data(path, files_pattern):
return sorted(glob.glob(os.path.join(path, files_pattern)))
def load_data(path, files_pattern, non_empty=True):
data = sorted(glob.glob(os.path.join(path, files_pattern)))
if non_empty:
assert len(data) > 0, f"No data found in {path} with pattern {files_pattern}"
return data


def get_kfold_splitter(nfolds):
Expand All @@ -87,7 +90,7 @@ def get_kfold_splitter(nfolds):

def get_test_fnames(args, data_path, meta=None):
kfold = get_kfold_splitter(args.nfolds)
test_imgs = load_data(data_path, "*_x.npy")
test_imgs = load_data(data_path, "*_x.npy", non_empty=False)
if args.exec_mode == "predict" and "val" in data_path:
_, val_idx = list(kfold.split(test_imgs))[args.fold]
test_imgs = sorted(get_split(test_imgs, val_idx))
Expand Down
2 changes: 1 addition & 1 deletion PyTorch/Segmentation/nnUNet/data_preprocessing/configs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion PyTorch/Segmentation/nnUNet/download.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion PyTorch/Segmentation/nnUNet/evaluate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
103 changes: 49 additions & 54 deletions PyTorch/Segmentation/nnUNet/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,49 +12,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ctypes
import os

import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, early_stopping
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, ModelSummary, RichProgressBar
from pytorch_lightning.loggers import TensorBoardLogger

from data_loading.data_module import DataModule
from nnunet.nn_unet import NNUnet
from utils.args import get_main_args
from utils.gpu_affinity import set_affinity
from utils.logger import LoggingCallback
from utils.utils import make_empty_dir, set_cuda_devices, verify_ckpt_path
from utils.utils import make_empty_dir, set_cuda_devices, set_granularity, verify_ckpt_path

if __name__ == "__main__":
args = get_main_args()

if args.affinity != "disabled":
set_affinity(int(os.getenv("LOCAL_RANK", "0")), args.gpus, mode=args.affinity)

# Limit number of CPU threads
os.environ["OMP_NUM_THREADS"] = "1"
# Set device limit on the current device cudaLimitMaxL2FetchGranularity = 0x05
_libcudart = ctypes.CDLL("libcudart.so")
pValue = ctypes.cast((ctypes.c_int * 1)(), ctypes.POINTER(ctypes.c_int))
_libcudart.cudaDeviceSetLimit(ctypes.c_int(0x05), ctypes.c_int(128))
_libcudart.cudaDeviceGetLimit(pValue, ctypes.c_int(0x05))
assert pValue.contents.value == 128

set_granularity() # Increase maximum fetch granularity of L2 to 128 bytes
set_cuda_devices(args)
seed_everything(args.seed)
data_module = DataModule(args)
data_module.prepare_data()
data_module.setup()
ckpt_path = verify_ckpt_path(args)

callbacks = None
model_ckpt = None
model = NNUnet(args)
callbacks = [RichProgressBar(), ModelSummary(max_depth=2)]
logger = False
if args.benchmark:
model = NNUnet(args)
batch_size = args.batch_size if args.exec_mode == "train" else args.val_batch_size
filnename = args.logname if args.logname is not None else "perf1.json"
callbacks = [
filnename = args.logname if args.logname is not None else "perf.json"
callbacks.append(
LoggingCallback(
log_dir=args.results,
filnename=filnename,
Expand All @@ -63,57 +48,67 @@
warmup=args.warmup,
dim=args.dim,
)
]
)
elif args.exec_mode == "train":
model = NNUnet(args)
early_stopping = EarlyStopping(monitor="dice_mean", patience=args.patience, verbose=True, mode="max")
callbacks = [early_stopping]
if args.tb_logs:
logger = TensorBoardLogger(
save_dir=f"{args.results}/tb_logs",
name=f"task={args.task}_dim={args.dim}_fold={args.fold}_precision={16 if args.amp else 32}",
default_hp_metric=False,
version=0,
)
callbacks.append(
EarlyStopping(
monitor="dice",
patience=args.patience,
verbose=True,
mode="max",
)
)
if args.save_ckpt:
model_ckpt = ModelCheckpoint(
dirpath=f"{args.ckpt_store_dir}/checkpoints", filename="{epoch}-{dice_mean:.2f}", monitor="dice_mean", mode="max", save_last=True
callbacks.append(
ModelCheckpoint(
dirpath=f"{args.ckpt_store_dir}/checkpoints",
filename="{epoch}-{dice:.2f}",
monitor="dice",
mode="max",
save_last=True,
)
)
callbacks.append(model_ckpt)
else: # Evaluation or inference
if ckpt_path is not None:
model = NNUnet.load_from_checkpoint(ckpt_path)
else:
model = NNUnet(args)

trainer = Trainer(
logger=False,
gpus=args.gpus,
precision=16 if args.amp else 32,
logger=logger,
default_root_dir=args.results,
benchmark=True,
deterministic=False,
min_epochs=args.epochs,
max_epochs=args.epochs,
sync_batchnorm=args.sync_batchnorm,
precision=16 if args.amp else 32,
gradient_clip_val=args.gradient_clip_val,
enable_checkpointing=args.save_ckpt,
callbacks=callbacks,
num_sanity_val_steps=0,
default_root_dir=args.results,
resume_from_checkpoint=ckpt_path,
accelerator="ddp" if args.gpus > 1 else None,
checkpoint_callback=args.save_ckpt,
accelerator="gpu",
devices=args.gpus,
num_nodes=args.nodes,
strategy="ddp" if args.gpus > 1 else None,
limit_train_batches=1.0 if args.train_batches == 0 else args.train_batches,
limit_val_batches=1.0 if args.test_batches == 0 else args.test_batches,
limit_test_batches=1.0 if args.test_batches == 0 else args.test_batches,
)

if args.benchmark:
if args.exec_mode == "train":
trainer.fit(model, train_dataloader=data_module.train_dataloader())
trainer.fit(model, train_dataloaders=data_module.train_dataloader())
else:
# warmup
trainer.test(model, test_dataloaders=data_module.test_dataloader())
trainer.test(model, dataloaders=data_module.test_dataloader(), verbose=False)
# benchmark run
trainer.current_epoch = 1
trainer.test(model, test_dataloaders=data_module.test_dataloader())
model.start_benchmark = 1
trainer.test(model, dataloaders=data_module.test_dataloader(), verbose=False)
elif args.exec_mode == "train":
trainer.fit(model, data_module)
trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)
elif args.exec_mode == "evaluate":
model.args = args
trainer.test(model, test_dataloaders=data_module.val_dataloader())
trainer.validate(model, val_dataloaders=data_module.val_dataloader())
elif args.exec_mode == "predict":
if args.save_preds:
ckpt_name = "_".join(args.ckpt_path.split("/")[-1].split(".")[:-1])
Expand All @@ -125,4 +120,4 @@
model.save_dir = save_dir
make_empty_dir(save_dir)
model.args = args
trainer.test(model, test_dataloaders=data_module.test_dataloader())
trainer.test(model, test_dataloaders=data_module.test_dataloader(), ckpt_path=ckpt_path)
10 changes: 6 additions & 4 deletions PyTorch/Segmentation/nnUNet/nnunet/loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,12 +20,14 @@ class Loss(nn.Module):
def __init__(self, focal):
super(Loss, self).__init__()
if focal:
self.loss = DiceFocalLoss(gamma=2.0, softmax=True, to_onehot_y=True, batch=True)
self.loss_fn = DiceFocalLoss(
include_background=False, softmax=True, to_onehot_y=True, batch=True, gamma=2.0
)
else:
self.loss = DiceCELoss(softmax=True, to_onehot_y=True, batch=True)
self.loss_fn = DiceCELoss(include_background=False, softmax=True, to_onehot_y=True, batch=True)

def forward(self, y_pred, y_true):
return self.loss(y_pred, y_true)
return self.loss_fn(y_pred, y_true)


class LossBraTS(nn.Module):
Expand Down
64 changes: 24 additions & 40 deletions PyTorch/Segmentation/nnUNet/nnunet/metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021-2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -13,6 +13,8 @@
# limitations under the License.

import torch
from monai.metrics import compute_meandice, do_metric_reduction
from monai.networks.utils import one_hot
from torchmetrics import Metric


Expand All @@ -21,53 +23,35 @@ def __init__(self, n_class, brats):
super().__init__(dist_sync_on_step=False)
self.n_class = n_class
self.brats = brats
self.add_state("loss", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state("steps", default=torch.zeros(1), dist_reduce_fx="sum")
self.add_state("dice", default=torch.zeros((n_class,)), dist_reduce_fx="sum")
self.add_state("loss", default=torch.zeros(1), dist_reduce_fx="sum")

def update(self, preds, target, loss):
def update(self, p, y, l):
if self.brats:
p = (torch.sigmoid(p) > 0.5).int()
y_wt, y_tc, y_et = y > 0, ((y == 1) + (y == 3)) > 0, y == 3
y = torch.stack([y_wt, y_tc, y_et], dim=1)
else:
p, y = self.ohe(torch.argmax(p, dim=1)), self.ohe(y)

self.steps += 1
self.dice += self.compute_stats_brats(preds, target) if self.brats else self.compute_stats(preds, target)
self.loss += loss
self.loss += l
self.dice += self.compute_metric(p, y, compute_meandice, 1, 0)

def compute(self):
return 100 * self.dice / self.steps, self.loss / self.steps

def compute_stats_brats(self, p, y):
scores = torch.zeros(self.n_class, device=p.device, dtype=torch.float32)
p = (torch.sigmoid(p) > 0.5).int()
y_wt, y_tc, y_et = y > 0, ((y == 1) + (y == 3)) > 0, y == 3
y = torch.stack([y_wt, y_tc, y_et], dim=1)
def ohe(self, x):
return one_hot(x.unsqueeze(1), num_classes=self.n_class + 1, dim=1)

for i in range(self.n_class):
p_i, y_i = p[:, i], y[:, i]
if (y_i != 1).all():
# no foreground class
scores[i - 1] += 1 if (p_i != 1).all() else 0
continue
tp, fn, fp = self.get_stats(p_i, y_i, 1)
denom = (2 * tp + fp + fn).to(torch.float)
score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
scores[i - 1] += score_cls
return scores
def compute_metric(self, p, y, metric_fn, best_metric, worst_metric):
metric = metric_fn(p, y, include_background=self.brats)
metric = torch.nan_to_num(metric, nan=worst_metric, posinf=worst_metric, neginf=worst_metric)
metric = do_metric_reduction(metric, "mean_batch")[0]

def compute_stats(self, preds, target):
scores = torch.zeros(self.n_class, device=preds.device, dtype=torch.float32)
preds = torch.argmax(preds, dim=1)
for i in range(1, self.n_class + 1):
if (target != i).all():
# no foreground class
scores[i - 1] += 1 if (preds != i).all() else 0
continue
tp, fn, fp = self.get_stats(preds, target, i)
denom = (2 * tp + fp + fn).to(torch.float)
score_cls = (2 * tp).to(torch.float) / denom if torch.is_nonzero(denom) else 0.0
scores[i - 1] += score_cls
return scores
for i in range(self.n_class):
if (y[:, i] != 1).all():
metric[i - 1] += best_metric if (p[:, i] != 1).all() else worst_metric

@staticmethod
def get_stats(preds, target, class_idx):
tp = torch.logical_and(preds == class_idx, target == class_idx).sum()
fn = torch.logical_and(preds != class_idx, target == class_idx).sum()
fp = torch.logical_and(preds == class_idx, target != class_idx).sum()
return tp, fn, fp
return metric
Loading

0 comments on commit 3e8897f

Please sign in to comment.