Add SGDOptimizer in the on-device training offline tooling (onnxblock) (microsoft#17085)

### Description
Adding SGDOptimizer to the on-device training onnxblock tooling.
AdamLouly authored Aug 18, 2023
1 parent ee09a5f commit c0b6c6c
Showing 11 changed files with 266 additions and 123 deletions.
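
For context, the new enum value slots into the existing artifacts API. A minimal usage sketch, assuming a forward-only ONNX model on disk and hypothetical parameter names:

```python
# Hedged sketch: generate training artifacts with the SGD optimizer added here.
# "model.onnx" and the parameter names are assumptions for illustration.
import onnx
from onnxruntime.training import artifacts

onnx_model = onnx.load("model.onnx")  # exported forward graph (hypothetical)

artifacts.generate_artifacts(
    onnx_model,
    requires_grad=["fc1.weight", "fc1.bias"],  # trainable parameters (assumed names)
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.SGD,         # enum value introduced by this commit
    artifact_directory="training_artifacts",
)
```

This writes the training, eval, and optimizer graphs (including the new SGD optimizer model) into `artifact_directory`.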
Binary file modified onnxruntime/test/testdata/training_api/adamw.onnx
2 changes: 1 addition & 1 deletion orttraining/orttraining/core/graph/training_op_defs.cc
@@ -1537,7 +1537,7 @@ void RegisterTrainingOpSchemas() {
 "This signal indicates if weight updates are skipped, applicable to gradient infinity check"
 " in mixed precision training. ",
 "T_BOOL", OpSchema::Optional)
-.Output(0, "updated_flag", "Whether gradient is applied or not.", "T2")
+.Output(0, "updated_flag", "Whether gradient is applied or not.", "T_BOOL")
 .Output(1, "updated_weights", "Sequence of weights after optimize.", "S_WEIGHT", OpSchema::Optional)
 .Output(2, "updated_momentums_1", "Sequence of momentum_1 after optimize.", "S_MOMENT", OpSchema::Optional)
 .Output(3, "updated_momentums_2", "Sequence of momentum_2 after optimize.", "S_MOMENT", OpSchema::Optional)
4 changes: 3 additions & 1 deletion orttraining/orttraining/python/training/artifacts.py
@@ -33,6 +33,7 @@ class OptimType(Enum):
 """

 AdamW = 1
+SGD = 2


 def generate_artifacts(
@@ -192,7 +193,8 @@ def _export_to_ort_format(model_path, output_dir, extra_options):
 logging.info("Optimizer enum provided: %s", optimizer.name)

 optim_model = None
-optim_blocks = {OptimType.AdamW: onnxblock.optim.AdamW}
+optim_blocks = {OptimType.AdamW: onnxblock.optim.AdamW, OptimType.SGD: onnxblock.optim.SGD}
+
 optim_block = optim_blocks[optimizer]()
 with onnxblock.empty_base():
     _ = optim_block(model_params)
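
The `optim_blocks` mapping above dispatches each `OptimType` to its onnxblock implementation. For reference, the SGD block can also be driven directly, mirroring what `generate_artifacts` does internally; a sketch assuming `model_params` holds the trainable parameters from a training block and that `SGD` exposes `to_model_proto()` the way `AdamW` does:

```python
# Hedged sketch of building a standalone SGD optimizer model with onnxblock.
import onnxruntime.training.onnxblock as onnxblock

optim_block = onnxblock.optim.SGD()
with onnxblock.empty_base():               # start from an empty base model
    _ = optim_block(model_params)          # wire parameter/gradient inputs
optim_model = optim_block.to_model_proto() # assumed serialization hook
```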
orttraining/orttraining/python/training/onnxblock/optim/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

-from onnxruntime.training.onnxblock.optim.optim import AdamW, ClipGradNorm
+from onnxruntime.training.onnxblock.optim.optim import SGD, AdamW, ClipGradNorm

-__all__ = ["AdamW", "ClipGradNorm"]
+__all__ = ["AdamW", "ClipGradNorm", "SGD"]
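
For intuition, SGD here is the vanilla momentum-free update w ← w − lr·g, in contrast to AdamW's per-parameter moment buffers. A minimal NumPy sketch of that rule (not the ORT kernel itself):

```python
import numpy as np

def sgd_step(weight: np.ndarray, grad: np.ndarray, lr: float = 1e-3) -> np.ndarray:
    """One plain SGD update: subtract the learning-rate-scaled gradient."""
    return weight - lr * grad

# Toy example: a single step on a two-element weight vector.
w = sgd_step(np.array([0.5, -0.2]), np.array([0.1, 0.3]), lr=0.01)
# w is now [0.499, -0.203]
```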