Skip to content

Commit

Permalink
Minimal Build for On-Device Training (microsoft#16326)
Browse files Browse the repository at this point in the history
🛠️ __Changes in this pull request:__

This pull request introduces two significant changes to the project:

- Changing the on-device training checkpoint format: The current
implementation stores the on device training checkpoint as a sequence of
tensors in multiple files inside a checkpoint folder, which can be
inefficient in terms of storage and performance. In this PR, I have
modified the checkpoint format to utilize the flatbuffer table to save
the checkpoint to a single file, providing a more compact and efficient
representation. The changes around this are twofold:
- Add the checkpoint flatbuffer schema that will generate the necessary
checkpoint source files.
- Update the checkpoint saving and loading functionality to use the new
format.

- Adding support for onnxruntime minimal build: To support scenarios
where binary size is a constraint, I made changes to ensure that the
training build can work well with the minimal build.

🔍 __Open Issues:__
- In order to extract the optimizer type, the existing implementation
re-loaded the onnx optimizer model and parsed it. This is no longer
possible, since the model format can either be onnx or ort. One idea is
to do the same for ort format optimizer model. This needs some
investigation.
- Changes to the offline tooling to generate ort format training
artifacts.
- End-to-end training example showcasing the use of the minimal training
build.
- Add support for export model for inferencing in a minimal build.
  • Loading branch information
baijumeswani authored Jun 22, 2023
1 parent 97f4484 commit 10ba1e2
Show file tree
Hide file tree
Showing 46 changed files with 2,232 additions and 755 deletions.
2 changes: 1 addition & 1 deletion .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ exclude_patterns = [
'java/**', # FIXME: Enable clang-format for java
'js/**',
'onnxruntime/contrib_ops/cuda/bert/tensorrt_fused_multihead_attention/**', # Contains data chunks
'onnxruntime/core/flatbuffers/schema/ort.fbs.h', # Generated code
'onnxruntime/core/flatbuffers/schema/*.fbs.h', # Generated code
'onnxruntime/core/graph/contrib_ops/quantization_defs.cc',
'onnxruntime/core/mlas/**', # Contains assembly code
'winml/**', # FIXME: Enable clang-format for winml
Expand Down
6 changes: 2 additions & 4 deletions cmake/onnxruntime_optimizer.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ else()
)
endif()

if (onnxruntime_ENABLE_TRAINING_APIS)
# we need optimizers for both full build as well as training api only build.
# Using onnxruntime_ENABLE_TRAINING_APIS since it is always ON in a full training build.
if (onnxruntime_ENABLE_TRAINING)
list(APPEND onnxruntime_optimizer_src_patterns
"${ORTTRAINING_SOURCE_DIR}/core/optimizer/*.h"
"${ORTTRAINING_SOURCE_DIR}/core/optimizer/*.cc"
Expand All @@ -101,7 +99,7 @@ onnxruntime_add_static_library(onnxruntime_optimizer ${onnxruntime_optimizer_src

onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface)
target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT})
if (onnxruntime_ENABLE_TRAINING_APIS)
if (onnxruntime_ENABLE_TRAINING)
target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT})
endif()
add_dependencies(onnxruntime_optimizer ${onnxruntime_EXTERNAL_DEPENDENCIES})
Expand Down
56 changes: 32 additions & 24 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,18 @@ file(GLOB onnxruntime_test_training_src
"${ORTTRAINING_SOURCE_DIR}/test/distributed/*.cc"
)

if (onnxruntime_ENABLE_TRAINING_APIS)
file(GLOB onnxruntime_test_training_api_src
"${ORTTRAINING_SOURCE_DIR}/test/training_api/common/*.cc"
"${ORTTRAINING_SOURCE_DIR}/test/training_api/common/*.h"
"${ORTTRAINING_SOURCE_DIR}/test/training_api/core/*.cc"
"${ORTTRAINING_SOURCE_DIR}/test/training_api/core/*.h"
)
# TODO (baijumeswani): Remove the minimal build check here.
# The training api tests should be runnable even on a minimal build.
# This requires converting all the *.onnx files to ort format.
if (NOT onnxruntime_MINIMAL_BUILD)
if (onnxruntime_ENABLE_TRAINING_APIS)
file(GLOB onnxruntime_test_training_api_src
"${ORTTRAINING_SOURCE_DIR}/test/training_api/common/*.cc"
"${ORTTRAINING_SOURCE_DIR}/test/training_api/common/*.h"
"${ORTTRAINING_SOURCE_DIR}/test/training_api/core/*.cc"
"${ORTTRAINING_SOURCE_DIR}/test/training_api/core/*.h"
)
endif()
endif()

