
Commit

fix llama
Signed-off-by: ftgreat <[email protected]>
ftgreat committed Jun 4, 2023
1 parent befa604 commit d48fa78
Showing 3 changed files with 10 additions and 10 deletions.
6 changes: 3 additions & 3 deletions flagai/model/blocks/llama_block.py
@@ -40,9 +40,9 @@ def __init__(self, layer_id, config ):
 
         self.layer_id = layer_id
         if config.flash_atten_llama_style:
-            from flash_attn.ops.rms_norm import RMSNorm
-            self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
-            self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
+            import flash_attn
+            self.attention_norm = flash_attn.ops.rms_norm.RMSNorm(config.dim, eps=config.norm_eps)
+            self.ffn_norm = flash_attn.ops.rms_norm.RMSNorm(config.dim, eps=config.norm_eps)
         else:
             self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
             self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
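Note on the import change above: a function-local `from flash_attn.ops.rms_norm import RMSNorm` makes `RMSNorm` a local name for the entire `__init__` body, so when `flash_atten_llama_style` is off, the `else` branch raises `UnboundLocalError` instead of using the module-level `RMSNorm`. Importing the package and accessing `flash_attn.ops.rms_norm.RMSNorm` by its full path avoids that shadowing. The commit message does not state the motivation, so treat this as a plausible reading; the stand-alone sketch below (the `RMSNorm` stand-in class and `build_norm` helper are hypothetical, not flagai code) only demonstrates the Python scoping rule involved:

    # Stand-alone sketch of the scoping pitfall; RMSNorm here is a stand-in,
    # not flagai's class, and build_norm is a hypothetical helper.
    class RMSNorm:
        def __init__(self, dim, eps=1e-6):
            self.dim, self.eps = dim, eps

    def build_norm(use_flash, dim=512, eps=1e-6):
        if use_flash:
            # A function-local import binds RMSNorm as a local variable for the
            # entire function, shadowing the module-level class above.
            from flash_attn.ops.rms_norm import RMSNorm
            return RMSNorm(dim, eps=eps)
        else:
            # With that local import present anywhere in the function, this line
            # raises UnboundLocalError when the else branch is taken.
            return RMSNorm(dim, eps=eps)

    try:
        build_norm(False)
    except UnboundLocalError as exc:
        print("shadowed module-level RMSNorm:", exc)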
10 changes: 5 additions & 5 deletions flagai/model/layers/attentions.py
@@ -171,11 +171,6 @@ def forward(
 
         xq, xk = apply_rotary_pos_emb(xq, xk, freqs_cis=freqs_cis)
 
-        xq = xq.view(bsz, seqlen, 1, self.n_local_heads, self.head_dim)
-        keys = keys.view(bsz, seqlen, 1, self.n_local_heads, self.head_dim)
-        values = values.view(bsz, seqlen, 1, self.n_local_heads, self.head_dim)
-        qkv = torch.concat([xq, keys, values], dim=2)
-
         if use_cache:
             self.cache_k = self.cache_k.to(xq)
             self.cache_v = self.cache_v.to(xq)
@@ -189,6 +184,11 @@ def forward(
             keys = xk
             values = xv
 
+        xq = xq.view(bsz, seqlen, 1, self.n_local_heads, self.head_dim)
+        keys = keys.view(bsz, seqlen, 1, self.n_local_heads, self.head_dim)
+        values = values.view(bsz, seqlen, 1, self.n_local_heads, self.head_dim)
+        qkv = torch.concat([xq, keys, values], dim=2)
+
         if self.config.flash_atten or (self.config.flash_atten_llama_style and not self.training):
             qkv = einops.rearrange(qkv, 'b s ... -> (b s) ...')
 
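With the packing moved after the cache branch (above), `qkv` is now built from whatever `keys` and `values` that branch resolved, cached or freshly computed, rather than from the pre-cache `xk`/`xv`. The stand-alone sketch below shows the packed layout this `view` + `concat` + `rearrange` sequence produces; the tensor sizes are illustrative, not flagai defaults:

    import torch
    import einops

    bsz, seqlen, n_heads, head_dim = 2, 16, 8, 64  # illustrative sizes only

    xq = torch.randn(bsz, seqlen, n_heads, head_dim)
    keys = torch.randn(bsz, seqlen, n_heads, head_dim)    # stands in for the post-cache keys
    values = torch.randn(bsz, seqlen, n_heads, head_dim)  # stands in for the post-cache values

    # Insert a size-1 axis and concatenate along it, mirroring the view(...) + concat above.
    qkv = torch.concat(
        [t.view(bsz, seqlen, 1, n_heads, head_dim) for t in (xq, keys, values)],
        dim=2,
    )  # -> (bsz, seqlen, 3, n_heads, head_dim)

    # The flash_atten branch then flattens batch and sequence into one token axis.
    qkv_packed = einops.rearrange(qkv, "b s ... -> (b s) ...")
    print(qkv_packed.shape)  # torch.Size([32, 3, 8, 64])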
4 changes: 2 additions & 2 deletions flagai/model/llama_model.py
@@ -159,8 +159,8 @@ def __init__(self, config, **kwargs):
             self.layers.append(LLAMABlock(layer_id, config))
 
         if config.flash_atten_llama_style:
-            from flash_attn.ops.rms_norm import RMSNorm
-            self.norm = RMSNorm(config.dim, eps=config.norm_eps)
+            import flash_attn
+            self.norm = flash_attn.ops.rms_norm.RMSNorm(config.dim, eps=config.norm_eps)
         else:
             self.norm = RMSNorm(config.dim, eps=config.norm_eps)
         if os.getenv("ENV_TYPE") == "deepspeed+mpu":
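The model-level norm gets the same treatment as the block norms. For comparison, a common alternative is to resolve the optional flash_attn class once at module scope, so no function-local import can shadow the default `RMSNorm` name. The sketch below is not flagai's code: `make_norm`, the pure-PyTorch `RMSNorm` stand-in, and the example dimensions are illustrative, and the `RMSNorm(dim, eps=...)` call shape simply mirrors the usage in this diff.

    import torch
    from torch import nn

    class RMSNorm(nn.Module):
        """Stand-in for the project's default RMSNorm (pure-PyTorch reference)."""
        def __init__(self, dim, eps=1e-6):
            super().__init__()
            self.eps = eps
            self.weight = nn.Parameter(torch.ones(dim))

        def forward(self, x):
            x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
            return x * self.weight

    try:
        # Optional fused implementation; resolved once, at module scope.
        from flash_attn.ops.rms_norm import RMSNorm as FlashRMSNorm
    except ImportError:
        FlashRMSNorm = None

    def make_norm(dim, eps, flash_atten_llama_style):
        # Because FlashRMSNorm was bound at module scope, nothing here can shadow
        # the default RMSNorm name the way a function-local import would.
        if flash_atten_llama_style and FlashRMSNorm is not None:
            return FlashRMSNorm(dim, eps=eps)
        return RMSNorm(dim, eps=eps)

    norm = make_norm(dim=4096, eps=1e-6, flash_atten_llama_style=False)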
