
Commit

Merge pull request princeton-nlp#78 from bcol23/main
fix for DDP
Tianyu Gao authored Aug 26, 2021
2 parents cc0a168 + 1c6e0be commit 46035b5
Showing 1 changed file with 4 additions and 3 deletions.
simcse/models.py: 7 changes (4 additions, 3 deletions)
@@ -89,7 +89,8 @@ def cl_init(cls, config):
     """
     cls.pooler_type = cls.model_args.pooler_type
     cls.pooler = Pooler(cls.model_args.pooler_type)
-    cls.mlp = MLPLayer(config)
+    if cls.model_args.pooler_type == "cls":
+        cls.mlp = MLPLayer(config)
     cls.sim = Similarity(temp=cls.model_args.temp)
     cls.init_weights()

@@ -277,7 +278,7 @@ class BertForCL(BertPreTrainedModel):
     def __init__(self, config, *model_args, **model_kargs):
         super().__init__(config)
         self.model_args = model_kargs["model_args"]
-        self.bert = BertModel(config)
+        self.bert = BertModel(config, add_pooling_layer=False)

         if self.model_args.do_mlm:
             self.lm_head = BertLMPredictionHead(config)
@@ -336,7 +337,7 @@ class RobertaForCL(RobertaPreTrainedModel):
     def __init__(self, config, *model_args, **model_kargs):
         super().__init__(config)
         self.model_args = model_kargs["model_args"]
-        self.roberta = RobertaModel(config)
+        self.roberta = RobertaModel(config, add_pooling_layer=False)

         if self.model_args.do_mlm:
             self.lm_head = RobertaLMHead(config)
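Note (not part of the commit): DistributedDataParallel, with its default find_unused_parameters=False, expects every registered parameter to receive a gradient in each backward pass. Submodules that are constructed but never used in forward, such as an MLP head built for a non-"cls" pooler or the BertModel/RobertaModel pooling layer, therefore make the DDP reducer raise an error; the conditional cls.mlp construction and add_pooling_layer=False in this diff avoid registering such parameters. The sketch below is a toy illustration of the same pattern with hypothetical names (ToyEncoder, extra_head); it is not code from this repository.

# Toy illustration (hypothetical names, not from simcse/models.py): only build
# submodules that the forward pass actually uses, mirroring the fix above.
import torch
import torch.nn as nn


class ToyEncoder(nn.Module):
    def __init__(self, use_extra_head: bool):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        # Conditional construction, like cls.mlp above: a head that is
        # registered but never called would leave parameters without
        # gradients, which DDP rejects when find_unused_parameters=False.
        if use_extra_head:
            self.extra_head = nn.Linear(8, 8)

    def forward(self, x):
        return self.backbone(x)


model = ToyEncoder(use_extra_head=False)
model(torch.randn(2, 8)).sum().backward()

# Every registered parameter participates in the backward pass, so the module
# is safe to wrap in torch.nn.parallel.DistributedDataParallel.
print(all(p.grad is not None for p in model.parameters()))  # True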
