Commit

sampler modified
arimousa committed Jan 3, 2023
1 parent 9006221 commit 3a482e8
Showing 12 changed files with 542 additions and 286 deletions.
79 changes: 44 additions & 35 deletions anomaly_map.py
@@ -6,56 +6,65 @@
 from utilities import *
 from backbone import *
 from dataset import *
-from visualize import *
+from feature_extractor import *




-def heat_map(outputs, targets, config, v, mean_train_dataset, representation_backbone):
+def heat_map(outputs, targets, feature_extractor, constants_dict, config):
     sigma = 4
     kernel_size = 2 * int(4 * sigma + 0.5) + 1
     anomaly_score = torch.zeros([outputs[0].shape[0], 3, int(config.data.image_size), int(config.data.image_size)], device=config.model.device)

-    distance_map_image = representation_score(outputs[-1], config, mean_train_dataset, representation_backbone)
-
-    for output in outputs:
-        feature_extractor = Feature_extractor(config=config, backbone="wide_resnet50_2", out_indices=[1])
-        feature_extractor.to(config.model.device)
-        outputs_features = feature_extractor(output.to(config.model.device))
-        targets_features = feature_extractor(targets.to(config.model.device))
-
-        distance_map = 1 - F.cosine_similarity(outputs_features.to(config.model.device), targets_features.to(config.model.device), dim=1).to(config.model.device)
-        distance_map = torch.unsqueeze(distance_map, dim=1)
-        distance_map = F.interpolate(distance_map, size=int(config.data.image_size), mode="bilinear")
-        anomaly_score += distance_map
-
-    print('distance_map_image : ', torch.mean(distance_map_image))
-    print('distance_map : ', torch.mean(anomaly_score))
-
-    anomaly_score = ((100 - v) / 100) * distance_map_image + (v / 100) * anomaly_score
+    i_d = color_distance(outputs, targets, config)
+    f_d = feature_distance(outputs, targets, feature_extractor, constants_dict, config)
+
+    print('image_distance : ', torch.mean(i_d))
+    print('feature_distance : ', torch.mean(f_d))
+
+    anomaly_score = 0.8 * f_d + 0.2 * i_d

     anomaly_score = gaussian_blur2d(
         anomaly_score, kernel_size=(kernel_size, kernel_size), sigma=(sigma, sigma)
     )
-    anomaly_score = torchvision.transforms.functional.rgb_to_grayscale(anomaly_score)
+    anomaly_score = torch.sum(anomaly_score, dim=1).unsqueeze(1)
+    print('anomaly_score : ', torch.mean(anomaly_score))

     return anomaly_score


-def representation_score(outputs, config, mean_train_dataset, representation_backbone):
-    if representation_backbone == 'cait_m48_448':
-        feature_extractor = Feature_extractor(config=config, backbone="cait_m48_448", out_indices=[1])
-        outputs = F.interpolate(outputs, size=448, mode="bilinear")
-    else:
-        feature_extractor = Feature_extractor(config=config, backbone="wide_resnet50_2", out_indices=[1])
-    feature_extractor.to(config.model.device)
-    outputs_features = feature_extractor(outputs.to(config.model.device))
-    mean_train_dataset = mean_train_dataset.to(config.model.device)
-    distance_map_image = 1 - F.cosine_similarity(outputs_features.to(config.model.device), mean_train_dataset, dim=1).to(config.model.device)
-    distance_map_image = torch.unsqueeze(distance_map_image, dim=1)
-    distance_map_image = F.interpolate(distance_map_image, size=int(config.data.image_size), mode="bilinear")
-    return distance_map_image
+def color_distance(image1, image2, config):
+    # Min-max normalize both images, then take the per-pixel absolute
+    # difference averaged over the channel dimension.
+    image1 = (image1 - image1.min()) / (image1.max() - image1.min())
+    image2 = (image2 - image2.min()) / (image2.max() - image2.min())
+
+    distance_map = image1.to(config.model.device) - image2.to(config.model.device)
+    distance_map = torch.abs(distance_map)
+    distance_map = torch.mean(distance_map, dim=1).unsqueeze(1)
+    return distance_map
+
+
+def feature_distance(output, target, feature_extractor, constants_dict, config):
+    outputs_features = extract_features(feature_extractor=feature_extractor, x=output.to(config.model.device), out_indices=[2, 3], config=config)
+    targets_features = extract_features(feature_extractor=feature_extractor, x=target.to(config.model.device), out_indices=[2, 3], config=config)
+
+    # 1 - cosine similarity over the channel dimension, upsampled to image size.
+    distance_map = 1 - F.cosine_similarity(outputs_features.to(config.model.device), targets_features.to(config.model.device), dim=1).to(config.model.device)
+    distance_map = torch.unsqueeze(distance_map, dim=1)
+
+    distance_map = F.interpolate(distance_map, size=int(config.data.image_size), mode="nearest")
+    return distance_map
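With this commit the scoring path reduces to: reconstruct the input, compute a pixel-space distance and a feature-space distance, and blend them 0.2/0.8 before blurring. A hypothetical usage sketch under that reading — `reconstructed`, `images`, and `threshold` are illustrative names, not part of this commit:

# Assumed setup: `reconstructed` and `images` are (B, 3, H, W) tensors and
# `feature_extractor` is the timm backbone returned by tune_feature_extractor().
anomaly_map = heat_map(reconstructed, images, feature_extractor, constants_dict, config)
predictions = (anomaly_map > threshold).float()  # threshold: assumed validation-tuned scalar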

22 changes: 11 additions & 11 deletions config.yaml
@@ -1,31 +1,31 @@
 data :
-  name: mvtec # mvtec if the dataset is MVTec AD; otherwise, the name of your desired dataset
-  data_dir: datasets/MVTec
-  category: None
+  name: mvtec # mtd # mvtec if the dataset is MVTec AD; otherwise, the name of your desired dataset
+  data_dir: datasets/MVTec # MTD
+  category: capsule # ['hazelnut', 'bottle', 'cable', 'carpet', 'leather', 'capsule', 'grid', 'pill', 'transistor', 'metal_nut', 'screw', 'toothbrush', 'zipper', 'tile', 'wood']
   image_size: 256
   batch_size: 32
   mask : True
   imput_channel : 3


 model:
-  checkpoint_dir: checkpoints/500250
+  checkpoint_dir: checkpoints/MVTec # MTD
   checkpoint_name: weights
   exp_name: default
-  backbone: wide_resnet50_2 # wide_resnet50_2
-  representation_backbone: wide_resnet50_2 # wide_resnet50_2 cait_m48_448
+  backbone: resnet18 # wide_resnet50_2
   pre_trained: True
   noise : Gaussian # options : [Gaussian, Perlin]
   schedule : linear # options: [linear, quad, const, jsd, sigmoid]
   learning_rate: 0.001
   weight_decay: 0.00001
-  epochs: 1000
-  trajectory_steps: 401
-  skip : 30
-  test_trajectoy_steps: 400 # 230 cannot reconstruct missed components
+  epochs: 600
+  trajectory_steps: 1000
+  test_trajectoy_steps: 200
+  generate_time_steps: 800
+  skip : 10
+  eta : 0.8
   beta_start : 0.0001 # 0.0001
-  beta_end : 0.007 # 0.006 for 300
+  beta_end : 0.02 # 0.006 for 300
   ema : True
   ema_rate : 0.999
   device: 'cuda' #<"cpu", "gpu", "tpu", "ipu">
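The repository reads this file with OmegaConf (imported in feature_extractor.py below), so a minimal loading sketch — assuming config.yaml sits in the working directory — looks like:

from omegaconf import OmegaConf

config = OmegaConf.load('config.yaml')
print(config.data.category)           # capsule
print(config.model.trajectory_steps)  # 1000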
156 changes: 156 additions & 0 deletions feature_extractor.py
@@ -0,0 +1,156 @@
import torch
import torch.nn as nn
import os
from tqdm import tqdm
from forward_process import *
from dataset import *
import timm
from torch import Tensor
from typing import Callable, List, Tuple, Union
from model import *
from omegaconf import OmegaConf
from sample import *
from visualize import *


def build_model(config):
    model = UNetModel(256, 64, dropout=0, n_heads=4, in_channels=config.data.imput_channel)
    return model


def fake_real_dataset(config, constants_dict):
    train_dataset = Dataset(
        root=config.data.data_dir,
        category=config.data.category,
        config=config,
        is_train=True,
    )
    trainloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.data.batch_size,
        shuffle=True,
        num_workers=config.model.num_workers,
        drop_last=True,
    )
    R_F_dataset = []
    print("Start generating fake real dataset")
    for step, batch in tqdm(enumerate(trainloader), total=len(trainloader)):
        image = batch[0]
        image = image.to(config.model.device)
        # Load the trained denoising model and generate a "fake" batch from noise.
        model = build_model(config)
        if config.data.category:
            checkpoint = torch.load(os.path.join(os.getcwd(), config.model.checkpoint_dir, config.data.category, '600'))
        else:
            checkpoint = torch.load(os.path.join(os.getcwd(), config.model.checkpoint_dir, '600'))
        model.load_state_dict(checkpoint)
        model.to(config.model.device)
        model.eval()
        generate_time_steps = torch.Tensor([config.model.generate_time_steps]).type(torch.int64)
        noise = get_noise(image, config)
        seq = range(0, config.model.generate_time_steps, config.model.skip)
        H_funcs = Denoising(config.data.imput_channel, config.data.image_size, config.model.device)
        reconstructed, _ = efficient_generalized_steps(config, noise, seq, model, constants_dict['betas'], H_funcs, image, cls_fn=None, classes=None)
        generated_image = reconstructed[-1].to(config.model.device)

        # Label generated images [1, 0] (fake) and training images [0, 1] (real).
        for fake, real in zip(generated_image, image):
            fake_label = torch.Tensor([1, 0]).type(torch.float32).to(config.model.device)
            real_label = torch.Tensor([0, 1]).type(torch.float32).to(config.model.device)
            R_F_dataset.append((fake.type(torch.float32), fake_label))
            R_F_dataset.append((real.type(torch.float32), real_label))
        break  # only the first batch is used
    return R_F_dataset


def tune_feature_extractor(constants_dict, config):
    R_F_dataset = fake_real_dataset(config, constants_dict)
    R_F_dataloader = torch.utils.data.DataLoader(R_F_dataset, batch_size=config.data.batch_size, shuffle=True)
    feature_extractor = timm.create_model(
        config.model.backbone,
        pretrained=True,
        num_classes=2,
    )
    print(feature_extractor.get_classifier())
    num_in_features = feature_extractor.get_classifier().in_features  # kept for swapping in a custom head
    feature_extractor.to(config.model.device)
    feature_extractor.train()
    optimizer = torch.optim.Adam(feature_extractor.parameters(), lr=config.model.learning_rate)
    criterion = nn.CrossEntropyLoss()
    print("Start training feature extractor")
    if False:  # real-vs-fake fine-tuning is currently disabled
        for epoch in tqdm(range(100)):
            for step, batch in enumerate(R_F_dataloader):
                image = batch[0]
                label = batch[1]
                output = feature_extractor(image)
                if epoch == 49:
                    for l, o in zip(label, output):
                        print('output : ', o, 'label : ', l, '\n')
                loss = criterion(output, label)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            print("Epoch: ", epoch, "Loss: ", loss.item())
        if config.data.category:
            torch.save(feature_extractor.state_dict(), os.path.join(os.getcwd(), config.model.checkpoint_dir, config.data.category, 'feature'))
        else:
            torch.save(feature_extractor.state_dict(), os.path.join(os.getcwd(), config.model.checkpoint_dir, 'feature'))

    return feature_extractor


def extract_features(feature_extractor, x, out_indices, config):
    with torch.no_grad():
        feature_extractor.eval()
        # Map inputs from [-1, 1] back to [0, 1] before feeding the backbone.
        reverse_transforms = transforms.Compose([
            transforms.Lambda(lambda t: (t + 1) / 2),
        ])
        x = reverse_transforms(x)

        for param in feature_extractor.parameters():
            param.requires_grad = False
        feature_extractor.features_only = True
        activations = []
        for name, module in feature_extractor.named_children():
            x = module(x)
            if name in ['layer1', 'layer3']:
                activations.append(x)
        # Upsample deeper activations to the size of the first and concatenate.
        embeddings = activations[0]
        for feature in activations[1:]:
            layer_embedding = F.interpolate(feature, size=int(embeddings.shape[-2]), mode='bilinear', align_corners=False)
            embeddings = torch.cat((embeddings, layer_embedding), 1)
        return embeddings
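As a rough shape check — assuming the resnet18 backbone from config.yaml and 256x256 inputs — layer1 yields (B, 64, 64, 64) and layer3 yields (B, 256, 16, 16), so the concatenated embedding is (B, 320, 64, 64):

# Illustrative call; `batch` is an assumed (B, 3, 256, 256) tensor in [-1, 1].
emb = extract_features(feature_extractor, batch, out_indices=[2, 3], config=config)
print(emb.shape)  # torch.Size([B, 320, 64, 64]) for resnet18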


2 changes: 1 addition & 1 deletion forward_process.py
@@ -18,7 +18,7 @@ def forward_diffusion_sample(x_0, t, constant_dict, config):
     """
     sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod = constant_dict['sqrt_alphas_cumprod'], constant_dict['sqrt_one_minus_alphas_cumprod']

-    noise = get_noise(x_0, t, config)
+    noise = get_noise(x_0, config)
     device = config.model.device

     sqrt_alphas_cumprod_t = get_index_from_list(sqrt_alphas_cumprod, t, x_0.shape, config)
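For reference, forward_diffusion_sample implements the standard closed-form forward process x_t = sqrt(ᾱ_t) · x_0 + sqrt(1 − ᾱ_t) · ε, with ᾱ_t the cumulative product of (1 − β). A self-contained sketch of that identity, with illustrative names rather than the repo's helpers:

import torch

def forward_sample(x_0, t, betas):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    alpha_bar = (1.0 - betas).cumprod(dim=0)[t].view(-1, 1, 1, 1)
    eps = torch.randn_like(x_0)
    return alpha_bar.sqrt() * x_0 + (1.0 - alpha_bar).sqrt() * eps, eps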
41 changes: 9 additions & 32 deletions loss.py
@@ -6,35 +6,12 @@
 from noise import *


-def get_loss(model, constant_dict, x_0, t, v, config):
-
-    x_noisy, noise = forward_diffusion_sample(x_0, t, constant_dict, config)
-
-    noise_pred = model(x_noisy, t)
-    loss = F.mse_loss(noise, noise_pred)
-    return loss
+def get_loss(model, constant_dict, x_0, t, config):
+    x_0 = x_0.to(config.model.device)
+    b = constant_dict['betas'].to(config.model.device)
+    e = torch.randn_like(x_0, device=x_0.device)
+    a = (1 - b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1).to(config.model.device)
+    x = x_0 * a.sqrt() + e * (1.0 - a).sqrt()
+    output = model(x, t.float())
+
+    return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)
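The rewritten get_loss is the usual ε-prediction objective: sample a timestep, noise x_0 in closed form, and regress the model output against the injected noise. A hypothetical training step around it — `trainloader`, `model`, and `optimizer` are assumed to exist and are not shown in this commit:

for images, *_ in trainloader:
    t = torch.randint(0, config.model.trajectory_steps, (images.shape[0],), device=config.model.device)
    loss = get_loss(model, constant_dict, images, t, config)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()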
