Skip to content

Commit

Permalink
Merge branch 'develop' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
Tianyu Gao committed May 11, 2021
2 parents 354784b + becfbf0 commit b0aaacd
Show file tree
Hide file tree
Showing 24 changed files with 1,701 additions and 28 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Byte-compiled / optimized / DLL files
__pycache__/
simcse/__pycache__/
*.py[cod]
*$py.class

Expand Down Expand Up @@ -128,3 +129,4 @@ dmypy.json
# Pyre type checker
.pyre/
.DS_Store
.vscode
105 changes: 81 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,22 @@ Wait a minute! The authors are working day and night 💪, to make the code and
We anticipate the code will be out * **in one week** *. -->

<!-- * 4/26: SimCSE is now on [Gradio Web Demo](https://gradio.app/g/AK391/SimCSE) (Thanks [@AK391](https://github.com/AK391)!). Try it out! -->
* 5/10: We released our [sentence embedding tool](#getting-started) and [demo code](./demo).
* 4/23: We released our [training code](#training).
* 4/20: We released our [model checkpoints](#use-our-models-out-of-the-box) and [evaluation code](#evaluation).
* 4/18: We released [our paper](https://arxiv.org/pdf/2104.08821.pdf). Check it out!


## Quick links
## Quick Links

- [Overview](#overview)
- [Pre-trained sentence embeddings](#use-our-models-out-of-the-box)
- [Requirements](#requirements)
- [Evaluation](#evaluation)
- [Training](#training)
- [Getting Started](#getting-started)
- [Model List](#model-list)
- [Use SimCSE with Huggingface](#use-our-models-out-of-the-box)
- [Train SimCSE](#train-simcse)
- [Requirements](#requirements)
- [Evaluation](#evaluation)
- [Training](#training)
- [Bugs or Questions?](#Bugs-or-questions)
- [Citation](#citation)
- [SimCSE Elsewhere](#simcse-elsewhere)
Expand All @@ -33,22 +37,71 @@ We propose a simple contrastive learning framework that works with both unlabele

![](figure/model.png)

## Use our models out of the box
Our pre-trained models are now publicly available with [HuggingFace's Transformers](https://github.com/huggingface/transformers). Models and their performance are presented as follows:
## Getting Started

We provide an easy-to-use sentence embedding tool based on our SimCSE model. To use the tool, first install the `simcse` package from pypi
```bash
pip install simcse
```

Or directly install it from our code
```bash
python setup.py install
```

Note that if you want to enable GPU encoding, you should install the correct version of PyTorch that supports CUDA. See [PyTorch official website](https://pytorch.org) for instructions.

After installing the package, you can load our model by just two lines of code
```python
from simcse import SimCSE
model = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased")
```
See [model list](#model-list) for a full list of available models.

Then you can use our model for **encoding sentences into embeddings**
```python
embeddings = model.encode("A woman is reading.")
```

**Compute the cosine similarities** between two groups of sentences
```python
sentences_a = ['A woman is reading.', 'A man is playing a guitar.']
sentences_b = ['He plays guitar.', 'A woman is making a photo.']
similarities = model.similarity(sentences_a, sentences_b)
```

Or build index for a group of sentences and **search** among them
```python
sentences = ['A woman is reading.', 'A man is playing a guitar.']
model.build_index(sentences)
results = model.search("He plays guitar.")
```

We also support [faiss](https://github.com/facebookresearch/faiss), an efficient similarity search library. Just install the package following [instructions](https://github.com/facebookresearch/faiss/blob/master/INSTALL.md) here and `simcse` will automatically use `faiss` for efficient search.

**WARNING**: We have found that `faiss` did not well support Nvidia AMPERE GPUs (3090 and A100). In that case, you should change to other GPUs or install the CPU version of `faiss` package.

We also provide an easy-to-build [demo website](./demo) to show how SimCSE can be used in sentence retrieval.

## Model List

Our released models are listed as following. You can import these models by using the `simcse` package or using [HuggingFace's Transformers](https://github.com/huggingface/transformers).
| Model | Avg. STS |
|:-------------------------------:|:--------:|
| [unsup-simcse-bert-base-uncased](https://huggingface.co/princeton-nlp/unsup-simcse-bert-base-uncased) | 74.54 |
| [unsup-simcse-bert-large-uncased](https://huggingface.co/princeton-nlp/unsup-simcse-bert-large-uncased) | 76.05 |
| [unsup-simcse-roberta-base](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-base) | 76.50 |
| [unsup-simcse-roberta-large](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-large) | 77.47 |
| [sup-simcse-bert-base-uncased](https://huggingface.co/princeton-nlp/sup-simcse-bert-base-uncased) | 81.57 |
| [sup-simcse-bert-large-uncased](https://huggingface.co/princeton-nlp/sup-simcse-bert-large-uncased) | 82.21 |
| [sup-simcse-roberta-base](https://huggingface.co/princeton-nlp/sup-simcse-roberta-base) | 82.52 |
| [sup-simcse-roberta-large](https://huggingface.co/princeton-nlp/sup-simcse-roberta-large) | 83.76 |
|:-------------------------------|:--------:|
| [princeton-nlp/unsup-simcse-bert-base-uncased](https://huggingface.co/princeton-nlp/unsup-simcse-bert-base-uncased) | 74.54 |
| [princeton-nlp/unsup-simcse-bert-large-uncased](https://huggingface.co/princeton-nlp/unsup-simcse-bert-large-uncased) | 76.05 |
| [princeton-nlp/unsup-simcse-roberta-base](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-base) | 76.50 |
| [princeton-nlp/unsup-simcse-roberta-large](https://huggingface.co/princeton-nlp/unsup-simcse-roberta-large) | 77.47 |
| [princeton-nlp/sup-simcse-bert-base-uncased](https://huggingface.co/princeton-nlp/sup-simcse-bert-base-uncased) | 81.57 |
| [princeton-nlp/sup-simcse-bert-large-uncased](https://huggingface.co/princeton-nlp/sup-simcse-bert-large-uncased) | 82.21 |
| [princeton-nlp/sup-simcse-roberta-base](https://huggingface.co/princeton-nlp/sup-simcse-roberta-base) | 82.52 |
| [princeton-nlp/sup-simcse-roberta-large](https://huggingface.co/princeton-nlp/sup-simcse-roberta-large) | 83.76 |

**Naming rules**: `unsup` and `sup` represent "unsupervised" (trained on Wikipedia corpus) and "supervised" (trained on NLI datasets) respectively.

You can easily import our model in an out-of-the-box way with HuggingFace's API:
## Use SimCSE with Huggingface

Besides using our provided sentence embedding tool, you can also easily import our models with HuggingFace's `transformers`:
```python
import torch
from scipy.spatial.distance import cosine
Expand Down Expand Up @@ -81,9 +134,11 @@ print("Cosine similarity between \"%s\" and \"%s\" is: %.3f" % (texts[0], texts[

If you encounter any problem when directly loading the models by HuggingFace's API, you can also download the models manually from the above table and use `model = AutoModel.from_pretrained({PATH TO THE DOWNLOAD MODEL})`.

If you only want to use our models in an out-of-the-box way, just installing the latest version of `torch`, `transformers` and `scipy` is enough. If you want to use our training or evaluation code, see the requirement section below.
## Train SimCSE

## Requirements
In the following section, we describe how to train a SimCSE model by using our code.

### Requirements

First, install PyTorch by following the instructions from [the official website](https://pytorch.org). To faithfully reproduce our results, please use the correct `1.7.1` version corresponding to your platforms/CUDA versions. PyTorch version higher than `1.7.1` should also work. For example, if you use Linux and **CUDA11** ([how to check CUDA version](https://varhowto.com/check-cuda-version/)), install PyTorch by the following command,

Expand All @@ -104,7 +159,7 @@ Then run the following script to install the remaining dependencies,
pip install -r requirements.txt
```

## Evaluation
### Evaluation
Our evaluation code for sentence embeddings is based on a modified version of [SentEval](https://github.com/facebookresearch/SentEval). It evaluates sentence embeddings on semantic textual similarity (STS) tasks and downstream transfer tasks. For STS tasks, our evaluation takes the "all" setting, and report Spearman's correlation. See [our paper](https://arxiv.org/pdf/2104.08821.pdf) (Appendix B) for evaluation details.

Before evaluation, please download the evaluation datasets by running
Expand Down Expand Up @@ -151,13 +206,13 @@ Arguments for the evaluation script are as follows,
* `na`: Manually set tasks by `--tasks`.
* `--tasks`: Specify which dataset(s) to evaluate on. Will be overridden if `--task_set` is not `na`. See the code for a full list of tasks.

## Training
### Training

### Data
#### Data

For unsupervised SimCSE, we sample 1 million sentences from English Wikipedia; for supervised SimCSE, we use the SNLI and MNLI datasets. You can run `data/download_wiki.sh` and `data/download_nli.sh` to download the two datasets.

### Training scripts
#### Training scripts

We provide example training scripts for both unsupervised and supervised SimCSE. In `run_unsup_example.sh`, we provide a single-GPU (or CPU) example for the unsupervised version, and in `run_sup_example.sh` we give a **multiple-GPU** example for the supervised version. Both scripts call `train.py` for training. We explain the arguments in following:
* `--train_file`: Training file path. We support "txt" files (one line for one sentence) and "csv" files (2-column: pair data with no hard negative; 3-column: pair data with one corresponding hard negative instance). You can use our provided Wikipedia or NLI data, or you can use your own data with the same format.
Expand All @@ -173,10 +228,12 @@ All the other arguments are standard Huggingface's `transformers` training argum

**REPRODUCTION**: For results in the paper, we use Nvidia 3090 GPUs with CUDA 11. Using different types of devices or different versions of CUDA/other softwares may lead to slightly different performance.

### Convert models
#### Convert models

**IMPORTANT**: Our saved checkpoints are slightly different from Huggingface's pre-trained checkpoints. Run `python simcse_to_huggingface.py --path {PATH_TO_CHECKPOINT_FOLDER}` to convert it. After that, you can evaluate it by our [evaluation](#evaluation) code or directly use it [out of the box](#use-our-models-out-of-the-box).



## Bugs or questions?

If you have any questions related to the code or the paper, feel free to email Tianyu (`[email protected]`) and Xingcheng (`[email protected]`). If you encounter any problems when using the code, or want to report a bug, you can open an issue. Please try to specify the problem with details so we can help you better and quicker!
Expand Down
18 changes: 18 additions & 0 deletions demo/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
## Demo of SimCSE
Several demos are available for people to play with our pre-trained SimCSE.

### Flask Demo
<div align="center">
<img src="../figure/demo.gif" width="750">
</div>

We provide a simple Web demo based on [flask](https://github.com/pallets/flask) to show how SimCSE can be directly used for information retrieval. To run this flask demo locally, make sure the SimCSE inference interfaces are setup:
```bash
git clone https://github.com/princeton-nlp/SimCSE
cd SimCSE
python setup.py develop
```
Then you can use `run_demo_example.sh` to launch the demo. As a default setting, we build the index for 1000 sentences sampled from STS-B dataset. Feel free to build the index of your own corpora. You can also install [faiss](https://github.com/facebookresearch/faiss) to speed up the retrieval process.

### Gradio Demo
[AK391](https://github.com/AK391) has provided a [Gradio Web Demo](https://gradio.app/g/AK391/SimCSE) of SimCSE to show how the pre-trained models can predict the semantic similarity between two sentences.
84 changes: 84 additions & 0 deletions demo/flaskdemo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import json
import argparse
import torch
import os
import random
import numpy as np
import requests
import logging
import math
import copy
import string

from tqdm import tqdm
from time import time
from flask import Flask, request, jsonify
from flask_cors import CORS
from tornado.wsgi import WSGIContainer
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop

from simcse import SimCSE

logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)

def run_simcse_demo(port, args):
app = Flask(__name__, static_folder='./static')
app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
CORS(app)

sentence_path = os.path.join(args.sentences_dir, args.example_sentences)
query_path = os.path.join(args.sentences_dir, args.example_query)
embedder = SimCSE(args.model_name_or_path)
embedder.build_index(sentence_path)
@app.route('/')
def index():
return app.send_static_file('index.html')

@app.route('/api', methods=['GET'])
def api():
query = request.args['query']
top_k = int(request.args['topk'])
threshold = float(request.args['threshold'])
start = time()
results = embedder.search(query, top_k=top_k, threshold=threshold)
ret = []
out = {}
for sentence, score in results:
ret.append({"sentence": sentence, "score": score})
span = time() - start
out['ret'] = ret
out['time'] = "{:.4f}".format(span)
return jsonify(out)

@app.route('/files/<path:path>')
def static_files(path):
return app.send_static_file('files/' + path)

@app.route('/get_examples', methods=['GET'])
def get_examples():
with open(query_path, 'r') as fp:
examples = [line.strip() for line in fp.readlines()]
return jsonify(examples)

addr = args.ip + ":" + args.port
logger.info(f'Starting Index server at {addr}')
http_server = HTTPServer(WSGIContainer(app))
http_server.listen(port)
IOLoop.instance().start()

if __name__=="__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--model_name_or_path', default=None, type=str)
parser.add_argument('--device', default='cpu', type=str)
parser.add_argument('--sentences_dir', default=None, type=str)
parser.add_argument('--example_query', default=None, type=str)
parser.add_argument('--example_sentences', default=None, type=str)
parser.add_argument('--port', default='8888', type=str)
parser.add_argument('--ip', default='http://127.0.0.1')
parser.add_argument('--load_light', default=False, action='store_true')
args = parser.parse_args()

run_simcse_demo(args.port, args)
File renamed without changes.
9 changes: 9 additions & 0 deletions demo/run_demo_example.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/bash

# This example shows how to run the flask demo of SimCSE

python flaskdemo.py \
--model_name_or_path princeton-nlp/sup-simcse-bert-base-uncased \
--sentences_dir ./static/ \
--example_query example_query.txt \
--example_sentences example_sentence.txt
3 changes: 3 additions & 0 deletions demo/static/example_query.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
a man is playing music
a woman is making a photo
a woman is taking some food
Loading

0 comments on commit b0aaacd

Please sign in to comment.