
Commit

clean the code
Tianyu Gao committed May 11, 2021
1 parent 7399dcb commit a670591
Showing 4 changed files with 92 additions and 79 deletions.
4 changes: 4 additions & 0 deletions .vscode/settings.json
@@ -0,0 +1,4 @@
+{
+    "python.linting.flake8Enabled": true,
+    "python.linting.enabled": false
+}
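A side note on this new settings file: in the VS Code Python extension of this era, "python.linting.enabled" acts as the master switch for all linters, so as committed the flake8 flag has no effect; presumably both settings would need to be true for flake8 to actually run.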
3 changes: 2 additions & 1 deletion README.md
@@ -11,7 +11,7 @@ Wait a minute! The authors are working day and night 💪, to make the code and
 We anticipate the code will be out * **in one week** *. -->
 
 <!-- * 4/26: SimCSE is now on [Gradio Web Demo](https://gradio.app/g/AK391/SimCSE) (Thanks [@AK391](https://github.com/AK391)!). Try it out! -->
-* 5/10: We released our [inference interface and demo code](#inference).
+* 5/10: We released our [sentence embedding tool and demo code](#inference).
 * 4/23: We released our [training code](#training).
 * 4/20: We released our [model checkpoints](#use-our-models-out-of-the-box) and [evaluation code](#evaluation).
 * 4/18: We released [our paper](https://arxiv.org/pdf/2104.08821.pdf). Check it out!
@@ -35,6 +35,7 @@ We propose a simple contrastive learning framework that works with both unlabele

 ![](figure/model.png)
 
+
 ## Use our models out of the box
 Our pre-trained models are now publicly available with [HuggingFace's Transformers](https://github.com/huggingface/transformers). Models and their performance are presented as follows:
 | Model | Avg. STS |
2 changes: 1 addition & 1 deletion simcse/__init__.py
@@ -1 +1 @@
-from .inference import SentenceEmbedder
+from .tool import SimCSE
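With the rename, the package now exports the SimCSE class from the new tool module. A minimal usage sketch (the checkpoint name is taken from the demo at the bottom of tool.py; a local install of the package and transformers is assumed):

    from simcse import SimCSE

    # Load a supervised SimCSE checkpoint and embed one sentence
    simcse = SimCSE("princeton-nlp/sup-simcse-bert-base-uncased")
    embedding = simcse.encode("A man plays the guitar.")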
162 changes: 85 additions & 77 deletions simcse/inference.py → simcse/tool.py
@@ -14,14 +14,15 @@
                     level=logging.INFO)
 logger = logging.getLogger(__name__)
 
-"""
-A simple class used for several downstream zero-shot sentence embedding tasks:
-"""
-class SentenceEmbedder(object):
+class SimCSE(object):
+    """
+    A class for embedding sentences, calculating similarities, and retrieving sentences by SimCSE.
+    """
     def __init__(self, model_name_or_path: str,
                  device: str = None,
                  num_cells: int = 100,
                  num_cells_in_search: int = 10):
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.model = AutoModel.from_pretrained(model_name_or_path)
         if device is None:
@@ -36,8 +37,10 @@ def __init__(self, model_name_or_path: str,
     def encode(self, sentence: Union[str, List[str]],
                device: str = None,
                return_numpy: bool = False,
-               normalize_to_unit: bool = False,
-               keep_dim: bool = False) -> Union[ndarray, Tensor]:
+               normalize_to_unit: bool = True,
+               keepdim: bool = False,
+               batch_size: int = 64,
+               max_length: int = 128) -> Union[ndarray, Tensor]:
 
         target_device = self.device if device is None else device
         self.model = self.model.to(target_device)
@@ -46,19 +49,26 @@ def encode(self, sentence: Union[str, List[str]],
         if isinstance(sentence, str):
             sentence = [sentence]
             single_sentence = True
 
-        inputs = self.tokenizer(sentence, padding=True, truncation=True, return_tensors="pt")
-        for feature in inputs:
-            inputs[feature] = inputs[feature].to(target_device)
-
+        embedding_list = []
         with torch.no_grad():
-            embeddings = self.model(**inputs, output_hidden_states=True, return_dict=True).pooler_output.cpu()
+            total_batch = len(sentence) // batch_size + (1 if len(sentence) % batch_size > 0 else 0)
+            for batch_id in tqdm(range(total_batch)):
+                inputs = self.tokenizer(
+                    sentence[batch_id*batch_size:(batch_id+1)*batch_size],
+                    padding=True,
+                    truncation=True,
+                    max_length=max_length,
+                    return_tensors="pt"
+                )
+                inputs = {k: v.to(target_device) for k, v in inputs.items()}
+                embeddings = self.model(**inputs, return_dict=True).pooler_output
+                if normalize_to_unit:
+                    embeddings = embeddings / embeddings.norm(dim=1, keepdim=True)
+                embedding_list.append(embeddings.cpu())
+        embeddings = torch.cat(embedding_list, 0)
 
-        if normalize_to_unit:
-            embeddings = normalize(embeddings, axis=1)
-        if not return_numpy:
-            embeddings = Tensor(embeddings)
-
-        if single_sentence and not keep_dim:
+        if single_sentence and not keepdim:
             embeddings = embeddings[0]
 
         if return_numpy and not isinstance(embeddings, ndarray):
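The rewritten encode embeds in mini-batches inside torch.no_grad(), truncates inputs at max_length, normalizes to unit length by default, and moves each batch to CPU as it finishes, so large sentence lists no longer have to fit through the tokenizer and model in one pass. A sketch of the new keyword arguments (the sentence list is illustrative, and simcse is an instance as in the sketch above):

    sentences = ["A man plays the piano.", "A woman is reading."] * 1000

    # Batched, unit-normalized embeddings as a NumPy array; batch_size and
    # max_length bound per-step memory instead of the whole corpus.
    vecs = simcse.encode(sentences, return_numpy=True, batch_size=64, max_length=128)
    print(vecs.shape)  # (2000, hidden_size)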
@@ -79,11 +89,11 @@ def similarity(self, queries: Union[str, List[str]],
         # check whether N == 1 or M == 1
         single_query, single_key = len(query_vecs.shape) == 1, len(key_vecs.shape) == 1
         if single_query:
-            query_vecs = query_vecs.reshape(1,-1)
+            query_vecs = query_vecs.reshape(1, -1)
         if single_key:
-            key_vecs = key_vecs.reshape(1,-1)
+            key_vecs = key_vecs.reshape(1, -1)
 
-        # returns a N*M similarity array
+        # returns an N*M similarity array
         similarities = cosine_similarity(query_vecs, key_vecs)
 
         if single_query:
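similarity itself is only touched up cosmetically here: it encodes both sides and returns an N*M array of cosine similarities, squeezing the result when either side is a single string. For instance (continuing the sketch above):

    # One query against two keys -> a length-2 array rather than 1x2
    sims = simcse.similarity("A man is playing music.",
                             ["A man plays the piano.", "A woman is reading."])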
@@ -95,51 +105,54 @@

     def build_index(self, sentences_or_file_path: Union[str, List[str]],
                         use_faiss: bool = None,
+                        faiss_fast: bool = False,
+                        faiss_gpu: bool = False,
                         device: str = None,
                         batch_size: int = 64):
 
-        if use_faiss is None or use_faiss == True:
+        if use_faiss is None or use_faiss:
             try:
                 import faiss
                 use_faiss = True
             except:
-                logger.warning("Fail to import faiss, try to use exact search")
+                logger.warning("Failed to import faiss. Please install faiss or set use_faiss=False. Continuing with brute force search.")
                 use_faiss = False
 
         # if the input sentence is a string, we assume it's the path of file that stores various sentences
         if isinstance(sentences_or_file_path, str):
             sentences = []
-            with open(sentences_or_file_path, "r", encoding="utf-8") as f:
-                lines = f.readlines()
-                for line in lines:
-                    sentences.append(line.strip())
+            with open(sentences_or_file_path, "r") as f:
+                logging.info("Loading sentences from %s ..." % (sentences_or_file_path))
+                for line in tqdm(f):
+                    sentences.append(line.rstrip())
             sentences_or_file_path = sentences
 
-        embeddings = []
-        logger.info("Building Index...")
-        logger.info("Generating Embeddings for Target Sentences...")
-        for i in tqdm(range(0, len(sentences_or_file_path), batch_size)):
-            batch_embeddings = self.encode(sentences_or_file_path[i:i + batch_size], device=device, return_numpy=True, normalize_to_unit=use_faiss)
-            embeddings.append(batch_embeddings)
-        embeddings = np.vstack(embeddings)
-
-        id2sentence = {}
-        for i, sentence in enumerate(sentences_or_file_path):
-            id2sentence[i] = sentence
-        self.index = {"id2sentence": id2sentence}
+        logger.info("Encoding embeddings for sentences...")
+        embeddings = self.encode(sentences_or_file_path, device=device, batch_size=batch_size, normalize_to_unit=True, return_numpy=True)
+
+        logger.info("Building index...")
+        self.index = {"sentences": sentences_or_file_path}
 
         if use_faiss:
-            quantizer = faiss.IndexFlatL2(embeddings.shape[1])
-            index = faiss.IndexIVFFlat(quantizer, embeddings.shape[1], min(self.num_cells, len(sentences_or_file_path)))
-            logger.info("Train and Insert Embeddings to Faiss Index...")
-            index.train(embeddings.astype(np.float32))
+            quantizer = faiss.IndexFlatIP(embeddings.shape[1])
+            if faiss_fast:
+                index = faiss.IndexIVFFlat(quantizer, embeddings.shape[1], min(self.num_cells, len(sentences_or_file_path)))
+            else:
+                index = quantizer
+            if faiss_gpu:
+                res = faiss.StandardGpuResources()
+                res.setTempMemory(20 * 1024 * 1024 * 1024)
+                index = faiss.index_cpu_to_gpu(res, 0, index)
+
+            if faiss_fast:
+                index.train(embeddings.astype(np.float32))
             index.add(embeddings.astype(np.float32))
             index.nprobe = min(self.num_cells_in_search, len(sentences_or_file_path))
-            self.index["index"] = index
             self.is_faiss_index = True
         else:
-            self.index["index"] = embeddings
+            index = embeddings
             self.is_faiss_index = False
+        self.index["index"] = index
 
     def search(self, queries: Union[str, List[str]],
                device: str = None,
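Two substantive changes in the new build_index: the faiss quantizer switches from IndexFlatL2 to IndexFlatIP, so with unit-normalized embeddings the index ranks by inner product, which equals cosine similarity; and exact flat search is now the default, with the trained IVF structure moved behind the new faiss_fast flag (and faiss_gpu optionally moving the index to GPU). A sketch of the resulting modes (the sentence file path "sentences.txt" is illustrative):

    # Exact NumPy search, no faiss dependency
    simcse.build_index(["A man plays the piano.", "A woman is reading."], use_faiss=False)

    # Exact faiss search over inner products of unit vectors
    simcse.build_index("sentences.txt", use_faiss=True)

    # Approximate IVF search behind the new faiss_fast flag
    simcse.build_index("sentences.txt", use_faiss=True, faiss_fast=True)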
@@ -160,16 +173,15 @@ def search(self, queries: Union[str, List[str]],
                 if s >= threshold:
                     id_and_score.append((i, s))
             id_and_score = sorted(id_and_score, key=lambda x: x[1], reverse=True)[:top_k]
-            results = [(self.index["id2sentence"][idx], score) for idx, score in id_and_score]
+            results = [(self.index["sentences"][idx], score) for idx, score in id_and_score]
             return results
         else:
-            query_vecs = self.encode(queries, device=device, normalize_to_unit=True, keep_dim=True, return_numpy=True)
+            query_vecs = self.encode(queries, device=device, normalize_to_unit=True, keepdim=True, return_numpy=True)
 
             distance, idx = self.index["index"].search(query_vecs.astype(np.float32), top_k)
 
             def pack_single_result(dist, idx):
-                score = [1.0 - d / 2.0 for d in dist]
-                results = [(self.index["id2sentence"][i], s) for i, s in zip(idx, score) if s >= threshold]
+                results = [(self.index["sentences"][i], s) for i, s in zip(idx, dist) if s >= threshold]
                 return results
 
             if isinstance(queries, list):
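Because an IndexFlatIP over unit vectors already returns cosine similarities, pack_single_result can now use the faiss distances as scores directly, where the old code had to convert L2 distances via 1 - d/2. A usage sketch of search after an index is built (top_k and threshold sit in the collapsed part of the signature, so the parameter names here are inferred from the body above):

    results = simcse.search("A man is playing music.", top_k=3)
    for sentence, score in results:
        print(sentence, score)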
@@ -183,48 +195,44 @@ def pack_single_result(dist, idx):

 if __name__=="__main__":
     example_sentences = [
-        'an animal is biting a persons finger .',
-        'a woman is reading .',
-        'a man is lifting weights in a garage .',
-        'a man plays the violin .',
-        'a man is eating food .',
-        'a man plays the piano .',
-        'a panda is climbing .',
-        'a man plays a guitar .',
-        'a woman is slicing a meat .',
-        'a woman is taking a picture .'
+        'An animal is biting a persons finger.',
+        'A woman is reading.',
+        'A man is lifting weights in a garage.',
+        'A man plays the violin.',
+        'A man is eating food.',
+        'A man plays the piano.',
+        'A panda is climbing.',
+        'A man plays a guitar.',
+        'A woman is slicing a meat.',
+        'A woman is taking a picture.'
     ]
     example_queries = [
-        'a man is playing music',
-        'a woman is making a photo'
+        'A man is playing music.',
+        'A woman is making a photo.'
     ]
 
     model_name = "princeton-nlp/sup-simcse-bert-base-uncased"
-    embedder = SentenceEmbedder(model_name)
+    simcse = SimCSE(model_name)
 
     print("\n=========Calculate cosine similarities between queries and sentences============\n")
-    similarities = embedder.similarity(example_queries, example_sentences)
+    similarities = simcse.similarity(example_queries, example_sentences)
     print(similarities)
 
-    print("\n=========Naive exact search============\n")
-    embedder.build_index(example_sentences, use_faiss=False)
-    results = embedder.search(example_queries)
+    print("\n=========Naive brute force search============\n")
+    simcse.build_index(example_sentences, use_faiss=False)
+    results = simcse.search(example_queries)
     for i, result in enumerate(results):
-        print("retrieval results for query: {}".format(example_queries[i]))
+        print("Retrieval results for query: {}".format(example_queries[i]))
         for sentence, score in result:
-            print("{} (cosine similarity: {:.4f})".format(sentence, score))
+            print("    {} (cosine similarity: {:.4f})".format(sentence, score))
         print("")
 
     print("\n=========Search with Faiss backend============\n")
-    embedder.build_index(example_sentences, use_faiss=True)
-    results = embedder.search(example_queries)
+    simcse.build_index(example_sentences, use_faiss=True)
+    results = simcse.search(example_queries)
     for i, result in enumerate(results):
-        print("retrieval results for query: {}".format(example_queries[i]))
+        print("Retrieval results for query: {}".format(example_queries[i]))
         for sentence, score in result:
-            print("{} (cosine similarity: {:.4f})".format(sentence, score))
-
-
-
-
-
+            print("    {} (cosine similarity: {:.4f})".format(sentence, score))
+    print("")