if(WIN32)
Expand Down Expand Up @@ -370,26 +375,29 @@ if (onnxruntime_USE_CANN)
list(APPEND onnxruntime_test_providers_src ${onnxruntime_test_providers_cann_src})
endif()

if (onnxruntime_ENABLE_TRAINING_APIS)
file(GLOB_RECURSE orttraining_test_trainingops_cpu_src CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/test/training_ops/compare_provider_test_utils.cc"
"${ORTTRAINING_SOURCE_DIR}/test/training_ops/function_op_test_utils.cc"
"${ORTTRAINING_SOURCE_DIR}/test/training_ops/cpu/*"
)

if (NOT onnxruntime_ENABLE_TRAINING)
list(REMOVE_ITEM orttraining_test_trainingops_cpu_src
"${ORTTRAINING_SOURCE_DIR}/test/training_ops/cpu/tensorboard/summary_op_test.cc"
# Disable training ops test for minimal build as a lot of these depend on loading an onnx model.
if (NOT onnxruntime_MINIMAL_BUILD)
if (onnxruntime_ENABLE_TRAINING_CORE)
file(GLOB_RECURSE orttraining_test_trainingops_cpu_src CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/test/training_ops/compare_provider_test_utils.cc"
"${ORTTRAINING_SOURCE_DIR}/test/training_ops/function_op_test_utils.cc"
"${ORTTRAINING_SOURCE_DIR}/test/training_ops/cpu/*"
)
endif()

list(APPEND onnxruntime_test_providers_src ${orttraining_test_trainingops_cpu_src})
if (NOT onnxruntime_ENABLE_TRAINING)
list(REMOVE_ITEM orttraining_test_trainingops_cpu_src
"${ORTTRAINING_SOURCE_DIR}/test/training_ops/cpu/tensorboard/summary_op_test.cc"
)
endif()

if (onnxruntime_USE_CUDA OR onnxruntime_USE_ROCM)
file(GLOB_RECURSE orttraining_test_trainingops_cuda_src CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/test/training_ops/cuda/*"
)
list(APPEND onnxruntime_test_providers_src ${orttraining_test_trainingops_cuda_src})
list(APPEND onnxruntime_test_providers_src ${orttraining_test_trainingops_cpu_src})

if (onnxruntime_USE_CUDA OR onnxruntime_USE_ROCM)
file(GLOB_RECURSE orttraining_test_trainingops_cuda_src CONFIGURE_DEPENDS
"${ORTTRAINING_SOURCE_DIR}/test/training_ops/cuda/*"
)
list(APPEND onnxruntime_test_providers_src ${orttraining_test_trainingops_cuda_src})
endif()
endif()
endif()

Expand Down
5 changes: 3 additions & 2 deletions java/src/test/java/ai/onnxruntime/TrainingTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -168,16 +168,17 @@ public void testSaveCheckpoint() throws IOException, OrtException {
String trainingPath = TestHelpers.getResourcePath("/training_model.onnx").toString();

Path tmpPath = Files.createTempDirectory("ort-java-training-test");
Path tmpCheckpointPath = tmpPath.resolve("checkpoint.ckpt");
try {
try (OrtTrainingSession trainingSession =
env.createTrainingSession(checkpointPath, trainingPath, null, null)) {

// Save checkpoint
trainingSession.saveCheckpoint(tmpPath, false);
trainingSession.saveCheckpoint(tmpCheckpointPath, false);
}

try (OrtTrainingSession trainingSession =
env.createTrainingSession(tmpPath.toString(), trainingPath, null, null)) {
env.createTrainingSession(tmpCheckpointPath.toString(), trainingPath, null, null)) {
// Load saved checkpoint into new session and run train step
runTrainStep(trainingSession);
}
Expand Down
27 changes: 27 additions & 0 deletions onnxruntime/core/flatbuffers/checkpoint_version.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

namespace onnxruntime {

// The versions below highlight the checkpoint format version history.
// Every time the checkpoint format changes, the version should be incremented
// and the changes should be documented here and in the
// onnxruntime/core/flatbuffers/schema/README.md file.
// Version 1: Introduces the On-Device Training Checkpoint format
// The format includes support for the ModuleState (stores the module parameters), OptimizerGroups
// (stores the optimizer states), and PropertyBag
// (stores custom user properties with support for int64, float and strings).
constexpr const int kCheckpointVersion = 1;

/**
 * @brief Check if the given checkpoint version is supported in this build
 *
 * Note that only an exact match with kCheckpointVersion is accepted; this
 * build does not read checkpoints written with any other format version.
 *
 * @param checkpoint_version The checkpoint version to check
 * @return true if the checkpoint version is supported, false otherwise
 */
inline constexpr bool IsCheckpointVersionSupported(const int checkpoint_version) {
  return kCheckpointVersion == checkpoint_version;
}

}  // namespace onnxruntime
6 changes: 3 additions & 3 deletions onnxruntime/core/flatbuffers/flatbuffers_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ onnxruntime::common::Status SaveValueInfoOrtFormat(

void LoadStringFromOrtFormat(std::string& dst, const flatbuffers::String* fbs_string);

// This macro is to be used on a protobuf message (protobug_msg), which will not create an empty string field (str_field)
// This macro is to be used on a protobuf message (protobuf_msg), which will not create an empty string field (str_field)
// if fbs_string is null
#define LOAD_STR_FROM_ORT_FORMAT(protobug_msg, str_field, fbs_string) \
#define LOAD_STR_FROM_ORT_FORMAT(protobuf_msg, str_field, fbs_string) \
{ \
if (fbs_string) \
protobug_msg.set_##str_field(fbs_string->c_str()); \
protobuf_msg.set_##str_field(fbs_string->c_str()); \
}

onnxruntime::common::Status LoadValueInfoOrtFormat(
Expand Down
87 changes: 87 additions & 0 deletions onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/Checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: fbs

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()

# NOTE(review): generated by the FlatBuffers compiler (flatc) from the
# checkpoint .fbs schema — do not hand-edit; regenerate from the schema so
# these accessors stay in sync with the binary checkpoint format.
class Checkpoint(object):
    """Read-only accessor for the root ``Checkpoint`` flatbuffer table.

    Vtable slots: 4 = Version (int32), 6 = ModuleState (table),
    8 = OptimizerGroups (vector of tables), 10 = PropertyBag (table).
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAsCheckpoint(cls, buf, offset):
        # Read the root table offset stored at the front of the buffer and
        # position a Checkpoint accessor there.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = Checkpoint()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def CheckpointBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # File identifier b"\x4F\x44\x54\x43" is ASCII "ODTC" — presumably
        # "On-Device Training Checkpoint"; confirm against the .fbs schema.
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)

    # Checkpoint
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # Checkpoint
    def Version(self):
        # int32 checkpoint format version; 0 when the field is absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
        return 0

    # Checkpoint
    def ModuleState(self):
        # Sub-table holding the module parameters; None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.ModuleState import ModuleState
            obj = ModuleState()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Checkpoint
    def OptimizerGroups(self, j):
        # j-th element of the OptimizerGroups vector (4-byte offsets);
        # None when the vector itself is absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            x = self._tab.Vector(o)
            x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
            x = self._tab.Indirect(x)
            from ort_flatbuffers_py.fbs.OptimizerGroup import OptimizerGroup
            obj = OptimizerGroup()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

    # Checkpoint
    def OptimizerGroupsLength(self):
        # Number of elements in the OptimizerGroups vector; 0 when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        if o != 0:
            return self._tab.VectorLen(o)
        return 0

    # Checkpoint
    def OptimizerGroupsIsNone(self):
        # True when the OptimizerGroups field is not present in the buffer
        # (distinct from present-but-empty).
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
        return o == 0

    # Checkpoint
    def PropertyBag(self):
        # Sub-table holding custom user properties; None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
        if o != 0:
            x = self._tab.Indirect(o + self._tab.Pos)
            from ort_flatbuffers_py.fbs.PropertyBag import PropertyBag
            obj = PropertyBag()
            obj.Init(self._tab.Bytes, x)
            return obj
        return None

# Builder helpers for serializing a Checkpoint table (generated code; the
# slot numbers below must match the accessor offsets 4/6/8/10 above).
def CheckpointStart(builder): builder.StartObject(4)
# Slot 0: int32 version, default 0.
def CheckpointAddVersion(builder, version): builder.PrependInt32Slot(0, version, 0)
# Slot 1: offset to the ModuleState sub-table.
def CheckpointAddModuleState(builder, moduleState): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(moduleState), 0)
# Slot 2: offset to the vector of OptimizerGroup tables.
def CheckpointAddOptimizerGroups(builder, optimizerGroups): builder.PrependUOffsetTRelativeSlot(2, flatbuffers.number_types.UOffsetTFlags.py_type(optimizerGroups), 0)
# Vector of 4-byte offsets with 4-byte alignment.
def CheckpointStartOptimizerGroupsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
# Slot 3: offset to the PropertyBag sub-table.
def CheckpointAddPropertyBag(builder, propertyBag): builder.PrependUOffsetTRelativeSlot(3, flatbuffers.number_types.UOffsetTFlags.py_type(propertyBag), 0)
def CheckpointEnd(builder): return builder.EndObject()
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: fbs

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()

# NOTE(review): generated by the FlatBuffers compiler (flatc) from the
# checkpoint .fbs schema — do not hand-edit; regenerate from the schema.
class FloatProperty(object):
    """Read-only accessor for a ``FloatProperty`` table: a (name, float32
    value) pair stored in the checkpoint's PropertyBag.

    Vtable slots: 4 = Name (string), 6 = Value (float32).
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAsFloatProperty(cls, buf, offset):
        # Read the root table offset at the front of the buffer and position
        # a FloatProperty accessor there.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = FloatProperty()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def FloatPropertyBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # Same file identifier as the Checkpoint root: b"\x4F\x44\x54\x43" ("ODTC").
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)

    # FloatProperty
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # FloatProperty
    def Name(self):
        # Property name string; None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # FloatProperty
    def Value(self):
        # float32 property value; 0.0 when the field is absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Float32Flags, o + self._tab.Pos)
        return 0.0

# Builder helpers for serializing a FloatProperty table (generated code;
# slot numbers must match the accessor offsets 4/6 above).
def FloatPropertyStart(builder): builder.StartObject(2)
# Slot 0: offset to the name string.
def FloatPropertyAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
# Slot 1: float32 value, default 0.0.
def FloatPropertyAddValue(builder, value): builder.PrependFloat32Slot(1, value, 0.0)
def FloatPropertyEnd(builder): return builder.EndObject()
44 changes: 44 additions & 0 deletions onnxruntime/core/flatbuffers/ort_flatbuffers_py/fbs/IntProperty.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# automatically generated by the FlatBuffers compiler, do not modify

# namespace: fbs

import flatbuffers
from flatbuffers.compat import import_numpy
np = import_numpy()

# NOTE(review): generated by the FlatBuffers compiler (flatc) from the
# checkpoint .fbs schema — do not hand-edit; regenerate from the schema.
class IntProperty(object):
    """Read-only accessor for an ``IntProperty`` table: a (name, int64
    value) pair stored in the checkpoint's PropertyBag.

    Vtable slots: 4 = Name (string), 6 = Value (int64).
    """
    __slots__ = ['_tab']

    @classmethod
    def GetRootAsIntProperty(cls, buf, offset):
        # Read the root table offset at the front of the buffer and position
        # an IntProperty accessor there.
        n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
        x = IntProperty()
        x.Init(buf, n + offset)
        return x

    @classmethod
    def IntPropertyBufferHasIdentifier(cls, buf, offset, size_prefixed=False):
        # Same file identifier as the Checkpoint root: b"\x4F\x44\x54\x43" ("ODTC").
        return flatbuffers.util.BufferHasIdentifier(buf, offset, b"\x4F\x44\x54\x43", size_prefixed=size_prefixed)

    # IntProperty
    def Init(self, buf, pos):
        self._tab = flatbuffers.table.Table(buf, pos)

    # IntProperty
    def Name(self):
        # Property name string; None when absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
        if o != 0:
            return self._tab.String(o + self._tab.Pos)
        return None

    # IntProperty
    def Value(self):
        # int64 property value; 0 when the field is absent.
        o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
        if o != 0:
            return self._tab.Get(flatbuffers.number_types.Int64Flags, o + self._tab.Pos)
        return 0

# Builder helpers for serializing an IntProperty table (generated code;
# slot numbers must match the accessor offsets 4/6 above).
def IntPropertyStart(builder): builder.StartObject(2)
# Slot 0: offset to the name string.
def IntPropertyAddName(builder, name): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
# Slot 1: int64 value, default 0.
def IntPropertyAddValue(builder, value): builder.PrependInt64Slot(1, value, 0)
def IntPropertyEnd(builder): return builder.EndObject()
Loading

0 comments on commit 10ba1e2

Please sign in to comment.