Skip to content

Commit

Permalink
Convert Safetensors to an Ollama model (ollama#2824)
Browse files Browse the repository at this point in the history
  • Loading branch information
pdevine authored Mar 7, 2024
1 parent 0ded7fd commit 2c017ca
Show file tree
Hide file tree
Showing 9 changed files with 3,083 additions and 153 deletions.
97 changes: 89 additions & 8 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"archive/zip"
"bytes"
"context"
"crypto/ed25519"
Expand Down Expand Up @@ -87,22 +88,82 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = filepath.Join(filepath.Dir(filename), path)
}

bin, err := os.Open(path)
fi, err := os.Stat(path)
if errors.Is(err, os.ErrNotExist) && c.Name == "model" {
continue
} else if err != nil {
return err
}
defer bin.Close()

hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil {
return err
// TODO make this work w/ adapters
if fi.IsDir() {
tf, err := os.CreateTemp("", "ollama-tf")
if err != nil {
return err
}
defer os.RemoveAll(tf.Name())

zf := zip.NewWriter(tf)

files, err := filepath.Glob(filepath.Join(path, "model-*.safetensors"))
if err != nil {
return err
}

if len(files) == 0 {
return fmt.Errorf("no safetensors files were found in '%s'", path)
}

// add the safetensor config file + tokenizer
files = append(files, filepath.Join(path, "config.json"))
files = append(files, filepath.Join(path, "added_tokens.json"))
files = append(files, filepath.Join(path, "tokenizer.model"))

for _, fn := range files {
f, err := os.Open(fn)
if os.IsNotExist(err) && strings.HasSuffix(fn, "added_tokens.json") {
continue
} else if err != nil {
return err
}

fi, err := f.Stat()
if err != nil {
return err
}

h, err := zip.FileInfoHeader(fi)
if err != nil {
return err
}

h.Name = filepath.Base(fn)
h.Method = zip.Store

w, err := zf.CreateHeader(h)
if err != nil {
return err
}

_, err = io.Copy(w, f)
if err != nil {
return err
}

}

if err := zf.Close(); err != nil {
return err
}

if err := tf.Close(); err != nil {
return err
}
path = tf.Name()
}
bin.Seek(0, io.SeekStart)

digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
digest, err := createBlob(cmd, client, path)
if err != nil {
return err
}

Expand Down Expand Up @@ -141,6 +202,26 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
return nil
}

// createBlob computes the SHA-256 digest of the file at path and hands the
// open file to client.CreateBlob together with that digest, using the
// command's context.
//
// On success it returns the digest formatted as "sha256:<hex>". If opening,
// hashing, rewinding, or the CreateBlob call fails, it returns an empty
// string and the underlying error.
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
	bin, err := os.Open(path)
	if err != nil {
		return "", err
	}
	defer bin.Close()

	hash := sha256.New()
	if _, err := io.Copy(hash, bin); err != nil {
		return "", err
	}

	// Hashing consumed the file, so rewind before handing it to CreateBlob.
	// The seek error must be checked: if the rewind fails the reader is still
	// positioned at EOF and an empty body would be sent under the digest of
	// the full file.
	if _, err := bin.Seek(0, io.SeekStart); err != nil {
		return "", err
	}

	digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
	if err := client.CreateBlob(cmd.Context(), digest, bin); err != nil {
		return "", err
	}
	return digest, nil
}

func RunHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
Expand Down
Loading

0 comments on commit 2c017ca

Please sign in to comment.