Commit

Add ablation experiment: Only Fundus
TianYi2000 committed Jul 23, 2021
1 parent 5c4d0ad commit e06d566
Showing 6 changed files with 362 additions and 166 deletions.
230 changes: 82 additions & 148 deletions net/two_stream.py
@@ -3,75 +3,78 @@
import torch.optim as optim
import torchvision
from torchvision import models
from resnest.torch import resnest50
from net.SCNet.scnet import scnet50

def pretrain_models(model_name='resnet50', inner_feature=1000, lock_weight=False):
if model_name == "resnet18":
model = models.resnet18(pretrained=True)

elif model_name == "resnet34":
model = models.resnet34(pretrained=True)

elif model_name == "resnet50":
model = models.resnet50(pretrained=True)

    elif model_name == "resnest50":
        model = resnest50(pretrained=True)  # uses the module-level resnest import

elif model_name == "inceptionv3":
model = models.inception_v3(pretrained=True)
kernel_count = model.AuxLogits.fc.in_features
model.AuxLogits.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
model.aux_logits = False

elif model_name == 'scnet50':
from net.SCNet.scnet import scnet50
model = scnet50(pretrained=True)
else:
return

    if lock_weight:
for p in model.parameters():
p.requires_grad = False
kernel_count = model.fc.in_features
model.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
return model
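A minimal usage sketch (illustration only, not part of this commit): pretrain_models returns an ImageNet-pretrained backbone whose final fc layer has been replaced by a head of width inner_feature; with lock_weight=True only that new head remains trainable.

# Hypothetical usage of pretrain_models (names and sizes are for illustration):
import torch
from net.two_stream import pretrain_models

backbone = pretrain_models(model_name='resnet50', inner_feature=1000, lock_weight=True)
dummy_fundus = torch.randn(2, 3, 224, 224)   # fake 2-image batch
features = backbone(dummy_fundus)            # -> tensor of shape (2, 1000)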

def load_models(model_path, model_name='resnet50', label_type='single-label', inner_feature=1000, lock_weight=False):
model = torch.load(model_path)
if (lock_weight):
for p in model.parameters():
p.requires_grad = False
if 'resnet' in model_name and '50' not in model_name:
        kernel_count = 512  # in_features read from the loaded model

elif 'resnest' in model_name or 'scnet' in model_name or 'resnet50' in model_name:
        kernel_count = 2048  # in_features read from the loaded model

elif 'inception' in model_name:
        kernel_count = 768  # in_features read from the loaded model
if label_type == 'multilabel':
model.AuxLogits.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature), nn.Sigmoid())
else:
model.AuxLogits.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
        kernel_count = 2048  # in_features read from the loaded model
model.aux_logits = False

    # todo(hty): is there a way to give this nn.Linear(kernel_count, inner_feature) explicit initial parameters (rather than all zeros or the built-in defaults)?
if label_type == 'multilabel':
model.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature), nn.Sigmoid())
else:
model.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))

return model
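One possible answer to the todo above, sketched under the assumption that an explicit initialization of the replacement head is wanted (this helper is not part of the repository):

# Hypothetical helper: give the new classifier head explicit initial parameters
# instead of nn.Linear's default initialization.
import torch.nn as nn
import torch.nn.init as init

def init_head(head: nn.Sequential) -> None:
    # Kaiming-uniform weights and zero bias for every Linear layer in the head.
    for m in head.modules():
        if isinstance(m, nn.Linear):
            init.kaiming_uniform_(m.weight, nonlinearity='relu')
            init.zeros_(m.bias)

# e.g. model.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature)); init_head(model.fc)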

class TwoStreamNet(nn.Module):
    def __init__(self, fundus_path, OCT_path, fundus_model='resnest50', OCT_model='inceptionv3',
num_classes=1000, label_type='single-label', inner_feature=1000):

super(TwoStreamNet, self).__init__()
self.label_type = label_type

self.model1 = torch.load(fundus_path)
self.model2 = torch.load(OCT_path)
# for p in self.model1.parameters():
# p.requires_grad = False
#
# for p in self.model2.parameters():
# p.requires_grad = False
        # todo(hty): is there a way to give this nn.Linear(kernel_count, inner_feature) explicit initial parameters (rather than all zeros or the built-in defaults)?
if 'resnet' in fundus_model and '50' not in fundus_model:
            kernel_count = 512  # in_features read from the loaded model
if self.label_type == 'multilabel':
self.model1.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature), nn.Sigmoid()) # kernelCount
else:
self.model1.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
elif 'resnest' in fundus_model or 'scnet' in fundus_model or 'resnet50' in fundus_model:
            kernel_count = 2048  # in_features read from the loaded model
if self.label_type == 'multilabel':
self.model1.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature), nn.Sigmoid()) # kernelCount
else:
self.model1.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
elif 'inception' in fundus_model:
            kernel_count = 768  # in_features read from the loaded model
if self.label_type == 'multilabel':
self.model1.AuxLogits.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature), nn.Sigmoid())
else:
self.model1.AuxLogits.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
            kernel_count = 2048  # in_features read from the loaded model
if self.label_type == 'multilabel':
self.model1.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature), nn.Sigmoid())
else:
self.model1.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
self.model1.aux_logits = False

if 'resnet' in OCT_model and '50' not in OCT_model:
            kernel_count = 512  # in_features read from the loaded model
if self.label_type == 'multilabel':
self.model2.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature), nn.Sigmoid()) # kernelCount
else:
self.model2.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
elif 'resnest' in OCT_model or 'scnet' in OCT_model or 'resnet50' in OCT_model:
            kernel_count = 2048  # in_features read from the loaded model
if self.label_type == 'multilabel':
self.model2.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature), nn.Sigmoid()) # kernelCount
else:
self.model2.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
elif 'inception' in OCT_model:
            kernel_count = 768  # in_features read from the loaded model
if self.label_type == 'multilabel':
self.model2.AuxLogits.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature), nn.Sigmoid())
else:
self.model2.AuxLogits.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
            kernel_count = 2048  # in_features read from the loaded model
if self.label_type == 'multilabel':
self.model2.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature), nn.Sigmoid())
else:
self.model2.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
self.model2.aux_logits = False

self.model1 = load_models(model_path= fundus_path, model_name=fundus_model, label_type=label_type, inner_feature=inner_feature)
self.model2 = load_models(model_path= OCT_path, model_name=OCT_model, label_type=label_type, inner_feature=inner_feature)
        # todo(hty): is a single fully connected layer here enough?
self.fc = nn.Sequential(nn.Linear(inner_feature * 2, num_classes))

@@ -82,6 +85,23 @@ def forward(self, x1, x2):
x = self.fc(torch.cat((x1, x2), 1))
return x

class Only_Fundus_Net(nn.Module):
def __init__(self, fundus_path, fundus_model='resnest50', OCT_model='inceptionv3',
num_classes=1000, label_type='single-label', inner_feature=1000):

super(Only_Fundus_Net, self).__init__()
self.label_type = label_type

self.model1 = load_models(model_path= fundus_path, model_name=fundus_model, label_type=label_type, inner_feature=inner_feature)
self.model2 = pretrain_models(model_name = OCT_model, inner_feature = inner_feature, lock_weight = False)
self.fc = nn.Sequential(nn.Linear(inner_feature * 2, num_classes))

def forward(self, x1, x2):

x1 = self.model1(x1)
x2 = self.model2(x2)
x = self.fc(torch.cat((x1, x2), 1))
return x
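A construction sketch for the new ablation (checkpoint path, class count and input sizes are placeholders, not values from this repository): the fundus stream is a fine-tuned checkpoint restored through load_models, while the OCT stream is only an ImageNet-pretrained backbone from pretrain_models, so comparing against TwoStreamNet isolates what the fine-tuned OCT branch contributes.

# Hypothetical usage of Only_Fundus_Net (illustration only):
import torch
from net.two_stream import Only_Fundus_Net

model = Only_Fundus_Net(fundus_path='checkpoints/fundus_resnest50.pth',  # placeholder path
                        fundus_model='resnest50',
                        OCT_model='inceptionv3',
                        num_classes=5,           # example class count
                        label_type='multilabel',
                        inner_feature=1000)
fundus = torch.randn(2, 3, 224, 224)
oct_scan = torch.randn(2, 3, 299, 299)           # inception_v3 pretraining uses 299x299 inputs
logits = model(fundus, oct_scan)                 # -> shape (2, 5), raw fusion logits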

class BaseLineNet(nn.Module):
def __init__(self, fundus_model='resnest50', OCT_model='inceptionv3', num_classes=1000,
Expand All @@ -90,94 +110,8 @@ def __init__(self, fundus_model='resnest50', OCT_model='inceptionv3', num_classe
super(BaseLineNet, self).__init__()
self.label_type = label_type

if fundus_model == "resnet18":
self.model1 = models.resnet18(pretrained=True)
# for p in self.model1.parameters():
# p.requires_grad = False
kernel_count = self.model1.fc.in_features
self.model1.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
elif fundus_model == "resnet34":
self.model1 = models.resnet34(pretrained=True)
# for p in self.model1.parameters():
# p.requires_grad = False
kernel_count = self.model1.fc.in_features
self.model1.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
elif fundus_model == "resnet50":
self.model1 = models.resnet50(pretrained=True)
# for p in self.model1.parameters():
# p.requires_grad = False
kernel_count = self.model1.fc.in_features
self.model1.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
elif fundus_model == "resnest50":
from resnest.torch import resnest50
self.model1 = resnest50(pretrained=True)
for p in self.model1.parameters():
p.requires_grad = False
kernel_count = self.model1.fc.in_features
self.model1.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
elif fundus_model == "inceptionv3":
self.model1 = models.inception_v3(pretrained=True)
# for p in self.model1.parameters():
# p.requires_grad = False
kernel_count = self.model1.AuxLogits.fc.in_features
self.model1.AuxLogits.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
kernel_count = self.model1.fc.in_features
self.model1.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
self.model1.aux_logits = False
elif fundus_model == 'scnet50':
from net.SCNet.scnet import scnet50
self.model1 = scnet50(pretrained=True)
for p in self.model1.parameters():
p.requires_grad = False
kernel_count = self.model1.fc.in_features
self.model1.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
else:
return

if OCT_model == "resnet18":
self.model2 = models.resnet18(pretrained=True)
# for p in self.model2.parameters():
# p.requires_grad = False
kernel_count = self.model2.fc.in_features
self.model2.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
elif OCT_model == "resnet34":
self.model2 = models.resnet34(pretrained=True)
# for p in self.model2.parameters():
# p.requires_grad = False
kernel_count = self.model2.fc.in_features
self.model2.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
elif OCT_model == "resnet50":
self.model2 = models.resnet50(pretrained=True)
# for p in self.model2.parameters():
# p.requires_grad = False
kernel_count = self.model2.fc.in_features
self.model2.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
elif OCT_model == "resnest50":
from resnest.torch import resnest50
self.model2 = resnest50(pretrained=True)
for p in self.model2.parameters():
p.requires_grad = False
kernel_count = self.model2.fc.in_features
self.model2.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
elif OCT_model == "inceptionv3":
self.model2 = models.inception_v3(pretrained=True)
# for p in self.model2.parameters():
# p.requires_grad = False
kernel_count = self.model2.AuxLogits.fc.in_features
self.model2.AuxLogits.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
kernel_count = self.model2.fc.in_features
self.model2.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
self.model2.aux_logits = False
elif OCT_model == 'scnet50':
from net.SCNet.scnet import scnet50
self.model2 = scnet50(pretrained=True)
for p in self.model2.parameters():
p.requires_grad = False
kernel_count = self.model2.fc.in_features
self.model2.fc = nn.Sequential(nn.Linear(kernel_count, inner_feature))
else:
return

self.model1 = pretrain_models(model_name = fundus_model, inner_feature = inner_feature, lock_weight = False)
self.model2 = pretrain_models(model_name = OCT_model, inner_feature = inner_feature, lock_weight = False)
self.fc = nn.Sequential(nn.Linear(inner_feature * 2, num_classes))

def forward(self, x1, x2):
6 changes: 3 additions & 3 deletions test_baseline.py
@@ -16,7 +16,7 @@
import cv2

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

BATCH_SIZE = 8 # RECEIVED_PARAMS["batch_size"]
WORKERS = 1
@@ -32,8 +32,8 @@

print("Test baseline ", model_path)

data_dir = '/home/hejiawen/datasets/AMD_processed/'
list_dir = '/home/hejiawen/datasets/AMD_processed/label/new_two_stream/'
data_dir = '/home/hutianyi/datasets/AMD_processed/'
list_dir = '/home/hutianyi/datasets/AMD_processed/label/new_two_stream/'


def test(model, val_loader, criterion):
21 changes: 9 additions & 12 deletions train_OCT.py
@@ -1,3 +1,4 @@
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.optim as optim
@@ -9,8 +10,7 @@
import os
from tqdm import tqdm
from tensorboardX import SummaryWriter
import requests
api_key = "a03f079569c5423fb8eca7be41f8dda5"  # WeChat notification token
from utils.Message import message

from sklearn.metrics import f1_score, roc_auc_score, recall_score, precision_score, accuracy_score, hamming_loss
# from sklearn.utils.class_weight import compute_sample_weight
@@ -19,11 +19,11 @@
import torch.nn.functional as F

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

METHOD = ''
LABEL = 'multilabel'
MODEL = "inceptionv3"
MODEL = "scnet50"
LOSS = 'bceloss'

START_EPOCH = 0
@@ -218,6 +218,8 @@ def validate(model, val_loader, criterion, writer, epoch):
writer.add_scalar("Val/ELoss", out_loss, epoch)
tbar.close()
print(f1, auroc, recall, precision, acc, avg, hamming)
if epoch % 10 == 0:
message('Train_OCT_Epoch' + str(epoch), 'f1='+str(f1)+'\nauroc='+ str(auroc)+'\nrecall='+ str(recall)+'\nprecision='+ str(precision)+'\nacc='+ str(acc)+'\navg='+ str(avg)+'\nhamming='+ str(hamming))
return avg


@@ -371,18 +373,13 @@ def main():
max_avg = avg
writer.close()

def message(title, body):
"""
    Send the result as a WeChat notification (pushplus).
"""
url = 'http://www.pushplus.plus/send?token='+api_key+'&title='+title+'&content='+body
requests.get(url)


if __name__ == '__main__':
import time
    message('Start training', f'Model: {model_name}')
    message('Start training Train_OCT', 'Model: ' + model_name)
start = time.time()
main()
end = time.time()
    message('Finished training', f'Total time: {end - start}')
    message('Finished training Train_OCT', 'Total time: ' + str(end - start))
    print('Total time:', end - start)
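The inline WeChat helper deleted above is now imported from utils.Message; a plausible sketch of that module, reconstructed from the removed code (the actual utils/Message.py is not shown in this diff, and the token is a placeholder):

# utils/Message.py: hypothetical reconstruction of the relocated helper
import requests

api_key = "<pushplus token>"  # WeChat notification token (placeholder)

def message(title, body):
    """Send a WeChat notification through pushplus."""
    url = 'http://www.pushplus.plus/send?token=' + api_key + '&title=' + title + '&content=' + body
    requests.get(url)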
6 changes: 3 additions & 3 deletions train_baseline.py
@@ -44,12 +44,12 @@

RESUME = False
NAME = METHOD + "+" + str(EPOCHS) + "+" + str(LR) + '+' + str(WEIGHT_DECAY) + '+' + LOSS
model_name = '2021_05_29+' + FUNDUS_MODEL + '+' + OCT_MODEL + '+' + NAME + '.pth'
model_name = '2021_07_23+' + FUNDUS_MODEL + '+' + OCT_MODEL + '+' + NAME + '.pth'

print("Train baseline ", model_name, 'RESUME:', RESUME)

data_dir = '/home/hejiawen/datasets/AMD_processed/'
list_dir = '/home/hejiawen/datasets/AMD_processed/label/new_two_stream/'
data_dir = '/home/hutianyi/datasets/AMD_processed/'
list_dir = '/home/hutianyi/datasets/AMD_processed/label/new_two_stream/'


def train(model, train_loader, optimizer, scheduler, criterion, writer, epoch):
(The diffs for the remaining two changed files are not shown.)
