diff --git a/flagai/model/llama_model.py b/flagai/model/llama_model.py index 9e9792d6..8e196478 100644 --- a/flagai/model/llama_model.py +++ b/flagai/model/llama_model.py @@ -78,6 +78,7 @@ def __init__( ignore_index=-100, bmt_comm_overlap=False, bmt_fused_ce=False, + bmt_fused_ce_inplace=True, # pad_token_id=-1, # bos_token_id=0, # eos_token_id=1, @@ -106,6 +107,7 @@ def __init__( self.ignore_index = ignore_index self.bmt_comm_overlap = bmt_comm_overlap self.bmt_fused_ce = bmt_fused_ce + self.bmt_fused_ce_inplace = bmt_fused_ce_inplace # super().__init__( # pad_token_id=pad_token_id, @@ -168,7 +170,9 @@ def __init__(self, config, **kwargs): if os.getenv("ENV_TYPE") == "bmtrain" and self.config.bmt_fused_ce: import bmtrain as bmt - self.loss_func = bmt.loss.FusedCrossEntropy(ignore_index=self.config.ignore_index) + self.loss_func = bmt.loss.FusedCrossEntropy( + ignore_index=self.config.ignore_index, + inplace=self.config.bmt_fused_ce_inplace) else: self.loss_func = nn.CrossEntropyLoss(ignore_index=self.config.ignore_index)