Add support for non cubic patch size in yucca train (#117)
* Add support for multiple args to patchsize but manager does not use them when testing. Still not working.

* make run_training script pass patch_size arg to manager
Drilip authored Feb 15, 2024
1 parent 0425feb commit 7952d8a
Showing 2 changed files with 63 additions and 21 deletions.
70 changes: 53 additions & 17 deletions yucca/documentation/guides/run_scripts_advanced.md
````diff
@@ -53,36 +53,72 @@ For help and all the available arguments see the output of the `-h` flag below.
 
 ```console
 > yucca_train -h
-usage: yucca_train [-h] [-t TASK] [-d D] [-m M] [-man MAN] [-pl PL] [--disable_logging] [--ds] [--epochs EPOCHS] [--experiment EXPERIMENT] [--loss LOSS] [--lr LR] [--mom MOM] [--new_version]
-                   [--patch_size PATCH_SIZE] [--precision PRECISION] [--profile] [--split_idx SPLIT_IDX] [--split_data_method SPLIT_DATA_METHOD] [--split_data_param SPLIT_DATA_PARAM]
-                   [--train_batches_per_step TRAIN_BATCHES_PER_STEP] [--val_batches_per_step VAL_BATCHES_PER_STEP]
+usage: yucca_train [-h] [-t TASK] [-d D] [-m M] [-man MAN] [-pl PL]
+                   [--batch_size BATCH_SIZE] [--disable_logging] [--ds]
+                   [--epochs EPOCHS] [--experiment EXPERIMENT] [--loss LOSS]
+                   [--lr LR] [--max_vram MAX_VRAM] [--mom MOM] [--new_version]
+                   [--num_workers NUM_WORKERS]
+                   [--patch_size PATCH_SIZE [PATCH_SIZE ...]]
+                   [--precision PRECISION] [--profile] [--split_idx SPLIT_IDX]
+                   [--split_data_method SPLIT_DATA_METHOD]
+                   [--split_data_param SPLIT_DATA_PARAM]
+                   [--train_batches_per_step TRAIN_BATCHES_PER_STEP]
+                   [--val_batches_per_step VAL_BATCHES_PER_STEP]
 
 options:
   -h, --help            show this help message and exit
-  -t TASK, --task TASK  Name of the task used for training. The data should already be preprocessed using yucca_preprocessArgument should be of format: TaskXXX_MYTASK
-  -d D                  Dimensionality of the Model. Can be 3D or 2D. Defaults to 3D. Note that this will always be 2D if ensemble is enabled.
-  -m M                  Model Architecture. Should be one of MultiResUNet or UNet Note that this is case sensitive. Defaults to the standard UNet.
-  -man MAN              Manager Class to be used. Defaults to the basic YuccaManager
-  -pl PL                Plan ID to be used. This specifies which plan and preprocessed data to use for training on the given task. Defaults to the YuccaPlanne folder
+  -t TASK, --task TASK  Name of the task used for training. The data should
+                        already be preprocessed using yucca_preprocessArgument
+                        should be of format: TaskXXX_MYTASK
+  -d D                  Dimensionality of the Model. Can be 3D or 2D. Defaults
+                        to 3D. Note that this will always be 2D if ensemble is
+                        enabled.
+  -m M                  Model Architecture. Should be one of MultiResUNet or
+                        UNet Note that this is case sensitive. Defaults to the
+                        standard UNet.
+  -man MAN              Manager Class to be used. Defaults to the basic
+                        YuccaManager
+  -pl PL                Plan ID to be used. This specifies which plan and
+                        preprocessed data to use for training on the given
+                        task. Defaults to the YuccaPlanner folder
+  --batch_size BATCH_SIZE
+                        Batch size to be used for training. Overrides the
+                        batch size specified in the plan.
   --disable_logging     disable logging.
   --ds                  Used to enable deep supervision
-  --epochs EPOCHS       Used to specify the number of epochs for training. Default is 1000
+  --epochs EPOCHS       Used to specify the number of epochs for training.
+                        Default is 1000
   --experiment EXPERIMENT
-                        A name for the experiment being performed, with no spaces.
-  --loss LOSS           Should only be used to employ alternative Loss Function
-  --lr LR               Should only be used to employ alternative Learning Rate.
+                        A name for the experiment being performed, with no
+                        spaces.
+  --loss LOSS           Should only be used to employ alternative Loss
+                        Function
+  --lr LR               Should only be used to employ alternative Learning
+                        Rate. Format should be scientific notation e.g. 1e-4.
+  --max_vram MAX_VRAM
   --mom MOM             Should only be used to employ alternative Momentum.
-  --new_version         Start a new version, instead of continuing from the most recent.
-  --patch_size PATCH_SIZE
-                        Use your own patch_size. Example: if 32 is provided and the model is 3D we will use patch size (32, 32, 32). Can also be min, max or mean.
+  --new_version         Start a new version, instead of continuing from the
+                        most recent.
+  --num_workers NUM_WORKERS
+                        Num workers used in the DataLoaders. By default this
+                        will be inferred from the number of available CPUs-1
+  --patch_size PATCH_SIZE [PATCH_SIZE ...]
+                        Use your own patch_size. Example: if 32 is provided
+                        and the model is 3D we will use patch size (32, 32,
+                        32). This patch size can be set manually by passing 32
+                        32 32 as arguments. The argument can also be min, max
+                        or mean.
   --precision PRECISION
   --profile             Enable profiling.
   --split_idx SPLIT_IDX
                         idx of splits to use for training.
   --split_data_method SPLIT_DATA_METHOD
-                        Specify splitting method. Either kfold, simple_train_val_split
+                        Specify splitting method. Either kfold,
+                        simple_train_val_split
   --split_data_param SPLIT_DATA_PARAM
-                        Specify the parameter for the selected split method. For KFold use an int, for simple_split use a float between 0.0-1.0.
+                        Specify the parameter for the selected split method.
+                        For KFold use an int, for simple_split use a float
+                        between 0.0-1.0.
   --train_batches_per_step TRAIN_BATCHES_PER_STEP
   --val_batches_per_step VAL_BATCHES_PER_STEP
 ```
````
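The widened `--patch_size` flag above relies on argparse's `nargs="+"` behavior: one or more values are collected into a list of strings. A minimal sketch of just that flag (assuming a bare parser mirroring the documented interface, not the full `yucca_train` CLI):

```python
import argparse

# Hypothetical minimal parser reproducing only the --patch_size flag
# as documented above; the real yucca_train parser has many more options.
parser = argparse.ArgumentParser(prog="yucca_train")
parser.add_argument("--patch_size", nargs="+", type=str, default=None)

# A single value still works (cubic patch, or the min/max/mean keywords)...
print(parser.parse_args(["--patch_size", "32"]).patch_size)  # → ['32']
# ...while several values describe a non-cubic patch.
print(parser.parse_args(["--patch_size", "64", "64", "32"]).patch_size)  # → ['64', '64', '32']
# Omitting the flag keeps the default of None.
print(parser.parse_args([]).patch_size)  # → None
```

Note that `nargs="+"` always yields a list (even for a single value), which is why the training script must then distinguish the one-element case from the multi-element case.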
14 changes: 10 additions & 4 deletions yucca/run/run_training.py
````diff
@@ -40,7 +40,7 @@ def main():
         "-pl",
         help="Plan ID to be used. "
         "This specifies which plan and preprocessed data to use for training "
-        "on the given task. Defaults to the YuccaPlanne folder",
+        "on the given task. Defaults to the YuccaPlanner folder",
         default="YuccaPlanner",
     )
@@ -79,8 +79,9 @@ def main():
     )
     parser.add_argument(
         "--patch_size",
+        nargs="+",
         type=str,
-        help="Use your own patch_size. Example: if 32 is provided and the model is 3D we will use patch size (32, 32, 32). Can also be min, max or mean.",
+        help="Use your own patch_size. Example: if 32 is provided and the model is 3D we will use patch size (32, 32, 32). This patch size can be set manually by passing 32 32 32 as arguments. The argument can also be min, max or mean.",
         default=None,
     )
     parser.add_argument("--precision", type=str, default="bf16-mixed")
@@ -130,8 +131,12 @@ def main():
     val_batches_per_step = args.val_batches_per_step
 
     if patch_size is not None:
-        if patch_size not in ["mean", "max", "min"]:
-            patch_size = (int(patch_size),) * 3 if dimensions == "3D" else (int(patch_size),) * 2
+        if len(patch_size) == 1:
+            patch_size = patch_size[0]
+            if patch_size not in ["mean", "max", "min"]:
+                patch_size = (int(patch_size),) * 3 if dimensions == "3D" else (int(patch_size),) * 2
+        else:
+            patch_size = tuple(int(n) for n in patch_size)
 
     kwargs = {}
@@ -166,6 +171,7 @@ def main():
         model_name=model_name,
         momentum=momentum,
         num_workers=num_workers,
+        patch_size=patch_size,
         planner=planner,
         precision=precision,
         profile=profile,
````
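The new normalization logic in `run_training.py` can be exercised in isolation. Below is a hedged sketch under the assumption that the list handed over by argparse is processed exactly as in the diff; `parse_patch_size` is a hypothetical helper name introduced only for illustration and does not exist in yucca:

```python
# Hypothetical standalone version of the patch_size handling added in this
# commit; in yucca the same logic lives inline in run_training.main().
def parse_patch_size(patch_size, dimensions):
    if patch_size is not None:
        if len(patch_size) == 1:
            patch_size = patch_size[0]
            if patch_size not in ["mean", "max", "min"]:
                # A single number becomes a cubic (3D) or square (2D) patch.
                patch_size = (int(patch_size),) * 3 if dimensions == "3D" else (int(patch_size),) * 2
        else:
            # Several numbers become a non-cubic patch, e.g. (64, 64, 32).
            patch_size = tuple(int(n) for n in patch_size)
    return patch_size

print(parse_patch_size(["32"], "3D"))              # → (32, 32, 32)
print(parse_patch_size(["32"], "2D"))              # → (32, 32)
print(parse_patch_size(["64", "64", "32"], "3D"))  # → (64, 64, 32)
print(parse_patch_size(["mean"], "3D"))            # → mean
print(parse_patch_size(None, "3D"))                # → None
```

This keeps backward compatibility: a single integer or one of the `min`/`max`/`mean` keywords behaves as before, while a list of integers now passes through as a non-cubic patch size.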
