feat: supports embeddings for T5 and ChatGLM family generation (#153)
aarnphm authored Jul 27, 2023
1 parent e075bd2 commit 15640a8
Showing 7 changed files with 33 additions and 13 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -518,7 +518,7 @@ client = openllm.client.HTTPClient("http://localhost:3000")
client.embed("I like to eat apples")
```

> **Note**: Currently, only LlaMA models and variants support embeddings.
> **Note**: Currently, the following model families support embeddings: Llama, T5 (Flan-T5, FastChat, etc.), and ChatGLM.
## ⚙️ Integrations

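To make the updated README note concrete, here is a minimal sketch of embedding text against one of the newly supported backends through the same HTTP client shown above. Which model is being served, and what the client call returns, are assumptions here rather than part of this diff.

```python
# Sketch only: assumes an OpenLLM server backed by one of the newly supported
# families (e.g. a Flan-T5 or ChatGLM model) is already running at this address.
import openllm

client = openllm.client.HTTPClient("http://localhost:3000")

# Same call as in the README snippet above, now valid for T5/ChatGLM backends.
result = client.embed("I like to eat apples")

# The client's exact return type is not shown in this diff; per the service
# schema it carries an embedding payload plus a token count.
print(result)
```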
1 change: 1 addition & 0 deletions changelog.d/153.feature.md
@@ -0,0 +1 @@
Added embeddings support for T5 and ChatGLM
2 changes: 1 addition & 1 deletion src/openllm/_service.py
@@ -48,7 +48,7 @@ def metadata_v1(_: str) -> openllm.MetadataOutput: return openllm.MetadataOutput
@svc.api(input=bentoml.io.JSON.from_sample(sample=["Hey Jude, welcome to the jumgle!", "What is the meaning of life?"]), output=bentoml.io.JSON.from_sample(sample={"embeddings": [0.007917795330286026, -0.014421648345887661, 0.00481307040899992, 0.007331526838243008, -0.0066398633643984795, 0.00945580005645752, 0.0087016262114048, -0.010709521360695362, 0.012635177001357079, 0.010541186667978764, -0.00730888033285737, -0.001783102168701589, 0.02339819073677063, -0.010825827717781067, -0.015888236463069916, 0.01876218430697918, 0.0076906150206923485, 0.0009032754460349679, -0.010024012066423893, 0.01090280432254076, -0.008668390102684498, 0.02070549875497818, 0.0014594447566196322, -0.018775740638375282, -0.014814382418990135, 0.01796768605709076], "num_tokens": 20}), route="/v1/embeddings")
async def embeddings_v1(phrases: list[str]) -> openllm.EmbeddingsOutput:
responses = await runner.embeddings.async_run(phrases)
return openllm.EmbeddingsOutput(embeddings=t.cast(t.List[t.List[float]], responses["embeddings"].tolist())[0], num_tokens=responses["num_tokens"])
return openllm.EmbeddingsOutput(embeddings=responses["embeddings"], num_tokens=responses["num_tokens"])
if runner.supports_hf_agent and openllm.utils.is_transformers_supports_agent():
@attr.define
class HfAgentInput:
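The decorator above documents the endpoint's wire format: the request body is a JSON list of strings and the response carries `embeddings` and `num_tokens`. Below is a hedged sketch of calling the route directly over HTTP; the server address is an assumption taken from the README's default, not something this diff specifies.

```python
# Minimal sketch of hitting the /v1/embeddings route without the OpenLLM client.
# The address is assumed to be the README's default, not part of this diff.
import requests

resp = requests.post(
    "http://localhost:3000/v1/embeddings",
    json=["Hey Jude, welcome to the jungle!", "What is the meaning of life?"],
)
resp.raise_for_status()
payload = resp.json()

# The decorator's sample shows a flat vector; after this change the handler
# passes the runner's embeddings through as-is, alongside the token count.
print(payload["num_tokens"], type(payload["embeddings"]))
```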
2 changes: 1 addition & 1 deletion src/openllm/_types.py
@@ -82,7 +82,7 @@ class PeftAdapterOutput(t.TypedDict):


class LLMEmbeddings(t.TypedDict):
embeddings: torch.Tensor
embeddings: t.List[t.List[float]]
num_tokens: int


19 changes: 14 additions & 5 deletions src/openllm/models/chatglm/modeling_chatglm.py
@@ -14,12 +14,10 @@
from __future__ import annotations
import typing as t
import openllm
from ..._llm import LLMEmbeddings
if t.TYPE_CHECKING:
import torch
import transformers
else:
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
transformers = openllm.utils.LazyLoader("transformers", globals(), "transformers")
import torch, transformers, torch.nn.functional as F
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
class ChatGLM(openllm.LLM["transformers.PreTrainedModel", "transformers.PreTrainedTokenizerFast"]):
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, num_beams: int | None = None, top_p: float | None = None, temperature: float | None = None, chat_history: list[str] | None = None, use_default_prompt_template: bool = False, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
@@ -44,3 +42,14 @@ def generate(self, prompt: str, **attrs: t.Any) -> tuple[str, list[tuple[str, st
# Only use half precision if the model is not yet quantized
if self.config.use_half_precision: self.model.half()
return self.model.chat(self.tokenizer, prompt, generation_config=self.config.model_construct_env(**attrs).to_generation_config())
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
embeddings: list[list[float]] = []
num_tokens = 0
for prompt in prompts:
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
with torch.inference_mode():
outputs = self.model(input_ids, output_hidden_states=True)
data = F.normalize(torch.mean(outputs.hidden_states[-1].transpose(0, 1), dim=0), p=2, dim=0)
embeddings.append(data.tolist())
num_tokens += len(input_ids[0])
return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens)
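The new ChatGLM `embeddings()` pools the final hidden states with a mean and L2-normalizes the result before converting to plain lists. Below is a self-contained illustration of mean pooling plus normalization on a dummy tensor; the shapes are made up for the example and do not mirror ChatGLM's exact layout.

```python
# Illustration of the pooling idea used above: average a sequence of hidden
# states into one vector, then L2-normalize it. Shapes and values are dummies.
import torch
import torch.nn.functional as F

hidden = torch.randn(12, 4096)                # (sequence_length, hidden_size)
pooled = torch.mean(hidden, dim=0)            # mean over the sequence -> (hidden_size,)
embedding = F.normalize(pooled, p=2, dim=0)   # unit-length embedding vector

print(embedding.shape, float(embedding.norm()))  # torch.Size([4096]) ~1.0
```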
18 changes: 14 additions & 4 deletions src/openllm/models/flan_t5/modeling_flan_t5.py
@@ -15,12 +15,11 @@
import typing as t
import openllm
from .configuration_flan_t5 import DEFAULT_PROMPT_TEMPLATE
from ..._llm import LLMEmbeddings
from ..._prompt import default_formatter
if t.TYPE_CHECKING:
import torch
import transformers # noqa: F401
else:
torch = openllm.utils.LazyLoader("torch", globals(), "torch")
import torch, transformers, torch.nn.functional as F
else: torch, transformers, F = openllm.utils.LazyLoader("torch", globals(), "torch"), openllm.utils.LazyLoader("transformers", globals(), "transformers"), openllm.utils.LazyLoader("F", globals(), "torch.nn.functional")
class FlanT5(openllm.LLM["transformers.T5ForConditionalGeneration", "transformers.T5TokenizerFast"]):
__openllm_internal__ = True
def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, temperature: float | None = None, top_k: int | None = None, top_p: float | None = None, repetition_penalty: float | None = None, use_default_prompt_template: bool = True, **attrs: t.Any) -> tuple[str, dict[str, t.Any], dict[str, t.Any]]:
@@ -35,3 +34,14 @@ def sanitize_parameters(self, prompt: str, max_new_tokens: int | None = None, te
def postprocess_generate(self, prompt: str, generation_result: t.Sequence[str], **_: t.Any) -> str: return generation_result[0]
def generate(self, prompt: str, **attrs: t.Any) -> list[str]:
with torch.inference_mode(): return self.tokenizer.batch_decode(self.model.generate(**self.tokenizer(prompt, return_tensors="pt").to(self.device), do_sample=True, generation_config=self.config.model_construct_env(**attrs).to_generation_config()), skip_special_tokens=True)
def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
embeddings: list[list[float]] = []
num_tokens = 0
for prompt in prompts:
input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
with torch.inference_mode():
outputs = self.model(input_ids, decoder_input_ids=input_ids)
data = F.normalize(torch.mean(outputs.encoder_last_hidden_state[0], dim=0), p=2, dim=0)
embeddings.append(data.tolist())
num_tokens += len(input_ids[0])
return LLMEmbeddings(embeddings=embeddings, num_tokens=num_tokens)
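For comparison with the ChatGLM path, the Flan-T5 `embeddings()` above runs the full seq2seq forward (feeding `decoder_input_ids`) and then reads `encoder_last_hidden_state`. The sketch below uses an encoder-only forward, which yields the same tensor without invoking the decoder; whether that shortcut is preferable here is my assumption, not something this commit claims, and the checkpoint name is only an example.

```python
# Sketch under assumptions: a small public checkpoint and the encoder-only
# forward of T5ForConditionalGeneration; not the code path added in this commit.
import torch
import torch.nn.functional as F
import transformers

tokenizer = transformers.AutoTokenizer.from_pretrained("google/flan-t5-small")
model = transformers.T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

input_ids = tokenizer.encode("What is the meaning of life?", return_tensors="pt")
with torch.inference_mode():
    encoder_out = model.encoder(input_ids=input_ids)   # BaseModelOutput

# Same pooling as the diff: mean over tokens, then L2-normalize.
pooled = torch.mean(encoder_out.last_hidden_state[0], dim=0)
embedding = F.normalize(pooled, p=2, dim=0).tolist()
print(len(embedding))  # the model's hidden size
```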
2 changes: 1 addition & 1 deletion src/openllm/models/llama/modeling_llama.py
@@ -46,4 +46,4 @@ def embeddings(self, prompts: list[str]) -> LLMEmbeddings:
mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
masked_embeddings = data * mask
sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
return LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1), num_tokens=torch.sum(attention_mask).item())
return LLMEmbeddings(embeddings=F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist(), num_tokens=torch.sum(attention_mask).item())
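The Llama hunk only changes the return to call `.tolist()`, but the surrounding context lines show the pooling it feeds: a mean over non-padding positions selected by the attention mask, followed by L2 normalization. A small self-contained illustration of that masked mean pooling on dummy tensors:

```python
# Illustration of the masked mean pooling visible above: padding positions are
# zeroed out via the attention mask before averaging. Values are dummies.
import torch
import torch.nn.functional as F

data = torch.randn(2, 5, 8)                      # (batch, seq_len, hidden_size)
attention_mask = torch.tensor([[1, 1, 1, 0, 0],  # first prompt: 3 real tokens
                               [1, 1, 1, 1, 1]]) # second prompt: 5 real tokens

mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
masked_embeddings = data * mask
sum_embeddings, seq_length = torch.sum(masked_embeddings, dim=1), torch.sum(mask, dim=1)
embeddings = F.normalize(sum_embeddings / seq_length, p=2, dim=1).tolist()
num_tokens = int(torch.sum(attention_mask))

print(len(embeddings), len(embeddings[0]), num_tokens)  # 2 8 8
```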
