Skip to content

Commit

Permalink
wandb
Browse files Browse the repository at this point in the history
  • Loading branch information
Yusuke Uchida committed Aug 18, 2020
1 parent 9f7551c commit 67d7eb7
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 5 deletions.
5 changes: 4 additions & 1 deletion src/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@ model:
train:
optimizer_name: adam
lr: 0.001
epochs: 30
epochs: 30

wandb:
project: null
16 changes: 12 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,14 @@

@hydra.main(config_path="src/config.yaml")
def main(cfg):
if cfg.wandb.project:
import wandb
from wandb.keras import WandbCallback
wandb.init(project="age-gender-estimation")
callbacks = [WandbCallback()]
else:
callbacks = []

csv_path = Path(to_absolute_path(__file__)).parent.joinpath("meta", f"{cfg.data.db}.csv")
df = pd.read_csv(str(csv_path))
train, val = train_test_split(df, random_state=42, test_size=0.1)
Expand All @@ -30,17 +38,17 @@ def main(cfg):

checkpoint_dir = Path(to_absolute_path(__file__)).parent.joinpath("checkpoint")
checkpoint_dir.mkdir(exist_ok=True)
callbacks = [
callbacks.extend([
LearningRateScheduler(schedule=scheduler),
ModelCheckpoint(str(checkpoint_dir) + "/weights.{epoch:02d}-{val_loss:.2f}.hdf5",
monitor="val_loss",
verbose=1,
save_best_only=True,
mode="auto")
]
])

hist = model.fit(train_gen, epochs=cfg.train.epochs, callbacks=callbacks, validation_data=val_gen,
workers=multiprocessing.cpu_count())
model.fit(train_gen, epochs=cfg.train.epochs, callbacks=callbacks, validation_data=val_gen,
workers=multiprocessing.cpu_count())


if __name__ == '__main__':
Expand Down

0 comments on commit 67d7eb7

Please sign in to comment.