forked from wangsssky/MedicalMatting

Commit: Merge pull request wangsssky#1 from wangsssky/master (init)
Showing 34 changed files with 2,335 additions and 2 deletions.
`README.md`
@@ -1,3 +1,114 @@

# Medical Matting: A New Perspective on Medical Segmentation with Uncertainty, [arXiv](https://arxiv.org/abs/2106.09887)

This is a PyTorch implementation of our paper. We introduce matting as a soft segmentation method and a new perspective for handling and representing uncertain regions in medical scenes.

## Reference
```
@InProceedings{10.1007/978-3-030-87199-4_54,
  author="Wang, Lin and Ju, Lie and Zhang, Donghao and Wang, Xin and He, Wanji and Huang, Yelin and Yang, Zhiwen and Yao, Xuan and Zhao, Xin and Ye, Xiufen and Ge, Zongyuan",
  title="Medical Matting: A New Perspective on Medical Segmentation with Uncertainty",
  booktitle="Medical Image Computing and Computer Assisted Intervention -- MICCAI 2021",
  year="2021",
  pages="573--583",
}
```
A journal extension is available on [arXiv](https://arxiv.org/abs/2106.09887v3).

## 1. Requirements

torch>=1.8.0; torchvision>=0.9.0; matplotlib; numpy; opencv_python; pandas; Pillow; PyYAML; scikit_image; scikit_learn; scipy; tensorboardX; tqdm

Data preparation: unpack the zip files in the dataset folder first.
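
If you prefer to script this step, here is a minimal sketch; it assumes the archives sit directly under the repo's `dataset/` folder:

```python
# Minimal sketch: extract every zip archive found in dataset/ in place.
# Assumes the archives live directly under the repo's dataset/ folder.
import glob
import zipfile

for archive in glob.glob('dataset/*.zip'):
    with zipfile.ZipFile(archive) as zf:
        zf.extractall('dataset')
        print('extracted', archive)
```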

Directory structure in this repo:
```
│MedicalMatting/
│   config.py
│   evaluate.py
│   params_brain.yaml
│   params_isic.yaml
│   params_lidc.yaml
│   README.md
│   train.py
+---dataloader
│       data_loader.py
│       data_spliter.py
│       transform.py
│       utils.py
+---dataset
│       brain_growth_alpha.pkl
│       isic_attributes.pkl
│       lidc_attributes.pkl
+---model
│   │   loss_functions.py
│   │   loss_strategy.py
│   │   medical_matting.py
│   │   utils.py
│   +---matting_network
│   │       cbam.py
│   │       matting_net.py
│   │       resnet_block.py
│   +---metrics
│   │       compute_connectivity_error.py
│   │       compute_gradient_loss.py
│   │       compute_mse_loss.py
│   │       compute_sad_loss.py
│   │       dice_accuracy.py
│   │       generalised_energy_distance.py
│   │       utils.py
│   +---probabilistic_unet
│   │       axis_aligned_conv_gaussian.py
│   │       encoder.py
│   │       fcomb.py
│   │       prob_unet.py
│   \---unet
│           unet.py
│           unet_blocks.py
+---models
\---utils
        logger.py
        utils.py
```

## 2. Train

- LIDC-IDRI
```bash
CUDA_VISIBLE_DEVICES=0 python train.py --config /path/to/params_lidc.yaml
```

- ISIC
```bash
CUDA_VISIBLE_DEVICES=0 python train.py --config /path/to/params_isic.yaml
```

- Brain-growth
```bash
CUDA_VISIBLE_DEVICES=0 python train.py --config /path/to/params_brain.yaml
```

## 3. Evaluation
```bash
CUDA_VISIBLE_DEVICES=0 python evaluate.py --config /path/to/params_<task>.yaml \
    --save_path /path/to/your/model/dir
```
Here `<task>` is one of `lidc`, `isic`, or `brain`.

## Acknowledgements
The following code is referenced in this repo:
- A PyTorch implementation of [Probabilistic UNet](https://github.com/stefanknegt/Probabilistic-Unet-Pytorch).
- [CBAM](https://github.com/Jongchan/attention-module)
- We reimplemented the Matlab evaluation metrics of [DIM_evaluate_code](https://sites.google.com/view/deepimagematting) in PyTorch; for details, please refer to [DIM_evaluation_code_python](https://github.com/wangsssky/DIM_evaluation_code_python).

Datasets:

The datasets in this paper were constructed based on the LIDC-IDRI, ISIC, and Brain-growth datasets, and the rights to the images used are owned by the original datasets. Please refer to the requirements of the original datasets for any use of the original images.
- [LIDC-IDRI](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI)
    - The authors acknowledge the National Cancer Institute and the Foundation for the National Institutes of Health, and their critical role in the creation of the free publicly available LIDC/IDRI Database used in this study.
    - The LIDC-IDRI data used in this paper were obtained by further annotating these [patches](https://github.com/stefanknegt/Probabilistic-Unet-Pytorch).
- [ISIC](https://www.isic-archive.com/#!/topWithHeader/wideContentTop/main)
- [Brain-growth](https://qubiq21.grand-challenge.org/QUBIQ2021/)

## LICENSE

The code in this repo is under the GPL license. For commercial use, please contact the authors.

`config.py`

@@ -0,0 +1,89 @@
# configuration for the models
import yaml


class Config:
    def __init__(self, config_path):

        with open(config_path, encoding='utf-8') as f:
            yaml_dict = yaml.load(f, Loader=yaml.FullLoader)

        # ----------- parse yaml ---------------#
        self.DATA_PATH = yaml_dict['DATA_PATH']
        self.DATASET = yaml_dict['DATASET']

        if 'MASK_NUM' in yaml_dict:
            self.MASK_NUM = yaml_dict['MASK_NUM']
        else:
            # default annotator count, input channels, and input size per dataset
            if self.DATASET == 'brain-growth':
                self.MASK_NUM = 7
                self.INPUT_CHANNEL = 1
                self.INPUT_SIZE = 128
            elif self.DATASET == 'lidc':
                self.MASK_NUM = 4
                self.INPUT_CHANNEL = 1
                self.INPUT_SIZE = 128
            elif self.DATASET == 'isic':
                self.MASK_NUM = 3
                self.INPUT_CHANNEL = 3
                self.INPUT_SIZE = 256
            else:
                raise ValueError('unsupported dataset {}'.format(self.DATASET))
        print('MASK_NUM:', self.MASK_NUM)

        if 'LEVEL' in yaml_dict:
            self.LEVEL = yaml_dict['LEVEL']
            print('LEVEL:', self.LEVEL)
        else:
            self.LEVEL = None

        self.KFOLD = yaml_dict['KFOLD']
        self.RANDOM_SEED = yaml_dict['RANDOM_SEED']

        self.USE_MATTING = yaml_dict['USE_MATTING']
        if self.USE_MATTING:
            self.MODEL_NAME = 'ProbUnet_Matting'
        else:
            self.MODEL_NAME = 'ProbUnet'
        self.MODEL_DIR = yaml_dict['MODEL_DIR'] + self.MODEL_NAME
        self.UNCERTAINTY_MAP = yaml_dict['UNCERTAINTY_MAP']

        self.EPOCH_NUM = yaml_dict['EPOCH_NUM']
        self.RESUME_FROM = yaml_dict['RESUME_FROM']
        self.TRAIN_MATTING_START_FROM = yaml_dict['TRAIN_MATTING_START_FROM']

        self.TRAIN_BATCHSIZE = yaml_dict['TRAIN_BATCHSIZE']
        self.VAL_BATCHSIZE = yaml_dict['VAL_BATCHSIZE']
        self.TRAIN_TIME_AUG = yaml_dict['TRAIN_TIME_AUG']

        self.OPTIMIZER = yaml_dict['OPTIMIZER']
        self.WEIGHT_DECAY = yaml_dict['WEIGHT_DECAY']
        self.MOMENTUM = yaml_dict['MOMENTUM']
        self.LEARNING_RATE = float(yaml_dict['LEARNING_RATE'])
        self.WARM_LEN = yaml_dict['WARM_LEN']

        self.GEN_TYPE = yaml_dict['GEN_TYPE']
        self.NUM_FILTERS = yaml_dict['NUM_FILTERS']
        self.LATENT_DIM = yaml_dict['LATENT_DIM']
        self.SAMPLING_NUM = yaml_dict['SAMPLING_NUM']
        self.USE_BN = yaml_dict['USE_BN']
        self.POSTERIOR_TARGET = yaml_dict['POSTERIOR_TARGET']

        # self.REG_SCALE = float(yaml_dict['REG_SCALE'])
        self.KL_SCALE = float(yaml_dict['KL_SCALE'])
        self.RECONSTRUCTION_SCALE = yaml_dict['RECONSTRUCTION_SCALE']
        self.ALPHA_SCALE = yaml_dict['ALPHA_SCALE']
        self.ALPHA_GRADIENT_SCALE = yaml_dict['ALPHA_GRADIENT_SCALE']
        self.LOSS_STRATEGY = yaml_dict['LOSS_STRATEGY']

        self.PRT_LOSS = yaml_dict['PRT_LOSS']
        self.VISUALIZE = yaml_dict['VISUALIZE']


if __name__ == '__main__':
    cfg = Config(config_path='./params.yaml')
    print(cfg)
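
For reference, a partial, illustrative parameter file is sketched below. The keys mirror what `Config` parses, but every value here is a placeholder rather than a setting used in the paper; the shipped `params_lidc.yaml`, `params_isic.yaml`, and `params_brain.yaml` are the authoritative configurations.

```yaml
# Illustrative sketch only -- use the provided params_*.yaml files for real runs.
DATA_PATH: ./dataset/lidc_attributes.pkl
DATASET: lidc            # one of: lidc, isic, brain-growth
MASK_NUM: 4              # optional; inferred from DATASET when omitted
KFOLD: 5
RANDOM_SEED: 42
USE_MATTING: True        # False trains the plain ProbUnet baseline
MODEL_DIR: ./models/
EPOCH_NUM: 100
TRAIN_BATCHSIZE: 16
VAL_BATCHSIZE: 16
LEARNING_RATE: 0.0001
# ...plus the remaining keys read by config.py (UNCERTAINTY_MAP, OPTIMIZER,
# loss scales, network sizes, etc.); Config expects all of them to be present.
```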

`dataloader/data_loader.py`

@@ -0,0 +1,95 @@
import torch
from torch.utils.data.dataset import Dataset

import os
import copy
import numpy as np
import cv2
import pickle


# Dataset for Medical Matting
class AlphaDataset(Dataset):
    def __init__(self, dataset_location, input_size=128):
        self.images = []
        self.mask_labels = []
        self.alphas = []
        self.series_uid = []

        # read dataset (pickled files may exceed 2 GB, so read in chunks)
        max_bytes = 2 ** 31 - 1
        data = {}
        print("Loading file", dataset_location)
        bytes_in = bytearray(0)
        file_size = os.path.getsize(dataset_location)
        with open(dataset_location, 'rb') as f_in:
            for _ in range(0, file_size, max_bytes):
                bytes_in += f_in.read(max_bytes)
        new_data = pickle.loads(bytes_in)
        data.update(new_data)

        # load dataset: image in 0-255, alpha in 0-255, mask values in [0, 1]
        for key, value in data.items():
            self.images.append(pad_im(value['image'], input_size))
            masks = []
            for mask in value['masks']:
                masks.append(pad_im(mask, input_size))
            self.mask_labels.append(masks)
            if 'alpha' in value.keys():
                self.alphas.append(pad_im(value['alpha'], input_size))
            else:
                self.alphas.append(None)
            self.series_uid.append(value['series_uid'])

        # sanity checks
        assert (len(self.images) == len(self.mask_labels) == len(self.series_uid))
        for image in self.images:
            assert np.max(image) <= 255 and np.min(image) >= 0
        for alpha in self.alphas:
            if alpha is not None:  # alpha is absent for some samples
                assert np.max(alpha) <= 255 and np.min(alpha) >= 0
        for mask in self.mask_labels:
            assert np.max(mask) <= 1 and np.min(mask) >= 0

        # free memory
        del new_data
        del data

    def __getitem__(self, index):
        image = copy.deepcopy(self.images[index])
        mask_labels = copy.deepcopy(self.mask_labels[index])
        alpha = copy.deepcopy(self.alphas[index])
        series_uid = self.series_uid[index]

        return image, mask_labels, alpha, series_uid

    # Override to give PyTorch the size of the dataset
    def __len__(self):
        return len(self.images)


def pad_im(image, size, value=0):
    """Pad an image to a square with `value`, then resize it to `size`."""
    shape = image.shape
    if len(shape) == 2:
        h, w = shape
    else:
        h, w, c = shape

    if h == w:
        if h == size:
            padded_im = image
        else:
            # interpolation must be passed by keyword: the third positional
            # argument of cv2.resize is dst, not interpolation
            padded_im = cv2.resize(image, (size, size), interpolation=cv2.INTER_CUBIC)
    else:
        if h > w:
            pad_1 = (h - w) // 2
            pad_2 = (h - w) - pad_1
            padded_im = cv2.copyMakeBorder(image, 0, 0, pad_1, pad_2, cv2.BORDER_CONSTANT, value=value)
        else:
            pad_1 = (w - h) // 2
            pad_2 = (w - h) - pad_1
            padded_im = cv2.copyMakeBorder(image, pad_1, pad_2, 0, 0, cv2.BORDER_CONSTANT, value=value)
        if padded_im.shape[0] != size:
            padded_im = cv2.resize(padded_im, (size, size), interpolation=cv2.INTER_CUBIC)

    return padded_im
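
A minimal usage sketch, assuming the repo root as working directory and the unpacked `dataset/` folder; the returned `mask_labels` is a list with one mask per annotator, and `alpha` can be `None` for samples without an alpha matte:

```python
# Hypothetical quick check of the dataset wrapper.
from dataloader.data_loader import AlphaDataset

dataset = AlphaDataset('dataset/lidc_attributes.pkl', input_size=128)
image, mask_labels, alpha, series_uid = dataset[0]
print(len(dataset), image.shape, len(mask_labels), series_uid)
```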

`dataloader/data_spliter.py`

@@ -0,0 +1,60 @@
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
from sklearn.model_selection import KFold
from dataloader.data_loader import AlphaDataset


class AlphaDatasetSpliter:
    def __init__(self, opt, input_size):
        self.opt = opt
        self.train_dataset = AlphaDataset(
            dataset_location=opt.DATA_PATH, input_size=input_size)
        self.test_dataset = AlphaDataset(
            dataset_location=opt.DATA_PATH, input_size=input_size)
        self.kf = KFold(n_splits=opt.KFOLD, shuffle=False)

        self.splits = []

        if opt.DATASET == 'lidc':
            # group sample indices by patient id so that all patches from one
            # patient land in the same fold (avoids train/test leakage)
            uid_dict = {}
            for idx, uid in enumerate(self.train_dataset.series_uid):
                pid = uid.split('_')[0]
                if pid in uid_dict.keys():
                    uid_dict[pid].append(idx)
                else:
                    uid_dict[pid] = [idx]

            pids = list(uid_dict.keys())
            np.random.seed(opt.RANDOM_SEED)
            np.random.shuffle(pids)
            for (train_pid_index, test_pid_index) in self.kf.split(np.arange(len(pids))):
                train_index = []
                test_index = []
                for pid_idx in train_pid_index:
                    train_index += uid_dict[pids[pid_idx]]
                for pid_idx in test_pid_index:
                    test_index += uid_dict[pids[pid_idx]]
                self.splits.append({'train_index': train_index, 'test_index': test_index})
        else:
            # otherwise shuffle sample indices directly before K-fold splitting
            indices = list(range(len(self.train_dataset)))
            np.random.seed(opt.RANDOM_SEED)
            np.random.shuffle(indices)
            for (train_index, test_index) in self.kf.split(np.arange(len(self.train_dataset))):
                self.splits.append({
                    'train_index': [indices[i] for i in train_index.tolist()],
                    'test_index': [indices[i] for i in test_index.tolist()]})

    def get_datasets(self, fold_idx):
        train_indices = self.splits[fold_idx]['train_index']
        test_indices = self.splits[fold_idx]['test_index']
        train_sampler = SubsetRandomSampler(train_indices)
        test_sampler = SubsetRandomSampler(test_indices)
        train_loader = DataLoader(
            self.train_dataset, batch_size=self.opt.TRAIN_BATCHSIZE, sampler=train_sampler)
        test_loader = DataLoader(
            self.test_dataset, batch_size=self.opt.VAL_BATCHSIZE, sampler=test_sampler)
        print("Number of training/test patches:", (len(train_indices), len(test_indices)))

        return train_loader, test_loader
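
A minimal sketch of iterating the K folds; it assumes the repo's `Config` and shipped `params_lidc.yaml`, and `input_size=128` matches the LIDC default in `config.py`:

```python
# Hypothetical fold loop over the splits produced by AlphaDatasetSpliter.
from config import Config
from dataloader.data_spliter import AlphaDatasetSpliter

cfg = Config(config_path='params_lidc.yaml')
spliter = AlphaDatasetSpliter(cfg, input_size=128)
for fold_idx in range(cfg.KFOLD):
    train_loader, test_loader = spliter.get_datasets(fold_idx)
    # ...train and evaluate on this fold...
```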