forked from ollama/ollama
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce `/api/embed` endpoint supporting batch embedding (ollama#5127)
* Initial Batch Embedding * Revert "Initial Batch Embedding" This reverts commit c22d548. * Initial Draft * mock up notes * api/embed draft * add server function * check normalization * clean up * normalization * playing around with truncate stuff * Truncation * Truncation * move normalization to go * Integration Test Template * Truncation Integration Tests * Clean up * use float32 * move normalize * move normalize test * refactoring * integration float32 * input handling and handler testing * Refactoring of legacy and new * clear comments * merge conflicts * touches * embedding type 64 * merge conflicts * fix hanging on single string * refactoring * test values * set context length * clean up * testing clean up * testing clean up * remove function closure * Revert "remove function closure" This reverts commit 55d48c6. * remove function closure * remove redundant error check * clean up * more clean up * clean up
- Loading branch information
Showing 8 changed files with 453 additions and 31 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
//go:build integration | ||
|
||
package integration | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
"time" | ||
|
||
"github.com/ollama/ollama/api" | ||
) | ||
|
||
func TestAllMiniLMEmbed(t *testing.T) { | ||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) | ||
defer cancel() | ||
|
||
req := api.EmbedRequest{ | ||
Model: "all-minilm", | ||
Input: "why is the sky blue?", | ||
} | ||
|
||
res, err := embedTestHelper(ctx, t, req) | ||
|
||
if err != nil { | ||
t.Fatalf("error: %v", err) | ||
} | ||
|
||
if len(res.Embeddings) != 1 { | ||
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings)) | ||
} | ||
|
||
if len(res.Embeddings[0]) != 384 { | ||
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0])) | ||
} | ||
|
||
if res.Embeddings[0][0] != 0.010071031 { | ||
t.Fatalf("expected 0.010071031, got %f", res.Embeddings[0][0]) | ||
} | ||
} | ||
|
||
func TestAllMiniLMBatchEmbed(t *testing.T) { | ||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) | ||
defer cancel() | ||
|
||
req := api.EmbedRequest{ | ||
Model: "all-minilm", | ||
Input: []string{"why is the sky blue?", "why is the grass green?"}, | ||
} | ||
|
||
res, err := embedTestHelper(ctx, t, req) | ||
|
||
if err != nil { | ||
t.Fatalf("error: %v", err) | ||
} | ||
|
||
if len(res.Embeddings) != 2 { | ||
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings)) | ||
} | ||
|
||
if len(res.Embeddings[0]) != 384 { | ||
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0])) | ||
} | ||
|
||
if res.Embeddings[0][0] != 0.010071031 || res.Embeddings[1][0] != -0.009802706 { | ||
t.Fatalf("expected 0.010071031 and -0.009802706, got %f and %f", res.Embeddings[0][0], res.Embeddings[1][0]) | ||
} | ||
} | ||
|
||
func TestAllMiniLmEmbedTruncate(t *testing.T) { | ||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) | ||
defer cancel() | ||
|
||
truncTrue, truncFalse := true, false | ||
|
||
type testReq struct { | ||
Name string | ||
Request api.EmbedRequest | ||
} | ||
|
||
reqs := []testReq{ | ||
{ | ||
Name: "Target Truncation", | ||
Request: api.EmbedRequest{ | ||
Model: "all-minilm", | ||
Input: "why", | ||
}, | ||
}, | ||
{ | ||
Name: "Default Truncate", | ||
Request: api.EmbedRequest{ | ||
Model: "all-minilm", | ||
Input: "why is the sky blue?", | ||
Options: map[string]any{"num_ctx": 1}, | ||
}, | ||
}, | ||
{ | ||
Name: "Explicit Truncate", | ||
Request: api.EmbedRequest{ | ||
Model: "all-minilm", | ||
Input: "why is the sky blue?", | ||
Truncate: &truncTrue, | ||
Options: map[string]any{"num_ctx": 1}, | ||
}, | ||
}, | ||
} | ||
|
||
res := make(map[string]*api.EmbedResponse) | ||
|
||
for _, req := range reqs { | ||
response, err := embedTestHelper(ctx, t, req.Request) | ||
if err != nil { | ||
t.Fatalf("error: %v", err) | ||
} | ||
res[req.Name] = response | ||
} | ||
|
||
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] { | ||
t.Fatal("expected default request to truncate correctly") | ||
} | ||
|
||
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] { | ||
t.Fatal("expected default request and truncate true request to be the same") | ||
} | ||
|
||
// check that truncate set to false returns an error if context length is exceeded | ||
_, err := embedTestHelper(ctx, t, api.EmbedRequest{ | ||
Model: "all-minilm", | ||
Input: "why is the sky blue?", | ||
Truncate: &truncFalse, | ||
Options: map[string]any{"num_ctx": 1}, | ||
}) | ||
|
||
if err == nil { | ||
t.Fatal("expected error, got nil") | ||
} | ||
} | ||
|
||
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) { | ||
client, _, cleanup := InitServerConnection(ctx, t) | ||
defer cleanup() | ||
if err := PullIfMissing(ctx, client, req.Model); err != nil { | ||
t.Fatalf("failed to pull model %s: %v", req.Model, err) | ||
} | ||
|
||
response, err := client.Embed(ctx, &req) | ||
|
||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return response, nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.