
Commit

move to sentence_transformers for embedding script
enjalot committed Jul 12, 2024
1 parent 2e786da commit 1793879
Showing 2 changed files with 37 additions and 31 deletions.
65 changes: 35 additions & 30 deletions latentscope/models/providers/transformers.py
@@ -7,50 +7,55 @@ def __init__(self, name, params):
         self.torch = torch
         self.device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

-    def cls_pooling(self, model_output):
-        return model_output[0][:, 0]
+    # def cls_pooling(self, model_output):
+    #     return model_output[0][:, 0]

-    def average_pooling(self, model_output, attention_mask):
-        last_hidden = model_output.last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
-        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+    # def average_pooling(self, model_output, attention_mask):
+    #     last_hidden = model_output.last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0)
+    #     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

-    def mean_pooling(self, model_output, attention_mask):
-        token_embeddings = model_output[0]
-        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
-        return self.torch.sum(token_embeddings * input_mask_expanded, 1) / self.torch.clamp(input_mask_expanded.sum(1), min=1e-9)
+    # def mean_pooling(self, model_output, attention_mask):
+    #     token_embeddings = model_output[0]
+    #     input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
+    #     return self.torch.sum(token_embeddings * input_mask_expanded, 1) / self.torch.clamp(input_mask_expanded.sum(1), min=1e-9)

     def load_model(self):
-        from transformers import AutoTokenizer, AutoModel
+        # from transformers import AutoTokenizer, AutoModel
+        from sentence_transformers import SentenceTransformer

-        if "rps" in self.params and self.params["rps"]:
-            self.model = AutoModel.from_pretrained(self.name, trust_remote_code=True, safe_serialization=True, rotary_scaling_factor=2 )
-        else:
-            self.model = AutoModel.from_pretrained(self.name, trust_remote_code=True)
+        self.model = SentenceTransformer(self.name, trust_remote_code=True)
+        self.tokenizer = self.model.tokenizer

-        print("CONFIG", self.model.config)
+        # if "rps" in self.params and self.params["rps"]:
+        #     self.model = AutoModel.from_pretrained(self.name, trust_remote_code=True, safe_serialization=True, rotary_scaling_factor=2 )
+        # else:
+        #     self.model = AutoModel.from_pretrained(self.name, trust_remote_code=True)

-        if self.name == "nomic-ai/nomic-embed-text-v1" or self.name == "nomic-ai/nomic-embed-text-v1.5":
-            self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", model_max_length=self.params["max_tokens"])
-        else:
-            self.tokenizer = AutoTokenizer.from_pretrained(self.name)
+        # print("CONFIG", self.model.config)
+
+        # if self.name == "nomic-ai/nomic-embed-text-v1" or self.name == "nomic-ai/nomic-embed-text-v1.5":
+        #     self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", model_max_length=self.params["max_tokens"])
+        # else:
+        #     self.tokenizer = AutoTokenizer.from_pretrained(self.name)

         self.model.to(self.device)
         self.model.eval()

     def embed(self, inputs, dimensions=None):
-        encoded_input = self.tokenizer(inputs, padding=self.params["padding"], truncation=self.params["truncation"], return_tensors='pt')
-        encoded_input = {key: value.to(self.device) for key, value in encoded_input.items()}
-        pool = self.params["pooling"]
+        # encoded_input = self.tokenizer(inputs, padding=self.params["padding"], truncation=self.params["truncation"], return_tensors='pt')
+        # encoded_input = {key: value.to(self.device) for key, value in encoded_input.items()}
+        # pool = self.params["pooling"]
         # Compute token embeddings
-        with self.torch.no_grad():
-            model_output = self.model(**encoded_input)
-            if pool == "cls":
-                embeddings = self.cls_pooling(model_output)
-            elif pool == "average":
-                embeddings = self.average_pooling(model_output, encoded_input["attention_mask"])
-            elif pool == "mean":
-                embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"])
+        # with self.torch.no_grad():
+        #     model_output = self.model(**encoded_input)
+        #     if pool == "cls":
+        #         embeddings = self.cls_pooling(model_output)
+        #     elif pool == "average":
+        #         embeddings = self.average_pooling(model_output, encoded_input["attention_mask"])
+        #     elif pool == "mean":
+        #         embeddings = self.mean_pooling(model_output, encoded_input["attention_mask"])

+        embeddings = self.model.encode(inputs, convert_to_tensor=True)
         # Support Matroyshka embeddings
         if dimensions is not None and dimensions > 0:
             embeddings = self.torch.nn.functional.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
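For reference, here is a minimal sketch of the embedding flow this commit moves to. The model name, input strings, and the truncation step after layer_norm are illustrative assumptions, not lines from the diff; the rest of embed() is collapsed in this view.

# Sketch of the sentence_transformers-based path (assumptions noted above).
# SentenceTransformer tokenizes, pools, and batches internally, which is why
# the manual cls/average/mean pooling helpers are commented out in this commit.
import torch
from sentence_transformers import SentenceTransformer

device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
model.to(device)
model.eval()

inputs = ["first document", "second document"]
embeddings = model.encode(inputs, convert_to_tensor=True)

# Matryoshka truncation: layer_norm over the full width (as in the diff), then
# slice to the requested size and L2-normalize (assumed from the collapsed lines).
dimensions = 256
if dimensions is not None and dimensions > 0:
    embeddings = torch.nn.functional.layer_norm(embeddings, normalized_shape=(embeddings.shape[1],))
    embeddings = embeddings[:, :dimensions]
    embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)

print(embeddings.shape)  # e.g. torch.Size([2, 256])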
3 changes: 2 additions & 1 deletion requirements.txt
@@ -133,6 +133,7 @@ safetensors~=0.4.1
 scikit-learn~=1.3.2
 scipy~=1.11.4
 Send2Trash~=1.8.2
+sentence-transformers~=3.0.1
 six~=1.16.0
 sniffio~=1.3.0
 soupsieve~=2.5
@@ -159,7 +160,7 @@ tzdata~=2023.4
 umap-learn~=0.5.5
 uri-template~=1.3.0
 urllib3~=2.1.0
-voyageai~=0.1.6
+voyageai~=0.2.3
 wcwidth~=0.2.13
 webcolors~=1.13
 webencodings~=0.5.1
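Both pins use compatible-release specifiers, so sentence-transformers must be a 3.0.x release at or above 3.0.1, and voyageai a 0.2.x release at or above 0.2.3. A small helper (not part of the commit, and only an approximate prefix check) to see whether an existing environment roughly matches the new pins:

# Illustrative check of the updated pins; package names match requirements.txt.
from importlib.metadata import PackageNotFoundError, version

for package, prefix in [("sentence-transformers", "3.0."), ("voyageai", "0.2.")]:
    try:
        installed = version(package)
    except PackageNotFoundError:
        print(f"{package}: not installed")
        continue
    ok = installed.startswith(prefix)
    print(f"{package}: {installed} ({'ok' if ok else 'does not match ~=' + prefix + 'x'})")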
