Commit

rewrite master weight for amp training (PaddlePaddle#59052)
* rewrite master weight for amp training

* some optimizers do not support master weight
zhangting2020 authored Dec 4, 2023
1 parent c32c14f commit ea5f3c5
Showing 4 changed files with 174 additions and 4 deletions.
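
A condensed usage sketch of the new flag in static-graph O2 training, distilled from the test added by this commit (a plain paddle.nn.Linear stands in for the test's SimpleNet; a CUDA build of Paddle and a GPU are assumed; this is an illustration, not a canonical recipe):

    import numpy as np
    import paddle

    paddle.enable_static()
    main_program = paddle.static.Program()
    startup_program = paddle.static.Program()
    with paddle.static.program_guard(main_program, startup_program):
        model = paddle.nn.Linear(100, 100)  # stand-in for the test's SimpleNet
        optimizer = paddle.optimizer.AdamW(learning_rate=0.01)
        # master_weight=True keeps fp32 copies of the fp16 parameters in the optimizer.
        optimizer = paddle.static.amp.decorate(
            optimizer, level='O2', dtype='float16', master_weight=True
        )
        x = paddle.static.data(name='input', shape=[100, 100], dtype='float16')
        loss = paddle.mean(model(x))
        optimizer.minimize(loss)

    place = paddle.CUDAPlace(0)
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    # New in this commit: rewrite_master_weight=True also overwrites the fp32 master
    # weights with values recovered from the freshly cast fp16 parameters.
    optimizer.amp_init(
        place, scope=paddle.static.global_scope(), rewrite_master_weight=True
    )
    x_data = np.random.random(size=[100, 100]).astype('float16')
    (loss_val,) = exe.run(main_program, feed={'input': x_data}, fetch_list=[loss])
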
1 change: 1 addition & 0 deletions python/paddle/optimizer/optimizer.py
@@ -277,6 +277,7 @@ def __init__(
         self._auxiliary_vars = {}
         self._already_create_accumulater = set()
 
+        self._master_weights = {}
         # create master gradients' states
         self._create_master_grad_states()
 
9 changes: 8 additions & 1 deletion python/paddle/static/amp/decorator.py
@@ -301,7 +301,12 @@ def _add_cast_ops_to_startup_program(self, startup_program):
         self._to_fp16_var_names = None
 
     def amp_init(
-        self, place, scope=None, test_program=None, use_fp16_test=False
+        self,
+        place,
+        scope=None,
+        test_program=None,
+        use_fp16_test=False,
+        rewrite_master_weight=False,
     ):
         """
         Init the amp training, such as cast fp32 parameters to fp16 type.
@@ -369,6 +374,8 @@ def amp_init(
                 scope,
                 self._to_fp16_var_names,
                 self._amp_vartype,
+                rewrite_master_weight,
+                self._optimizer._master_weights,
             )
         if test_program is not None:
             if self._use_pure_fp16:
26 changes: 23 additions & 3 deletions python/paddle/static/amp/fp16_utils.py
@@ -846,12 +846,23 @@ def _convert_float_to_bfloat16(place, fp32_array):
     return bf16_array
 
 
+def _convert_to_float(place, org_array):
+    paddle.disable_static()
+    framework._set_expected_place(place)
+    org_tensor = paddle.to_tensor(org_array)
+    fp32_array = paddle.cast(org_tensor, paddle.float32).numpy()
+    paddle.enable_static()
+    return fp32_array
+
+
 def cast_parameters_to_fp16(
     place,
     program,
     scope=None,
     to_fp16_var_names=None,
     dest_type=core.VarDesc.VarType.FP16,
+    rewrite_master_weight=False,
+    master_weights={},
 ):
     """
     Traverse all parameters in the whole model and set them to the FP16 data type.
@@ -882,10 +893,19 @@ def cast_parameters_to_fp16(
                 param_t = var_scope.find_var(param.name).get_tensor()
                 data = np.array(param_t)
                 if dest_type == core.VarDesc.VarType.BF16:
-                    bf16_data = _convert_float_to_bfloat16(place, data)
-                    param_t.set(bf16_data, place)
+                    p_array = _convert_float_to_bfloat16(place, data)
+                    param_t.set(p_array, place)
                 else:
-                    param_t.set(np.float16(data), place)
+                    p_array = np.float16(data)
+                    param_t.set(p_array, place)
+                # rewrite master weight
+                if rewrite_master_weight and param.name in master_weights:
+                    master_p_var = var_scope.find_var(
+                        master_weights[param.name].name
+                    )
+                    master_p_t = master_p_var.get_tensor()
+                    master_p_array = _convert_to_float(place, p_array)
+                    master_p_t.set(master_p_array, place)
             else:
                 _logger.warning(f"Cannot find {param.name}")
 
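
For intuition, a standalone numpy illustration (not part of the commit) of why the master weight is rewritten: casting fp32 to fp16 is lossy, so after cast_parameters_to_fp16 the original fp32 master weight no longer matches what the fp16 parameter actually stores; rewriting it from the cast value keeps the optimizer state consistent with the fp16 weights, which is what lets the new test below expect dygraph and static losses to match.

    import numpy as np

    w_fp32 = np.float32(0.1234567)    # illustrative original fp32 parameter value
    w_fp16 = np.float16(w_fp32)       # what cast_parameters_to_fp16 stores in the scope
    rewritten = np.float32(w_fp16)    # master weight after the rewrite (fp16 -> fp32 is exact)

    print(w_fp32 == rewritten)              # False: the old fp32 master weight has drifted
    print(np.float16(rewritten) == w_fp16)  # True: the rewritten master weight matches the fp16 param
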
142 changes: 142 additions & 0 deletions test/amp/test_amp_master_weight.py
@@ -0,0 +1,142 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from amp_base_models import AmpTestBase

import paddle
from paddle.base import core


class SimpleNet(paddle.nn.Layer):
    def __init__(self, input_size, output_size):
        super().__init__()
        weight_attr = paddle.ParamAttr(
            name="weight", initializer=paddle.nn.initializer.Constant(value=0.5)
        )
        bias_attr = paddle.ParamAttr(
            name="bias", initializer=paddle.nn.initializer.Constant(value=1.0)
        )
        self.linear = paddle.nn.Linear(
            input_size, output_size, weight_attr, bias_attr
        )

    def forward(self, x):
        x = self.linear(x)
        return x


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or paddle.device.cuda.get_device_capability()[0] < 7.0,
    "run test when gpu's compute capability is at least 7.0.",
)
class TestMasterWeight(AmpTestBase):
    def run_dygraph(self, dtype, level, use_promote, max_iters, x_data):
        losses = []
        model = SimpleNet(100, 100)
        optimizer = paddle.optimizer.AdamW(
            learning_rate=0.01,
            parameters=model.parameters(),
        )
        scaler = paddle.amp.GradScaler()
        model, optimizer = paddle.amp.decorate(
            models=model,
            optimizers=optimizer,
            level=level,
            dtype=dtype,
        )

        for i in range(max_iters):
            with paddle.amp.auto_cast(
                enable=True,
                dtype=dtype,
                level=level,
                use_promote=use_promote,
            ):
                x = paddle.to_tensor(x_data, dtype='float16')
                out = model(x)
                loss = paddle.mean(out)
            losses.append(loss)
            scaled = scaler.scale(loss)
            scaled.backward()
            scaler.minimize(optimizer, scaled)
            optimizer.clear_grad()
        return losses

    def run_static(self, dtype, level, use_promote, max_iters, x_data):
        paddle.enable_static()
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        losses = []
        with paddle.utils.unique_name.guard():
            with paddle.static.program_guard(main_program, startup_program):
                model = SimpleNet(100, 100)
                optimizer = paddle.optimizer.AdamW(learning_rate=0.01)
                optimizer = paddle.static.amp.decorate(
                    optimizer,
                    level=level,
                    dtype=dtype,
                    use_promote=use_promote,
                    master_weight=True,
                )
                x = paddle.static.data(
                    name='input', shape=[100, 100], dtype='float16'
                )
                out = model(x)
                loss = paddle.mean(out)
                optimizer.minimize(loss)

        place = paddle.CUDAPlace(0)
        exe = paddle.static.Executor(place)
        exe.run(startup_program)
        optimizer.amp_init(
            place,
            scope=paddle.static.global_scope(),
            rewrite_master_weight=True,
        )
        for iter_id in range(max_iters):
            results = exe.run(
                program=main_program,
                feed={x.name: x_data},
                fetch_list=[loss],
            )
            print(f"-- [AMP {dtype} {level}] iter={iter_id}, loss={results[0]}")
            losses.append(results[0])

        paddle.disable_static()
        return losses

    def test_master_weight(self):
        dtype = 'float16'
        level = 'O2'
        use_promote = True
        total_steps = 4
        x_data = np.random.random(size=[100, 100]).astype("float16")

        loss_dygraph = self.run_dygraph(
            dtype, level, use_promote, total_steps, x_data
        )
        loss_static = self.run_static(
            dtype, level, use_promote, total_steps, x_data
        )

        for i in range(total_steps):
            self.assertEqual(loss_dygraph[i], loss_static[i])


if __name__ == '__main__':
    unittest.main()
