Skip to content

Commit

Permalink
add bmt_fused_ce inplace
Browse files Browse the repository at this point in the history
Signed-off-by: ftgreat <[email protected]>
  • Loading branch information
ftgreat committed May 6, 2023
1 parent 12d7186 commit 713b6c3
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion flagai/model/llama_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(
ignore_index=-100,
bmt_comm_overlap=False,
bmt_fused_ce=False,
bmt_fused_ce_inplace=True,
# pad_token_id=-1,
# bos_token_id=0,
# eos_token_id=1,
Expand Down Expand Up @@ -106,6 +107,7 @@ def __init__(
self.ignore_index = ignore_index
self.bmt_comm_overlap = bmt_comm_overlap
self.bmt_fused_ce = bmt_fused_ce
self.bmt_fused_ce_inplace = bmt_fused_ce_inplace

# super().__init__(
# pad_token_id=pad_token_id,
Expand Down Expand Up @@ -168,7 +170,9 @@ def __init__(self, config, **kwargs):

if os.getenv("ENV_TYPE") == "bmtrain" and self.config.bmt_fused_ce:
import bmtrain as bmt
self.loss_func = bmt.loss.FusedCrossEntropy(ignore_index=self.config.ignore_index)
self.loss_func = bmt.loss.FusedCrossEntropy(
ignore_index=self.config.ignore_index,
inplace=self.config.bmt_fused_ce_inplace)
else:
self.loss_func = nn.CrossEntropyLoss(ignore_index=self.config.ignore_index)

Expand Down

0 comments on commit 713b6c3

Please sign in to comment.