Skip to content

Commit

Permalink
add prefix to model path
Browse files Browse the repository at this point in the history
  • Loading branch information
deforum committed Feb 9, 2023
1 parent dbb6417 commit 2a0ee56
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion app.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def train():
def training():
global train_from
prepare_training_data(root_folder,database_file,train_from)
train_predictor(root_folder,train_from)
train_predictor(root_folder,database_file,train_from)
predict_score(root_folder,database_file,train_from)
validate_prediction(root_folder,database_file,train_from)
return redirect('/')
Expand Down
3 changes: 2 additions & 1 deletion predict_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@


def predict_score(root_folder, database_file, train_from, clip_model="ViT-L/14"):
prefix = database_file.split(".")[0]
path = pathlib.Path(root_folder)
database_path = path / database_file
database = pd.read_csv(database_path)
model = MLP(768) # CLIP embedding dim is 768 for CLIP ViT L 14
model_name = f"linear_predictor_{clip_model.replace('/', '').lower()}_{train_from}_mse.pth"
model_name = f"{prefix}_linear_predictor_{clip_model.replace('/', '').lower()}_{train_from}_mse.pth"
s = torch.load(path / model_name) # load the model you trained previously or the model available in this repo
model.load_state_dict(s)
model.to("cuda")
Expand Down
9 changes: 5 additions & 4 deletions train_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,13 @@ def configure_optimizers(self):
return optimizer


def train_predictor(root_folder, train_from, clip_model="ViT-L/14", val_percentage=0.05, epochs=50, batch_size=256):
def train_predictor(root_folder, database_file, train_from, clip_model="ViT-L/14", val_percentage=0.05, epochs=50, batch_size=256):

prefix = database_file.split(".")[0]
out_path = pathlib.Path(root_folder)
x_out = f"x_{clip_model.replace('/', '').lower()}_ebeddings.npy"
y_out = f"y_{train_from}.npy"
save_name = f"linear_predictor_{clip_model.replace('/', '').lower()}_{train_from}_mse.pth"
x_out = f"{prefix}_x_{clip_model.replace('/', '').lower()}_ebeddings.npy"
y_out = f"{prefix}_y_{train_from}.npy"
save_name = f"{prefix}_linear_predictor_{clip_model.replace('/', '').lower()}_{train_from}_mse.pth"

x = np.load(out_path / x_out)
y = np.load(out_path / y_out)
Expand Down

0 comments on commit 2a0ee56

Please sign in to comment.