FlagAI (飞智) supports Megatron-LM model parallelism for the BERT, GLM, GPT-2 and T5 models. The core operations are taken from Megatron-LM and collected in the mpu module.
Code location: flagai/model/layers/embeddings_mpu.py
Core idea: split the forward pass of the two linear layers in a specific order, column-parallel first and then row-parallel:
intermediate_parallel = self.dense_h_to_4h(hidden_states)
intermediate_parallel = gelu(intermediate_parallel)
output = self.dense_4h_to_h(intermediate_parallel)
where self.dense_h_to_4h and self.dense_4h_to_h are defined as:
# Project to 4h.
self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
                                          4 * hidden_size,
                                          gather_output=False,  # could also be set to True
                                          init_method=init_method)
# Project back to h.
self.dense_4h_to_h = RowParallelLinear(
    4 * hidden_size,
    hidden_size,
    input_is_parallel=True,  # depends on the gather_output setting of self.dense_h_to_4h
    init_method=output_layer_init_method)
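Why the column-then-row order works: GeLU is applied element-wise, so each partition can run its slice of both layers without ever exchanging the 4h intermediate activation; a single all-reduce at the end of RowParallelLinear recovers the full output. A minimal single-process sketch (toy shapes, no torch.distributed, biases omitted; not FlagAI code) that checks this equivalence:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, world_size = 8, 2
x = torch.randn(4, hidden)               # [batch, h]
W1 = torch.randn(hidden, 4 * hidden)     # h -> 4h
W2 = torch.randn(4 * hidden, hidden)     # 4h -> h

# Unpartitioned reference: gelu(x @ W1) @ W2
reference = F.gelu(x @ W1) @ W2

# Column-parallel: each simulated rank holds a column slice of W1;
# row-parallel: the matching row slice of W2.
W1_shards = W1.chunk(world_size, dim=1)
W2_shards = W2.chunk(world_size, dim=0)

# gather_output=False / input_is_parallel=True: the 4h activation is never
# gathered; every rank produces a partial output of size h instead.
partials = [F.gelu(x @ w1) @ w2 for w1, w2 in zip(W1_shards, W2_shards)]

# The all-reduce inside RowParallelLinear sums the partial outputs.
output = sum(partials)
print(torch.allclose(reference, output, atol=1e-5))  # True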
Key parameters:
world_size = get_model_parallel_world_size()
self.hidden_size_per_partition = divide(hidden_size, world_size)
self.hidden_size_per_attention_head = divide(hidden_size,
                                             num_attention_heads)
self.num_attention_heads_per_partition = divide(
    num_attention_heads, world_size)
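For concreteness, a toy calculation of these sizes (hidden_size=1024, 16 heads and 4-way model parallelism are made-up values; divide below simply mirrors mpu's exact-division check):

def divide(numerator, denominator):
    # mpu's divide asserts the split is exact before doing integer division.
    assert numerator % denominator == 0
    return numerator // denominator

hidden_size, num_attention_heads, world_size = 1024, 16, 4
print(divide(hidden_size, world_size))           # 256: rows/columns owned by each partition
print(divide(hidden_size, num_attention_heads))  # 64: size of a single attention head
print(divide(num_attention_heads, world_size))   # 4: attention heads handled by each partition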
The two linear layers in self-attention are likewise converted to the ColumnParallelLinear / RowParallelLinear versions.
Code location: flagai/model/layers/attentions_mpu.py
As follows:
self.query_key_value = ColumnParallelLinear(hidden_size,
                                            3 * hidden_size,
                                            stride=3,
                                            gather_output=False,
                                            init_method=init_method)
if relative_encoding:
    self.relative = ColumnParallelLinear(hidden_size,
                                         hidden_size,
                                         gather_output=False,
                                         init_method=init_method)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
# Output.
self.dense = RowParallelLinear(hidden_size,
                               hidden_size,
                               input_is_parallel=True,
                               init_method=output_layer_init_method)
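After self.query_key_value, each rank holds only a 3 * hidden_size / world_size slice (gather_output=False), which it splits into its own Q, K and V shards and reshapes into the heads it owns. A rough sketch using the per-partition sizes above (assumed toy shapes, not FlagAI code):

import torch

batch, seq = 2, 5
hidden_size, num_attention_heads, world_size = 1024, 16, 4
hidden_size_per_partition = hidden_size // world_size                  # 256
hidden_size_per_attention_head = hidden_size // num_attention_heads    # 64
num_attention_heads_per_partition = num_attention_heads // world_size  # 4

# With gather_output=False each rank sees only its own 3h/p slice.
mixed_x_layer = torch.randn(batch, seq, 3 * hidden_size_per_partition)

# Split the fused projection into this rank's Q, K, V shards ...
q, k, v = mixed_x_layer.split(hidden_size_per_partition, dim=-1)

# ... and expose the heads that live on this partition.
def to_heads(t):
    return t.view(batch, seq,
                  num_attention_heads_per_partition,
                  hidden_size_per_attention_head).transpose(1, 2)

q, k, v = (to_heads(t) for t in (q, k, v))
print(q.shape)  # [batch, heads per partition, seq, head size] = [2, 4, 5, 64]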
Similarly, for cross-attention:
self.query = ColumnParallelLinear(hidden_size,
                                  hidden_size,
                                  gather_output=False,
                                  init_method=init_method)
self.key_value = ColumnParallelLinear(hidden_size,
                                      2 * hidden_size,
                                      stride=2,
                                      gather_output=False,
                                      init_method=init_method)
# Dropout. Note that for a single iteration, this layer will generate
# different outputs on different number of parallel partitions but
# on average it should not be partition dependent.
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
# Output.
self.dense = RowParallelLinear(hidden_size,
                               hidden_size,
                               input_is_parallel=True,
                               init_method=output_layer_init_method)
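The stride argument (3 for the fused QKV projection, 2 for key_value) controls how the fused weight is sharded, so that every rank ends up with a matching slice of each sub-projection instead of, say, all of Q and part of K. A rough sketch of that partitioning, modeled on Megatron-LM's weight initialization (shapes and names here are illustrative only):

import torch

hidden_size, world_size, stride = 8, 2, 3
master_weight = torch.randn(stride * hidden_size, hidden_size)  # fused QKV, [3h, h]

per_partition_size = stride * hidden_size // world_size  # output rows per rank
per_stride_size = per_partition_size // stride           # rows per sub-projection shard

# Cut the fused weight into world_size * stride blocks along the output dim,
# then give rank r every world_size-th block starting at block r.
blocks = torch.split(master_weight, per_stride_size, dim=0)
rank_weights = [torch.cat(blocks[r::world_size], dim=0) for r in range(world_size)]

# Each rank now holds [Q_shard_r; K_shard_r; V_shard_r] of shape [3h/p, h].
print(rank_weights[0].shape)  # torch.Size([12, 8])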