transformers top model and training configs
anananan116 committed Mar 17, 2024
1 parent e0f1fa8 commit d504604
Showing 14 changed files with 2,673 additions and 174 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -165,4 +165,6 @@ autoencoders/VQ_VAE/layers_2.py
 *.pth
 !results/best_model.pth
 *.npy
-*.0
+*.0
+result_transformers/
+results_top/
10 changes: 5 additions & 5 deletions configs_pixelsnail/default_bottom.yaml
@@ -1,13 +1,13 @@
 batch_size: 16
-epochs: 70
+epochs: 30
 lr: 2e-4
 hier: bottom
 channel: 256
 n_res_block: 5
 n_res_channel: 256
 n_out_res_block: 0
-n_cond_res_block: 3
+n_cond_res_block: 4
 dropout: 0.15
-exp_id: 1
-early_stop_patience: 10
-validation_ratio: 0.1
+exp_id: 2
+early_stop_patience: 15
+validation_ratio: 0.02
36 changes: 36 additions & 0 deletions configs_transformers/bottom.yaml
@@ -0,0 +1,36 @@
+resolution: 64
+hier: bottom
+validation_ratio: 0.004
+num_eval_per_epoch: 10
+trainer:
+  per_device_train_batch_size: 4
+  per_device_eval_batch_size: 4
+  num_train_epochs: 20
+  learning_rate: 0.0008
+  weight_decay: 0.001
+  fp16: True
+  output_dir: ./results_bottom
+  evaluation_strategy: steps
+  lr_scheduler_type: cosine
+  warmup_ratio: 0.1
+  logging_dir: ./logs_bottom
+  logging_strategy: steps
+  save_strategy: epoch
+  logging_steps: 10
+  gradient_accumulation_steps: 6
+
+model:
+  vocab_size: 1026
+  num_layers: 6
+  num_decoder_layers: 6
+  d_ff: 1024
+  d_model: 512
+  d_kv: 64
+  num_heads: 8
+  relative_attention_num_buckets: 32
+  relative_attention_max_distance: 128
+  dropout_rate: 0.1
+  pad_token_id: 0
+  eos_token_id: 1
+  decoder_start_token_id: 0
+  n_positions: 4098
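The model block's field names (num_layers, num_decoder_layers, d_ff, d_model, d_kv, relative_attention_num_buckets) match Hugging Face's T5Config, and the trainer block matches transformers.TrainingArguments. A minimal sketch of how such a config could be consumed; the commit's actual training script is among the files not rendered on this page, so the wiring below is an assumption:

    # Hypothetical loader: maps the YAML's `model` block onto T5Config and the
    # `trainer` block onto TrainingArguments. T5 uses relative attention, so
    # n_positions is not a T5Config field; PretrainedConfig simply stores it
    # as an extra attribute.
    from yaml import safe_load
    from transformers import T5Config, T5ForConditionalGeneration, TrainingArguments

    cfg = safe_load(open("configs_transformers/bottom.yaml"))
    model = T5ForConditionalGeneration(T5Config(**cfg["model"]))
    args = TrainingArguments(**cfg["trainer"])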
33 changes: 33 additions & 0 deletions configs_transformers/top.yaml
@@ -0,0 +1,33 @@
+resolution: 32
+hier: top
+validation_ratio: 0.005
+num_eval_per_epoch: 10
+trainer:
+  per_device_train_batch_size: 8
+  per_device_eval_batch_size: 8
+  num_train_epochs: 30
+  learning_rate: 0.0005
+  weight_decay: 0.001
+  fp16: True
+  output_dir: ./results_top
+  evaluation_strategy: steps
+  lr_scheduler_type: cosine
+  warmup_ratio: 0.1
+  logging_dir: ./logs_top
+  logging_strategy: steps
+  save_strategy: epoch
+  logging_steps: 10
+  gradient_accumulation_steps: 6
+
+model:
+  vocab_size: 514
+  n_positions: 1026
+  n_embd: 1024
+  n_layer: 16
+  n_head: 16
+  resid_pdrop: 0.15
+  embd_pdrop: 0.15
+  attn_pdrop: 0.15
+  bos_token_id: 512
+  eos_token_id: 513
+  side: 32
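Here the model block matches transformers.GPT2Config: vocab_size 514 is the 512 codebook ids plus BOS 512 and EOS 513, and n_positions 1026 is BOS + 32x32 codes + EOS. side: 32 looks like the latent grid's side length for sampling rather than a GPT2Config field. A hedged sketch, assuming the same YAML-to-transformers mapping as above:

    # Hypothetical loader for the top prior. `side` is popped because it
    # appears to describe the 32x32 latent grid, not the model architecture.
    from yaml import safe_load
    from transformers import GPT2Config, GPT2LMHeadModel

    cfg = safe_load(open("configs_transformers/top.yaml"))
    model_cfg = dict(cfg["model"])
    side = model_cfg.pop("side")  # 32 -> 1024 code positions per image
    model = GPT2LMHeadModel(GPT2Config(**model_cfg))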
23 changes: 19 additions & 4 deletions data_utils/datasets.py
@@ -3,6 +3,7 @@
 from tqdm import tqdm
 import torch
 import os
+import numpy as np
 class ImageDataset(Dataset):
     def __init__(self, image_files, save_porcessed= None, transform=None):
         self.transform = transform
@@ -33,12 +34,26 @@ def __getitem__(self, idx):
         return self.images[idx]
 
 class latentDataset(Dataset):
-    def __init__(self, top_codes, bottom_codes):
-        self.top_codes = torch.tensor(top_codes, dtype=torch.long)
-        self.bottom_codes = torch.tensor(bottom_codes, dtype=torch.long)
-
+    def __init__(self, top_codes, bottom_codes, for_transformer=False, hier=None):
+        if for_transformer:
+            if hier == 'top':
+                self.top_codes = np.hstack([np.ones((top_codes.shape[0], 1), dtype=np.int64) * 512,top_codes.reshape(top_codes.shape[0],-1), np.ones((top_codes.shape[0], 1), dtype=np.int64) * 513])
+                print(self.top_codes.shape), print(self.top_codes[0])
+            else:
+                self.top_codes = np.hstack([top_codes.reshape(top_codes.shape[0],-1)+2, np.ones((top_codes.shape[0], 1))])
+                self.bottom_codes = np.hstack([bottom_codes.reshape(bottom_codes.shape[0],-1)+514, np.ones((bottom_codes.shape[0], 1))])
+            self.top_codes = torch.tensor(self.top_codes, dtype=torch.long)
+            self.bottom_codes = torch.tensor(self.bottom_codes, dtype=torch.long)
+        else:
+            self.top_codes = torch.tensor(top_codes, dtype=torch.long)
+            self.bottom_codes = torch.tensor(bottom_codes, dtype=torch.long)
+        self.hier = hier
     def __len__(self):
         return len(self.top_codes)
 
     def __getitem__(self, index):
+        if self.hier == 'top':
+            return {'input_ids': self.top_codes[index], 'labels': self.top_codes[index]}
+        elif self.hier == 'bottom':
+            return {'input_ids': self.top_codes[index], 'labels': self.bottom_codes[index]}
        return self.top_codes[index], self.bottom_codes[index]
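The offsets define one consistent token layout across the two priors: for the top prior each flattened grid is wrapped in BOS 512 / EOS 513 (matching vocab_size 514 in top.yaml); for the bottom prior, top codes are shifted by +2 so ids 0 and 1 stay free for pad/decoder-start and EOS, and bottom codes are shifted by +514, giving 2 + 512 + 512 = 1026, the vocab_size in bottom.yaml. A toy worked example of the bottom-prior encoding, assuming a 512-id codebook and 2x2 grids for brevity (the real grids are 32x32 and 64x64):

    import numpy as np

    top = np.array([[[0, 7], [3, 1]]])     # one sample, toy 2x2 top grid, ids in [0, 512)
    bottom = np.array([[[5, 2], [9, 4]]])  # matching toy 2x2 bottom grid

    enc = np.hstack([top.reshape(1, -1) + 2, [[1]]])       # encoder input: ids 2..513, then EOS=1
    lab = np.hstack([bottom.reshape(1, -1) + 514, [[1]]])  # decoder labels: ids 514..1025, then EOS=1
    print(enc)  # [[2 9 5 3 1]]
    print(lab)  # [[519 516 523 518 1]]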
8 changes: 5 additions & 3 deletions data_utils/prepare_data.py
@@ -42,14 +42,16 @@ def load_images(img_size, validation_ratio, test_ratio, batch_size, dataset_name
 
     return train_dataloader, validation_dataloader, test_dataloader
 
-def load_latent_code(validation_ratio, batch_size):
+def load_latent_code(validation_ratio, batch_size, for_transformer=False, hier = None):
     top_codes = np.load("t_codes.npy").astype(np.int64)
     bottom_codes = np.load("b_codes.npy").astype(np.int64)
-    dataset = latentDataset(top_codes, bottom_codes)
+    dataset = latentDataset(top_codes, bottom_codes, for_transformer=for_transformer, hier=hier)
     dataset_size = len(dataset)
     validation_size = int(validation_ratio * dataset_size)
     train_size = dataset_size - validation_size
-    train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size], generator=torch.Generator().manual_seed(42))
+    train_dataset, validation_dataset = random_split(dataset, [train_size, validation_size], generator=torch.Generator().manual_seed(43))
+    if batch_size == 0:
+        return train_dataset, validation_dataset
     train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
     val_dataloader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)
     return train_dataloader, val_dataloader
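The new batch_size == 0 branch returns the raw dataset splits instead of DataLoaders, which is the shape transformers.Trainer expects, since it builds its own loaders and collation. A hedged usage sketch; the actual trainer wiring lives in the files not rendered on this page:

    # Assumed usage, not the repo's actual script: the dict items
    # ({'input_ids': ..., 'labels': ...}) from latentDataset feed Trainer's
    # default collator directly.
    from transformers import GPT2Config, GPT2LMHeadModel, Trainer, TrainingArguments
    from data_utils.prepare_data import load_latent_code

    train_ds, val_ds = load_latent_code(validation_ratio=0.005, batch_size=0,
                                        for_transformer=True, hier='top')
    model = GPT2LMHeadModel(GPT2Config(vocab_size=514, n_positions=1026))
    args = TrainingArguments(output_dir="./results_top",
                             per_device_train_batch_size=8,
                             num_train_epochs=30, fp16=True)
    Trainer(model=model, args=args, train_dataset=train_ds,
            eval_dataset=val_ds).train()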
20 changes: 11 additions & 9 deletions extract_code.py
@@ -1,6 +1,7 @@
 from data_utils.prepare_data import load_images
 from autoencoders.VQ_VAE.VQ_VAE import VQ_VAE,VQ_VAE2
 from trainers.autoencoder_trainer import VQVAE_Trainer
+from tqdm import tqdm
 from yaml import safe_load
 import argparse
 import torch
@@ -12,7 +13,7 @@ def extract(model, dataloader):
     t_codes = []
     b_codes = []
     with torch.no_grad():
-        for data in dataloader:
+        for data in tqdm(dataloader):
             b_code, t_code = model.encode_to_id(data.to(device))
             t_code = t_code.cpu().numpy()
             b_code = b_code.cpu().numpy()
@@ -26,15 +27,15 @@ def extract(model, dataloader):
 
 if __name__ == "__main__":
     argparser = argparse.ArgumentParser()
-    argparser.add_argument('--config', type=str, default="./configs/default.yaml", help="Path to the config file")
+    argparser.add_argument('--config', type=str, default="./configs/default_all256.yaml", help="Path to the config file")
     args = argparser.parse_args()
     config = safe_load(open(args.config))
     train_dataloader, validation_dataloader, test_dataloader = load_images(
-        config['img_size'],
-        config['validation_ratio'],
-        config['test_ratio'],
-        config['batch_size'],
-        config['dataset']
+        config['img_size'],
+        config['validation_ratio'],
+        config['test_ratio'],
+        config['batch_size'],
+        'both'
     )
 
     vqvae_config = config['VQ-VAE']
@@ -44,7 +45,8 @@ def extract(model, dataloader):
             vqvae_config['latent_dimension'],
             vqvae_config['kernel_sizes'],
             vqvae_config['res_layers'],
-            vqvae_config['code_book_size']
+            vqvae_config['code_book_size'],
+            vqvae_config['lower_bound_factor']
         )
     else:
        vqvae_config['version'] = 1
@@ -59,7 +61,7 @@ def extract(model, dataloader):
    vqvae_config['device'] = device
 
 
-    checkpoint = torch.load('./results/best_model.pth')
+    checkpoint = torch.load('./results/best_model_exp22.pth')
     vqvae.load_state_dict(checkpoint)
     vqvae.to(device)
     extract(vqvae, train_dataloader)
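The tail of extract() is elided by the viewer, but prepare_data.py reads t_codes.npy and b_codes.npy, so the collected per-batch lists are presumably concatenated and saved along these lines (an assumption about the hidden code, not the rendered diff):

    import numpy as np

    def save_codes(t_codes, b_codes):
        # Hypothetical helper mirroring what extract() presumably does with
        # the lists it builds: persist the ids that
        # data_utils/prepare_data.py later loads back.
        np.save("t_codes.npy", np.concatenate(t_codes))
        np.save("b_codes.npy", np.concatenate(b_codes))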
(The remaining 7 of the 14 changed files were not loaded in this view.)