Skip to content

Commit

Permalink
default to "FROM ." if a Modelfile isn't present (ollama#7250)
Browse files Browse the repository at this point in the history
  • Loading branch information
pdevine authored Oct 22, 2024
1 parent 5c44461 commit d78fb62
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 12 deletions.
59 changes: 47 additions & 12 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,28 +46,58 @@ import (
"github.com/ollama/ollama/version"
)

func CreateHandler(cmd *cobra.Command, args []string) error {
filename, _ := cmd.Flags().GetString("file")
filename, err := filepath.Abs(filename)
var (
errModelNotFound = errors.New("no Modelfile or safetensors files found")
errModelfileNotFound = errors.New("specified Modelfile wasn't found")
)

func getModelfileName(cmd *cobra.Command) (string, error) {
fn, _ := cmd.Flags().GetString("file")

filename := fn
if filename == "" {
filename = "Modelfile"
}

absName, err := filepath.Abs(filename)
if err != nil {
return err
return "", err
}

client, err := api.ClientFromEnvironment()
_, err = os.Stat(absName)
if err != nil {
return err
return fn, err
}

return absName, nil
}

func CreateHandler(cmd *cobra.Command, args []string) error {
p := progress.NewProgress(os.Stderr)
defer p.Stop()

f, err := os.Open(filename)
if err != nil {
var reader io.Reader

filename, err := getModelfileName(cmd)
if os.IsNotExist(err) {
if filename == "" {
reader = strings.NewReader("FROM .\n")
} else {
return errModelfileNotFound
}
} else if err != nil {
return err
} else {
f, err := os.Open(filename)
if err != nil {
return err
}

reader = f
defer f.Close()
}
defer f.Close()

modelfile, err := parser.ParseFile(f)
modelfile, err := parser.ParseFile(reader)
if err != nil {
return err
}
Expand All @@ -82,6 +112,11 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
p.Add(status, spinner)
defer p.Stop()

client, err := api.ClientFromEnvironment()
if err != nil {
return err
}

for i := range modelfile.Commands {
switch modelfile.Commands[i].Name {
case "model", "adapter":
Expand Down Expand Up @@ -220,7 +255,7 @@ func tempZipFiles(path string) (string, error) {
// covers consolidated.x.pth, consolidated.pth
files = append(files, pt...)
} else {
return "", errors.New("no safetensors or torch files found")
return "", errModelNotFound
}

// add configuration files, json files are detected as text/plain
Expand Down Expand Up @@ -1315,7 +1350,7 @@ func NewCLI() *cobra.Command {
RunE: CreateHandler,
}

createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile")
createCmd.Flags().StringP("file", "f", "", "Name of the Modelfile (default \"Modelfile\"")
createCmd.Flags().StringP("quantize", "q", "", "Quantize model to this level (e.g. q4_0)")

showCmd := &cobra.Command{
Expand Down
99 changes: 99 additions & 0 deletions cmd/cmd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,102 @@ func TestDeleteHandler(t *testing.T) {
t.Fatalf("DeleteHandler failed: expected error about stopping non-existent model, got %v", err)
}
}

func TestGetModelfileName(t *testing.T) {
tests := []struct {
name string
modelfileName string
fileExists bool
expectedName string
expectedErr error
}{
{
name: "no modelfile specified, no modelfile exists",
modelfileName: "",
fileExists: false,
expectedName: "",
expectedErr: os.ErrNotExist,
},
{
name: "no modelfile specified, modelfile exists",
modelfileName: "",
fileExists: true,
expectedName: "Modelfile",
expectedErr: nil,
},
{
name: "modelfile specified, no modelfile exists",
modelfileName: "crazyfile",
fileExists: false,
expectedName: "crazyfile",
expectedErr: os.ErrNotExist,
},
{
name: "modelfile specified, modelfile exists",
modelfileName: "anotherfile",
fileExists: true,
expectedName: "anotherfile",
expectedErr: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := &cobra.Command{
Use: "fakecmd",
}
cmd.Flags().String("file", "", "path to modelfile")

var expectedFilename string

if tt.fileExists {
tempDir, err := os.MkdirTemp("", "modelfiledir")
defer os.RemoveAll(tempDir)
if err != nil {
t.Fatalf("temp modelfile dir creation failed: %v", err)
}
var fn string
if tt.modelfileName != "" {
fn = tt.modelfileName
} else {
fn = "Modelfile"
}

tempFile, err := os.CreateTemp(tempDir, fn)
if err != nil {
t.Fatalf("temp modelfile creation failed: %v", err)
}

expectedFilename = tempFile.Name()
err = cmd.Flags().Set("file", expectedFilename)
if err != nil {
t.Fatalf("couldn't set file flag: %v", err)
}
} else {
if tt.modelfileName != "" {
expectedFilename = tt.modelfileName
err := cmd.Flags().Set("file", tt.modelfileName)
if err != nil {
t.Fatalf("couldn't set file flag: %v", err)
}
}
}

actualFilename, actualErr := getModelfileName(cmd)

if actualFilename != expectedFilename {
t.Errorf("expected filename: '%s' actual filename: '%s'", expectedFilename, actualFilename)
}

if tt.expectedErr != os.ErrNotExist {
if actualErr != tt.expectedErr {
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
}
} else {
if !os.IsNotExist(actualErr) {
t.Errorf("expected err: %v actual err: %v", tt.expectedErr, actualErr)
}
}
})
}
}

0 comments on commit d78fb62

Please sign in to comment.