Skip to content

Commit

Permalink
Revamp go based integration tests
Browse files Browse the repository at this point in the history
This uplevels the integration tests to run the server which can allow
testing an existing server, or a remote server.
  • Loading branch information
dhiltgen committed Mar 23, 2024
1 parent a5ba0fc commit 949b6c0
Show file tree
Hide file tree
Showing 8 changed files with 313 additions and 261 deletions.
11 changes: 11 additions & 0 deletions integration/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Integration Tests

This directory contains integration tests to exercise Ollama end-to-end to verify behavior

By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`


The integration tests have 2 modes of operating.

1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote
28 changes: 28 additions & 0 deletions integration/basic_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//go:build integration

package integration

import (
"context"
"net/http"
"testing"
"time"

"github.com/jmorganca/ollama/api"
)

func TestOrcaMiniBlueSky(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: "orca-mini",
Prompt: "why is the sky blue?",
Stream: &stream,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh"})
}
33 changes: 11 additions & 22 deletions server/llm_image_test.go → integration/llm_image_test.go
Original file line number Diff line number Diff line change
@@ -1,49 +1,38 @@
//go:build integration

package server
package integration

import (
"context"
"encoding/base64"
"log"
"os"
"strings"
"net/http"
"testing"
"time"

"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func TestIntegrationMultimodal(t *testing.T) {
SkipIFNoTestData(t)
image, err := base64.StdEncoding.DecodeString(imageEncoding)
require.NoError(t, err)
req := api.GenerateRequest{
Model: "llava:7b",
Prompt: "what does the text in this image say?",
Options: map[string]interface{}{},
Model: "llava:7b",
Prompt: "what does the text in this image say?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}

resp := "the ollamas"
workDir, err := os.MkdirTemp("", "ollama")
require.NoError(t, err)
defer os.RemoveAll(workDir)
require.NoError(t, llm.Init(workDir))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
defer cancel()
opts := api.DefaultOptions()
opts.Seed = 42
opts.Temperature = 0.0
model, llmRunner := PrepareModelForPrompts(t, req.Model, opts)
defer llmRunner.Close()
response := OneShotPromptResponse(t, ctx, req, model, llmRunner)
log.Print(response)
assert.Contains(t, strings.ToLower(response), resp)
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{resp})
}

const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
Expand Down
73 changes: 73 additions & 0 deletions integration/llm_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//go:build integration

package integration

import (
"context"
"net/http"
"sync"
"testing"
"time"

"github.com/jmorganca/ollama/api"
)

// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
// package to avoid circular dependencies

// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
//
// TODO - Fix this ^^

var (
stream = false
req = [2]api.GenerateRequest{
{
Model: "orca-mini",
Prompt: "why is the ocean blue?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
}, {
Model: "orca-mini",
Prompt: "what is the origin of the us thanksgiving holiday?",
Stream: &stream,
Options: map[string]interface{}{
"seed": 42,
"temperature": 0.0,
},
},
}
resp = [2]string{
"scattering",
"united states thanksgiving",
}
)

func TestIntegrationSimpleOrcaMini(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
defer cancel()
GenerateTestHelper(ctx, t, &http.Client{}, req[0], []string{resp[0]})
}

// TODO
// The server always loads a new runner and closes the old one, which forces serial execution
// At present this test case fails with concurrency problems. Eventually we should try to
// get true concurrency working with n_parallel support in the backend
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
var wg sync.WaitGroup
wg.Add(len(req))
ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
defer cancel()
for i := 0; i < len(req); i++ {
go func(i int) {
defer wg.Done()
GenerateTestHelper(ctx, t, &http.Client{}, req[i], []string{resp[i]})
}(i)
}
wg.Wait()
}

// TODO - create a parallel test with 2 different models once we support concurrency
190 changes: 190 additions & 0 deletions integration/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
//go:build integration

package integration

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"log/slog"
"math/rand"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"

"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/app/lifecycle"
"github.com/stretchr/testify/assert"
)

func FindPort() string {
port := 0
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
var l *net.TCPListener
if l, err = net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
}
return strconv.Itoa(port)
}

func GetTestEndpoint() (string, string) {
defaultPort := "11434"
ollamaHost := os.Getenv("OLLAMA_HOST")

scheme, hostport, ok := strings.Cut(ollamaHost, "://")
if !ok {
scheme, hostport = "http", ollamaHost
}

// trim trailing slashes
hostport = strings.TrimRight(hostport, "/")

host, port, err := net.SplitHostPort(hostport)
if err != nil {
host, port = "127.0.0.1", defaultPort
if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
host = ip.String()
} else if hostport != "" {
host = hostport
}
}

if os.Getenv("OLLAMA_TEST_EXISTING") == "" && port == defaultPort {
port = FindPort()
}

url := fmt.Sprintf("%s:%s", host, port)
slog.Info("server connection", "url", url)
return scheme, url
}

// TODO make fanicier, grab logs, etc.
var serverMutex sync.Mutex
var serverReady bool

func StartServer(ctx context.Context, ollamaHost string) error {
// Make sure the server has been built
CLIName, err := filepath.Abs("../ollama")
if err != nil {
return err
}

if runtime.GOOS == "windows" {
CLIName += ".exe"
}
_, err = os.Stat(CLIName)
if err != nil {
return fmt.Errorf("CLI missing, did you forget to build first? %w", err)
}
serverMutex.Lock()
defer serverMutex.Unlock()
if serverReady {
return nil
}

if tmp := os.Getenv("OLLAMA_HOST"); tmp != ollamaHost {
slog.Info("setting env", "OLLAMA_HOST", ollamaHost)
os.Setenv("OLLAMA_HOST", ollamaHost)
}

slog.Info("starting server", "url", ollamaHost)
done, err := lifecycle.SpawnServer(ctx, "../ollama")
if err != nil {
return fmt.Errorf("failed to start server: %w", err)
}

go func() {
<-ctx.Done()
serverMutex.Lock()
defer serverMutex.Unlock()
exitCode := <-done
if exitCode > 0 {
slog.Warn("server failure", "exit", exitCode)
}
serverReady = false
}()

// TODO wait only long enough for the server to be responsive...
time.Sleep(500 * time.Millisecond)

serverReady = true
return nil
}

func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
requestJSON, err := json.Marshal(genReq)
if err != nil {
t.Fatalf("Error serializing request: %v", err)
}
defer func() {
if t.Failed() && os.Getenv("OLLAMA_TEST_EXISTING") == "" {
// TODO
fp, err := os.Open(lifecycle.ServerLogFile)
if err != nil {
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
return
}
data, err := io.ReadAll(fp)
if err != nil {
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
return
}
slog.Warn("SERVER LOG FOLLOWS")
os.Stderr.Write(data)
slog.Warn("END OF SERVER")
}
err = os.Remove(lifecycle.ServerLogFile)
if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
}
}()
scheme, testEndpoint := GetTestEndpoint()

if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
assert.NoError(t, StartServer(ctx, testEndpoint))
}

// Make the request and get the response
req, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/generate", bytes.NewReader(requestJSON))
if err != nil {
t.Fatalf("Error creating request: %v", err)
}

// Set the content type for the request
req.Header.Set("Content-Type", "application/json")

// Make the request with the HTTP client
response, err := client.Do(req.WithContext(ctx))
if err != nil {
t.Fatalf("Error making request: %v", err)
}
body, err := io.ReadAll(response.Body)
assert.NoError(t, err)
assert.Equal(t, response.StatusCode, 200, string(body))

// Verify the response is valid JSON
var payload api.GenerateResponse
err = json.Unmarshal(body, &payload)
if err != nil {
assert.NoError(t, err, body)
}

// Verify the response contains the expected data
for _, resp := range anyResp {
assert.Contains(t, strings.ToLower(payload.Response), resp)
}
}
Loading

0 comments on commit 949b6c0

Please sign in to comment.