Skip to content

Commit

Permalink
Merge branch 'merge-0-9-0' of https://github.com/AntreasAntoniou/GATE
Browse files Browse the repository at this point in the history
…into merge-0-9-0
  • Loading branch information
AntreasAntoniou committed Dec 14, 2023
2 parents accfaeb + acf25c6 commit 9248668
Showing 1 changed file with 43 additions and 16 deletions.
59 changes: 43 additions & 16 deletions gate/menu/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,23 @@ def run_experiments(
"med-class": medical_image_classification_config,
"image-seg": image_segmentation_config,
"image-text": image_text_zero_shot_classification_config,
# "acdc": acdc_config,
# "md": md_config,
"acdc": acdc_config,
"md": md_config,
"rr": rr_config,
"rr-mm": rr_mm_config,
"video-class": video_classification_config,
"all": {
**image_classification_config,
**few_shot_learning_config,
**medical_image_classification_config,
**image_segmentation_config,
**image_text_zero_shot_classification_config,
**acdc_config,
**md_config,
**rr_config,
**rr_mm_config,
**video_classification_config,
},
}

if "+" in experiment_type:
Expand All @@ -205,21 +217,36 @@ def run_experiments(
logger.error("Invalid experiment type selected.")
return

elif experiment_type == "all":
for config in experiment_configs.values():
experiment_dict.update(
generate_commands(
prefix=prefix,
seed_list=seed_list,
experiment_config=config,
num_workers=num_workers,
accelerate_launch_path=accelerate_launch_path,
gate_run_path=gate_run_path,
gpu_ids=gpu_ids,
train_iters=train_iters,
evaluate_every_n_steps=evaluate_every_n_steps,
)
if "~" in experiment_type:
base_experiment, *removed_experiments = experiment_type.split("~")
if base_experiment not in experiment_configs:
logger.error(
f"Invalid base experiment type {base_experiment} selected."
)
return
for removed_experiment in removed_experiments:
if removed_experiment not in experiment_configs:
logger.error(
f"Invalid removed experiment type {removed_experiment} selected."
)
return
experiment_configs_adjusted = {
k: v
for k, v in experiment_configs.items()
if k not in removed_experiments
}
experiment_dict = generate_commands(
prefix=prefix,
seed_list=seed_list,
experiment_config=experiment_configs_adjusted[base_experiment],
num_workers=num_workers,
accelerate_launch_path=accelerate_launch_path,
gate_run_path=gate_run_path,
gpu_ids=gpu_ids,
train_iters=train_iters,
evaluate_every_n_steps=evaluate_every_n_steps,
)

else:
if experiment_type in experiment_configs:
experiment_dict = generate_commands(
Expand Down

0 comments on commit 9248668

Please sign in to comment.