Skip to content

Commit

Permalink
rgb to ycml
Browse files Browse the repository at this point in the history
  • Loading branch information
arimousa committed Jan 10, 2023
1 parent 3a482e8 commit 9c6899f
Show file tree
Hide file tree
Showing 10 changed files with 265 additions and 93 deletions.
93 changes: 74 additions & 19 deletions anomaly_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,23 @@
from dataset import *
from visualize import *
from feature_extractor import *
from PIL import Image



def heat_map(outputs, targets, feature_extractor, constants_dict, config):
sigma = 4
kernel_size = 2*int(4 * sigma + 0.5) +1

i_d = color_distance(outputs, targets, config)
f_d = feature_distance(outputs, targets,feature_extractor, constants_dict, config)


print('image_distance : ',torch.mean(i_d))
print('feature_distance : ',torch.mean(f_d))
anomaly_score = 0
for output, target in zip(outputs, targets):
i_d = color_distance(output, target, config)
f_d = feature_distance(output, target,feature_extractor, constants_dict, config)
print('image_distance : ',torch.mean(i_d))
print('feature_distance : ',torch.mean(f_d))

visualalize_distance(output, target, i_d, f_d)

anomaly_score = (0.8) * f_d + (0.2) * i_d
anomaly_score += i_d #(f_d + .4 * i_d)

