Skip to content

Commit f75b111

Browse files
author
neelabh17
committed
mrnet
1 parent 003f4d2 commit f75b111

14 files changed

+589
-0
lines changed

MRNet-Single-Model/.gitignore

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
*.npy
2+
*.DS_Store
3+
4+
*.pth
5+
*.csv
6+
7+
*.pyc
8+
9+
.vscode
10+
11+
runs/
12+
images/
13+
*__pycache__*

MRNet-Single-Model/LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2020 Big Vision LLC
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

MRNet-Single-Model/README.md

+98
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
<div align="center">
2+
<img src="content/logo.jpg" width ="600" height="300"/>
3+
4+
# Stanford MRnet Challenge
5+
6+
**This repo contains code for the MRNet Challenge**
7+
8+
9+
For more details refer to https://stanfordmlgroup.github.io/competitions/mrnet/
10+
11+
</div>
12+
13+
# Install dependencies
14+
1. `pip install git+https://github.com/ncullen93/torchsample`
15+
2. `pip install nibabel`
16+
3. `pip install sklearn`
17+
4. `pip install pandas`
18+
19+
Install other dependencies as per requirement
20+
21+
# Instructions to run the training
22+
1. Clone the repository.
23+
24+
2. Download the dataset (~5.7 GB), and put `train` and `valid` folders along with all the the `.csv` files inside `images` folder at root directory.
25+
```Shell
26+
images/
27+
train/
28+
axial/
29+
sagittal/
30+
coronal/
31+
val/
32+
axial/
33+
sagittal/
34+
coronal/
35+
train-abnormal.csv
36+
train-acl.csv
37+
train-meniscus.csv
38+
valid-abnormal.csv
39+
valid-acl.csv
40+
valid-meniscus.csv
41+
```
42+
43+
3. Make a new folder called `weights` at root directory, and inside the `weights` folder create three more folders namely `acl`, `abnormal` and `meniscus`.
44+
45+
4. All the hyperparameters are defined in `config.py` file. Feel free to play around those.
46+
47+
5. Now finally run the training using `python train.py`. All the logs for tensorboard will be stored in the `runs` directory at the root of the project.
48+
49+
# Understanding the Dataset
50+
51+
<div align="center">
52+
53+
<img src="content/mri_scan.png" width ="650" height="600"/>
54+
55+
</div>
56+
57+
The dataset contains MRIs of different people. Each MRI consists of multiple images.
58+
Each MRI has data in 3 perpendicular planes. And each plane as variable number of slices.
59+
60+
Each slice is an `256x256` image
61+
62+
For example:
63+
64+
For `MRI 1` we will have 3 planes:
65+
66+
Plane 1- with 35 slices
67+
68+
Plane 2- with 34 slices
69+
70+
Place 3 with 35 slices
71+
72+
Each MRI has to be classisifed against 3 diseases
73+
74+
Major challenge with while selecting the model structure was the inconsistency in the data. Although the image size remains constant , the number of slices per plane are variable within a single MRI and varies across all MRIs.
75+
76+
So we are proposing a model for each plane. For each model the `batch size` will be variable and equal to `number of slices in the plane of the MRI`. So training each model, we will get features for each plane.
77+
78+
We also plan to have 3 separate models for each disease.
79+
80+
# Model Specifications
81+
82+
<div align="center">
83+
84+
<img src="content/model.png" width ="700" height="490"/>
85+
86+
</div>
87+
88+
We will be using Alexnet pretrained as a feature extractor. When we would have trained the 3 models on the 3 planes, we will use its feature extractor layer as an input to a `global` model for the final classification
89+
90+
# Contributors
91+
<p >
92+
-- Neelabh Madan
93+
<a href = https://github.com/neelabh17 target='blank'> <img src=https://github.com/edent/SuperTinyIcons/blob/master/images/svg/github.svg height='30' weight='30'/></a>
94+
<br>
95+
96+
-- Jatin Prakash <a href = https://github.com/bicycleman15 target='blank'> <img src=https://github.com/edent/SuperTinyIcons/blob/master/images/svg/github.svg height='30' weight='30'/></a>
97+
98+

