[AMP] fix static promote (PaddlePaddle#53439)
zhangting2020 authored May 8, 2023
1 parent 3fd2e76 commit 2bf6128
Showing 6 changed files with 87 additions and 52 deletions.
68 changes: 34 additions & 34 deletions python/paddle/static/amp/fp16_lists.py
@@ -23,11 +23,15 @@
)

# lookup_table fp16 is slower than fp32, though fp16 is supported.
-_extra_unsupported_list = {
+_extra_black_list = {
     'lookup_table',
     'lookup_table_v2',
     'scatter',
     'scatter_grad',
+    'linear_interp_v2',
+    'nearest_interp_v2',
+    'bilinear_interp_v2',
+    'bicubic_interp_v2',
+    'trilinear_interp_v2',
 }


@@ -118,8 +122,7 @@ def _get_sys_unsupported_list(dtype):
 def _get_unsupported_list(dtype):
     # The set of ops that don't support fp16 calculation
     _, _sys_unsupported_list = _get_sys_unsupported_list(dtype)
-    unsupported_list = _extra_unsupported_list | _sys_unsupported_list
-    return unsupported_list
+    return _sys_unsupported_list


# The three sets listed below are changed dynamically. They don't contain all
@@ -145,6 +148,32 @@ def _get_white_list(dtype):
return white_list_for_dtype


+# The set of ops that support fp16 calculation and are considered numerically
+# dangerous, and whose effects may also be observed in downstream ops.
+black_list = {
+    'exp',
+    'square',
+    'log',
+    'mean',
+    'sum',
+    'cos_sim',
+    'softmax',
+    'softmax_with_cross_entropy',
+    'sigmoid_cross_entropy_with_logits',
+    'c_softmax_with_cross_entropy',
+    'cross_entropy',
+    'cross_entropy2',
+    # defaulting to fp32 avoids returning inf when the sum value is larger than 65504
+    'reduce_sum',
+}
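A quick aside on that last entry: fp16's largest finite value is 65504, so an fp16 accumulator overflows to inf on sums that fp32 handles easily. A minimal numpy-only illustration (not part of this diff):

import numpy as np

# float16 saturates to inf past its maximum finite value, 65504.
print(np.finfo(np.float16).max)               # 65504.0
print(np.float16(60000) + np.float16(10000))  # inf
print(np.float32(60000) + np.float32(10000))  # 70000.0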


+def _get_black_list():
+    _black_list = copy.copy(black_list)
+    _black_list = _black_list | _extra_black_list
+    return _black_list
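The helper above returns a fresh set (copy plus union), so callers can customize the result without mutating the module-level defaults. A small sketch of that contract, assuming a Paddle build that includes this change:

from paddle.static.amp import fp16_lists

merged = fp16_lists._get_black_list()
assert 'reduce_sum' in merged       # from black_list
assert 'lookup_table' in merged     # from _extra_black_list
merged.discard('lookup_table')      # local edit only
assert 'lookup_table' in fp16_lists._get_black_list()  # defaults intact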


class AutoMixedPrecisionLists:
    """
    AutoMixedPrecisionLists is a class for the black/white lists. It can update
@@ -170,7 +199,7 @@ def __init__(
         self._custom_white_list = custom_white_list
         self._custom_black_list = custom_black_list
         self.white_list = copy.copy(_get_white_list(self.amp_dtype))
-        self.black_list = copy.copy(black_list)
+        self.black_list = copy.copy(_get_black_list())
         self.gray_list = copy.copy(gray_list)
         self.unsupported_list = copy.copy(_get_unsupported_list(self.amp_dtype))
         self.black_varnames = copy.copy(custom_black_varnames)
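Because the "slow but supported" ops now live in the overridable black list rather than the hard unsupported list, a custom white list can promote them. A hedged usage sketch (the constructor keyword mirrors what the tests below use; assumes a Paddle build with this change):

import paddle

lists = paddle.static.amp.AutoMixedPrecisionLists(
    custom_white_list={'lookup_table_v2'}
)
assert 'lookup_table_v2' in lists.white_list
assert 'lookup_table_v2' not in lists.black_list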
@@ -196,8 +225,6 @@ def _update_list(self):
             elif op_name in self.gray_list:
                 self.gray_list.remove(op_name)
             self.white_list.add(op_name)
-            if op_name in _extra_unsupported_list:
-                self.unsupported_list.remove(op_name)
         if self._custom_black_list:
             for op_name in self._custom_black_list:
                 if op_name in self.white_list:
@@ -217,33 +244,6 @@
)


-# The set of ops that support fp16 calculation and are considered numerically-
-# dangerous and whose effects may also be observed in downstream ops.
-black_list = {
-    'exp',
-    'square',
-    'log',
-    'mean',
-    'sum',
-    'cos_sim',
-    'softmax',
-    'softmax_with_cross_entropy',
-    'sigmoid_cross_entropy_with_logits',
-    'c_softmax_with_cross_entropy',
-    'cross_entropy',
-    'cross_entropy2',
-    # fp16 is slower than fp32, though fp16 is supported.
-    'lookup_table',
-    'lookup_table_v2',
-    'linear_interp_v2',
-    'nearest_interp_v2',
-    'bilinear_interp_v2',
-    'bicubic_interp_v2',
-    'trilinear_interp_v2',
-    # default fp32 can avoid return inf when the sum value large than 65504
-    'reduce_sum',
-}

# This set contains two types of ops. All ops support fp16 calculation. One
# of the two types is considered numerically safe, but may be made unsafe by an
# upstream blacklist op. The other type does not have numerically-significant
20 changes: 16 additions & 4 deletions python/paddle/static/amp/fp16_utils.py
@@ -425,15 +425,22 @@ def set_var_dst_dtype(


 def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
-    if level == "O1":
-        return
+    keep_fp32_var_names = set()
+    if level == "O1":
+        return keep_fp32_var_names
     all_parameters = []
     for block in program.blocks:
         all_parameters.extend(block.all_parameters())
         ops = block.ops
         for op in ops:
-            if op_need_keep_fp32(op, amp_lists, use_fp16_guard):
+            # Currently, lookup_table is in black_list and unsupported_list; its weight
+            # will be set to fp32 in step 1 of cast_model_to_fp16. But the weight may be
+            # used as matmul's input in transformer, so the weight is also in
+            # to_fp16_var_names.
+            # TODO(zhangting2020): consider fixing auto_parallel_fp16 and removing
+            # lookup_table from black_list and unsupported_list.
+            if op.type in ['lookup_table', 'lookup_table_v2']:
+                continue
+            if _need_keep_fp32(op, amp_lists.unsupported_list, use_fp16_guard):
                 for in_name in op.input_names:
                     keep_fp32_var_names = keep_fp32_var_names.union(
                         op.input(in_name)
@@ -451,6 +458,7 @@ def set_param_dtype(program, dtype, amp_lists, use_fp16_guard, level):
         if param.name not in keep_fp32_var_names:
             _logger.debug(f"-- set param {param.name} to {dtype} --.")
             param.desc.set_dtype(dtype)
+    return keep_fp32_var_names


def op_need_keep_fp32(op, amp_lists, use_fp16_guard):
@@ -607,15 +615,17 @@ def cast_model_to_fp16(
     keep_fp32_ops = set()
     keep_fp16_ops = set()
     to_fp16_var_names = set()
+    keep_fp32_var_names = set()

     # step 1: set params dtype.
-    set_param_dtype(
+    fp32_var_names = set_param_dtype(
         program,
         dtype=dest_type,
         amp_lists=amp_lists,
         use_fp16_guard=use_fp16_guard,
         level=level,
     )
+    keep_fp32_var_names = keep_fp32_var_names.union(fp32_var_names)

def need_process(op):
need_process = True
@@ -719,6 +729,8 @@ def need_process(op):
         idx += num_cast_ops + 1
     _logger.debug("---- after cast model to fp16 ----")
     _logger.debug(program)
+
+    to_fp16_var_names.difference_update(keep_fp32_var_names)
     return to_fp16_var_names
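A pure-Python mirror of why this subtraction matters (no paddle required; all names hypothetical): step 1 may keep a weight in fp32 even though earlier bookkeeping put it in to_fp16_var_names, e.g. an embedding table that also feeds a matmul, and the names returned by set_param_dtype reconcile the two sets.

to_fp16_var_names = {'embedding_w', 'fc_w', 'fc_b'}
keep_fp32_var_names = {'embedding_w'}   # as reported back by set_param_dtype
to_fp16_var_names.difference_update(keep_fp32_var_names)
assert to_fp16_var_names == {'fc_w', 'fc_b'}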


7 changes: 4 additions & 3 deletions test/amp/test_amp_list.py
@@ -50,9 +50,10 @@ def test_static(self):
         self.check_if_op_not_in_list(
             self.custom_white_list, amp_list.black_list
         )
-        self.check_if_op_not_in_list(
-            self.custom_white_list, amp_list.unsupported_list
-        )
+        if paddle.amp.is_float16_supported():
+            self.check_if_op_not_in_list(
+                self.custom_white_list, amp_list.unsupported_list
+            )

def test_eager(self):
if not paddle.amp.is_float16_supported():
1 change: 0 additions & 1 deletion test/amp/test_amp_promote.py
@@ -48,7 +48,6 @@ def check_promote_results(

         max_iters = 2
         x_fp32 = np.random.random(size=[1, 1, 6, 6]).astype("float32")
-        print(main_program)
         losses_o1 = self.run_program(
             main_program,
             startup_program,
8 changes: 5 additions & 3 deletions test/amp/test_model_cast_to_bf16.py
@@ -265,18 +265,20 @@ def test_amp_bf16_o2(self):

amp.debugging.collect_operator_stats(main_program)
op_stats_list = amp.debugging._get_op_stats_list(main_program)
expected_fp32_calls = {"lookup_table_v2": 1}
expected_bf16_calls = {
"matmul_v2": 1,
"elementwise_add": 1,
"dropout": 1,
"lookup_table_v2": 0,
"squared_l2_norm": 2,
"adamw": 2,
"squared_l2_norm": 3,
"adamw": 3,
}
self._check_optimizer(
main_program,
expected_bf16_calls["matmul_v2"]
+ expected_bf16_calls["elementwise_add"],
+ expected_bf16_calls["elementwise_add"]
+ expected_fp32_calls["lookup_table_v2"],
)
self._check_op_calls(op_stats_list[0], expected_bf16_calls)

35 changes: 28 additions & 7 deletions test/contrib/test_image_classification_fp16.py
@@ -318,7 +318,10 @@ def test_amp_lists(self):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

amp_lists = paddle.static.amp.AutoMixedPrecisionLists()
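The same four-line black_list replacement recurs in each of the seven tests in this file; it reproduces what fp16_lists._get_black_list() returns, so a small helper (hypothetical, not in this commit) could cut the repetition:

import copy

import paddle

def default_black_list():
    # Mirrors fp16_lists._get_black_list(): module defaults plus the extras.
    return copy.copy(
        paddle.static.amp.fp16_lists.black_list
        | paddle.static.amp.fp16_lists._extra_black_list
    )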
@@ -331,7 +334,10 @@ def test_amp_lists_1(self):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

# 1. w={'exp'}, b=None
@@ -348,7 +354,10 @@ def test_amp_lists_2(self):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

# 2. w={'tanh'}, b=None
@@ -365,7 +374,10 @@ def test_amp_lists_3(self):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

# 3. w={'lstm'}, b=None
@@ -381,7 +393,10 @@ def test_amp_lists_4(self):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

# 4. w=None, b={'conv2d'}
@@ -400,7 +415,10 @@ def test_amp_lists_5(self):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

# 5. w=None, b={'tanh'}
@@ -419,7 +437,10 @@ def test_amp_lists_6(self):
             copy.copy(paddle.static.amp.fp16_lists.white_list)
             | paddle.static.amp.fp16_lists._only_supported_fp16_list
         )
-        black_list = copy.copy(paddle.static.amp.fp16_lists.black_list)
+        black_list = copy.copy(
+            paddle.static.amp.fp16_lists.black_list
+            | paddle.static.amp.fp16_lists._extra_black_list
+        )
gray_list = copy.copy(paddle.static.amp.fp16_lists.gray_list)

# 6. w=None, b={'lstm'}
