Skip to content

Commit

Permalink
first commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Koori authored and Koori committed Apr 18, 2024
1 parent 4acb779 commit 05446e3
Show file tree
Hide file tree
Showing 13 changed files with 1,786 additions and 1 deletion.
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
/utils/__pycache__
/output
/network/__pycache__
/modules/__pycache__
/build_dataset/__pycache__
**/__init__.py
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ This repository contains a PyTorch Lightning implementation of U-Net for Dual En
- This project also utilizes GitHub Copilot and ChatGPT-4 for code suggestions and debugging assistance.

#### Dataset
- The code is designed to be compatible with any DECT Pair Dataset.
- The code is designed to be compatible with any DECT Pair Dataset. Model hyperparameters should be fine-tuned on the data set to achieve optimal accuracy.
- **Note:** The private dataset PLAData scanned at Nanjing General Hospital of PLA, is not authorized for public distribution.

#### Contact Information
Expand Down
37 changes: 37 additions & 0 deletions U-NetModelREADME.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
### pytorch Implementation of U-Net, R2U-Net, Attention U-Net, Attention R2U-Net

**(This repository is no longer being updated)**

**U-Net: Convolutional Networks for Biomedical Image Segmentation**

https://arxiv.org/abs/1505.04597

**Recurrent Residual Convolutional Neural Network based on U-Net (R2U-Net) for Medical Image Segmentation**

https://arxiv.org/abs/1802.06955

**Attention U-Net: Learning Where to Look for the Pancreas**

https://arxiv.org/abs/1804.03999

**Attention R2U-Net : Just integration of two recent advanced works (R2U-Net + Attention U-Net)**


## U-Net
![U-Net](/img/U-Net.png)


## R2U-Net
![R2U-Net](/img/R2U-Net.png)

## Attention U-Net
![AttU-Net](/img/AttU-Net.png)

## Attention R2U-Net
![AttR2U-Net](/img/AttR2U-Net.png)

## Evaluation
we just test the models with [ISIC 2018 dataset](https://challenge2018.isic-archive.com/task1/training/). The dataset was split into three subsets, training set, validation set, and test set, which the proportion is 70%, 10% and 20% of the whole dataset, respectively. The entire dataset contains 2594 images where 1815 images were used
for training, 259 for validation and 520 for testing models.

![evaluation](/img/Evaluation.png)
56 changes: 56 additions & 0 deletions build_dataset/PLADataModule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@

import numpy as np
import torch
from torch.utils.data import DataLoader


from utils.utils import normalize

torch.use_deterministic_algorithms(True, warn_only=True)
import os

from torch.utils.data import IterableDataset
from pytorch_lightning import LightningDataModule

class NPYIterableDataset(IterableDataset):
def __init__(self, root_dir):
super().__init__()
self.root_dir = root_dir
self.files = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.npy')]

def __len__(self):
return len(self.files)

def __iter__(self):
for file in self.files:
data = np.load(file, allow_pickle=True).item()
image = data['data']
label = data['label']
patient_id = data['patient_id']
image_id = data['image_id']
yield torch.tensor(image, dtype=torch.float32), torch.tensor(label, dtype=torch.float32), str(patient_id), str(image_id)

class NanjingPLA_DECT(LightningDataModule):
def __init__(self,datatype,train_root_dir, valid_root_dir,test_root_dir, batch_size, gt_shape):
super().__init__()
self.datatype = datatype
self.train_root_dir = train_root_dir
self.valid_root_dir = valid_root_dir
self.test_root_dir = test_root_dir
self.batch_size = batch_size
self.gt_shape = gt_shape
self.train_mean,self.val_mean=0,0
self.train_std,self.val_std=1,1

def train_dataloader(self):
train_dataset = NPYIterableDataset(self.train_root_dir)
return DataLoader(train_dataset, batch_size=self.batch_size, num_workers=32,pin_memory=True,
prefetch_factor=2)

def val_dataloader(self):
valid_dataset = NPYIterableDataset(self.valid_root_dir)
return DataLoader(valid_dataset, batch_size=1, num_workers=16)

def test_dataloader(self):
test_dataset = NPYIterableDataset(self.test_root_dir)
return DataLoader(test_dataset, batch_size=1, num_workers=4)
67 changes: 67 additions & 0 deletions build_dataset/dataset_zipper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
from argparse import ArgumentParser
import os
import glob
import numpy as np
import pydicom as dicom
import cv2

ROOT_FOLDER = '/data_new3/username/DL/PLA_data_bak/denoised/train'
ROOT_FOLDER2 = '/data_new3/username/DL/PLA_data/denoised/test'
MASK_PATH = '/data_new3/username/DL/scripts/result.tif'
SUB_FOLDER = ['100kv', '140kv']
START_SLICE = [10, 120, 150, 70, 15, 35, 35, 20, 15, 75, 35, 53, 45, 10, 90, 25]

def apply_mask(image, mask_path):
mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
mask = cv2.resize(mask, (image.shape[1], image.shape[0]))
masked_image = image.copy()
masked_image[mask == 0] = 0
return masked_image.astype(np.float32)

def save_patient_data(root_folder, save_path="/data_new3/username/DualEnergyCTSynthesis/dataset", dataset_type='train'):
print(root_folder)
patient_names = sorted([d for d in os.listdir(root_folder) if os.path.isdir(os.path.join(root_folder, d))])
patient_id=0
#print(patient_names)

for i, patient_name in enumerate(patient_names):
patient_folder = os.path.join(root_folder, patient_name)
data_files = glob.glob(os.path.join(patient_folder, SUB_FOLDER[0])+ '/*.IMA')
label_files = glob.glob(os.path.join(patient_folder, SUB_FOLDER[1])+ '/*.IMA')
data_files.sort()
label_files.sort()
#print(data_files)

if len(data_files) != len(label_files):
raise RuntimeError("Unequal number between data files and label files!")
patient_id = i
if dataset_type == 'train' and patient_id >=13:
dataset_type = 'valid'
for j, (data_file, label_file) in enumerate(zip(data_files, label_files)):
if j < START_SLICE[i] and dataset_type == 'train':
continue
data_dcm = dicom.read_file(data_file)
label_dcm = dicom.read_file(label_file)
data = apply_mask(data_dcm.pixel_array, MASK_PATH)
label = apply_mask(label_dcm.pixel_array, MASK_PATH)

file_name = f"{dataset_type}_{patient_id:02d}_{j + 1:03d}.npy"
print(f"Saving {file_name}...")
np.save(os.path.join(save_path, file_name), {'data': data, 'label': label, 'patient_id': patient_id, 'image_id': j + 1})


def main_func(save_path):
if not os.path.exists(save_path):
os.makedirs(save_path)

save_patient_data(ROOT_FOLDER,save_path, 'train')
save_patient_data(ROOT_FOLDER2,save_path, 'test')

print("Data saved.")

if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--save-path", type=str, default="/data_new3/username/DualEnergyCTSynthesis/dataset",
help="Path to save npy files.")
args = parser.parse_args()
main_func(**vars(args))
Loading

0 comments on commit 05446e3

Please sign in to comment.