anomaly_score = gaussian_blur2d(
anomaly_score , kernel_size=(kernel_size,kernel_size), sigma=(sigma,sigma)
Expand All @@ -32,21 +34,71 @@ def heat_map(outputs, targets, feature_extractor, constants_dict, config):

return anomaly_score

def rgb_to_cmyk(r, g, b):
RGB_SCALE = 1
CMYK_SCALE = 100
# if (r, g, b) == (0, 0, 0):
# # black
# return 0, 0, 0, CMYK_SCALE

# rgb [0,255] -> cmy [0,1]
c = 1 - r / RGB_SCALE
m = 1 - g / RGB_SCALE
y = 1 - b / RGB_SCALE

# extract out k [0, 1]
min_cmy = torch.zeros(c.shape, device=c.device)

c = c.view(-1)
m = m.view(-1)
y = y.view(-1)
min_cmy = min_cmy.view(-1)
for i in range(len(c)):
min_cmy[i] = min(c[i], m[i], y[i])
c = c.view((256,256))
m = m.view((256,256))
y = y.view((256,256))
min_cmy = min_cmy.view((256,256))

c = (c - min_cmy) / (1 - min_cmy)
m = (m - min_cmy) / (1 - min_cmy)
y = (y - min_cmy) / (1 - min_cmy)
k = min_cmy

# rescale to the range [0,CMYK_SCALE]
return c * CMYK_SCALE, m * CMYK_SCALE, y * CMYK_SCALE, k * CMYK_SCALE


def color_distance(image1, image2, config):
image1 = (image1 - image1.min())/ (image1.max() - image1.min())
image2 = (image2 - image2.min())/ (image2.max() - image2.min())
image1 = ((image1 - image1.min())/ (image1.max() - image1.min()))
image2 = ((image2 - image2.min())/ (image2.max() - image2.min()))

for i, (img1, img2) in enumerate(zip(image1, image2)):
c1,m1,y1,k1 = rgb_to_cmyk(img1[0,:,:], img1[1,:,:], img1[2,:,:])
c2,m2,y2,k2 = rgb_to_cmyk(img2[0], img2[1], img2[2])
img1_cmyk = torch.stack((c1,m1,y1), dim=0)
print('img1_cmyk : ',img1_cmyk.shape)
img2_cmyk = torch.stack((c2,m2,y2), dim=0)
img1_cmyk = img1_cmyk.to(config.model.device)
img2_cmyk = img2_cmyk.to(config.model.device)
distance_map = torch.abs(img1_cmyk - img2_cmyk).to(config.model.device)
distance_map = torch.mean(distance_map, dim=0).unsqueeze(0)
if i == 0:
batch = distance_map
else:
batch = torch.cat((batch , distance_map) , dim=0)
batch = batch.unsqueeze(1)
print('batch :', batch.shape)
return batch

# distance_map = 1 - F.cosine_similarity(image1.to(config.model.device), image2.to(config.model.device), dim=1).to(config.model.device)
# distance_map = torch.unsqueeze(distance_map, dim=1)

distance_map = image1.to(config.model.device) - image2.to(config.model.device)
distance_map = torch.abs(distance_map)

# distance_map **= 2
distance_map = torch.mean(distance_map, dim=1).unsqueeze(1)
return distance_map
# distance_map = image1.to(config.model.device) - image2.to(config.model.device)
# distance_map = torch.abs(distance_map)
# #visualalize_rgb(image1, image2 ,distance_map)

# distance_map = torch.mean(distance_map, dim=1).unsqueeze(1)
# return distance_map



Expand All @@ -55,8 +107,8 @@ def feature_distance(output, target,feature_extractor, constants_dict, config):
outputs_features = extract_features(feature_extractor=feature_extractor, x=output.to(config.model.device), out_indices=[2,3], config=config) #feature_extractor(output.to(config.model.device))
targets_features = extract_features(feature_extractor=feature_extractor, x=target.to(config.model.device), out_indices=[2,3], config=config) #feature_extractor(target.to(config.model.device))

# outputs_features = (outputs_features - outputs_features.min())/ (outputs_features.max() - outputs_features.min())
# targets_features = (targets_features - targets_features.min())/ (targets_features.max() - targets_features.min())
outputs_features = (outputs_features - outputs_features.min())/ (outputs_features.max() - outputs_features.min())
targets_features = (targets_features - targets_features.min())/ (targets_features.max() - targets_features.min())

# distance_map = (outputs_features.to(config.model.device) - targets_features.to(config.model.device))
# distance_map = torch.abs(distance_map)
Expand All @@ -66,5 +118,8 @@ def feature_distance(output, target,feature_extractor, constants_dict, config):
distance_map = torch.unsqueeze(distance_map, dim=1)

distance_map = F.interpolate(distance_map , size = int(config.data.image_size), mode="nearest")

# distance_map = torch.mean(torch.pow((outputs_features - targets_features), 2), dim=1).unsqueeze(1)

return distance_map

17 changes: 8 additions & 9 deletions config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
data :
name: mvtec #mtd # mvtec #mvtec if the dataset is MVTec AD, otherwise, it is the name of your desire dataset
data_dir: datasets/MVTec #MTD #MVTec
category: capsule #['hazelnut', 'bottle', 'cable', 'carpet', 'leather', 'capsule', 'grid', 'pill','transistor', 'metal_nut', 'screw','toothbrush', 'zipper', 'tile', 'wood']
category: cable #['hazelnut', 'bottle', 'cable', 'carpet', 'leather', 'capsule', 'grid', 'pill','transistor', 'metal_nut', 'screw','toothbrush', 'zipper', 'tile', 'wood']
image_size: 256
batch_size: 32
mask : True
Expand All @@ -12,18 +12,18 @@ model:
checkpoint_dir: checkpoints/MVTec #MTD
checkpoint_name: weights
exp_name: default
backbone: resnet18 # wide_resnet50_2
backbone: resnet18 #resnet34 #resnet18 # wide_resnet50_2
pre_trained: True
noise : Gaussian # options : [Gaussian, Perlin]
schedule : linear # options: [linear, quad, const, jsd, sigmoid]
learning_rate: 0.001
weight_decay: 0.00001
epochs: 600
learning_rate: 1e-4 #0.0002
weight_decay: 0 #0.00001
epochs: 400
trajectory_steps: 1000
test_trajectoy_steps: 200
test_trajectoy_steps: 200 #200
generate_time_steps: 800
skip : 10
eta : 0.8
skip : 5
sigma : 0.5
beta_start : 0.0001 # 0.0001
beta_end : 0.02 # 0.006 for 300
ema : True
Expand All @@ -37,7 +37,6 @@ model:
metrics:
image_level_F1Score: True
image_level_AUROC: True
pixel_level_F1Score: True
pixel_level_AUROC: True
threshold:
method: adaptive #options: [adaptive, manual]
Expand Down
22 changes: 11 additions & 11 deletions feature_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,15 +70,19 @@ def fake_real_dataset(config, constants_dict):


def tune_feature_extractor(constants_dict, config):
R_F_dataset = fake_real_dataset(config, constants_dict)
R_F_dataloader = torch.utils.data.DataLoader(R_F_dataset, batch_size=config.data.batch_size, shuffle=True)
# R_F_dataset = fake_real_dataset(config, constants_dict)
# R_F_dataloader = torch.utils.data.DataLoader(R_F_dataset, batch_size=config.data.batch_size, shuffle=True)
feature_extractor = timm.create_model(
config.model.backbone,
pretrained=True,
num_classes=2,
)
print(feature_extractor.get_classifier())
num_in_features = feature_extractor.get_classifier().in_features
# print(feature_extractor.get_classifier())




# num_in_features = feature_extractor.get_classifier().in_features
# feature_extractor.fc = nn.Sequential(
# nn.BatchNorm1d(num_in_features),
# nn.Linear(num_in_features, 512, bias = True),
Expand Down Expand Up @@ -130,10 +134,7 @@ def extract_features(feature_extractor, x, out_indices, config):
with torch.no_grad():
feature_extractor.eval()
reverse_transforms = transforms.Compose([
transforms.Lambda(lambda t: (t + 1) / (2)),
# transforms.Lambda(lambda t: t * 255.),
# transforms.Lambda(lambda t: t.cpu().numpy().astype(np.uint8)),
# transforms.ToPILImage(),
transforms.Lambda(lambda t: (t + 1) / (2))
])
x = reverse_transforms(x)

Expand All @@ -143,9 +144,8 @@ def extract_features(feature_extractor, x, out_indices, config):
activations = []
for name, module in feature_extractor.named_children():
x = module(x)
# print('name : ', name)
if name in ['layer1', 'layer3']:
activations.append(x)
if name in ['layer1', 'layer2' ,'layer3']:
activations.append(x )
embeddings = activations[0]
for feature in activations[1:]:
layer_embedding = feature
Expand Down
11 changes: 11 additions & 0 deletions loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,17 @@


def get_loss(model, constant_dict, x_0, t, config):

# x_0 = x_0.to(config.model.device)
# b = constant_dict['betas'].to(config.model.device)
# x, e = forward_diffusion_sample(x_0, t, constant_dict, config)
# # a = (1-b).cumprod(dim=0).index_select(0, t).view(-1, 1, 1, 1).to(config.model.device)
# # x = x_0 * a.sqrt() + e * (1.0 - a).sqrt()
# output = model(x, t.float())

# #return (e - output).square().sum(dim=(1, 2, 3)).mean(dim=0)


x_0 = x_0.to(config.model.device)
b = constant_dict['betas'].to(config.model.device)
e = torch.randn_like(x_0, device = x_0.device)
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def evaluate(args):
config = OmegaConf.load(args.config)
model = build_model(config)
if config.data.category:
checkpoint = torch.load(os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), config.data.category,'600')) # config.model.checkpoint_name 300+50
checkpoint = torch.load(os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), config.data.category,'900')) # config.model.checkpoint_name 300+50
else:
checkpoint = torch.load(os.path.join(os.path.join(os.getcwd(), config.model.checkpoint_dir), '150'))
model.load_state_dict(checkpoint)
Expand Down
14 changes: 7 additions & 7 deletions metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,22 +32,22 @@ def metric(labels_list, predictions, anomaly_map_list, GT_list, config):
thresholdOpt = thresholds[thresholdOpt_index]

