forked from microsoft/onnxruntime
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Minimal Build for On-Device Training (microsoft#16326)
🛠️ __Changes in this pull request:__ This pull request introduces two significant changes to the project: - Changing the on-device training checkpoint format: The current implementation stores the on-device training checkpoint as a sequence of tensors in multiple files inside a checkpoint folder, which can be inefficient in terms of storage and performance. In this PR, I have modified the checkpoint format to use a FlatBuffers table to save the checkpoint to a single file, providing a more compact and efficient representation. The changes around this are twofold: - Add the checkpoint FlatBuffers schema that will generate the necessary checkpoint source files. - Update the checkpoint saving and loading functionality to use the new format. - Adding support for the onnxruntime minimal build: To support scenarios where binary size is a constraint, I made changes to ensure that the training build can work well with the minimal build. 🔍 __Open Issues:__ - In order to extract the optimizer type, the existing implementation re-loaded the ONNX optimizer model and parsed it. This is no longer possible, since the model format can be either ONNX or ORT. One idea is to do the same for the ORT-format optimizer model. This needs some investigation. - Changes to the offline tooling to generate ORT-format training artifacts. - End-to-end training example showcasing the use of the minimal training build. - Add support for exporting the model for inferencing in a minimal build.
- Loading branch information
1 parent
97f4484
commit 10ba1e2
Showing
46 changed files
with
2,232 additions
and
755 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// Copyright (c) Microsoft Corporation. All rights reserved. | ||
// Licensed under the MIT License. | ||
|
||
#pragma once | ||
|
||
namespace onnxruntime {

// The versions below highlight the checkpoint format version history.
// Every time the checkpoint format changes, the version should be incremented
// and the changes should be documented here and in the
// onnxruntime/core/flatbuffers/schema/README.md file.
// Version 1: Introduces the On-Device Training Checkpoint format
//            The format includes support for the ModuleState (stores the module parameters), OptimizerGroups
//            (stores the optimizer states), and PropertyBag
//            (stores custom user properties with support for int64, float and strings).
constexpr int kCheckpointVersion = 1;

/**
 * @brief Check if the given checkpoint version is supported in this build
 *
 * Only an exact match with kCheckpointVersion is accepted; older and newer
 * versions are both reported as unsupported.
 *
 * @param checkpoint_version The checkpoint version to check
 * @return true if the checkpoint version is supported, false otherwise
 */
inline constexpr bool IsCheckpointVersionSupported(const int checkpoint_version) {
  return kCheckpointVersion == checkpoint_version;
}

}  // namespace onnxruntime
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
87 changes: 87 additions & 0 deletions
87
onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/Checkpoint.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# automatically generated by the FlatBuffers compiler, do not modify | ||
|
||
# namespace: fbs | ||
|
||
import flatbuffers | ||
from flatbuffers.compat import import_numpy | ||
np = import_numpy() | ||
|
||
class Checkpoint(object):
    """Accessor for a ``Checkpoint`` FlatBuffers table.

    Wraps a ``flatbuffers.table.Table`` (kept in ``_tab``) and exposes the
    table's four fields: ``Version`` (vtable offset 4), ``ModuleState`` (6),
    ``OptimizerGroups`` (8) and ``PropertyBag`` (10).
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAsCheckpoint(cls, buf, offset):
        """Return a ``Checkpoint`` positioned at the root table of ``buf``.

        The table position is the unsigned offset read at ``offset`` plus
        ``offset`` itself.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Checkpoint()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def CheckpointBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check ``buf`` for the file identifier ``b"ODTC"`` (0x4F 0x44 0x54 0x43)."""
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)

    # Checkpoint
    def Init(self, buf, pos):
        """Bind this accessor to the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # Checkpoint
    def Version(self):
        """Return the int32 version field, or 0 when the field is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0

    # Checkpoint
    def ModuleState(self):
        """Return a ``ModuleState`` accessor for the nested table, or ``None`` when absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            # Imported lazily, exactly as emitted by the generator.
            from ort_flatbuffers_py.fbs.ModuleState import ModuleState
            obj = ModuleState()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Checkpoint
    def OptimizerGroups(self, j):
        """Return the ``OptimizerGroup`` accessor at index ``j``, or ``None`` when the vector is absent.

        ``j`` is not range-checked here; use ``OptimizerGroupsLength`` to
        stay within bounds.
        """
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            x = self._tab.Vector(o)
            # Each vector element is a 4-byte offset to the element's table.
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.OptimizerGroup import OptimizerGroup
            obj = OptimizerGroup()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Checkpoint
    def OptimizerGroupsLength(self):
        """Return the number of optimizer groups, or 0 when the vector is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Checkpoint
    def OptimizerGroupsIsNone(self):
        """Return ``True`` when the optimizer-groups vector is absent from the table."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        return o == 0

    # Checkpoint
    def PropertyBag(self):
        """Return a ``PropertyBag`` accessor for the nested table, or ``None`` when absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.PropertyBag import PropertyBag
            obj = PropertyBag()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None
|
||
def CheckpointStart(builder):
    """Begin a ``Checkpoint`` table on ``builder`` with four field slots."""
    builder.StartObject(4)
def CheckpointAddVersion(builder, version):
    """Write ``version`` into field slot 0 as an int32, passing 0 as the slot default."""
    builder.PrependInt32Slot(0, version, 0)
def CheckpointAddModuleState(builder, moduleState):
    """Write the ``moduleState`` table offset into field slot 1.

    ``moduleState`` is converted with ``UOffsetTFlags.py_type`` before being
    handed to the builder; 0 is passed as the slot default.
    """
    builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(moduleState), 0)
def CheckpointAddOptimizerGroups(builder, optimizerGroups):
    """Write the ``optimizerGroups`` vector offset into field slot 2.

    ``optimizerGroups`` is converted with ``UOffsetTFlags.py_type`` before
    being handed to the builder; 0 is passed as the slot default.
    """
    builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(optimizerGroups), 0)
def CheckpointStartOptimizerGroupsVector(builder, numElems):
    """Start the optimizer-groups vector and return the builder's result.

    The vector is started with an element size of 4, ``numElems`` elements
    and an alignment of 4.
    """
    return builder.StartVector(4, numElems, 4)
def CheckpointAddPropertyBag(builder, propertyBag):
    """Write the ``propertyBag`` table offset into field slot 3.

    ``propertyBag`` is converted with ``UOffsetTFlags.py_type`` before being
    handed to the builder; 0 is passed as the slot default.
    """
    builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(propertyBag), 0)
def CheckpointEnd(builder):
    """Finish the ``Checkpoint`` table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()
44 changes: 44 additions & 0 deletions
44
onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/FloatProperty.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# automatically generated by the FlatBuffers compiler, do not modify | ||
|
||
# namespace: fbs | ||
|
||
import flatbuffers | ||
from flatbuffers.compat import import_numpy | ||
np = import_numpy() | ||
|
||
class FloatProperty(object):
    """Accessor for a ``FloatProperty`` FlatBuffers table.

    Wraps a ``flatbuffers.table.Table`` (kept in ``_tab``) and exposes the
    table's two fields: ``Name`` (vtable offset 4) and ``Value`` (6).
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAsFloatProperty(cls, buf, offset):
        """Return a ``FloatProperty`` positioned at the root table of ``buf``.

        The table position is the unsigned offset read at ``offset`` plus
        ``offset`` itself.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = FloatProperty()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def FloatPropertyBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check ``buf`` for the file identifier ``b"ODTC"`` (0x4F 0x44 0x54 0x43)."""
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)

    # FloatProperty
    def Init(self, buf, pos):
        """Bind this accessor to the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # FloatProperty
    def Name(self):
        """Return the name field as read by ``Table.String``, or ``None`` when absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # FloatProperty
    def Value(self):
        """Return the float32 value field, or 0.0 when the field is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
        return 0.0
|
||
def FloatPropertyStart(builder):
    """Begin a ``FloatProperty`` table on ``builder`` with two field slots."""
    builder.StartObject(2)
def FloatPropertyAddName(builder, name):
    """Write the ``name`` string offset into field slot 0.

    ``name`` is converted with ``UOffsetTFlags.py_type`` before being handed
    to the builder; 0 is passed as the slot default.
    """
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
def FloatPropertyAddValue(builder, value):
    """Write ``value`` into field slot 1 as a float32, passing 0.0 as the slot default."""
    builder.PrependFloat32Slot(1, value, 0.0)
def FloatPropertyEnd(builder):
    """Finish the ``FloatProperty`` table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()
44 changes: 44 additions & 0 deletions
44
onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/IntProperty.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
# automatically generated by the FlatBuffers compiler, do not modify | ||
|
||
# namespace: fbs | ||
|
||
import flatbuffers | ||
from flatbuffers.compat import import_numpy | ||
np = import_numpy() | ||
|
||
class IntProperty(object):
    """Accessor for an ``IntProperty`` FlatBuffers table.

    Wraps a ``flatbuffers.table.Table`` (kept in ``_tab``) and exposes the
    table's two fields: ``Name`` (vtable offset 4) and ``Value`` (6).
    """

    __slots__ = ['_tab']

    @classmethod
    def GetRootAsIntProperty(cls, buf, offset):
        """Return an ``IntProperty`` positioned at the root table of ``buf``.

        The table position is the unsigned offset read at ``offset`` plus
        ``offset`` itself.
        """
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = IntProperty()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def IntPropertyBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        """Check ``buf`` for the file identifier ``b"ODTC"`` (0x4F 0x44 0x54 0x43)."""
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)

    # IntProperty
    def Init(self, buf, pos):
        """Bind this accessor to the table located at ``pos`` inside ``buf``."""
        self._tab = flatbuffers.table.Table(buf, pos)

    # IntProperty
    def Name(self):
        """Return the name field as read by ``Table.String``, or ``None`` when absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # IntProperty
    def Value(self):
        """Return the int64 value field, or 0 when the field is absent."""
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos)
        return 0
|
||
def IntPropertyStart(builder):
    """Begin an ``IntProperty`` table on ``builder`` with two field slots."""
    builder.StartObject(2)
def IntPropertyAddName(builder, name):
    """Write the ``name`` string offset into field slot 0.

    ``name`` is converted with ``UOffsetTFlags.py_type`` before being handed
    to the builder; 0 is passed as the slot default.
    """
    builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
def IntPropertyAddValue(builder, value):
    """Write ``value`` into field slot 1 as an int64, passing 0 as the slot default."""
    builder.PrependInt64Slot(1, value, 0)
def IntPropertyEnd(builder):
    """Finish the ``IntProperty`` table and return whatever ``builder.EndObject()`` returns."""
    return builder.EndObject()
Oops, something went wrong.