Merge pull request wangsssky#1 from wangsssky/master
init
wangsssky authored Jan 1, 2022
2 parents 798f2bd + a3784c6 commit e560d10
Showing 34 changed files with 2,335 additions and 2 deletions.
115 changes: 113 additions & 2 deletions README.md
# Medical Matting: A New Perspective on Medical Segmentation with Uncertainty ([arXiv](https://arxiv.org/abs/2106.09887))

This is a PyTorch implementation of our paper. We introduce matting as a soft segmentation method and a new perspective for handling and representing uncertain regions in medical images.


## Reference
```
@InProceedings{10.1007/978-3-030-87199-4_54,
author="Wang, Lin and Ju, Lie and Zhang, Donghao and Wang, Xin and He, Wanji and Huang, Yelin and Yang, Zhiwen and Yao, Xuan and Zhao, Xin and Ye, Xiufen and Ge, Zongyuan",
title="Medical Matting: A New Perspective on Medical Segmentation with Uncertainty",
booktitle="Medical Image Computing and Computer Assisted Intervention -- MICCAI 2021",
year="2021",
pages="573--583",
}
```
A journal extension of this work is available on [arXiv](https://arxiv.org/abs/2106.09887v3).

## 1. Requirements

torch>=1.8.0; torchvision>=0.9.0; matplotlib; numpy; opencv_python; pandas; Pillow; PyYAML; scikit_image; scikit_learn; scipy; tensorboardX; tqdm (`pickle` is part of the Python standard library).

Data preparation: unpack the zip files in the `dataset` folder first.
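
For example, a minimal Python sketch for unpacking the archives (it assumes the zip files sit directly in `./dataset` and should be extracted in place; adjust the paths to your layout):
```python
# Unpack every zip archive found in ./dataset (the path is an assumption).
import glob
import zipfile

for zip_path in glob.glob('./dataset/*.zip'):
    print('Extracting', zip_path)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall('./dataset')
```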

Directory structure in this repo:
```
│MedicalMatting/
│ config.py
│ evaluate.py
│ params_brain.yaml
│ params_isic.yaml
│ params_lidc.yaml
│ README.md
│ train.py
+---dataloader
│ data_loader.py
│ data_spliter.py
│ transform.py
│ utils.py
+---dataset
│ brain_growth_alpha.pkl
│ isic_attributes.pkl
│ lidc_attributes.pkl
+---model
│ │ loss_functions.py
│ │ loss_strategy.py
│ │ medical_matting.py
│ │ utils.py
│ +---matting_network
│ │ cbam.py
│ │ matting_net.py
│ │ resnet_block.py
│ +---metrics
│ │ compute_connectivity_error.py
│ │ compute_gradient_loss.py
│ │ compute_mse_loss.py
│ │ compute_sad_loss.py
│ │ dice_accuracy.py
│ │ generalised_energy_distance.py
│ │ utils.py
│ +---probabilistic_unet
│ │ axis_aligned_conv_gaussian.py
│ │ encoder.py
│ │ fcomb.py
│ │ prob_unet.py
│ \---unet
│ unet.py
│ unet_blocks.py
+---models
\---utils
logger.py
utils.py
```

## 2. Train

- LIDC-IDRI
```bash
CUDA_VISIBLE_DEVICES=0 python train.py --config /path/to/params_lidc.yaml
```

- ISIC
```bash
CUDA_VISIBLE_DEVICES=0 python train.py --config /path/to/params_isic.yaml
```

- Brain-growth
```bash
CUDA_VISIBLE_DEVICES=0 python train.py --config /path/to/params_brain.yaml
```

## 3. Evaluation
```bash
CUDA_VISIBLE_DEVICES=0 python evaluate.py --config /path/to/params_<task>.yaml \
--save_path /path/to/your/model/dir
```

## Acknowledgements
The following code is referenced in this repo:
- A PyTorch implementation of [Probabilistic UNet](https://github.com/stefanknegt/Probabilistic-Unet-Pytorch).
- [CBAM](https://github.com/Jongchan/attention-module)
- We reimplemented the Matlab evaluation metrics from [DIM_evaluate_code](https://sites.google.com/view/deepimagematting) in PyTorch; for details, please refer to [DIM_evaluation_code_python](https://github.com/wangsssky/DIM_evaluation_code_python).

Datasets:

The datasets in this paper were constructed from the LIDC-IDRI, ISIC, and Brain-growth datasets; the rights to the images remain with the original datasets. Please follow the original datasets' terms for any use of the images.
- [LIDC-IDRI](https://wiki.cancerimagingarchive.net/display/Public/LIDC-IDRI)
- The authors acknowledge the National Cancer Institute and the Foundation for the National Institutes of Health, and their critical role in the creation of the free publicly available LIDC/IDRI Database used in this study.
- The LIDC-IDRI data used in this paper were obtained by further annotating these [patches](https://github.com/stefanknegt/Probabilistic-Unet-Pytorch).
- [ISIC](https://www.isic-archive.com/#!/topWithHeader/wideContentTop/main)
- [Brain-growth](https://qubiq21.grand-challenge.org/QUBIQ2021/)

## LICENSE

The code in this repo is released under the GPL license. For commercial use, please contact the authors.
89 changes: 89 additions & 0 deletions config.py
# configuration for the models
import yaml


class Config:
    def __init__(self, config_path):

        with open(config_path, encoding='utf-8') as f:
            yaml_dict = yaml.load(f, Loader=yaml.FullLoader)

        # ----------- parse yaml ---------------#
        self.DATA_PATH = yaml_dict['DATA_PATH']
        self.DATASET = yaml_dict['DATASET']

        # INPUT_CHANNEL / INPUT_SIZE are derived from DATASET when MASK_NUM is not given explicitly
        if 'MASK_NUM' in yaml_dict:
            self.MASK_NUM = yaml_dict['MASK_NUM']
        else:
            if self.DATASET == 'brain-growth':
                self.MASK_NUM = 7
                self.INPUT_CHANNEL = 1
                self.INPUT_SIZE = 128
            elif self.DATASET == 'lidc':
                self.MASK_NUM = 4
                self.INPUT_CHANNEL = 1
                self.INPUT_SIZE = 128
            elif self.DATASET == 'isic':
                self.MASK_NUM = 3
                self.INPUT_CHANNEL = 3
                self.INPUT_SIZE = 256
            else:
                raise ValueError('unsupported dataset {}'.format(self.DATASET))
        print('MASK_NUM:', self.MASK_NUM)

        if 'LEVEL' in yaml_dict:
            self.LEVEL = yaml_dict['LEVEL']
            print('LEVEL:', self.LEVEL)
        else:
            self.LEVEL = None

        self.KFOLD = yaml_dict['KFOLD']
        self.RANDOM_SEED = yaml_dict['RANDOM_SEED']

        self.USE_MATTING = yaml_dict['USE_MATTING']
        if self.USE_MATTING:
            self.MODEL_NAME = 'ProbUnet_Matting'
        else:
            self.MODEL_NAME = 'ProbUnet'
        self.MODEL_DIR = yaml_dict['MODEL_DIR'] + self.MODEL_NAME
        self.UNCERTAINTY_MAP = yaml_dict['UNCERTAINTY_MAP']

        self.EPOCH_NUM = yaml_dict['EPOCH_NUM']
        self.RESUME_FROM = yaml_dict['RESUME_FROM']
        self.TRAIN_MATTING_START_FROM = yaml_dict['TRAIN_MATTING_START_FROM']

        self.TRAIN_BATCHSIZE = yaml_dict['TRAIN_BATCHSIZE']
        self.VAL_BATCHSIZE = yaml_dict['VAL_BATCHSIZE']
        self.TRAIN_TIME_AUG = yaml_dict['TRAIN_TIME_AUG']

        self.OPTIMIZER = yaml_dict['OPTIMIZER']
        self.WEIGHT_DECAY = yaml_dict['WEIGHT_DECAY']
        self.MOMENTUM = yaml_dict['MOMENTUM']
        self.LEARNING_RATE = float(yaml_dict['LEARNING_RATE'])
        self.WARM_LEN = yaml_dict['WARM_LEN']

        self.GEN_TYPE = yaml_dict['GEN_TYPE']
        self.NUM_FILTERS = yaml_dict['NUM_FILTERS']
        self.LATENT_DIM = yaml_dict['LATENT_DIM']
        self.SAMPLING_NUM = yaml_dict['SAMPLING_NUM']
        self.USE_BN = yaml_dict['USE_BN']
        self.POSTERIOR_TARGET = yaml_dict['POSTERIOR_TARGET']

        # self.REG_SCALE = float(yaml_dict['REG_SCALE'])
        self.KL_SCALE = float(yaml_dict['KL_SCALE'])
        self.RECONSTRUCTION_SCALE = yaml_dict['RECONSTRUCTION_SCALE']
        self.ALPHA_SCALE = yaml_dict['ALPHA_SCALE']
        self.ALPHA_GRADIENT_SCALE = yaml_dict['ALPHA_GRADIENT_SCALE']
        self.LOSS_STRATEGY = yaml_dict['LOSS_STRATEGY']

        self.PRT_LOSS = yaml_dict['PRT_LOSS']
        self.VISUALIZE = yaml_dict['VISUALIZE']


if __name__ == '__main__':
    cfg = Config(config_path='./params.yaml')
    print(cfg)
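
For reference, a hypothetical usage sketch of `Config`. The key names mirror what `__init__` reads above, but the values are illustrative placeholders rather than the repository's actual settings (see the shipped `params_*.yaml` files for those); `params_example.yaml` is a made-up file name, and the script is assumed to run from the repository root so that `config.py` is importable.
```python
# Build a placeholder YAML with the keys Config expects, then load it.
import yaml
from config import Config  # assumes the repository root is on the path

example = {
    'DATA_PATH': './dataset/lidc_attributes.pkl', 'DATASET': 'lidc',
    'KFOLD': 5, 'RANDOM_SEED': 42,
    'USE_MATTING': True, 'MODEL_DIR': './models/', 'UNCERTAINTY_MAP': True,
    'EPOCH_NUM': 100, 'RESUME_FROM': 0, 'TRAIN_MATTING_START_FROM': 0,
    'TRAIN_BATCHSIZE': 16, 'VAL_BATCHSIZE': 16, 'TRAIN_TIME_AUG': True,
    'OPTIMIZER': 'adam', 'WEIGHT_DECAY': 1e-5, 'MOMENTUM': 0.9,
    'LEARNING_RATE': 1e-4, 'WARM_LEN': 5,
    'GEN_TYPE': 0, 'NUM_FILTERS': [32, 64, 128, 192], 'LATENT_DIM': 6,
    'SAMPLING_NUM': 16, 'USE_BN': True, 'POSTERIOR_TARGET': 'alpha',
    'KL_SCALE': 10.0, 'RECONSTRUCTION_SCALE': 1.0,
    'ALPHA_SCALE': 1.0, 'ALPHA_GRADIENT_SCALE': 1.0, 'LOSS_STRATEGY': 'uncertainty',
    'PRT_LOSS': True, 'VISUALIZE': False,
}

with open('params_example.yaml', 'w') as f:
    yaml.dump(example, f)

# MASK_NUM is omitted above, so MASK_NUM / INPUT_CHANNEL / INPUT_SIZE are inferred from DATASET.
cfg = Config('params_example.yaml')
print(cfg.MODEL_NAME, cfg.MASK_NUM, cfg.INPUT_SIZE)
```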




95 changes: 95 additions & 0 deletions dataloader/data_loader.py
import torch
from torch.utils.data.dataset import Dataset

import os
import copy
import numpy as np
import cv2
import pickle


# Dataset for Medical Matting
class AlphaDataset(Dataset):
    def __init__(self, dataset_location, input_size=128):
        self.images = []
        self.mask_labels = []
        self.alphas = []
        self.series_uid = []

        # read dataset
        max_bytes = 2 ** 31 - 1
        data = {}
        print("Loading file", dataset_location)
        bytes_in = bytearray(0)
        file_size = os.path.getsize(dataset_location)
        with open(dataset_location, 'rb') as f_in:
            for _ in range(0, file_size, max_bytes):
                bytes_in += f_in.read(max_bytes)
        new_data = pickle.loads(bytes_in)
        data.update(new_data)

        # load dataset
        for key, value in data.items():
            # image 0-255, alpha 0-255, mask [0,1]
            self.images.append(pad_im(value['image'], input_size))
            masks = []
            for mask in value['masks']:
                masks.append(pad_im(mask, input_size))
            self.mask_labels.append(masks)
            if 'alpha' in value.keys():
                self.alphas.append(pad_im(value['alpha'], input_size))
            else:
                self.alphas.append(None)
            self.series_uid.append(value['series_uid'])

        # check
        assert (len(self.images) == len(self.mask_labels) == len(self.series_uid))
        for image in self.images:
            assert np.max(image) <= 255 and np.min(image) >= 0
        for alpha in self.alphas:
            # alpha may be absent for some samples; only check when present
            if alpha is not None:
                assert np.max(alpha) <= 255 and np.min(alpha) >= 0
        for mask in self.mask_labels:
            assert np.max(mask) <= 1 and np.min(mask) >= 0

        # free
        del new_data
        del data

    def __getitem__(self, index):
        image = copy.deepcopy(self.images[index])
        mask_labels = copy.deepcopy(self.mask_labels[index])
        alpha = copy.deepcopy(self.alphas[index])
        series_uid = self.series_uid[index]

        return image, mask_labels, alpha, series_uid

    # Override to give PyTorch size of dataset
    def __len__(self):
        return len(self.images)


def pad_im(image, size, value=0):
    shape = image.shape
    if len(shape) == 2:
        h, w = shape
    else:
        h, w, c = shape

    if h == w:
        if h == size:
            padded_im = image
        else:
            padded_im = cv2.resize(image, (size, size), interpolation=cv2.INTER_CUBIC)
    else:
        if h > w:
            pad_1 = (h - w) // 2
            pad_2 = (h - w) - pad_1
            padded_im = cv2.copyMakeBorder(image, 0, 0, pad_1, pad_2, cv2.BORDER_CONSTANT, value=value)
        else:
            pad_1 = (w - h) // 2
            pad_2 = (w - h) - pad_1
            padded_im = cv2.copyMakeBorder(image, pad_1, pad_2, 0, 0, cv2.BORDER_CONSTANT, value=value)
        if padded_im.shape[0] != size:
            padded_im = cv2.resize(padded_im, (size, size), interpolation=cv2.INTER_CUBIC)

    return padded_im
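
A minimal inspection sketch for `AlphaDataset` (it assumes the script is run from the repository root and that `./dataset/lidc_attributes.pkl` has been unpacked; swap in the other `.pkl` files as needed). `pad_im` pads a non-square image to a square with constant borders before resizing it to `input_size`.
```python
from dataloader.data_loader import AlphaDataset

# Load one of the shipped pickle datasets and look at the first sample.
ds = AlphaDataset('./dataset/lidc_attributes.pkl', input_size=128)
print('number of samples:', len(ds))

image, mask_labels, alpha, series_uid = ds[0]
print(series_uid, image.shape, len(mask_labels),
      None if alpha is None else alpha.shape)
```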
60 changes: 60 additions & 0 deletions dataloader/data_spliter.py
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

import numpy as np
from sklearn.model_selection import KFold
from dataloader.data_loader import AlphaDataset


class AlphaDatasetSpliter():
    def __init__(self, opt, input_size):
        self.opt = opt
        self.train_dataset = AlphaDataset(
            dataset_location=opt.DATA_PATH, input_size=input_size)
        self.test_dataset = AlphaDataset(
            dataset_location=opt.DATA_PATH, input_size=input_size)
        self.kf = KFold(n_splits=opt.KFOLD, shuffle=False)

        self.splits = []

        if opt.DATASET == 'lidc':
            uid_dict = {}
            for idx, uid in enumerate(self.train_dataset.series_uid):
                pid = uid.split('_')[0]
                if pid in uid_dict.keys():
                    uid_dict[pid].append(idx)
                else:
                    uid_dict[pid] = [idx]

            pids = list(uid_dict.keys())
            np.random.seed(opt.RANDOM_SEED)
            np.random.shuffle(pids)
            for (train_pid_index, test_pid_index) in self.kf.split(np.arange(len(pids))):
                train_index = []
                test_index = []
                for pid_idx in train_pid_index:
                    train_index += uid_dict[pids[pid_idx]]
                for pid_idx in test_pid_index:
                    test_index += uid_dict[pids[pid_idx]]
                self.splits.append({'train_index': train_index, 'test_index': test_index})
        else:
            indices = list(range(len(self.train_dataset)))
            np.random.seed(opt.RANDOM_SEED)
            np.random.shuffle(indices)
            for (train_index, test_index) in self.kf.split(np.arange(len(self.train_dataset))):
                self.splits.append({
                    'train_index': [indices[i] for i in train_index.tolist()],
                    'test_index': [indices[i] for i in test_index.tolist()]})

    def get_datasets(self, fold_idx):
        train_indices = self.splits[fold_idx]['train_index']
        test_indices = self.splits[fold_idx]['test_index']
        train_sampler = SubsetRandomSampler(train_indices)
        test_sampler = SubsetRandomSampler(test_indices)
        train_loader = DataLoader(
            self.train_dataset, batch_size=self.opt.TRAIN_BATCHSIZE, sampler=train_sampler)
        test_loader = DataLoader(
            self.test_dataset, batch_size=self.opt.VAL_BATCHSIZE, sampler=test_sampler)
        print("Number of training/test patches:", (len(train_indices), len(test_indices)))

        return train_loader, test_loader
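
A hypothetical k-fold loop built on `AlphaDatasetSpliter`. It assumes a `Config` created from one of the shipped YAML files (here `params_lidc.yaml`) with `MASK_NUM` left unset, so `INPUT_SIZE` is inferred from `DATASET`; it only sketches how the folds are consumed, not the repository's actual training loop in `train.py`.
```python
from config import Config
from dataloader.data_spliter import AlphaDatasetSpliter

cfg = Config('params_lidc.yaml')
spliter = AlphaDatasetSpliter(cfg, input_size=cfg.INPUT_SIZE)

# One train/test loader pair per fold; KFOLD controls the number of splits.
for fold_idx in range(cfg.KFOLD):
    train_loader, test_loader = spliter.get_datasets(fold_idx)
    # ... build the model for this fold and iterate over train_loader / test_loader ...
```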