f1 = F1Score()
predictions0_1 = (predictions > thresholdOpt).int()
for i,(l,p) in enumerate(zip(labels_list, predictions0_1)):
print('sample : ', i, ' prediction is: ',p,' label is: ',l,'\n' ) if l != p else None

f1_scor = f1(predictions, labels_list)
f1_score_pixel = f1(resutls_embeddings, GT_embeddings)
f1_scor = f1(predictions0_1, labels_list)

if config.metrics.image_level_AUROC:
print(f'AUROC: {auroc_score}')
if config.metrics.image_level_F1Score:
print(f'F1SCORE: {f1_scor}')
if config.metrics.pixel_level_F1Score:
print(f'f1_score_pixel: {f1_score_pixel}')
if config.metrics.pixel_level_AUROC:
print(f"auroc_pixel{auroc_pixel} ")
if config.metrics.image_level_F1Score:
print(f'F1SCORE: {f1_scor}')

with open('readme.txt', 'a') as f:
f.write(
f"AUROC: {auroc_score} | auroc_pixel{auroc_pixel} | F1SCORE: {f1_scor} | f1_score_pixel: {f1_score_pixel}\n")
f"AUROC: {auroc_score} | auroc_pixel{auroc_pixel} | F1SCORE: {f1_scor} \n")
roc = roc.reset()
auroc = auroc.reset()
f1 = f1.reset()
Expand Down
53 changes: 26 additions & 27 deletions sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,36 @@
import tqdm
from tqdm import tqdm

