Copy to buffer for both trainable as well as non trainable parameters (
baijumeswani authored Aug 10, 2023
1 parent 555f346 commit f17efb5
Showing 3 changed files with 30 additions and 21 deletions.
8 changes: 4 additions & 4 deletions orttraining/orttraining/python/orttraining_pybind_state.cc
@@ -1006,12 +1006,12 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn
         ORT_THROW_IF_ERROR(model->LazyResetGrad());
       })
       .def("copy_parameters_to_buffer",
-           [](onnxruntime::training::api::Module* model, OrtValue& output) -> void {
-             ORT_THROW_IF_ERROR(model->CopyParametersToBuffer(output));
+           [](onnxruntime::training::api::Module* model, OrtValue& output, bool trainable_only) -> void {
+             ORT_THROW_IF_ERROR(model->CopyParametersToBuffer(output, trainable_only));
           })
       .def("copy_buffer_to_parameters",
-           [](onnxruntime::training::api::Module* model, OrtValue& input) -> void {
-             ORT_THROW_IF_ERROR(model->CopyBufferToParameters(input));
+           [](onnxruntime::training::api::Module* model, OrtValue& input, bool trainable_only) -> void {
+             ORT_THROW_IF_ERROR(model->CopyBufferToParameters(input, trainable_only));
           })
       .def("get_parameters_size",
           [](onnxruntime::training::api::Module* model, bool trainable_only) -> size_t {
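Both bound lambdas now mirror the C++ Module methods, which take an explicit trainable_only flag. A minimal sketch of how this surfaces from Python, assuming model is an already-constructed onnxruntime.training.api.Module and buffer is a preallocated OrtValue (model._model is the private pybind handle the Python wrapper holds):

# Sketch only: `buffer` is assumed to be an OrtValue preallocated to
# model.get_parameters_size(trainable_only=True) elements.
model._model.copy_parameters_to_buffer(buffer, True)   # trainable params only
model._model.copy_buffer_to_parameters(buffer, True)

In practice callers go through the Python wrapper below, which threads the flag through for them.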
8 changes: 4 additions & 4 deletions orttraining/orttraining/python/training/api/module.py
Original file line number Diff line number Diff line change
@@ -160,11 +160,11 @@ def get_contiguous_parameters(self, trainable_only: bool = False) -> OrtValue:
             self._device_type,
             self._device.device_id(),
         )._ortvalue
-        self._model.copy_parameters_to_buffer(parameters)
+        self._model.copy_parameters_to_buffer(parameters, trainable_only)

         return parameters

-    def get_parameters_size(self, trainable_only: bool = False) -> int:
+    def get_parameters_size(self, trainable_only: bool = True) -> int:
         """Returns the size of the parameters.

         Args:
@@ -175,13 +175,13 @@ def get_parameters_size(self, trainable_only: bool = False) -> int:
         """
         return self._model.get_parameters_size(trainable_only)

-    def copy_buffer_to_parameters(self, buffer: OrtValue) -> None:
+    def copy_buffer_to_parameters(self, buffer: OrtValue, trainable_only: bool = True) -> None:
         """Copies the OrtValue buffer to the training session parameters.

         Args:
             buffer: The OrtValue buffer to copy to the training session parameters.
         """
-        self._model.copy_buffer_to_parameters(buffer)
+        self._model.copy_buffer_to_parameters(buffer, trainable_only)

     def export_model_for_inferencing(
         self, inference_model_uri: str | os.PathLike, graph_output_names: list[str]
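With the flag now plumbed through and the defaults set to trainable_only=True, the high-level API round-trips just the trainable parameters by default. A minimal usage sketch, assuming artifacts generated with onnxruntime.training.artifacts and hypothetical file paths:

from onnxruntime.training.api import CheckpointState, Module

# Hypothetical artifact paths.
state = CheckpointState.load_checkpoint("checkpoint")
model = Module("training_model.onnx", state)

# Snapshot the trainable parameters into one contiguous OrtValue.
snapshot = model.get_contiguous_parameters(trainable_only=True)

# ... run training steps ...

# Restore only the trainable parameters; frozen parameters are untouched.
model.copy_buffer_to_parameters(snapshot, trainable_only=True)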
@@ -28,13 +28,20 @@ def build(self, output_name):
         return self.loss(output_name)


-def _create_training_artifacts(artifact_directory: str | os.PathLike):
+def _create_training_artifacts(
+    artifact_directory: str | os.PathLike,
+    requires_grad: list[str] | None = None,
+    frozen_params: list[str] | None = None,
+):
     device = "cpu"
     batch_size, input_size, hidden_size, output_size = 64, 784, 500, 10
     pt_model, onnx_model = _get_models(device, batch_size, input_size, hidden_size, output_size)

-    requires_grad = [name for name, param in pt_model.named_parameters() if param.requires_grad]
-    frozen_params = [name for name, param in pt_model.named_parameters() if not param.requires_grad]
+    if requires_grad is None:
+        requires_grad = [name for name, param in pt_model.named_parameters() if param.requires_grad]
+
+    if frozen_params is None:
+        frozen_params = [name for name, param in pt_model.named_parameters() if not param.requires_grad]

     artifacts.generate_artifacts(
         onnx_model,
@@ -69,7 +76,6 @@ def test_train_step():
     # Create Checkpoint State.
     state = CheckpointState.load_checkpoint(checkpoint_file_path)
     # Create a Module.
-    print(training_model_file_path)
     model = Module(training_model_file_path, state)
     model.train()
     ort_loss = model(inputs, labels)
@@ -234,7 +240,8 @@ def test_training_module_checkpoint():
     assert np.array_equal(old_flatten_params.numpy(), new_params.numpy())


-def test_copy_buffer_to_parameters():
+@pytest.mark.parametrize("trainable_only", [True, False])
+def test_copy_buffer_to_parameters(trainable_only):
     # Generating random data for testing.
     inputs = torch.randn(64, 784).numpy()
     labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy()
@@ -246,31 +253,33 @@ def test_copy_buffer_to_parameters():
         _,
         optimizer_model_file_path,
         _,
-    ) = _create_training_artifacts(temp_dir)
+    ) = _create_training_artifacts(
+        temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"]
+    )
     state = CheckpointState.load_checkpoint(checkpoint_file_path)

     # Create a Module and Optimizer.
     model = Module(training_model_file_path, state)
     optimizer = Optimizer(optimizer_model_file_path, model)

     # Keep a copy of the parameters.
-    old_output_params = model.get_contiguous_parameters()
+    old_output_params = model.get_contiguous_parameters(trainable_only=trainable_only)

     # Run a Training Step.
     model.train()
     model(inputs, labels)
     optimizer.step()

     # Get the new parameters.
-    output_params = model.get_contiguous_parameters()
+    output_params = model.get_contiguous_parameters(trainable_only=trainable_only)
     # Make sure old params are different from new params.
     assert not np.array_equal(old_output_params.numpy(), output_params.numpy())

     # Copy the old parameters to the model.
-    model.copy_buffer_to_parameters(old_output_params)
+    model.copy_buffer_to_parameters(old_output_params, trainable_only=trainable_only)

     # Get the saved parameters.
-    saved_params = model.get_contiguous_parameters()
+    saved_params = model.get_contiguous_parameters(trainable_only=trainable_only)

     # Make sure the saved parameters are the same as the old parameters.
     assert np.array_equal(old_output_params.numpy(), saved_params.numpy())
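A quick way to see what the two parametrizations cover: with fc1.* frozen and fc2.* trainable, the trainable-only buffer is strictly smaller than the full one. A hedged sketch reusing the model constructed in the test above:

# Sketch only: sizes are element counts over the selected parameters.
trainable_size = model.get_parameters_size(trainable_only=True)   # fc2.* only
full_size = model.get_parameters_size(trainable_only=False)       # fc1.* + fc2.*
assert trainable_size < full_size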
@@ -369,8 +378,9 @@ def test_get_input_output_names():
     # Create a Module.
     model = Module(training_model_file_path, state, eval_model_file_path)

-    assert model.input_names() == ["input-0", "labels"]
-    assert model.output_names() == ["onnx::loss::128"]
+    training_model = onnx.load(training_model_file_path)
+    assert model.input_names() == [input.name for input in training_model.graph.input][:2]
+    assert model.output_names() == [output.name for output in training_model.graph.output][:1]


def test_ort_custom_ops():
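Deriving the expected names from the generated training graph removes the dependence on auto-generated names like onnx::loss::128, which can change as artifact generation evolves. The same pattern in isolation, with a hypothetical model path:

import onnx

# Hypothetical path; inspect any generated training model the same way.
training_model = onnx.load("training_model.onnx")
print([i.name for i in training_model.graph.input])    # e.g. user inputs + labels
print([o.name for o in training_model.graph.output])   # e.g. the loss output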
@@ -506,7 +516,6 @@ def test_train_step_with_ort_values():
     # Create Checkpoint State.
     state = CheckpointState.load_checkpoint(checkpoint_file_path)
     # Create a Module.
-    print(training_model_file_path)
     model = Module(training_model_file_path, state)
     model.train()
     ort_loss = model(inputs, labels)