MRNet-Single-Model/config.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
config = {
2+
'max_epoch' : 50,
3+
'log_train' : 100,
4+
'lr' : 1e-5,
5+
'starting_epoch' : 0,
6+
'batch_size' : 1,
7+
'log_val' : 10,
8+
'task' : 'abnormal', # "meniscus" and "acl" are the other options
9+
'weight_decay' : 0.01,
10+
'patience' : 5,
11+
'save_model' : 1,
12+
'exp_name' : 'test'
13+
}

MRNet-Single-Model/content/logo.jpg

27.4 KB
Loading

MRNet-Single-Model/content/model.png

207 KB
Loading
122 KB
Loading
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .dataset import MRData, load_data

MRNet-Single-Model/dataset/dataset.py

+145
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
import os
2+
import pandas as pd
3+
import numpy as np
4+
5+
import torch
6+
import torch.utils.data as data
7+
8+
from torchsample.transforms import RandomRotate, RandomTranslate, RandomFlip, ToTensor, Compose, RandomAffine
9+
from torchvision import transforms
10+
11+
INPUT_DIM = 224
12+
MAX_PIXEL_VAL = 255
13+
MEAN = 58.09
14+
STDDEV = 49.73
15+
16+
class MRData():
17+
"""This class used to load MRnet dataset from `./images` dir
18+
"""
19+
20+
def __init__(self,task = 'acl', train = True, transform = None, weights = None):
21+
"""Initialize the dataset
22+
23+
Args:
24+
plane : along which plane to load the data
25+
task : for which task to load the labels
26+
train : whether to load the train or val data
27+
transform : which transforms to apply
28+
weights (Tensor) : Give wieghted loss to postive class eg. `weights=torch.tensor([2.223])`
29+
"""
30+
self.planes=['axial', 'coronal', 'sagittal']
31+
self.records = None
32+
# an empty dictionary
33+
self.image_path={}
34+
35+
if train:
36+
self.records = pd.read_csv('./images/train-{}.csv'.format(task),header=None, names=['id', 'label'])
37+
38+
'''
39+
self.image_path[<plane>]= dictionary {<plane>: path to folder containing
40+
image for that plane}
41+
'''
42+
for plane in self.planes:
43+
self.image_path[plane] = './images/train/{}/'.format(plane)
44+
else:
45+
transform = None
46+
self.records = pd.read_csv('./images/valid-{}.csv'.format(task),header=None, names=['id', 'label'])
47+
'''
48+
self.image_path[<plane>]= dictionary {<plane>: path to folder containing
49+
image for that plane}
50+
'''
51+
for plane in self.planes:
52+
self.image_path[plane] = './images/valid/{}/'.format(plane)
53+
54+
55+
self.transform = transform
56+
57+
self.records['id'] = self.records['id'].map(
58+
lambda i: '0' * (4 - len(str(i))) + str(i))
59+
# empty dictionary
60+
self.paths={}
61+
for plane in self.planes:
62+
self.paths[plane] = [self.image_path[plane] + filename +
63+
'.npy' for filename in self.records['id'].tolist()]
64+
65+
self.labels = self.records['label'].tolist()
66+
67+
pos = sum(self.labels)
68+
neg = len(self.labels) - pos
69+
70+
# Find the wieghts of pos and neg classes
71+
if weights:
72+
self.weights = torch.FloatTensor(weights)
73+
else:
74+
self.weights = torch.FloatTensor([neg / pos])
75+
76+
print('Number of -ve samples : ', neg)
77+
print('Number of +ve samples : ', pos)
78+
print('Weights for loss is : ', self.weights)
79+
80+
def __len__(self):
81+
"""Return the total number of images in the dataset."""
82+
return len(self.records)
83+
84+
def __getitem__(self, index):
85+
"""
86+
Returns `(images,labels)` pair
87+
where image is a list [imgsPlane1,imgsPlane2,imgsPlane3]
88+
and labels is a list [gt,gt,gt]
89+
"""
90+
img_raw = {}
91+
92+
for plane in self.planes:
93+
img_raw[plane] = np.load(self.paths[plane][index])
94+
img_raw[plane] = self._resize_image(img_raw[plane])
95+
96+
label = self.labels[index]
97+
if label == 1:
98+
label = torch.FloatTensor([1])
99+
elif label == 0:
100+
label = torch.FloatTensor([0])
101+
102+
return [img_raw[plane] for plane in self.planes], label
103+
104+
def _resize_image(self, image):
105+
"""Resize the image to `(3,224,224)` and apply
106+
transforms if possible.
107+
"""
108+
# Resize the image
109+
pad = int((image.shape[2] - INPUT_DIM)/2)
110+
image = image[:,pad:-pad,pad:-pad]
111+
image = (image-np.min(image))/(np.max(image)-np.min(image))*MAX_PIXEL_VAL
112+
image = (image - MEAN) / STDDEV
113+
114+
if self.transform:
115+
image = self.transform(image)
116+
else:
117+
image = np.stack((image,)*3, axis=1)
118+
119+
image = torch.FloatTensor(image)
120+
return image
121+
122+
def load_data(task : str):
123+
124+
# Define the Augmentation here only
125+
augments = Compose([
126+
transforms.Lambda(lambda x: torch.Tensor(x)),
127+
RandomRotate(25),
128+
RandomTranslate([0.11, 0.11]),
129+
RandomFlip(),
130+
transforms.Lambda(lambda x: x.repeat(3, 1, 1, 1).permute(1, 0, 2, 3)),
131+
])
132+
133+
print('Loading Train Dataset of {} task...'.format(task))
134+
train_data = MRData(task, train=True, transform=augments)
135+
train_loader = data.DataLoader(
136+
train_data, batch_size=1, num_workers=11, shuffle=True
137+
)
138+
139+
print('Loading Validation Dataset of {} task...'.format(task))
140+
val_data = MRData(task, train=False)
141+
val_loader = data.DataLoader(
142+
val_data, batch_size=1, num_workers=11, shuffle=False
143+
)
144+
145+
return train_loader, val_loader, train_data.weights, val_data.weights

