
Commit: interpolation
arimousa committed Dec 8, 2022
1 parent 2fefb17 commit 198c7f9
Showing 9 changed files with 66 additions and 41 deletions.
7 changes: 5 additions & 2 deletions anomaly_map.py
@@ -9,7 +9,7 @@



def heat_map(outputs, targets, config):
def heat_map(outputs, targets, mean_train_dataset, config, v):
sigma = 4
kernel_size = 2*int(4 * sigma + 0.5) +1
anomaly_map = torch.zeros([outputs[0].shape[0], 3, int(config.data.image_size), int(config.data.image_size)], device = config.model.device)#distance_map.shape[0]
@@ -30,9 +30,12 @@ def heat_map(outputs, targets, config):
distance_map = torch.unsqueeze(distance_map, dim=1)

distance_map = F.interpolate(distance_map , size = int(config.data.image_size), mode="bilinear")
# mean_train_dataset = torch.Tensor(mean_train_dataset).to(config.model.device)
distance_map_image = 1 - F.cosine_similarity(output.to(config.model.device), mean_train_dataset,dim=1).to(config.model.device)
distance_map_image = torch.unsqueeze(distance_map_image, dim=1)


anomaly_map += distance_map
anomaly_map += ((v/100)*distance_map + ((100-v)/100)*distance_map_image)

# anomaly_map += (output-target).square()*2 - 1
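
For context, a minimal sketch (not part of the repository) of the weighted combination this hunk introduces: the cosine distance between the reconstruction and its forward-diffused target is blended with the cosine distance to the mean training image, controlled by a percentage weight v. The function and argument names below are illustrative assumptions.

import torch
import torch.nn.functional as F

def blend_anomaly_maps(output, target, mean_train_image, v, image_size):
    # assumes output, target and mean_train_image are already at the final image resolution
    # cosine distance between reconstruction and target, per pixel: (B, H, W)
    feat_distance = 1 - F.cosine_similarity(output, target, dim=1)
    # upsample to the full image resolution (a no-op if sizes already match)
    feat_distance = F.interpolate(feat_distance.unsqueeze(1), size=image_size, mode="bilinear")

    # cosine distance between reconstruction and the mean training image
    image_distance = (1 - F.cosine_similarity(output, mean_train_image, dim=1)).unsqueeze(1)

    # interpolate between the two maps; v is given in percent
    return (v / 100) * feat_distance + ((100 - v) / 100) * image_distance

With v = 100 only the reconstruction-based map is used, with v = 0 only the distance to the training mean; intermediate values give the interpolation the commit title refers to.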

6 changes: 3 additions & 3 deletions config.yaml
@@ -16,9 +16,9 @@ model:
schedule : linear # options: [linear, quad, const, jsd, sigmoid]
learning_rate: 0.001
weight_decay: 0.00001
epochs: 500
trajectory_steps: 601
test_trajectoy_steps: 600 # 230 cannot reconstruct missed components
epochs: 400
trajectory_steps: 301
test_trajectoy_steps: 300 # 230 cannot reconstruct missed components
beta_start : 0.0001 # 0.0001
beta_end : 0.006 # 0.006 for 300
ema : True
4 changes: 2 additions & 2 deletions loss.py
@@ -6,7 +6,7 @@
from noise import *


def get_loss(model, constant_dict, x_0, t, config):
def get_loss(model, constant_dict, x_0, t, v, config):

x_noisy, noise = forward_diffusion_sample(x_0, t , constant_dict, config)

@@ -36,6 +36,6 @@ def get_loss(model, constant_dict, x_0, t, config):
cosloss += torch.mean(1-cos_loss(F_x_noisy[item].view(F_x_noisy[item].shape[0],-1),
F_x_prime_noisy[item].view(F_x_prime_noisy[item].shape[0],-1)))

return 0.5*(cosloss) + 0.5*(loss)
return (v/100)*(cosloss) + ((100-v)/100)*(loss)
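
As a rough illustration (a sketch, not the repository function): the fixed 0.5/0.5 weighting of the two loss terms is replaced by a convex combination driven by the new argument v, given in percent. Here cosloss and loss stand for the feature-space cosine loss and the noise-prediction loss computed earlier in get_loss.

def combine_losses(cosloss, loss, v):
    # v is a percentage in [0, 100]; v = 50 reproduces the previous 0.5 / 0.5 weighting
    return (v / 100) * cosloss + ((100 - v) / 100) * loss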


21 changes: 15 additions & 6 deletions main.py
@@ -48,7 +48,7 @@ def build_model(config):

def train(args, category):
config = OmegaConf.load(args.config)
start = time.time()

model = build_model(config)
print("Num params: ", sum(p.numel() for p in model.parameters()))
model = model.to(config.model.device)
@@ -60,9 +60,14 @@ def train(args, category):
ema_helper = None
# model = torch.nn.DataParallel(model)
constants_dict = constant(config)
trainer(model, constants_dict, ema_helper, config, category)
end = time.time()
print('training time on ',config.model.epochs,' epochs is ', str(timedelta(seconds=end - start)),'\n')
for v in [0,10,20,30,40,50,60,70,80,90,100]:
start = time.time()
print('v_train : ',v,'\n')
with open('readme.txt', 'a') as f:
f.write(f'v_train : {v} \n')
trainer(model, constants_dict, v, ema_helper, config, category)
end = time.time()
print('training time on ',config.model.epochs,' epochs is ', str(timedelta(seconds=end - start)),'\n')
with open('readme.txt', 'a') as f:
f.write('\n training time is {}\n'.format(str(timedelta(seconds=end - start))))

@@ -71,7 +76,7 @@ def evaluate(args, category):
start = time.time()
config = OmegaConf.load(args.config)
model = build_model(config)
checkpoint = torch.load(os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir),category)) # config.model.checkpoint_name
checkpoint = torch.load(os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir),os.path.join(category,str(250)))) # config.model.checkpoint_name
model.load_state_dict(checkpoint)
model.to(config.model.device)
model.eval()
@@ -83,7 +88,11 @@ def evaluate(args, category):
else:
ema_helper = None
constants_dict = constant(config)
validate(model, constants_dict, config, category)
for v in [0,10,20,30,40,50,60,70,80,90,100]:
print('v_test : ',v,'\n')
with open('readme.txt', 'a') as f:
f.write(f'v_test : {v} \n')
validate(model, constants_dict, config, category, v)
end = time.time()
print('Test time is ', str(timedelta(seconds=end - start)))

2 changes: 1 addition & 1 deletion metrics.py
@@ -45,4 +45,4 @@ def metric(labels_list, predictions_max, predictions_mean, anomaly_map_list, GT_

with open('readme.txt', 'a') as f:
f.write(
f"AUROC_max: {auroc_max} | AUROC_mean: {auroc_mean} | auroc_pixel{auroc_pixel} | F1SCORE_max: {f1_scor_max} | F1SCORE_mean: {f1_scor_mean} | f1_score_pixel: {f1_score_pixel}")
f"AUROC_max: {auroc_max} | AUROC_mean: {auroc_mean} | auroc_pixel{auroc_pixel} | F1SCORE_max: {f1_scor_max} | F1SCORE_mean: {f1_scor_mean} | f1_score_pixel: {f1_score_pixel}\n")
5 changes: 2 additions & 3 deletions sample.py
@@ -1,6 +1,6 @@
import torch
from noise import *

from utilities import *

@torch.no_grad()
def sample_timestep(config, model, constant_dict, x, t):
@@ -25,7 +25,6 @@ def sample_timestep(config, model, constant_dict, x, t):
return model_mean
else:
noise = get_noise(x, t, config)

return model_mean + torch.sqrt(posterior_variance_t) * noise
return model_mean + torch.sqrt(posterior_variance_t) * noise


44 changes: 26 additions & 18 deletions test.py
@@ -10,15 +10,11 @@
from metrics import metric

from EMA import EMAHelper


from torch.utils.tensorboard import SummaryWriter




@torch.no_grad()
def validate(model, constants_dict, config, category):
def validate(model, constants_dict, config, category, v):

test_dataset = MVTecDataset(
root= config.data.data_dir,
@@ -34,6 +30,28 @@ def validate(model, constants_dict, config, category):
drop_last=False,
)

train_dataset = MVTecDataset(
root= config.data.data_dir,
category=category,
input_size= config.data.image_size,
is_train=True,
)
trainloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=config.data.batch_size,
shuffle=True,
num_workers=config.model.num_workers,
drop_last=True,
)

mean_train_dataset = torch.zeros([3, config.data.image_size, config.data.image_size])
n_samples = 0
for step, batch in enumerate(trainloader):
mean_train_dataset += batch[0].sum(dim=0)
n_samples += batch[0].shape[0]
mean_train_dataset /= n_samples
mean_train_dataset = mean_train_dataset.to(config.model.device)


labels_list = []
predictions_max = []
@@ -48,30 +66,20 @@ def validate(model, constants_dict, config, category):
data_forward = []
data_reconstructed = []

# for j in range(0,10):
# noisy_image = forward_diffusion_sample(data, test_trajectoy_steps, constants_dict, config)[0]
# for i in range(0,10)[::-1]:
# t = torch.full((1,), i, device=config.model.device, dtype=torch.long)
# noisy_image = test_sample_timestep(config, model, noisy_image.to(config.model.device), t, constants_dict)
# if j == 9:
# if i in [0,5,10]:
# f_image = forward_diffusion_sample(data, t , constants_dict, config)[0]
# data_forward.append(f_image)
# data_reconstructed.append(noisy_image)

noisy_image = forward_diffusion_sample(data, test_trajectoy_steps, constants_dict, config)[0]
for i in range(0,test_trajectoy_steps)[::-1]:
t = torch.full((1,), i, device=config.model.device, dtype=torch.long)
noisy_image = test_sample_timestep(config, model, noisy_image.to(config.model.device), t, constants_dict)
if i in [0,5]: #[0,5,10]
noisy_image = sample_timestep(config, model, constants_dict, noisy_image.to(config.model.device), t)
if i in [0,5,10]: #[0,5,10]
f_image = forward_diffusion_sample(data, t , constants_dict, config)[0]
data_forward.append(f_image)
data_reconstructed.append(noisy_image)




anomaly_map = heat_map(data_reconstructed, data_forward, config)
anomaly_map = heat_map(data_reconstructed, data_forward, mean_train_dataset , config, v)

for pred, label in zip(anomaly_map, labels):
labels_list.append(0 if label == 'good' else 1)
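
For clarity, a self-contained sketch of the running mean over the training set that validate now builds before scoring (loader and shapes follow the hunk above; the helper name is an illustrative assumption):

import torch

def compute_mean_train_image(trainloader, image_size, device):
    # accumulate a running sum over all training images, then divide by the sample count
    mean_image = torch.zeros([3, image_size, image_size])
    n_samples = 0
    for batch in trainloader:
        images = batch[0]                  # (B, 3, H, W) image tensor from the MVTec loader
        mean_image += images.sum(dim=0)    # sum over the batch dimension
        n_samples += images.shape[0]
    return (mean_image / n_samples).to(device)

The resulting (3, H, W) tensor is what heat_map compares each reconstruction against when v is below 100.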
17 changes: 11 additions & 6 deletions train.py
@@ -22,7 +22,7 @@



def trainer(model, constant_dict, ema_helper, config, category):
def trainer(model, constants_dict, v, ema_helper, config, category):
with open('readme.txt', 'a') as f:
f.write(f"\n {category} : ")
optimizer = build_optimizer(model, config)
@@ -49,7 +49,7 @@ def trainer(model, constant_dict, ema_helper, config, category):


optimizer.zero_grad()
loss = get_loss(model, constant_dict, batch[0], t, config)
loss = get_loss(model, constants_dict, batch[0], t, v, config)
writer.add_scalar('loss', loss, epoch)

loss.backward()
@@ -60,15 +60,20 @@ def trainer(model, constant_dict, ema_helper, config, category):
print(f"Epoch {epoch} | Loss: {loss.item()}")
with open('readme.txt', 'a') as f:
f.write(f"\n Epoch {epoch} | Loss: {loss.item()} | ")
# if epoch %50 == 0 and step ==0:
# sample_plot_image(model, trainloader, constant_dict, epoch, category, config)
if epoch %50 == 0 and step ==0:
sample_plot_image(model, trainloader, constants_dict, epoch, category, config)
if epoch %50 == 0 and epoch > 0 and step ==0:
validate(model, constant_dict, config, category)
for v in [0,10,20,30,40,50,60,70,80,90,100]:
print('v_test : ',v,'\n')
with open('readme.txt', 'a') as f:
f.write(f'v_test : {v} \n')
validate(model, constants_dict, config, category, v)
#validate(model, constant_dict, config, category)
if config.model.save_model:
model_save_dir = os.path.join(os.getcwd(), config.model.checkpoint_dir)
if not os.path.exists(model_save_dir):
os.mkdir(model_save_dir)
torch.save(model.state_dict(), os.path.join(config.model.checkpoint_dir, os.path.join(category,str(epoch))), #config.model.checkpoint_name
torch.save(model.state_dict(), os.path.join(config.model.checkpoint_dir, os.path.join(category,str(f'{epoch}+{v}'))), #config.model.checkpoint_name
)


1 change: 1 addition & 0 deletions visualize.py
@@ -5,6 +5,7 @@
import os
from forward_process import *
from dataset import *
from sample import *

from noise import *

