update train_faster method
lee-zq committed Mar 15, 2020
1 parent 14bcb95 commit 0689fe4
Showing 2 changed files with 8 additions and 6 deletions.
12 changes: 7 additions & 5 deletions README.md
@@ -2,7 +2,7 @@

## Introduction
This repository is a 3DUNet implemented in PyTorch, based on this [project](https://github.com/panxiaobai/lits_pytorch). I have redesigned the code structure and used the model to perform liver and tumor segmentation on the LiTS2017 dataset.
-requirement:
+#### requirement:
```
pytorch >= 1.1.0
torchvision
@@ -12,16 +12,16 @@ Scipy
```
##

-### Quickly Start
-#### 1) LITS2017 dataset preprocessing:
+## Quick Start
+### 1) LITS2017 dataset preprocessing:
1. Download the dataset from Google Drive: [Liver Tumor Segmentation Challenge](https://drive.google.com/drive/folders/0B0vscETPGI1-Q1h1WFdEM2FHSUE)
Or from my Baidu share: https://pan.baidu.com/s/1WgP2Ttxn_CV-yRT4UyqHWw
Extraction code: hfl8
2. Then decompress the dataset and put the volume data and segmentation labels into separate local folders, such as `./dataset/data` and `./dataset/label`.
3. Finally, change the root paths of the volume data and segmentation labels in `preprocess/preprocess_LiTS.py`, for example:
```
row_dataset_path = './dataset/' # path of origin dataset
-fixed_dataset_path = './fixed/' # path of fixed/preprocessed dataset
+fixed_dataset_path = './fixed/' # path of fixed(preprocessed) dataset
```
4. Run `python preprocess/preprocess_LiTS.py`.
If nothing goes wrong, you will see the following files in the directory `./fixed`:
@@ -41,7 +41,9 @@ If nothing goes wrong, you can see the following files in the dir `./fixed`
segmentation-10.nii
...
```
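After preprocessing, `./fixed` contains matching `volume-N.nii` / `segmentation-N.nii` pairs, and a loader needs them matched by index. A minimal sketch of that pairing step; the helper name and the flat-directory layout are my assumptions for illustration, not the repository's code:

```python
import os
import re

def pair_volumes_and_labels(fixed_dir='./fixed'):
    """Pair each volume-N.nii with its segmentation-N.nii by index N.

    Helper name and flat layout are illustrative assumptions; adapt the
    paths to your own preprocessing output.
    """
    files = os.listdir(fixed_dir)
    vols = {m.group(1): f for f in files
            if (m := re.match(r'volume-(\d+)\.nii$', f))}
    segs = {m.group(1): f for f in files
            if (m := re.match(r'segmentation-(\d+)\.nii$', f))}
    # keep only indices that have both a volume and a label, sorted numerically
    return [(vols[i], segs[i]) for i in sorted(vols, key=int) if i in segs]
```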
-#### 2) Training 3DUNet
+### 2) Training 3DUNet
1. First, change the parameters in `config.py`; in particular, set `--dataset_path` to `./fixed`. All parameters are commented in `config.py`.
2. Then run `python train.py`
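The training scripts read their settings from `config.py` via command-line flags such as `--dataset_path`, `--lr`, `--momentum`, `--epochs`, and `--save`. A minimal argparse sketch of such a config; the flag names come from the scripts, but the default values here are placeholders, not the repository's:

```python
import argparse

def get_args(argv=None):
    """Sketch of an argparse-based config like config.py (defaults are placeholders)."""
    parser = argparse.ArgumentParser(description='3DUNet training config (sketch)')
    parser.add_argument('--dataset_path', default='./fixed',
                        help='root of the preprocessed LiTS data')
    parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='SGD momentum')
    parser.add_argument('--epochs', type=int, default=50, help='training epochs')
    parser.add_argument('--save', default='unet3d', help='output sub-directory name')
    return parser.parse_args(argv)
```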
---
In addition, during training you will find that loading the training data is time-consuming; you can use `train_faster.py` to train the model instead. `train_faster.py` calls `./dataset/dataset2_lits.py`, which crops multiple training samples from a single input sample to form a batch, speeding up training.
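The cropping strategy described above can be sketched with NumPy; the function name, patch size, and batch size are illustrative assumptions, not the actual logic in `./dataset/dataset2_lits.py`:

```python
import numpy as np

def crop_batch(volume, label, crop_size=(32, 64, 64), n_crops=4, seed=0):
    """Crop n_crops random patches from one (D, H, W) volume/label pair and
    stack them into a batch with a channel axis: (n_crops, 1, d, h, w).

    Names and sizes here are illustrative sketches of the idea in
    dataset2_lits.py, not the repository's implementation.
    """
    rng = np.random.default_rng(seed)
    vols, labs = [], []
    for _ in range(n_crops):
        # pick a random corner that keeps the whole patch inside the volume
        starts = [rng.integers(0, dim - c + 1)
                  for dim, c in zip(volume.shape, crop_size)]
        sl = tuple(slice(s, s + c) for s, c in zip(starts, crop_size))
        vols.append(volume[sl])
        labs.append(label[sl])
    # identical crop locations are applied to image and label
    return np.stack(vols)[:, None], np.stack(labs)[:, None]
```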
2 changes: 1 addition & 1 deletion train_faster.py
@@ -97,7 +97,7 @@ def train(model, train_loader, optimizer, epoch, logger):
model = UNet(1, [32, 48, 64, 96, 128], 3, net_mode='3d',conv_block=RecombinationBlock).to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
init_util.print_network(model)
-# model = nn.DataParallel(model, device_ids=[0]) # multi-GPU
+# model = nn.DataParallel(model, device_ids=[0])  # multi-GPU

logger = logger.Logger('./output/{}'.format(args.save))
for epoch in range(1, args.epochs + 1):
