
[Draft] Add support for seq split in Domino #7111

Draft
wants to merge 2 commits into base: master

Conversation

duanhx1037

No description provided.

@GuanhuaWang
Contributor

Hi @duanhx1037

Really appreciate your quick action on this effort! But right now it is far from ready. Please allow me to change the title to "[Draft] Add support for seq split in Domino". Thank you!

@GuanhuaWang GuanhuaWang changed the title Add support for seq split in Domino [Draft] Add support for seq split in Domino Mar 4, 2025
@GuanhuaWang GuanhuaWang marked this pull request as draft March 4, 2025 22:24
Comment on lines +387 to +433
layernorm_output = torch.concat([layernorm_output0, layernorm_output1], dim=0)
mixed_x_layer, _ = self.query_key_value(layernorm_output)

# [s, b, np * 3 * hn] --> [s, b, np, 3 * hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (
    self.num_attention_heads_per_partition,
    3 * self.hidden_size_per_attention_head,
)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [s, b, np, 3 * hn] --> [b, np, s, 3 * hn]
mixed_x_layer = mixed_x_layer.permute(1, 2, 0, 3).contiguous()
# [b, np, s, 3 * hn] --> [b, np, s, hn], [b, np, s, hn], [b, np, s, hn]
(query_layer, key_layer, value_layer) = torch.split(mixed_x_layer, [
    self.hidden_size_per_attention_head, self.hidden_size_per_attention_head,
    self.hidden_size_per_attention_head
], dim=3)

# [b, np, s, hn], keeping the per-head hidden size as the last dimension
query_layer = query_layer.view(query_layer.size(0), query_layer.size(1), -1,
                               self.hidden_size_per_attention_head)

if rotary_pos_emb is not None:
    # Duplicate a single rotary embedding for both query and key if needed.
    if not isinstance(rotary_pos_emb, tuple):
        rotary_pos_emb = (rotary_pos_emb, ) * 2
    q_pos_emb, k_pos_emb = rotary_pos_emb
    query_layer = self.apply_rotary_pos_emb(query_layer, q_pos_emb)
    key_layer = self.apply_rotary_pos_emb(key_layer, k_pos_emb)

batchsize, num_heads, seq_len, hidden_per_head = query_layer.shape

# seq 0: core attention on the first half of the sequence
context_layer0 = self.self_attention_sp(query_layer[:, :, :seq_len // 2, :], key_layer, value_layer,
                                        attention_mask[:, :, :seq_len // 2, :])
# Output. [s, b, h]
attention_output0, attention_bias0 = self.dense(context_layer0)

# Launch the all-reduce for seq 0 asynchronously so it can overlap with seq 1 compute.
handle0 = dist.all_reduce(attention_output0,
                          group=self.mpu.get_tensor_model_parallel_group(),
                          async_op=True)

# seq 1: core attention on the second half of the sequence
context_layer1 = self.self_attention_sp(query_layer[:, :, seq_len // 2:, :], key_layer, value_layer,
                                        attention_mask[:, :, seq_len // 2:, :])
# Output. [s, b, h]
attention_output1, attention_bias1 = self.dense(context_layer1)

handle1 = dist.all_reduce(attention_output1,
                          group=self.mpu.get_tensor_model_parallel_group(),
                          async_op=True)
handle0.wait()
Contributor

These are exactly the same as the ShardedAttention forward; I don't see why we need this duplication here.

Also, please follow our current code hierarchy; do not pull lower-layer module implementation code up into an upper-layer module.

For example, if any real change needs to be made to ShardedAttention, create a similar module, say XYZAttention; then here in the DominoTransformerLayer forward function we can simply call the XYZAttention module without duplicating every line of its forward code.
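
To make the suggestion concrete, here is a minimal sketch (assumptions only, not Domino's actual API: the XYZAttention name comes from the comment above, and the wrapped attention callable, constructor arguments, and [s, b, h] layout are illustrative) of moving the seq-split logic into its own module and simply calling it from the layer's forward:

import torch
import torch.nn as nn


class XYZAttention(nn.Module):
    """Hypothetical wrapper that owns the seq-split forward logic."""

    def __init__(self, attention: nn.Module):
        super().__init__()
        # Reuse the existing attention implementation instead of copying its code.
        self.attention = attention

    def forward(self, hidden_states, attention_mask):
        # Split along the sequence dimension (dim 0 in [s, b, h] layout), run the
        # wrapped attention on each half, and concatenate the results. The real
        # Domino code would additionally overlap each half's all-reduce with the
        # other half's compute, as in the diff above.
        seq_len = hidden_states.size(0)
        out0 = self.attention(hidden_states[:seq_len // 2], attention_mask)
        out1 = self.attention(hidden_states[seq_len // 2:], attention_mask)
        return torch.cat([out0, out1], dim=0)


class DominoTransformerLayerSketch(nn.Module):
    """Illustrative layer: the forward only calls the module, no inlined copy."""

    def __init__(self, attention: nn.Module):
        super().__init__()
        self.self_attention = XYZAttention(attention)

    def forward(self, hidden_states, attention_mask):
        return self.self_attention(hidden_states, attention_mask)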

Comment on lines +280 to +301
elif self.input_split_dim == "seq":
    query_projection_size = config.kv_channels * config.num_attention_heads
    kv_projection_size = config.kv_channels * config.num_attention_heads

    # Per attention head and per partition values.
    world_size = mpu.get_tensor_model_parallel_world_size()
    self.hidden_size_per_attention_head = query_projection_size // config.num_attention_heads
    self.num_attention_heads_per_partition = config.num_attention_heads // world_size
    self.query_key_value = ColumnParallelLinear(config.hidden_size,
                                                query_projection_size + 2 * kv_projection_size,
                                                config=config,
                                                init_method=config.init_method,
                                                bias=config.add_bias_linear,
                                                gather_output=False)
    self.self_attention_sp = CoreAttention(config, self.layer_number, mpu, self_attn_mask_type)
    self.dense = RowParallelLinearNoComm(query_projection_size,
                                         config.hidden_size,
                                         config=config,
                                         init_method=config.output_layer_init_method,
                                         bias=config.add_bias_linear,
                                         input_is_parallel=True,
                                         skip_bias_add=True)
Contributor

This is a duplication of the ShardedAttention module's detailed implementation.
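
A minimal sketch of the alternative (assumptions only: the ShardedAttention name is taken from this comment, and the constructor signature mirrors the CoreAttention call in the diff above; the real Domino classes may differ): the "seq" branch holds a single lower-layer attention module instead of re-declaring its projection layers at this level.

import torch.nn as nn


class DominoLayerInitSketch(nn.Module):
    """Illustrative __init__ only: both split modes reuse one attention module."""

    def __init__(self, attention_cls, config, layer_number, mpu,
                 self_attn_mask_type, input_split_dim="batch"):
        super().__init__()
        self.input_split_dim = input_split_dim
        # The projection layers (query_key_value, dense) live inside the
        # lower-layer attention module, e.g. ShardedAttention; the "seq" branch
        # only changes how the forward pass splits its input, so nothing from
        # that module's constructor is duplicated here.
        self.self_attention = attention_cls(config, layer_number, mpu,
                                            self_attn_mask_type)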