# #https://github.com/ermongroup/ddim
# def generalized_steps(x, seq, model, b, config, **kwargs):
# with torch.no_grad():
# n = x.size(0)
# seq_next = [-1] + list(seq[:-1])
# x0_preds = []
# xs = [x]
# for i, j in zip(reversed(seq), reversed(seq_next)):
# t = (torch.ones(n) * i).to(x.device)
# next_t = (torch.ones(n) * j).to(x.device)
# at = compute_alpha(b, t.long(),config)
# at_next = compute_alpha(b, next_t.long(),config)
# xt = xs[-1].to('cuda')
# et = model(xt, t)
# x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
# x0_preds.append(x0_t.to('cpu'))
# c1 = (
# kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
# )
# c2 = ((1 - at_next) - c1 ** 2).sqrt()
# xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
# xs.append(xt_next.to('cpu'))

# return xs, x0_preds
#https://github.com/ermongroup/ddim
def generalized_steps(x, seq, model, b, config, **kwargs):
with torch.no_grad():
n = x.size(0)
seq_next = [-1] + list(seq[:-1])
x0_preds = []
xs = [x]
for i, j in zip(reversed(seq), reversed(seq_next)):
t = (torch.ones(n) * i).to(x.device)
next_t = (torch.ones(n) * j).to(x.device)
at = compute_alpha(b, t.long(),config)
at_next = compute_alpha(b, next_t.long(),config)
xt = xs[-1].to('cuda')
et = model(xt, t)
x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
x0_preds.append(x0_t.to('cpu'))
c1 = (
kwargs.get("eta", 0) * ((1 - at / at_next) * (1 - at_next) / (1 - at)).sqrt()
)
c2 = ((1 - at_next) - c1 ** 2).sqrt()
xt_next = at_next.sqrt() * x0_t + c1 * torch.randn_like(x) + c2 * et
xs.append(xt_next.to('cpu'))

return xs, x0_preds

def efficient_generalized_steps(config, x, seq, model, b, H_funcs, y_0, cls_fn=None, classes=None):
with torch.no_grad():
#setup vectors used in the algorithm
sigma_0 = 0.5 #0.5
etaB = 1
sigma_0 = config.model.sigma #0.5
etaB = 1
etaA = 1
etaC = 1
singulars = H_funcs.singulars()
Expand Down Expand Up @@ -66,7 +66,6 @@ def efficient_generalized_steps(config, x, seq, model, b, H_funcs, y_0, cls_fn=N
seq_next = [-1] + list(seq[:-1])
x0_preds = []
xs = [x]

#iterate over the timesteps
for i, j in zip(reversed(seq), reversed(seq_next)):
t = (torch.ones(n) * i).to(x.device)
Expand Down
Loading

0 comments on commit 9c6899f

Please sign in to comment.