Commit: sampler modifies
arimousa committed Jan 17, 2023
1 parent 9c6899f commit adae31a
Showing 7 changed files with 200 additions and 136 deletions.
163 changes: 91 additions & 72 deletions anomaly_map.py
@@ -2,7 +2,7 @@
import torch.nn.functional as F
from kornia.filters import gaussian_blur2d
import torchvision

from torchvision.transforms import transforms
from utilities import *
from backbone import *
from dataset import *
@@ -14,83 +14,57 @@

def heat_map(outputs, targets, feature_extractor, constants_dict, config):
sigma = 4
kernel_size = 2*int(4 * sigma + 0.5) +1
kernel_size = 2 * int(4 * sigma + 0.5) +1
anomaly_score = 0
for output, target in zip(outputs, targets):
i_d = color_distance(output, target, config)
f_d = feature_distance(output, target,feature_extractor, constants_dict, config)
print('image_distance : ',torch.mean(i_d))
print('feature_distance : ',torch.mean(f_d))

visualalize_distance(output, target, i_d, f_d)
i_d = color_distance(output, target, config)
f_d = feature_distance(output, target, feature_extractor, constants_dict, config)
# print('image_distance mean : ',torch.mean(i_d))
# print('feature_distance mean : ',torch.mean(f_d))
# print('image_distance max : ',torch.max(i_d))
# print('feature_distance max : ',torch.max(f_d))

visualalize_distance(output, target, i_d, f_d)

anomaly_score += i_d #(f_d + .4 * i_d)
anomaly_score += f_d + .2* i_d #0.7 * f_d + 0.3 * i_d # .8*

anomaly_score = gaussian_blur2d(
anomaly_score , kernel_size=(kernel_size,kernel_size), sigma=(sigma,sigma)
)
anomaly_score = torch.sum(anomaly_score, dim=1).unsqueeze(1)
print( 'anomaly_score : ',torch.mean(anomaly_score))
# print( 'anomaly_score : ',torch.mean(anomaly_score))

return anomaly_score

def rgb_to_cmyk(r, g, b):
def rgb_to_cmyk(images):
RGB_SCALE = 1
CMYK_SCALE = 100
# if (r, g, b) == (0, 0, 0):
# # black
# return 0, 0, 0, CMYK_SCALE

# rgb [0,255] -> cmy [0,1]
c = 1 - r / RGB_SCALE
m = 1 - g / RGB_SCALE
y = 1 - b / RGB_SCALE

# extract out k [0, 1]
min_cmy = torch.zeros(c.shape, device=c.device)

c = c.view(-1)
m = m.view(-1)
y = y.view(-1)
min_cmy = min_cmy.view(-1)
for i in range(len(c)):
min_cmy[i] = min(c[i], m[i], y[i])
c = c.view((256,256))
m = m.view((256,256))
y = y.view((256,256))
min_cmy = min_cmy.view((256,256))

c = (c - min_cmy) / (1 - min_cmy)
m = (m - min_cmy) / (1 - min_cmy)
y = (y - min_cmy) / (1 - min_cmy)
k = min_cmy

# rescale to the range [0,CMYK_SCALE]
return c * CMYK_SCALE, m * CMYK_SCALE, y * CMYK_SCALE, k * CMYK_SCALE
cmy = 1 - images / RGB_SCALE

min_cmy = torch.zeros(images.shape, device=images.device)
min_cmy = torch.amin(cmy, dim=1).unsqueeze(1)-.001
cmy = (cmy - min_cmy) / (1 - min_cmy)
k = min_cmy
cmyk = torch.cat((cmy,k), dim=1)
return cmyk * CMYK_SCALE


def color_distance(image1, image2, config):
image1 = ((image1 - image1.min())/ (image1.max() - image1.min()))
image2 = ((image2 - image2.min())/ (image2.max() - image2.min()))

for i, (img1, img2) in enumerate(zip(image1, image2)):
c1,m1,y1,k1 = rgb_to_cmyk(img1[0,:,:], img1[1,:,:], img1[2,:,:])
c2,m2,y2,k2 = rgb_to_cmyk(img2[0], img2[1], img2[2])
img1_cmyk = torch.stack((c1,m1,y1), dim=0)
print('img1_cmyk : ',img1_cmyk.shape)
img2_cmyk = torch.stack((c2,m2,y2), dim=0)
img1_cmyk = img1_cmyk.to(config.model.device)
img2_cmyk = img2_cmyk.to(config.model.device)
distance_map = torch.abs(img1_cmyk - img2_cmyk).to(config.model.device)
distance_map = torch.mean(distance_map, dim=0).unsqueeze(0)
if i == 0:
batch = distance_map
else:
batch = torch.cat((batch , distance_map) , dim=0)
batch = batch.unsqueeze(1)
print('batch :', batch.shape)
return batch

cmyk_image_1 = rgb_to_cmyk(image1)
cmyk_image_2 = rgb_to_cmyk(image2)

cmyk_image_1 = ((cmyk_image_1 - cmyk_image_1.min())/ (cmyk_image_1.max() - cmyk_image_1.min()))
cmyk_image_2 = ((cmyk_image_2 - cmyk_image_2.min())/ (cmyk_image_2.max() - cmyk_image_2.min()))

distance_map = cmyk_image_1.to(config.model.device) - cmyk_image_2.to(config.model.device)

distance_map = torch.abs(distance_map)
distance_map = torch.mean(distance_map, dim=1).unsqueeze(1)
return distance_map


# distance_map = image1.to(config.model.device) - image2.to(config.model.device)
@@ -104,22 +78,67 @@ def color_distance(image1, image2, config):

def feature_distance(output, target,feature_extractor, constants_dict, config):

outputs_features = extract_features(feature_extractor=feature_extractor, x=output.to(config.model.device), out_indices=[2,3], config=config) #feature_extractor(output.to(config.model.device))
targets_features = extract_features(feature_extractor=feature_extractor, x=target.to(config.model.device), out_indices=[2,3], config=config) #feature_extractor(target.to(config.model.device))
# output = ((output - output.min())/ (output.max() - output.min()))
# target = ((target - target.min())/ (target.max() - target.min()))

# reversed = transforms.Compose([
# transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# ])

# output = reversed(output)
# target = reversed(target)

# print('output : ', output.max(), output.min())

outputs_features = extract_features(feature_extractor=feature_extractor, x=output.to(config.model.device), out_indices=[2,3], config=config)
targets_features = extract_features(feature_extractor=feature_extractor, x=target.to(config.model.device), out_indices=[2,3], config=config)

outputs_features = (outputs_features - outputs_features.min())/ (outputs_features.max() - outputs_features.min())
targets_features = (targets_features - targets_features.min())/ (targets_features.max() - targets_features.min())

# distance_map = (outputs_features.to(config.model.device) - targets_features.to(config.model.device))
# distance_map = torch.abs(distance_map)
# distance_map = torch.mean(distance_map, dim=1).unsqueeze(1)
outputs_features = (outputs_features - outputs_features.min())/ (outputs_features.max() - outputs_features.min())
targets_features = (targets_features - targets_features.min())/ (targets_features.max() - targets_features.min())

distance_map = 1 - F.cosine_similarity(outputs_features.to(config.model.device), targets_features.to(config.model.device), dim=1).to(config.model.device)
distance_map = torch.unsqueeze(distance_map, dim=1)

distance_map = F.interpolate(distance_map , size = int(config.data.image_size), mode="nearest")
distance_map = 1 - F.cosine_similarity(outputs_features.to(config.model.device), targets_features.to(config.model.device), dim=1).to(config.model.device)
distance_map = torch.unsqueeze(distance_map, dim=1)

# distance_map = torch.mean(torch.pow((outputs_features - targets_features), 2), dim=1).unsqueeze(1)
distance_map = F.interpolate(distance_map , size = int(config.data.image_size), mode="bilinear")

return distance_map

return distance_map


# patches1_features = []
# patches2_features = []
# patch_size = (64, 64)
# stride = (32, 32)
# print('output : ', output.shape)
# # patchify the two images
# patches1 = output.unfold(2, patch_size[0], patch_size[0]).unfold(3, patch_size[1], patch_size[1])
# patches2 = target.unfold(2, patch_size[0], patch_size[0]).unfold(3, patch_size[1], patch_size[1])
# print('patches1 : ', len(patches1))
# print('patches[0] : ', patches1[0].shape)

# patches1 = torch.stack(patches1, dim=0)
# patches2 = torch.stack(patches2, dim=0)
# print('patches1 stack : ', patches1.shape)
# for patch1, patch2 in zip(patches1, patches2):
# patch1_feature = extract_features(feature_extractor=feature_extractor, x=patch1.to(config.model.device), out_indices=[2,3], config=config)
# patch2_feature = extract_features(feature_extractor=feature_extractor, x=patch2.to(config.model.device), out_indices=[2,3], config=config)
# patches1_features.append(patch1_feature)
# patches2_features.append(patch2_feature)


# print('image1 : ', image1.shape)

# image1 = image1.view(32,3,256,256)
# image2 = image2.view(32,3,256,256)
# distance_map = 1 - F.cosine_similarity(image1.to(config.model.device), image2.to(config.model.device), dim=1).to(config.model.device)
# distance_map = torch.unsqueeze(distance_map, dim=1)
# distance_map = F.interpolate(distance_map , size = int(config.data.image_size), mode="bilinear")

# def patchify(img, patch_size):
# patches = []
# for i in range(0, img.shape[1], patch_size):
# for j in range(0, img.shape[2], patch_size):
# patch = img[:, i:i+patch_size, j:j+patch_size]
# patches.append(patch)
# return patches
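
Note (not part of the commit): the anomaly_map.py hunks above vectorize rgb_to_cmyk over the whole batch with torch.amin and keep the cosine-similarity feature distance, now upsampled bilinearly and combined as f_d + 0.2 * i_d in heat_map. A minimal sketch of how those pieces compose, assuming (B, 3, 256, 256) inputs in [0, 1]; the random tensors, feature shapes, and simplified signatures are illustrative only:

import torch
import torch.nn.functional as F

RGB_SCALE, CMYK_SCALE = 1, 100

def rgb_to_cmyk(images):
    # images: (B, 3, H, W) in [0, 1] -> CMYK tensor of shape (B, 4, H, W)
    cmy = 1 - images / RGB_SCALE
    k = torch.amin(cmy, dim=1, keepdim=True) - 0.001  # per-pixel black; offset keeps 1 - k > 0
    cmy = (cmy - k) / (1 - k)
    return torch.cat((cmy, k), dim=1) * CMYK_SCALE

def color_distance(img1, img2):
    # mean absolute CMYK difference per pixel -> (B, 1, H, W)
    c1, c2 = rgb_to_cmyk(img1), rgb_to_cmyk(img2)
    c1 = (c1 - c1.min()) / (c1.max() - c1.min())
    c2 = (c2 - c2.min()) / (c2.max() - c2.min())
    return torch.abs(c1 - c2).mean(dim=1, keepdim=True)

def feature_distance(f1, f2, image_size=256):
    # f1, f2: (B, C, h, w) backbone features -> (B, 1, image_size, image_size)
    d = (1 - F.cosine_similarity(f1, f2, dim=1)).unsqueeze(1)
    return F.interpolate(d, size=image_size, mode="bilinear")

x, y = torch.rand(2, 3, 256, 256), torch.rand(2, 3, 256, 256)
f1, f2 = torch.rand(2, 1792, 64, 64), torch.rand(2, 1792, 64, 64)
anomaly_score = feature_distance(f1, f2) + 0.2 * color_distance(x, y)
print(anomaly_score.shape)  # torch.Size([2, 1, 256, 256])

Vectorizing with torch.amin removes the old per-pixel Python loop, and bilinear (rather than nearest) upsampling yields a smoother anomaly map before the Gaussian blur applied in heat_map.
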
20 changes: 14 additions & 6 deletions config.yaml
@@ -1,7 +1,7 @@
data :
name: mvtec #mtd # mvtec #mvtec if the dataset is MVTec AD, otherwise, it is the name of your desire dataset
data_dir: datasets/MVTec #MTD #MVTec
category: cable #['hazelnut', 'bottle', 'cable', 'carpet', 'leather', 'capsule', 'grid', 'pill','transistor', 'metal_nut', 'screw','toothbrush', 'zipper', 'tile', 'wood']
category: metal_nut #['hazelnut', 'bottle', 'cable', 'carpet', 'leather', 'capsule', 'grid', 'pill','transistor', 'metal_nut', 'screw','toothbrush', 'zipper', 'tile', 'wood']
image_size: 256
batch_size: 32
mask : True
@@ -12,18 +12,22 @@ model:
checkpoint_dir: checkpoints/MVTec #MTD
checkpoint_name: weights
exp_name: default
backbone: resnet18 #resnet34 #resnet18 # wide_resnet50_2
backbone: wide_resnet101_2 #wide_resnet101_2 #resnet18 # wide_resnet50_2 #resnet34
pre_trained: True
noise : Gaussian # options : [Gaussian, Perlin]
schedule : linear # options: [linear, quad, const, jsd, sigmoid]
fine_tune : False
learning_rate: 1e-4 #0.0002
weight_decay: 0 #0.00001
epochs: 400
weight_decay: 0
epochs: 1000
trajectory_steps: 1000
test_trajectoy_steps: 200 #200
test_trajectoy_steps: 400 #200
test_trajectoy_steps2: 300
generate_time_steps: 800
skip : 5
skip : 40 #10
skip2 : 30
sigma : 0.5
eta : 0.99
beta_start : 0.0001 # 0.0001
beta_end : 0.02 # 0.006 for 300
ema : True
@@ -42,3 +46,7 @@ metrics:
method: adaptive #options: [adaptive, manual]
manual_image: null
manual_pixel: null
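
Note (not part of the commit): test_trajectoy_steps, skip, and eta appear to configure the test-time denoising schedule, and the feature_extractor.py hunk below builds its timestep sequence the same way (a range with a skip stride). A hedged illustration of how the updated values plausibly translate into a skipped step sequence; the construction is assumed, only the key names and values come from the config:

# Illustrative only; key names are copied verbatim from the config (including the trajectoy spelling).
test_trajectoy_steps = 400
skip = 40

seq = list(range(0, test_trajectoy_steps, skip))
print(seq)       # [0, 40, 80, ..., 360]
print(len(seq))  # 10 denoising steps instead of 400
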




79 changes: 34 additions & 45 deletions feature_extractor.py
@@ -41,16 +41,16 @@ def fake_real_dataset(config, constants_dict):
image = image.to(config.model.device)
model = build_model(config)
if config.data.category:
checkpoint = torch.load(os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), config.data.category,'600')) # config.model.checkpoint_name 300+50
checkpoint = torch.load(os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), config.data.category,'200')) # config.model.checkpoint_name 300+50
else:
checkpoint = torch.load(os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), '600'))
checkpoint = torch.load(os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), '200'))
model.load_state_dict(checkpoint)
model.to(config.model.device)
model.eval()
generate_time_steps = torch.Tensor([config.model.generate_time_steps]).type(torch.int64)
noise = get_noise(image,config)
# noise = forward_diffusion_sample(image, generate_time_steps, constants_dict, config)[0]
seq = range(0, config.model.generate_time_steps, config.model.skip)
seq = range(0, config.model.generate_time_steps, 50 * config.model.skip)
H_funcs = Denoising(config.data.imput_channel, config.data.image_size, config.model.device)
reconstructed,_ = efficient_generalized_steps(config, noise, seq, model, constants_dict['betas'], H_funcs, image, cls_fn=None, classes=None) #generalized_steps(noise, seq, model, constants_dict['betas'], config, eta=config.model.eta)
generated_image = reconstructed[-1]
@@ -61,7 +61,7 @@ def fake_real_dataset(config, constants_dict):
real_label = torch.Tensor([0,1]).type(torch.float32).to(config.model.device)
R_F_dataset.append((fake.type(torch.float32), fake_label))
R_F_dataset.append((real.type(torch.float32), real_label))
break
# break
return R_F_dataset


@@ -70,59 +70,48 @@ def fake_real_dataset(config, constants_dict):


def tune_feature_extractor(constants_dict, config):
# R_F_dataset = fake_real_dataset(config, constants_dict)
# R_F_dataloader = torch.utils.data.DataLoader(R_F_dataset, batch_size=config.data.batch_size, shuffle=True)

feature_extractor = timm.create_model(
config.model.backbone,
pretrained=True,
num_classes=2,
)
# print(feature_extractor.get_classifier())




# num_in_features = feature_extractor.get_classifier().in_features
# feature_extractor.fc = nn.Sequential(
# nn.BatchNorm1d(num_in_features),
# nn.Linear(num_in_features, 512, bias = True),
# nn.ReLU(),
# nn.BatchNorm1d(512),
# nn.Dropout(0.4),
# nn.Linear(512, 2, bias = False),
# )
feature_extractor.to(config.model.device)
feature_extractor.train()
optimizer = torch.optim.Adam(feature_extractor.parameters(), lr=config.model.learning_rate)
criterion = nn.CrossEntropyLoss() #nn.BCELoss()
print("Start training feature extractor")
if False:
for epoch in tqdm(range(100)):
feature_extractor.to(config.model.device)
if config.model.fine_tune:
R_F_dataset = fake_real_dataset(config, constants_dict)
R_F_dataloader = torch.utils.data.DataLoader(R_F_dataset, batch_size=config.data.batch_size, shuffle=True)
feature_extractor.train()
optimizer = torch.optim.SGD(feature_extractor.parameters(), lr=0.001, momentum=0.9) #config.model.learning_rate
criterion = nn.CrossEntropyLoss()
print("Start training feature extractor")
for epoch in tqdm(range(50)):
for step, batch in enumerate(R_F_dataloader):
image = batch[0]
label = batch[1]
# plt.figure(figsize=(11,11))
# plt.axis('off')
# plt.subplot(1, 1, 1)
# plt.imshow(show_tensor_image(image))
# plt.title(label[0])
# plt.savefig('results/F_or_R{}_{}.png'.format(epoch,step))
# plt.close()
if epoch == 49:
plt.figure(figsize=(11,11))
plt.axis('off')
plt.subplot(1, 1, 1)
plt.imshow(show_tensor_image(image))
plt.title(label[0])
plt.savefig('results/F_or_R{}_{}.png'.format(epoch,step))
plt.close()
output = feature_extractor(image)
if epoch ==49:
for l, o in zip(label, output):
print('output : ' , o , 'label : ' , l,'\n')
# if epoch ==49:
# for l, o in zip(label, output):
# print('output : ' , o , 'label : ' , l,'\n')
loss = criterion(output, label)
loss.requires_grad = True
optimizer.zero_grad()

loss.backward()
optimizer.step()
print("Epoch: ", epoch, "Loss: ", loss.item())
if config.data.category:
torch.save(feature_extractor.state_dict(), os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), config.data.category,'feature'))
else:
torch.save(feature_extractor.state_dict(), os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), 'feature'))
if epoch % 10 == 0:
print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, 50, loss.item()))
if config.data.category:
torch.save(feature_extractor.state_dict(), os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), config.data.category,'feature'))
else:
torch.save(feature_extractor.state_dict(), os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), 'feature'))

return feature_extractor

@@ -144,13 +133,13 @@ def extract_features(feature_extractor, x, out_indices, config):
activations = []
for name, module in feature_extractor.named_children():
x = module(x)
if name in ['layer1', 'layer2' ,'layer3']:
activations.append(x )
if name in ['layer1', 'layer2', 'layer3']:
activations.append(x)
embeddings = activations[0]
for feature in activations[1:]:
layer_embedding = feature
layer_embedding = F.interpolate(layer_embedding, size=int(embeddings.shape[-2]), mode='bilinear', align_corners=False)
embeddings = torch.cat((embeddings,layer_embedding),1)
return embeddings
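
Note (not part of the commit): extract_features passes the input through the backbone's children and concatenates the layer1–layer3 activations after upsampling them to a common resolution. A self-contained sketch of that step using the wide_resnet101_2 backbone named in the updated config; the early break, helper name, and printed shape are illustrative assumptions:

import timm
import torch
import torch.nn.functional as F

feature_extractor = timm.create_model("wide_resnet101_2", pretrained=True, num_classes=2)
feature_extractor.eval()

def extract_embeddings(model, x):
    activations = []
    for name, module in model.named_children():
        x = module(x)
        if name in ["layer1", "layer2", "layer3"]:
            activations.append(x)
        if name == "layer3":  # later children (layer4, pooling, fc) are not needed here
            break
    embeddings = activations[0]
    for feature in activations[1:]:
        # upsample deeper, coarser maps to the layer1 resolution and concatenate along channels
        feature = F.interpolate(feature, size=embeddings.shape[-2], mode="bilinear", align_corners=False)
        embeddings = torch.cat((embeddings, feature), dim=1)
    return embeddings

with torch.no_grad():
    emb = extract_embeddings(feature_extractor, torch.rand(1, 3, 256, 256))
print(emb.shape)  # torch.Size([1, 1792, 64, 64]) for a 256x256 input
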


2 changes: 1 addition & 1 deletion loss.py
@@ -15,7 +15,7 @@ def get_loss(model, constant_dict, x_0, t, config):
# # x = x_0 * a.sqrt() + e * (1.0 - a).sqrt()
# output = model(x, t.float())

# #return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)
# return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)


x_0 = x_0.to(config.model.device)
(Diffs for the remaining three changed files are not shown.)
