update readme; fix faiss bug
Tianyu Gao committed May 12, 2021
1 parent 725f70e commit 8ab1e5b
Showing 3 changed files with 44 additions and 16 deletions.
33 changes: 22 additions & 11 deletions README.md
@@ -88,10 +88,10 @@ We also provide an easy-to-build [demo website](./demo) to show how SimCSE can b
Our released models are listed as follows. You can import these models with the `simcse` package or with [HuggingFace's Transformers](https://github.com/huggingface/transformers); a minimal loading sketch follows the table.
| Model | Avg. STS |
|:-------------------------------|:--------:|
| [princeton-nlp/unsup-simcse-bert-base-uncased](https://huggingface.co/princeton-nlp/unsup-simcse-bert-base-uncased) | 74.54 |
| [princeton-nlp/unsup-simcse-bert-large-uncased](https://huggingface.co/princeton-nlp/unsup-simcse-bert-large-uncased) | 76.05 |
| [princeton-nlp/unsup-simcse-roberta-base](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-base) | 76.50 |
| [princeton-nlp/unsup-simcse-roberta-large](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-large) | 77.47 |
| [princeton-nlp/unsup-simcse-bert-base-uncased](https://huggingface.co/princeton-nlp/unsup-simcse-bert-base-uncased) | 76.25 |
| [princeton-nlp/unsup-simcse-bert-large-uncased](https://huggingface.co/princeton-nlp/unsup-simcse-bert-large-uncased) | 78.41 |
| [princeton-nlp/unsup-simcse-roberta-base](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-base) | 76.57 |
| [princeton-nlp/unsup-simcse-roberta-large](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-large) | 78.90 |
| [princeton-nlp/sup-simcse-bert-base-uncased](https://huggingface.co/princeton-nlp/sup-simcse-bert-base-uncased) | 81.57 |
| [princeton-nlp/sup-simcse-bert-large-uncased](https://huggingface.co/princeton-nlp/sup-simcse-bert-large-uncased) | 82.21 |
| [princeton-nlp/sup-simcse-roberta-base](https://huggingface.co/princeton-nlp/sup-simcse-roberta-base) | 82.52 |
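
A minimal loading sketch with HuggingFace's Transformers, using one of the checkpoints from the table above (the sentences are illustrative; see the evaluation section below for the pooling options):

```python
import torch
from transformers import AutoModel, AutoTokenizer

# Any checkpoint from the table above works here.
name = "princeton-nlp/sup-simcse-bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name)

texts = ["There's a kid on a skateboard.", "A kid is skateboarding."]
inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    # Supervised SimCSE checkpoints use the [CLS] representation after the MLP pooler.
    embeddings = model(**inputs, return_dict=True).pooler_output
```
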
@@ -190,8 +190,8 @@ Arguments for the evaluation script are as follows,

* `--model_name_or_path`: The name or path of a `transformers`-based pre-trained checkpoint. You can directly use the models in the above table, e.g., `princeton-nlp/sup-simcse-bert-base-uncased`.
* `--pooler`: Pooling method (a short sketch of these options follows the list). We currently support
* `cls` (default): Use the representation of `[CLS]` token. A linear+activation layer is applied after the representation (it's in the standard BERT implementation). If you use SimCSE, you should use this option.
* `cls_before_pooler`: Use the representation of `[CLS]` token without the extra linear+activation.
* `cls` (default): Use the representation of the `[CLS]` token. A linear+activation layer is applied after the representation (it is part of the standard BERT implementation). If you use **supervised SimCSE**, you should use this option.
* `cls_before_pooler`: Use the representation of the `[CLS]` token without the extra linear+activation. If you use **unsupervised SimCSE**, you should use this option.
* `avg`: Average embeddings of the last layer. If you use checkpoints of SBERT/SRoBERTa ([paper](https://arxiv.org/abs/1908.10084)), you should use this option.
* `avg_top2`: Average embeddings of the last two layers.
* `avg_first_last`: Average embeddings of the first and last layers. If you use vanilla BERT or RoBERTa, this works best.
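
A hedged sketch of how these pooling options map onto `transformers` model outputs (variable names are illustrative, padding is ignored for brevity, and the actual evaluation code may differ in detail):

```python
import torch
from transformers import AutoModel, AutoTokenizer

name = "princeton-nlp/unsup-simcse-bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModel.from_pretrained(name, output_hidden_states=True)

inputs = tokenizer(["An example sentence."], return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, return_dict=True)

hidden = out.hidden_states                                   # embedding output + one tensor per layer
cls_before_pooler = out.last_hidden_state[:, 0]              # `cls_before_pooler`
cls = out.pooler_output                                      # `cls` (linear + tanh on top of [CLS])
avg = out.last_hidden_state.mean(dim=1)                      # `avg`
avg_top2 = ((hidden[-1] + hidden[-2]) / 2).mean(dim=1)       # `avg_top2`
avg_first_last = ((hidden[0] + hidden[-1]) / 2).mean(dim=1)  # `avg_first_last` (embedding output as "first")
```
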
@@ -208,11 +208,11 @@ Arguments for the evaluation script are as follows,

### Training

#### Data
**Data**

For unsupervised SimCSE, we sample 1 million sentences from English Wikipedia; for supervised SimCSE, we use the SNLI and MNLI datasets. You can run `data/download_wiki.sh` and `data/download_nli.sh` to download the two datasets.

#### Training scripts
**Training scripts**

We provide example training scripts for both unsupervised and supervised SimCSE. In `run_unsup_example.sh`, we provide a single-GPU (or CPU) example for the unsupervised version, and in `run_sup_example.sh` we give a **multi-GPU** example for the supervised version. Both scripts call `train.py` for training. We explain the arguments in the following:
* `--train_file`: Training file path. We support "txt" files (one line per sentence) and "csv" files (2-column: pair data with no hard negative; 3-column: pair data with one corresponding hard negative instance); see the format sketch below. You can use our provided Wikipedia or NLI data, or you can use your own data in the same format.
@@ -226,11 +226,22 @@ We provide example training scripts for both unsupervised and supervised SimCSE.

All the other arguments are standard Huggingface `transformers` training arguments. Some frequently used ones are `--output_dir`, `--learning_rate`, and `--per_device_train_batch_size`. In our example scripts, we also evaluate the model on the STS-B development set (you need to download the dataset following the [evaluation](#evaluation) section) and save the best checkpoint.
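
To make the `--train_file` formats described above concrete, here is an illustrative sketch that writes both file types (the CSV header names `sent0`/`sent1`/`hard_neg` mirror the provided NLI file; double-check the downloaded data if you build your own):

```python
import csv

# "txt" format: one sentence per line (unsupervised setting).
with open("my_unsup_data.txt", "w") as f:
    f.write("A man is playing a guitar.\n")
    f.write("The weather is nice today.\n")

# "csv" format: sentence pairs, optionally with one hard negative each (supervised setting).
with open("my_sup_data.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["sent0", "sent1", "hard_neg"])
    writer.writerow(["A man is playing a guitar.",
                     "A person plays an instrument.",
                     "A man is eating dinner."])
```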

**REPRODUCTION**: For results in the paper, we use Nvidia 3090 GPUs with CUDA 11. Using different types of devices or different versions of CUDA/other softwares may lead to slightly different performance.
For results in the paper, we use Nvidia 3090 GPUs with CUDA 11. Using different types of devices or different versions of CUDA/other software may lead to slightly different performance.

#### Convert models
**Hyperparameters**

**IMPORTANT**: Our saved checkpoints are slightly different from Huggingface's pre-trained checkpoints. Run `python simcse_to_huggingface.py --path {PATH_TO_CHECKPOINT_FOLDER}` to convert it. After that, you can evaluate it by our [evaluation](#evaluation) code or directly use it [out of the box](#use-our-models-out-of-the-box).
We use the following hyperparameters for training SimCSE:

| | Unsup. BERT | Unsup. RoBERTa | Sup. |
|:--------------|:-----------:|:--------------:|:---------:|
| Batch size | 64 | 512 | 512 |
| Learning rate (base) | 3e-5 | 1e-5 | 5e-5 |
| Learning rate (large) | 1e-5 | 3e-5 | 1e-5 |


**Convert models**

Our saved checkpoints are slightly different from Huggingface's pre-trained checkpoints. Run `python simcse_to_huggingface.py --path {PATH_TO_CHECKPOINT_FOLDER}` to convert them. After that, you can evaluate them with our [evaluation](#evaluation) code or use them directly [out of the box](#use-our-models-out-of-the-box).



4 changes: 2 additions & 2 deletions run_unsup_example.sh
@@ -9,8 +9,8 @@ python train.py \
--train_file data/wiki1m_for_simcse.txt \
--output_dir result/my-unsup-simcse-bert-base-uncased \
--num_train_epochs 1 \
--per_device_train_batch_size 512 \
--learning_rate 5e-5 \
--per_device_train_batch_size 64 \
--learning_rate 3e-5 \
--max_seq_length 32 \
--evaluation_strategy steps \
--metric_for_best_model stsb_spearman \
23 changes: 20 additions & 3 deletions simcse/tool.py
@@ -21,7 +21,8 @@ class SimCSE(object):
def __init__(self, model_name_or_path: str,
device: str = None,
num_cells: int = 100,
num_cells_in_search: int = 10):
num_cells_in_search: int = 10,
pooler = None):

self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.model = AutoModel.from_pretrained(model_name_or_path)
@@ -33,6 +34,14 @@ def __init__(self, model_name_or_path: str,
self.is_faiss_index = False
self.num_cells = num_cells
self.num_cells_in_search = num_cells_in_search

if pooler is not None:
self.pooler = pooler
elif "unsup" in model_name_or_path:
logger.info("Use `cls_before_pooler` for unsupervised models. If you want to use other pooling policy, specify `pooler` argument.")
self.pooler = "cls_before_pooler"
else:
self.pooler = "cls"

def encode(self, sentence: Union[str, List[str]],
device: str = None,
Expand Down Expand Up @@ -62,7 +71,13 @@ def encode(self, sentence: Union[str, List[str]],
return_tensors="pt"
)
inputs = {k: v.to(target_device) for k, v in inputs.items()}
embeddings = self.model(**inputs, return_dict=True).pooler_output
outputs = self.model(**inputs, return_dict=True)
if self.pooler == "cls":
embeddings = outputs.pooler_output
elif self.pooler == "cls_before_pooler":
embeddings = outputs.last_hidden_state[:, 0]
else:
raise NotImplementedError
if normalize_to_unit:
embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)
embedding_list.append(embeddings.cpu())
@@ -112,9 +127,10 @@ def build_index(self, sentences_or_file_path: Union[str, List[str]],
if use_faiss is None or use_faiss:
try:
import faiss
assert hasattr(faiss, "IndexFlatIP")
use_faiss = True
except:
logger.warning("Fail to import faiss. Please install faiss or set faiss=False. Now the program continues with brute force search.")
logger.warning("Fail to import faiss. If you want to use faiss, install faiss through PyPI. Now the program continues with brute force search.")
use_faiss = False

# if the input sentence is a string, we assume it's the path of file that stores various sentences
@@ -159,6 +175,7 @@ def build_index(self, sentences_or_file_path: Union[str, List[str]],
index = embeddings
self.is_faiss_index = False
self.index["index"] = index
logger.info("Finished")

def search(self, queries: Union[str, List[str]],
device: str = None,
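
Putting the changes above together, a hedged usage sketch of the `simcse` tool (method names follow the diff; the exact return format of `search` may differ):

```python
from simcse import SimCSE

# For an "unsup" model name, the tool now defaults to `cls_before_pooler`;
# pass pooler="cls" to override.
model = SimCSE("princeton-nlp/unsup-simcse-bert-base-uncased")

embeddings = model.encode(["A cat sits on the mat."])

sentences = ["A man is playing a guitar.",
             "A kid is skateboarding.",
             "The weather is nice today."]
model.build_index(sentences)   # uses faiss if it imports cleanly, otherwise brute-force search
results = model.search("Someone plays music.")
print(results)
```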
