Skip to content

Commit

Permalink
A way to load and evaluate a model from a WandB run.
Browse files Browse the repository at this point in the history
  • Loading branch information
GreenWizard2015 committed Apr 10, 2024
1 parent bf4ed1e commit bfc2416
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tensorflow as tf
from NN import model_from_config, model_to_architecture
from Utils import dataset_from_config
from Utils.WandBUtils import CWBRun

def validateLayersNames(model):
not_unique_layers = []
Expand All @@ -21,7 +22,13 @@ def validateLayersNames(model):

def main(args):
folder = os.path.dirname(__file__)
config = load_config(args.config, folder=folder)
if args.wandb_id:
run = CWBRun(args.wandb_id)
config = run.config
args.model = run.bestModel.pathTo()
args.no_train = True
else:
config = load_config(args.config, folder=folder)

assert "experiment" in config, "Config must contain 'experiment' key"
# store args as part of config
Expand Down Expand Up @@ -138,7 +145,8 @@ def main(args):
parser.add_argument('--wandb', type=str, help='Wandb project name (optional)')
parser.add_argument('--wandb-entity', type=str, help='Wandb entity name (optional)')
parser.add_argument('--wandb-name', type=str, help='Wandb run name (optional)')

parser.add_argument('--wandb-id', type=str, help='Wandb run id, to load and test model (optional)')

args = parser.parse_args()
if args.gpu_memory_mb: setGPUMemoryLimit(args.gpu_memory_mb)
main(args)
Expand Down

0 comments on commit bfc2416

Please sign in to comment.