diff --git a/api/client.go b/api/client.go index 4e434faea3c..a0aa2afea61 100644 --- a/api/client.go +++ b/api/client.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -63,20 +64,19 @@ func (c *Client) stream(ctx context.Context, method string, path string, reqData for { line, err := reader.ReadBytes('\n') - if err != nil { - if err == io.EOF { - break - } else { - return err // Handle other errors - } + switch { + case errors.Is(err, io.EOF): + return nil + case err != nil: + return err } + if err := checkError(res, line); err != nil { return err } + callback(bytes.TrimSuffix(line, []byte("\n"))) } - - return nil } func (c *Client) do(ctx context.Context, method string, path string, reqData any, respData any) error { @@ -124,11 +124,9 @@ func (c *Client) do(ctx context.Context, method string, path string, reqData any return nil } -func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback func(token string)) (*GenerateResponse, error) { +func (c *Client) Generate(ctx context.Context, req *GenerateRequest, callback func(bts []byte)) (*GenerateResponse, error) { var res GenerateResponse - if err := c.stream(ctx, http.MethodPost, "/api/generate", req, func(token []byte) { - callback(string(token)) - }); err != nil { + if err := c.stream(ctx, http.MethodPost, "/api/generate", req, callback); err != nil { return nil, err } diff --git a/cmd/cmd.go b/cmd/cmd.go index 1619be09aec..8df05117b42 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -1,7 +1,9 @@ package cmd import ( + "bufio" "context" + "encoding/json" "fmt" "log" "net" @@ -10,9 +12,11 @@ import ( "sync" "github.com/gosuri/uiprogress" + "github.com/spf13/cobra" + "golang.org/x/term" + "github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/server" - "github.com/spf13/cobra" ) func cacheDir() string { @@ -28,13 +32,13 @@ func bytesToGB(bytes int) float64 { return float64(bytes) / float64(1<<30) } -func run(model string) error { +func RunRun(cmd *cobra.Command, args []string) error { client, err := NewAPIClient() if err != nil { return err } pr := api.PullRequest{ - Model: model, + Model: args[0], } var bar *uiprogress.Bar mutex := &sync.Mutex{} @@ -60,10 +64,71 @@ func run(model string) error { return err } fmt.Println("Up to date.") + return RunGenerate(cmd, args) +} + +func RunGenerate(_ *cobra.Command, args []string) error { + if len(args) > 1 { + return generate(args[0], args[1:]...) + } + + if term.IsTerminal(int(os.Stdin.Fd())) { + return generateInteractive(args[0]) + } + + return generateBatch(args[0]) +} + +func generate(model string, prompts ...string) error { + client, err := NewAPIClient() + if err != nil { + return err + } + + for _, prompt := range prompts { + client.Generate(context.Background(), &api.GenerateRequest{Model: model, Prompt: prompt}, func(bts []byte) { + var resp api.GenerateResponse + if err := json.Unmarshal(bts, &resp); err != nil { + return + } + + fmt.Print(resp.Response) + }) + } + + fmt.Println() + fmt.Println() + return nil +} + +func generateInteractive(model string) error { + fmt.Print(">>> ") + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + if err := generate(model, scanner.Text()); err != nil { + return err + } + + fmt.Print(">>> ") + } + return nil } -func serve() error { +func generateBatch(model string) error { + scanner := bufio.NewScanner(os.Stdin) + for scanner.Scan() { + prompt := scanner.Text() + fmt.Printf(">>> %s\n", prompt) + if err := generate(model, prompt); err != nil { + return err + } + } + + return nil +} + +func RunServer(_ *cobra.Command, _ []string) error { ln, err := net.Listen("tcp", "127.0.0.1:11434") if err != nil { return err @@ -82,39 +147,32 @@ func NewCLI() *cobra.Command { log.SetFlags(log.LstdFlags | log.Lshortfile) rootCmd := &cobra.Command{ - Use: "ollama", - Short: "Large language model runner", + Use: "ollama", + Short: "Large language model runner", + SilenceUsage: true, CompletionOptions: cobra.CompletionOptions{ DisableDefaultCmd: true, }, - PersistentPreRun: func(cmd *cobra.Command, args []string) { - // Disable usage printing on errors - cmd.SilenceUsage = true + PersistentPreRunE: func(_ *cobra.Command, args []string) error { // create the models directory and it's parent - if err := os.MkdirAll(path.Join(cacheDir(), "models"), 0o700); err != nil { - panic(err) - } + return os.MkdirAll(path.Join(cacheDir(), "models"), 0o700) }, } cobra.EnableCommandSorting = false runCmd := &cobra.Command{ - Use: "run MODEL", + Use: "run MODEL [PROMPT]", Short: "Run a model", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - return run(args[0]) - }, + Args: cobra.MinimumNArgs(1), + RunE: RunRun, } serveCmd := &cobra.Command{ Use: "serve", Aliases: []string{"start"}, Short: "Start ollama", - RunE: func(cmd *cobra.Command, args []string) error { - return serve() - }, + RunE: RunServer, } rootCmd.AddCommand( diff --git a/go.mod b/go.mod index 6ca336d1cce..6169db4c308 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( golang.org/x/crypto v0.10.0 // indirect golang.org/x/net v0.10.0 // indirect golang.org/x/sys v0.10.0 // indirect + golang.org/x/term v0.10.0 golang.org/x/text v0.10.0 // indirect google.golang.org/protobuf v1.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 065bb0db2c4..92f13ede5c7 100644 --- a/go.sum +++ b/go.sum @@ -106,6 +106,8 @@ golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/term v0.10.0 h1:3R7pNqamzBraeqj/Tj8qt1aQ2HpmlC+Cx/qL/7hn4/c= +golang.org/x/term v0.10.0/go.mod h1:lpqdcUyK/oCiQxvxVrppt5ggO2KCZ5QblwqPnfZ6d5o= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=