Introduce /api/embed endpoint supporting batch embedding (ollama#5127)
* Initial Batch Embedding

* Revert "Initial Batch Embedding"

This reverts commit c22d548.

* Initial Draft

* mock up notes

* api/embed draft

* add server function

* check normalization

* clean up

* normalization

* playing around with truncate stuff

* Truncation

* Truncation

* move normalization to go

* Integration Test Template

* Truncation Integration Tests

* Clean up

* use float32

* move normalize

* move normalize test

* refactoring

* integration float32

* input handling and handler testing

* Refactoring of legacy and new

* clear comments

* merge conflicts

* touches

* embedding type 64

* merge conflicts

* fix hanging on single string

* refactoring

* test values

* set context length

* clean up

* testing clean up

* testing clean up

* remove function closure

* Revert "remove function closure"

This reverts commit 55d48c6.

* remove function closure

* remove redundant error check

* clean up

* more clean up

* clean up
royjhan authored Jul 15, 2024
1 parent e9f7f36 commit b9f5e16
Showing 8 changed files with 453 additions and 31 deletions.
11 changes: 10 additions & 1 deletion api/client.go
@@ -347,7 +347,16 @@ func (c *Client) Heartbeat(ctx context.Context) error {
 	return nil
 }
 
-// Embeddings generates embeddings from a model.
+// Embed generates embeddings from a model.
+func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
+	var resp EmbedResponse
+	if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
+		return nil, err
+	}
+	return &resp, nil
+}
+
+// Embeddings generates an embedding from a model.
 func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
 	var resp EmbeddingResponse
 	if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
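
For illustration, a minimal sketch of calling the new method from Go. This is not part of the diff; it assumes a locally running server reachable through the standard environment configuration, and uses the all-minilm model from the integration tests below:

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// Input accepts a single string or a batch of strings (see api/types.go below).
	resp, err := client.Embed(context.Background(), &api.EmbedRequest{
		Model: "all-minilm",
		Input: []string{"why is the sky blue?", "why is the grass green?"},
	})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(len(resp.Embeddings)) // one embedding per input
}
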
24 changes: 24 additions & 0 deletions api/types.go
@@ -173,6 +173,30 @@ type Runner struct {
	NumThread int `json:"num_thread,omitempty"`
}

// EmbedRequest is the request passed to [Client.Embed].
type EmbedRequest struct {
	// Model is the model name.
	Model string `json:"model"`

	// Input is the input to embed.
	Input any `json:"input"`

	// KeepAlive controls how long the model will stay loaded in memory following
	// this request.
	KeepAlive *Duration `json:"keep_alive,omitempty"`

	Truncate *bool `json:"truncate,omitempty"`

	// Options lists model-specific options.
	Options map[string]interface{} `json:"options"`
}

// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
	Model      string      `json:"model"`
	Embeddings [][]float32 `json:"embeddings,omitempty"`
}

// EmbeddingRequest is the request passed to [Client.Embeddings].
type EmbeddingRequest struct {
	// Model is the model name.
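
Because Input is declared as any, the same request type carries either a bare string or a list of strings on the wire. A quick sketch (illustrative, not part of the diff) of the JSON each form produces:

package main

import (
	"encoding/json"
	"fmt"

	"github.com/ollama/ollama/api"
)

func main() {
	// A single string marshals as a JSON string...
	single, _ := json.Marshal(api.EmbedRequest{Model: "all-minilm", Input: "why is the sky blue?"})
	fmt.Println(string(single)) // {"model":"all-minilm","input":"why is the sky blue?","options":null}

	// ...and a slice marshals as a JSON array, with no change to the struct.
	batch, _ := json.Marshal(api.EmbedRequest{Model: "all-minilm", Input: []string{"a", "b"}})
	fmt.Println(string(batch)) // {"model":"all-minilm","input":["a","b"],"options":null}
}
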
152 changes: 152 additions & 0 deletions integration/embed_test.go
@@ -0,0 +1,152 @@
//go:build integration

package integration

import (
	"context"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

func TestAllMiniLMEmbed(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	req := api.EmbedRequest{
		Model: "all-minilm",
		Input: "why is the sky blue?",
	}

	res, err := embedTestHelper(ctx, t, req)

	if err != nil {
		t.Fatalf("error: %v", err)
	}

	if len(res.Embeddings) != 1 {
		t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
	}

	if len(res.Embeddings[0]) != 384 {
		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
	}

	if res.Embeddings[0][0] != 0.010071031 {
		t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0])
	}
}

func TestAllMiniLMBatchEmbed(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	req := api.EmbedRequest{
		Model: "all-minilm",
		Input: []string{"why is the sky blue?", "why is the grass green?"},
	}

	res, err := embedTestHelper(ctx, t, req)

	if err != nil {
		t.Fatalf("error: %v", err)
	}

	if len(res.Embeddings) != 2 {
		t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
	}

	if len(res.Embeddings[0]) != 384 {
		t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
	}

	if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 {
		t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0])
	}
}

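// Reasoning for the truncation cases below: with num_ctx set to 1, the full
// question should be cut down to its first token, so its embedding is expected
// to match the embedding of just "why" (the "Target Truncation" case).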
func TestAllMiniLmEmbedTruncate(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	truncTrue, truncFalse := true, false

	type testReq struct {
		Name    string
		Request api.EmbedRequest
	}

	reqs := []testReq{
		{
			Name: "Target Truncation",
			Request: api.EmbedRequest{
				Model: "all-minilm",
				Input: "why",
			},
		},
		{
			Name: "Default Truncate",
			Request: api.EmbedRequest{
				Model:   "all-minilm",
				Input:   "why is the sky blue?",
				Options: map[string]any{"num_ctx": 1},
			},
		},
		{
			Name: "Explicit Truncate",
			Request: api.EmbedRequest{
				Model:    "all-minilm",
				Input:    "why is the sky blue?",
				Truncate: &truncTrue,
				Options:  map[string]any{"num_ctx": 1},
			},
		},
	}

	res := make(map[string]*api.EmbedResponse)

	for _, req := range reqs {
		response, err := embedTestHelper(ctx, t, req.Request)
		if err != nil {
			t.Fatalf("error: %v", err)
		}
		res[req.Name] = response
	}

	if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
		t.Fatal("expected default request to truncate correctly")
	}

	if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
		t.Fatal("expected default request and truncate true request to be the same")
	}

	// check that truncate set to false returns an error if context length is exceeded
	_, err := embedTestHelper(ctx, t, api.EmbedRequest{
		Model:    "all-minilm",
		Input:    "why is the sky blue?",
		Truncate: &truncFalse,
		Options:  map[string]any{"num_ctx": 1},
	})

	if err == nil {
		t.Fatal("expected error, got nil")
	}
}

func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()
	if err := PullIfMissing(ctx, client, req.Model); err != nil {
		t.Fatalf("failed to pull model %s: %v", req.Model, err)
	}

	response, err := client.Embed(ctx, &req)

	if err != nil {
		return nil, err
	}

	return response, nil
}
39 changes: 23 additions & 16 deletions llm/ext_server/server.cpp
@@ -3188,26 +3188,33 @@ int main(int argc, char **argv) {
             prompt = "";
         }
 
-        json image_data;
-        if (body.count("image_data") != 0) {
-            image_data = body["image_data"];
-        }
-        else
-        {
-            image_data = "";
+        if (prompt.size() == 1) {
+            prompt = prompt[0];
         }
 
-        // create and queue the task
-        const int task_id = llama.queue_tasks.get_new_id();
-        llama.queue_results.add_waiting_task_id(task_id);
-        llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
+        json responses;
+        {
+            const int id_task = llama.queue_tasks.get_new_id();
+            llama.queue_results.add_waiting_task_id(id_task);
+            llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
 
-        // get the result
-        task_result result = llama.queue_results.recv(task_id);
-        llama.queue_results.remove_waiting_task_id(task_id);
+            // get the result
+            task_result result = llama.queue_results.recv(id_task);
+            llama.queue_results.remove_waiting_task_id(id_task);
+            if (result.error) {
+                return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
+            }
 
-        // send the result
-        return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
+            responses = result.result_json.value("results", std::vector<json>{result.result_json});
+            json embeddings = json::array();
+            for (auto & elem : responses) {
+                embeddings.push_back(elem.at("embedding"));
+            }
+            // send the result
+            json embedding_res = json{{"embedding", embeddings}};
+            return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
+        }
     });
 
     // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
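
In short, the rewritten handler collapses a single-element batch back to a bare prompt, queues the whole (possibly batched) input as one task, and always replies with an array under the "embedding" key, one inner embedding per input. When the runner returns a batch, the per-input results are read from the "results" field.
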
16 changes: 8 additions & 8 deletions llm/server.go
@@ -33,7 +33,7 @@ type LlamaServer interface {
 	Ping(ctx context.Context) error
 	WaitUntilRunning(ctx context.Context) error
 	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
-	Embedding(ctx context.Context, prompt string) ([]float64, error)
+	Embed(ctx context.Context, input []string) ([][]float32, error)
 	Tokenize(ctx context.Context, content string) ([]int, error)
 	Detokenize(ctx context.Context, tokens []int) (string, error)
 	Close() error
@@ -867,15 +867,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 	return nil
 }
 
-type EmbeddingRequest struct {
-	Content string `json:"content"`
+type EmbedRequest struct {
+	Content []string `json:"content"`
 }
 
-type EmbeddingResponse struct {
-	Embedding []float64 `json:"embedding"`
+type EmbedResponse struct {
+	Embedding [][]float32 `json:"embedding"`
 }
 
-func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
+func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) {
 	if err := s.sem.Acquire(ctx, 1); err != nil {
 		slog.Error("Failed to acquire semaphore", "error", err)
 		return nil, err
@@ -890,7 +890,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}

data, err := json.Marshal(TokenizeRequest{Content: prompt})
data, err := json.Marshal(EmbedRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
@@ -917,7 +917,7 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("%s", body)
}

var embedding EmbeddingResponse
var embedding EmbedResponse
if err := json.Unmarshal(body, &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}
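
To make the new wire format concrete, here is a rough sketch of the round trip to the runner's /embedding endpoint. The address and port are hypothetical (the real server picks a port at runtime), and the struct shapes mirror the EmbedRequest and EmbedResponse types above:

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Local mirrors of llm/server.go's EmbedRequest and EmbedResponse.
type embedRequest struct {
	Content []string `json:"content"`
}

type embedResponse struct {
	Embedding [][]float32 `json:"embedding"`
}

func main() {
	body, err := json.Marshal(embedRequest{Content: []string{"why is the sky blue?", "why is the grass green?"}})
	if err != nil {
		panic(err)
	}

	// Hypothetical runner address; the real port is chosen when the runner starts.
	resp, err := http.Post("http://127.0.0.1:8080/embedding", "application/json", bytes.NewReader(body))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	var out embedResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		panic(err)
	}
	fmt.Println(len(out.Embedding)) // one embedding per input string
}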