MRNet-Single-Model/models/MRnet.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import torch
2+
import torch.nn as nn
3+
from torchvision import models
4+
import os
5+
6+
class MRnet(nn.Module):
7+
"""MRnet uses pretrained resnet50 as a backbone to extract features
8+
"""
9+
10+
def __init__(self): # add conf file
11+
12+
super(MRnet,self).__init__()
13+
14+
# init three backbones for three axis
15+
self.axial = models.alexnet(pretrained=True).features
16+
self.coronal = models.alexnet(pretrained=True).features
17+
self.saggital = models.alexnet(pretrained=True).features
18+
19+
self.pool_axial = nn.AdaptiveAvgPool2d(1)
20+
self.pool_coronal = nn.AdaptiveAvgPool2d(1)
21+
self.pool_saggital = nn.AdaptiveAvgPool2d(1)
22+
23+
self.fc = nn.Sequential(
24+
nn.Linear(in_features=3*256,out_features=1)
25+
)
26+
27+
def forward(self,x):
28+
""" Input is given in the form of `[image1, image2, image3]` where
29+
`image1 = [1, slices, 3, 224, 224]`. Note that `1` is due to the
30+
dataloader assigning it a single batch.
31+
"""
32+
33+
# squeeze the first dimension as there
34+
# is only one patient in each batch
35+
images = [torch.squeeze(img, dim=0) for img in x]
36+
37+
image1 = self.axial(images[0])
38+
image2 = self.coronal(images[1])
39+
image3 = self.saggital(images[2])
40+
41+
image1 = self.pool_axial(image1).view(image1.size(0), -1)
42+
image2 = self.pool_coronal(image2).view(image2.size(0), -1)
43+
image3 = self.pool_saggital(image3).view(image3.size(0), -1)
44+
45+
image1 = torch.max(image1,dim=0,keepdim=True)[0]
46+
image2 = torch.max(image2,dim=0,keepdim=True)[0]
47+
image3 = torch.max(image3,dim=0,keepdim=True)[0]
48+
49+
output = torch.cat([image1,image2,image3], dim=1)
50+
51+
output = self.fc(output)
52+
return output
53+
54+
def _load_wieghts(self):
55+
"""load pretrained weights"""
56+
pass

MRNet-Single-Model/models/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .MRnet import MRnet

0 commit comments

Comments
 (0)