diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 262960df9a6ff..33cb0ee4e2b53 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -302,7 +302,7 @@ jobs: echo "cover=false" >> $GITHUB_OUTPUT fi - gotestsum --junitfile="gotests.xml" --packages="./..." -- -parallel=8 -timeout=5m -short -failfast $COVERAGE_FLAGS + gotestsum --junitfile="gotests.xml" --packages="./..." -- -parallel=8 -timeout=7m -short -failfast $COVERAGE_FLAGS - uses: actions/upload-artifact@v3 if: success() || failure() diff --git a/Makefile b/Makefile index 1ea380add9e1d..0e8508c2a81aa 100644 --- a/Makefile +++ b/Makefile @@ -501,8 +501,6 @@ docs/admin/prometheus.md: scripts/metricsdocgen/main.go scripts/metricsdocgen/me yarn run format:write:only ../docs/admin/prometheus.md docs/cli.md: scripts/clidocgen/main.go $(GO_SRC_FILES) docs/manifest.json - # TODO(@ammario): re-enable server.md once we finish clibase migration. - ls ./docs/cli/*.md | grep -vP "\/coder_server" | xargs rm BASE_PATH="." go run ./scripts/clidocgen cd site yarn run format:write:only ../docs/cli.md ../docs/cli/*.md ../docs/manifest.json @@ -519,7 +517,7 @@ coderd/apidoc/swagger.json: $(shell find ./scripts/apidocgen $(FIND_EXCLUSIONS) update-golden-files: cli/testdata/.gen-golden helm/tests/testdata/.gen-golden .PHONY: update-golden-files -cli/testdata/.gen-golden: $(wildcard cli/testdata/*.golden) $(GO_SRC_FILES) +cli/testdata/.gen-golden: $(wildcard cli/testdata/*.golden) $(wildcard cli/*.tpl) $(GO_SRC_FILES) go test ./cli -run=TestCommandHelp -update touch "$@" diff --git a/cli/agent.go b/cli/agent.go index b3086815b2bad..d9058912e8207 100644 --- a/cli/agent.go +++ b/cli/agent.go @@ -16,7 +16,6 @@ import ( "time" "cloud.google.com/go/compute/metadata" - "github.com/spf13/cobra" "golang.org/x/xerrors" "gopkg.in/natefinch/lumberjack.v2" @@ -25,11 +24,11 @@ import ( "github.com/coder/coder/agent" "github.com/coder/coder/agent/reaper" "github.com/coder/coder/buildinfo" - "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/codersdk/agentsdk" ) -func workspaceAgent() *cobra.Command { +func (r *RootCmd) workspaceAgent() *clibase.Cmd { var ( auth string logDir string @@ -37,22 +36,15 @@ func workspaceAgent() *cobra.Command { noReap bool sshMaxTimeout time.Duration ) - cmd := &cobra.Command{ - Use: "agent", + cmd := &clibase.Cmd{ + Use: "agent", + Short: `Starts the Coder workspace agent.`, // This command isn't useful to manually execute. Hidden: true, - RunE: func(cmd *cobra.Command, _ []string) error { - ctx, cancel := context.WithCancel(cmd.Context()) + Handler: func(inv *clibase.Invocation) error { + ctx, cancel := context.WithCancel(inv.Context()) defer cancel() - rawURL, err := cmd.Flags().GetString(varAgentURL) - if err != nil { - return xerrors.Errorf("CODER_AGENT_URL must be set: %w", err) - } - coderURL, err := url.Parse(rawURL) - if err != nil { - return xerrors.Errorf("parse %q: %w", rawURL, err) - } agentPorts := map[int]string{} isLinux := runtime.GOOS == "linux" @@ -65,7 +57,7 @@ func workspaceAgent() *cobra.Command { MaxSize: 5, // MB } defer logWriter.Close() - logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()), sloghuman.Sink(logWriter)).Leveled(slog.LevelDebug) + logger := slog.Make(sloghuman.Sink(inv.Stderr), sloghuman.Sink(logWriter)).Leveled(slog.LevelDebug) logger.Info(ctx, "spawning reaper process") // Do not start a reaper on the child process. It's important @@ -107,15 +99,15 @@ func workspaceAgent() *cobra.Command { logWriter := &closeWriter{w: ljLogger} defer logWriter.Close() - logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()), sloghuman.Sink(logWriter)).Leveled(slog.LevelDebug) + logger := slog.Make(sloghuman.Sink(inv.Stderr), sloghuman.Sink(logWriter)).Leveled(slog.LevelDebug) version := buildinfo.Version() logger.Info(ctx, "starting agent", - slog.F("url", coderURL), + slog.F("url", r.agentURL), slog.F("auth", auth), slog.F("version", version), ) - client := agentsdk.New(coderURL) + client := agentsdk.New(r.agentURL) client.SDK.Logger = logger // Set a reasonable timeout so requests can't hang forever! // The timeout needs to be reasonably long, because requests @@ -139,7 +131,7 @@ func workspaceAgent() *cobra.Command { var exchangeToken func(context.Context) (agentsdk.AuthenticateResponse, error) switch auth { case "token": - token, err := cmd.Flags().GetString(varAgentToken) + token, err := inv.ParsedFlags().GetString(varAgentToken) if err != nil { return xerrors.Errorf("CODER_AGENT_TOKEN must be set for token auth: %w", err) } @@ -220,11 +212,44 @@ func workspaceAgent() *cobra.Command { }, } - cliflag.StringVarP(cmd.Flags(), &auth, "auth", "", "CODER_AGENT_AUTH", "token", "Specify the authentication type to use for the agent") - cliflag.StringVarP(cmd.Flags(), &logDir, "log-dir", "", "CODER_AGENT_LOG_DIR", os.TempDir(), "Specify the location for the agent log files") - cliflag.StringVarP(cmd.Flags(), &pprofAddress, "pprof-address", "", "CODER_AGENT_PPROF_ADDRESS", "127.0.0.1:6060", "The address to serve pprof.") - cliflag.BoolVarP(cmd.Flags(), &noReap, "no-reap", "", "", false, "Do not start a process reaper.") - cliflag.DurationVarP(cmd.Flags(), &sshMaxTimeout, "ssh-max-timeout", "", "CODER_AGENT_SSH_MAX_TIMEOUT", time.Duration(0), "Specify the max timeout for a SSH connection") + cmd.Options = clibase.OptionSet{ + { + Flag: "auth", + Default: "token", + Description: "Specify the authentication type to use for the agent.", + Env: "CODER_AGENT_AUTH", + Value: clibase.StringOf(&auth), + }, + { + Flag: "log-dir", + Default: os.TempDir(), + Description: "Specify the location for the agent log files.", + Env: "CODER_AGENT_LOG_DIR", + Value: clibase.StringOf(&logDir), + }, + { + Flag: "pprof-address", + Default: "127.0.0.1:6060", + Env: "CODER_AGENT_PPROF_ADDRESS", + Value: clibase.StringOf(&pprofAddress), + Description: "The address to serve pprof.", + }, + { + Flag: "no-reap", + + Env: "", + Description: "Do not start a process reaper.", + Value: clibase.BoolOf(&noReap), + }, + { + Flag: "ssh-max-timeout", + Default: "0", + Env: "CODER_AGENT_SSH_MAX_TIMEOUT", + Description: "Specify the max timeout for a SSH connection.", + Value: clibase.DurationOf(&sshMaxTimeout), + }, + } + return cmd } diff --git a/cli/agent_test.go b/cli/agent_test.go index b285d7dba9e45..3911258cc19d6 100644 --- a/cli/agent_test.go +++ b/cli/agent_test.go @@ -16,7 +16,7 @@ import ( "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/provisioner/echo" "github.com/coder/coder/provisionersdk/proto" - "github.com/coder/coder/testutil" + "github.com/coder/coder/pty/ptytest" ) func TestWorkspaceAgent(t *testing.T) { @@ -40,24 +40,20 @@ func TestWorkspaceAgent(t *testing.T) { coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) logDir := t.TempDir() - cmd, _ := clitest.New(t, + inv, _ := clitest.New(t, "agent", "--auth", "token", "--agent-token", authToken, "--agent-url", client.URL.String(), "--log-dir", logDir, ) - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) - defer cancel() - errC := make(chan error, 1) - go func() { - errC <- cmd.ExecuteContext(ctx) - }() - coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) - cancel() - err := <-errC - require.NoError(t, err) + pty := ptytest.New(t).Attach(inv) + + clitest.Start(t, inv) + pty.ExpectMatch("starting agent") + + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) info, err := os.Stat(filepath.Join(logDir, "coder-agent.log")) require.NoError(t, err) @@ -96,16 +92,14 @@ func TestWorkspaceAgent(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String()) + inv, _ := clitest.New(t, "agent", "--auth", "azure-instance-identity", "--agent-url", client.URL.String()) + inv = inv.WithContext( + //nolint:revive,staticcheck + context.WithValue(inv.Context(), "azure-client", metadataClient), + ) ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - errC := make(chan error) - go func() { - // A linting error occurs for weakly typing the context value here. - //nolint // The above seems reasonable for a one-off test. - ctx := context.WithValue(ctx, "azure-client", metadataClient) - errC <- cmd.ExecuteContext(ctx) - }() + clitest.Start(t, inv) coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) workspace, err := client.Workspace(ctx, workspace.ID) require.NoError(t, err) @@ -117,9 +111,6 @@ func TestWorkspaceAgent(t *testing.T) { require.NoError(t, err) defer dialer.Close() require.True(t, dialer.AwaitReachable(context.Background())) - cancelFunc() - err = <-errC - require.NoError(t, err) }) t.Run("AWS", func(t *testing.T) { @@ -154,36 +145,29 @@ func TestWorkspaceAgent(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String()) - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - errC := make(chan error) - go func() { - // A linting error occurs for weakly typing the context value here. - //nolint // The above seems reasonable for a one-off test. - ctx := context.WithValue(ctx, "aws-client", metadataClient) - errC <- cmd.ExecuteContext(ctx) - }() + inv, _ := clitest.New(t, "agent", "--auth", "aws-instance-identity", "--agent-url", client.URL.String()) + inv = inv.WithContext( + //nolint:revive,staticcheck + context.WithValue(inv.Context(), "aws-client", metadataClient), + ) + clitest.Start(t, inv) coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) - workspace, err := client.Workspace(ctx, workspace.ID) + workspace, err := client.Workspace(inv.Context(), workspace.ID) require.NoError(t, err) resources := workspace.LatestBuild.Resources if assert.NotEmpty(t, resources) && assert.NotEmpty(t, resources[0].Agents) { assert.NotEmpty(t, resources[0].Agents[0].Version) } - dialer, err := client.DialWorkspaceAgent(ctx, resources[0].Agents[0].ID, nil) + dialer, err := client.DialWorkspaceAgent(inv.Context(), resources[0].Agents[0].ID, nil) require.NoError(t, err) defer dialer.Close() require.True(t, dialer.AwaitReachable(context.Background())) - cancelFunc() - err = <-errC - require.NoError(t, err) }) t.Run("GoogleCloud", func(t *testing.T) { t.Parallel() instanceID := "instanceidentifier" - validator, metadata := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false) + validator, metadataClient := coderdtest.NewGoogleInstanceIdentity(t, instanceID, false) client := coderdtest.New(t, &coderdtest.Options{ GoogleTokenValidator: validator, IncludeProvisionerDaemon: true, @@ -212,16 +196,18 @@ func TestWorkspaceAgent(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, _ := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String()) - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - errC := make(chan error) - go func() { - // A linting error occurs for weakly typing the context value here. - //nolint // The above seems reasonable for a one-off test. - ctx := context.WithValue(ctx, "gcp-client", metadata) - errC <- cmd.ExecuteContext(ctx) - }() + inv, cfg := clitest.New(t, "agent", "--auth", "google-instance-identity", "--agent-url", client.URL.String()) + ptytest.New(t).Attach(inv) + clitest.SetupConfig(t, client, cfg) + clitest.Start(t, + inv.WithContext( + //nolint:revive,staticcheck + context.WithValue(context.Background(), "gcp-client", metadataClient), + ), + ) + + ctx := inv.Context() + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) workspace, err := client.Workspace(ctx, workspace.ID) require.NoError(t, err) @@ -248,9 +234,5 @@ func TestWorkspaceAgent(t *testing.T) { require.NoError(t, err) _, err = uuid.Parse(strings.TrimSpace(string(token))) require.NoError(t, err) - - cancelFunc() - err = <-errC - require.NoError(t, err) }) } diff --git a/cli/clibase/clibase.go b/cli/clibase/clibase.go index cbca6b81433cc..bdad2e97c36a6 100644 --- a/cli/clibase/clibase.go +++ b/cli/clibase/clibase.go @@ -1,10 +1,6 @@ // Package clibase offers an all-in-one solution for a highly configurable CLI -// application. Within Coder, we use it for our `server` subcommand, which -// demands more functionality than cobra/viper can offer. -// -// We will extend its usage to the rest of our application, completely replacing -// cobra/viper. It's also a candidate to be broken out into its own open-source -// library, so we avoid deep coupling with Coder concepts. +// application. Within Coder, we use it for all of our subcommands, which +// demands more functionality than cobra/viber offers. // // The Command interface is loosely based on the chi middleware pattern and // http.Handler/HandlerFunc. diff --git a/cli/clibase/cmd.go b/cli/clibase/cmd.go index 1b91f8e976a00..313a2afc9454a 100644 --- a/cli/clibase/cmd.go +++ b/cli/clibase/cmd.go @@ -3,11 +3,15 @@ package clibase import ( "context" "errors" + "flag" + "fmt" "io" "os" "strings" + "unicode" "github.com/spf13/pflag" + "golang.org/x/exp/slices" "golang.org/x/xerrors" ) @@ -47,14 +51,70 @@ type Cmd struct { HelpHandler HandlerFunc } +// AddSubcommands adds the given subcommands, setting their +// Parent field automatically. +func (c *Cmd) AddSubcommands(cmds ...*Cmd) { + for _, cmd := range cmds { + cmd.Parent = c + c.Children = append(c.Children, cmd) + } +} + // Walk calls fn for the command and all its children. func (c *Cmd) Walk(fn func(*Cmd)) { fn(c) for _, child := range c.Children { + child.Parent = c child.Walk(fn) } } +// PrepareAll performs initialization and linting on the command and all its children. +func (c *Cmd) PrepareAll() error { + if c.Use == "" { + return xerrors.New("command must have a Use field so that it has a name") + } + var merr error + + slices.SortFunc(c.Options, func(a, b Option) bool { + return a.Flag < b.Flag + }) + for _, opt := range c.Options { + if opt.Name == "" { + switch { + case opt.Flag != "": + opt.Name = opt.Flag + case opt.Env != "": + opt.Name = opt.Env + case opt.YAML != "": + opt.Name = opt.YAML + default: + merr = errors.Join(merr, xerrors.Errorf("option must have a Name, Flag, Env or YAML field")) + } + } + if opt.Description != "" { + // Enforce that description uses sentence form. + if unicode.IsLower(rune(opt.Description[0])) { + merr = errors.Join(merr, xerrors.Errorf("option %q description should start with a capital letter", opt.Name)) + } + if !strings.HasSuffix(opt.Description, ".") { + merr = errors.Join(merr, xerrors.Errorf("option %q description should end with a period", opt.Name)) + } + } + } + slices.SortFunc(c.Children, func(a, b *Cmd) bool { + return a.Name() < b.Name() + }) + for _, child := range c.Children { + child.Parent = c + err := child.PrepareAll() + if err != nil { + merr = errors.Join(merr, xerrors.Errorf("command %v: %w", child.Name(), err)) + } + } + return merr +} + // Name returns the first word in the Use string. func (c *Cmd) Name() string { return strings.Split(c.Use, " ")[0] @@ -64,7 +124,6 @@ func (c *Cmd) Name() string { // as seen on the command line. func (c *Cmd) FullName() string { var names []string - if c.Parent != nil { names = append(names, c.Parent.FullName()) } @@ -77,7 +136,7 @@ func (c *Cmd) FullName() string { func (c *Cmd) FullUsage() string { var uses []string if c.Parent != nil { - uses = append(uses, c.Parent.FullUsage()) + uses = append(uses, c.Parent.FullName()) } uses = append(uses, c.Use) return strings.Join(uses, " ") @@ -115,28 +174,17 @@ type Invocation struct { // fields with OS defaults. func (i *Invocation) WithOS() *Invocation { return i.with(func(i *Invocation) { - if i.Stdout == nil { - i.Stdout = os.Stdout - } - if i.Stderr == nil { - i.Stderr = os.Stderr - } - if i.Stdin == nil { - i.Stdin = os.Stdin - } - if i.Args == nil { - i.Args = os.Args[1:] - } - if i.Environ == nil { - i.Environ = ParseEnviron(os.Environ(), "") - } + i.Stdout = os.Stdout + i.Stderr = os.Stderr + i.Stdin = os.Stdin + i.Args = os.Args[1:] + i.Environ = ParseEnviron(os.Environ(), "") }) } func (i *Invocation) Context() context.Context { if i.ctx == nil { - // Consider returning context.Background() instead? - panic("context not set, has WithContext() or Run() been called?") + return context.Background() } return i.ctx } @@ -155,6 +203,18 @@ type runState struct { flagParseErr error } +func copyFlagSetWithout(fs *pflag.FlagSet, without string) *pflag.FlagSet { + fs2 := pflag.NewFlagSet("", pflag.ContinueOnError) + fs2.Usage = func() {} + fs.VisitAll(func(f *pflag.Flag) { + if f.Name == without { + return + } + fs2.AddFlag(f) + }) + return fs2 +} + // run recursively executes the command and its children. // allArgs is wired through the stack so that global flags can be accepted // anywhere in the command invocation. @@ -164,6 +224,23 @@ func (i *Invocation) run(state *runState) error { return xerrors.Errorf("setting defaults: %w", err) } + // If we set the Default of an array but later see a flag for it, we + // don't want to append, we want to replace. So, we need to keep the state + // of defaulted array options. + defaultedArrays := make(map[string]int) + for _, opt := range i.Command.Options { + sv, ok := opt.Value.(pflag.SliceValue) + if !ok { + continue + } + + if opt.Flag == "" { + continue + } + + defaultedArrays[opt.Flag] = len(sv.GetSlice()) + } + err = i.Command.Options.ParseEnv(i.Environ) if err != nil { return xerrors.Errorf("parsing env: %w", err) @@ -173,6 +250,7 @@ func (i *Invocation) run(state *runState) error { children := make(map[string]*Cmd) for _, child := range i.Command.Children { + child.Parent = i.Command for _, name := range append(child.Aliases, child.Name()) { if _, ok := children[name]; ok { return xerrors.Errorf("duplicate command name: %s", name) @@ -187,7 +265,15 @@ func (i *Invocation) run(state *runState) error { i.parsedFlags.Usage = func() {} } - i.parsedFlags.AddFlagSet(i.Command.Options.FlagSet()) + // If we find a duplicate flag, we want the deeper command's flag to override + // the shallow one. Unfortunately, pflag has no way to remove a flag, so we + // have to create a copy of the flagset without a value. + i.Command.Options.FlagSet().VisitAll(func(f *pflag.Flag) { + if i.parsedFlags.Lookup(f.Name) != nil { + i.parsedFlags = copyFlagSetWithout(i.parsedFlags, f.Name) + } + i.parsedFlags.AddFlag(f) + }) var parsedArgs []string @@ -196,24 +282,38 @@ func (i *Invocation) run(state *runState) error { // so we check the error after looking for a child command. state.flagParseErr = i.parsedFlags.Parse(state.allArgs) parsedArgs = i.parsedFlags.Args() + + i.parsedFlags.VisitAll(func(f *pflag.Flag) { + i, ok := defaultedArrays[f.Name] + if !ok { + return + } + + if !f.Changed { + return + } + + sv, ok := f.Value.(pflag.SliceValue) + if !ok { + panic("defaulted array option is not a slice value") + } + err := sv.Replace(sv.GetSlice()[i:]) + if err != nil { + panic(err) + } + }) } // Run child command if found (next child only) // We must do subcommand detection after flag parsing so we don't mistake flag // values for subcommand names. - if len(parsedArgs) > 0 { - nextArg := parsedArgs[0] + if len(parsedArgs) > state.commandDepth { + nextArg := parsedArgs[state.commandDepth] if child, ok := children[nextArg]; ok { child.Parent = i.Command i.Command = child state.commandDepth++ - err = i.run(state) - if err != nil { - return xerrors.Errorf( - "subcommand %s: %w", child.Name(), err, - ) - } - return nil + return i.run(state) } } @@ -266,11 +366,27 @@ func (i *Invocation) run(state *runState) error { err = mw(i.Command.Handler)(i) if err != nil { - return xerrors.Errorf("running command %s: %w", i.Command.FullName(), err) + return &RunCommandError{ + Cmd: i.Command, + Err: err, + } } return nil } +type RunCommandError struct { + Cmd *Cmd + Err error +} + +func (e *RunCommandError) Unwrap() error { + return e.Err +} + +func (e *RunCommandError) Error() string { + return fmt.Sprintf("running command %q: %+v", e.Cmd.FullName(), e.Err) +} + // findArg returns the index of the first occurrence of arg in args, skipping // over all flags. func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) { @@ -314,10 +430,21 @@ func findArg(want string, args []string, fs *pflag.FlagSet) (int, error) { // If two command share a flag name, the first command wins. // //nolint:revive -func (i *Invocation) Run() error { - return i.run(&runState{ +func (i *Invocation) Run() (err error) { + defer func() { + // Pflag is panicky, so additional context is helpful in tests. + if flag.Lookup("test.v") == nil { + return + } + if r := recover(); r != nil { + err = xerrors.Errorf("panic recovered for %s: %v", i.Command.FullName(), r) + panic(err) + } + }() + err = i.run(&runState{ allArgs: i.Args, }) + return err } // WithContext returns a copy of the Invocation with the given context. @@ -378,6 +505,9 @@ func RequireRangeArgs(start, end int) MiddlewareFunc { case start == end && got != start: switch start { case 0: + if len(i.Command.Children) > 0 { + return xerrors.Errorf("unrecognized subcommand %q", i.Args[0]) + } return xerrors.Errorf("wanted no args but got %v %v", got, i.Args) default: return xerrors.Errorf( diff --git a/cli/clibase/cmd_test.go b/cli/clibase/cmd_test.go index ac08b26837072..cc6ff5858c3e8 100644 --- a/cli/clibase/cmd_test.go +++ b/cli/clibase/cmd_test.go @@ -213,6 +213,66 @@ func TestCommand(t *testing.T) { }) } +func TestCommand_DeepNest(t *testing.T) { + t.Parallel() + cmd := &clibase.Cmd{ + Use: "1", + Children: []*clibase.Cmd{ + { + Use: "2", + Children: []*clibase.Cmd{ + { + Use: "3", + Handler: func(i *clibase.Invocation) error { + i.Stdout.Write([]byte("3")) + return nil + }, + }, + }, + }, + }, + } + inv := cmd.Invoke("2", "3") + stdio := fakeIO(inv) + err := inv.Run() + require.NoError(t, err) + require.Equal(t, "3", stdio.Stdout.String()) +} + +func TestCommand_FlagOverride(t *testing.T) { + t.Parallel() + var flag string + + cmd := &clibase.Cmd{ + Use: "1", + Options: clibase.OptionSet{ + { + Flag: "f", + Value: clibase.DiscardValue, + }, + }, + Children: []*clibase.Cmd{ + { + Use: "2", + Options: clibase.OptionSet{ + { + Flag: "f", + Value: clibase.StringOf(&flag), + }, + }, + Handler: func(i *clibase.Invocation) error { + return nil + }, + }, + }, + } + + err := cmd.Invoke("2", "--f", "mhmm").Run() + require.NoError(t, err) + + require.Equal(t, "mhmm", flag) +} + func TestCommand_MiddlewareOrder(t *testing.T) { t.Parallel() @@ -252,7 +312,7 @@ func TestCommand_RawArgs(t *testing.T) { cmd := func() *clibase.Cmd { return &clibase.Cmd{ Use: "root", - Options: []clibase.Option{ + Options: clibase.OptionSet{ { Name: "password", Flag: "password", @@ -366,3 +426,80 @@ func TestCommand_ContextCancels(t *testing.T) { require.Error(t, gotCtx.Err()) } + +func TestCommand_Help(t *testing.T) { + t.Parallel() + + cmd := func() *clibase.Cmd { + return &clibase.Cmd{ + Use: "root", + HelpHandler: (func(i *clibase.Invocation) error { + i.Stdout.Write([]byte("abdracadabra")) + return nil + }), + Handler: (func(i *clibase.Invocation) error { + return xerrors.New("should not be called") + }), + } + } + + t.Run("NoHandler", func(t *testing.T) { + t.Parallel() + + c := cmd() + c.HelpHandler = nil + err := c.Invoke("--help").Run() + require.Error(t, err) + }) + + t.Run("Long", func(t *testing.T) { + t.Parallel() + + inv := cmd().Invoke("--help") + stdio := fakeIO(inv) + err := inv.Run() + require.NoError(t, err) + + require.Contains(t, stdio.Stdout.String(), "abdracadabra") + }) + + t.Run("Short", func(t *testing.T) { + t.Parallel() + + inv := cmd().Invoke("-h") + stdio := fakeIO(inv) + err := inv.Run() + require.NoError(t, err) + + require.Contains(t, stdio.Stdout.String(), "abdracadabra") + }) +} + +func TestCommand_SliceFlags(t *testing.T) { + t.Parallel() + + cmd := func(want ...string) *clibase.Cmd { + var got []string + return &clibase.Cmd{ + Use: "root", + Options: clibase.OptionSet{ + { + Name: "arr", + Flag: "arr", + Default: "bad,bad,bad", + Value: clibase.StringArrayOf(&got), + }, + }, + Handler: (func(i *clibase.Invocation) error { + require.Equal(t, want, got) + return nil + }), + } + } + + err := cmd("good", "good", "good").Invoke("--arr", "good", "--arr", "good", "--arr", "good").Run() + require.NoError(t, err) + + err = cmd("bad", "bad", "bad").Invoke().Run() + require.NoError(t, err) +} diff --git a/cli/clibase/env.go b/cli/clibase/env.go index 1a73f66a34b80..11fb50d4e0389 100644 --- a/cli/clibase/env.go +++ b/cli/clibase/env.go @@ -44,6 +44,11 @@ func (e Environ) Lookup(name string) (string, bool) { return "", false } +func (e Environ) Get(name string) string { + v, _ := e.Lookup(name) + return v +} + func (e *Environ) Set(name, value string) { for i, v := range *e { if v.Name == name { diff --git a/cli/clibase/option.go b/cli/clibase/option.go index 76fd4f51117c6..836517979db6c 100644 --- a/cli/clibase/option.go +++ b/cli/clibase/option.go @@ -77,7 +77,7 @@ func (s *OptionSet) FlagSet() *pflag.FlagSet { val := opt.Value if val == nil { - val = &DiscardValue{} + val = DiscardValue } fs.AddFlag(&pflag.Flag{ diff --git a/cli/clibase/option_test.go b/cli/clibase/option_test.go index 7b3702d83714c..d9d38cc6c7bd9 100644 --- a/cli/clibase/option_test.go +++ b/cli/clibase/option_test.go @@ -35,10 +35,10 @@ func TestOptionSet_ParseFlags(t *testing.T) { require.EqualValues(t, "f", workspaceName) }) - t.Run("Strings", func(t *testing.T) { + t.Run("StringArray", func(t *testing.T) { t.Parallel() - var names clibase.Strings + var names clibase.StringArray os := clibase.OptionSet{ clibase.Option{ @@ -49,7 +49,10 @@ func TestOptionSet_ParseFlags(t *testing.T) { }, } - err := os.FlagSet().Parse([]string{"--name", "foo", "--name", "bar"}) + err := os.SetDefaults() + require.NoError(t, err) + + err = os.FlagSet().Parse([]string{"--name", "foo", "--name", "bar"}) require.NoError(t, err) require.EqualValues(t, []string{"foo", "bar"}, names) }) diff --git a/cli/clibase/values.go b/cli/clibase/values.go index 7fdfd6b730411..acb4cab5d50f7 100644 --- a/cli/clibase/values.go +++ b/cli/clibase/values.go @@ -109,26 +109,26 @@ func (String) Type() string { return "string" } -var _ pflag.SliceValue = &Strings{} +var _ pflag.SliceValue = &StringArray{} -// Strings is a slice of strings that implements pflag.Value and pflag.SliceValue. -type Strings []string +// StringArray is a slice of strings that implements pflag.Value and pflag.SliceValue. +type StringArray []string -func StringsOf(ss *[]string) *Strings { - return (*Strings)(ss) +func StringArrayOf(ss *[]string) *StringArray { + return (*StringArray)(ss) } -func (s *Strings) Append(v string) error { +func (s *StringArray) Append(v string) error { *s = append(*s, v) return nil } -func (s *Strings) Replace(vals []string) error { +func (s *StringArray) Replace(vals []string) error { *s = vals return nil } -func (s *Strings) GetSlice() []string { +func (s *StringArray) GetSlice() []string { return *s } @@ -145,7 +145,7 @@ func writeAsCSV(vals []string) string { return sb.String() } -func (s *Strings) Set(v string) error { +func (s *StringArray) Set(v string) error { ss, err := readAsCSV(v) if err != nil { return err @@ -154,16 +154,16 @@ func (s *Strings) Set(v string) error { return nil } -func (s Strings) String() string { +func (s StringArray) String() string { return writeAsCSV([]string(s)) } -func (s Strings) Value() []string { +func (s StringArray) Value() []string { return []string(s) } -func (Strings) Type() string { - return "strings" +func (StringArray) Type() string { + return "string-array" } type Duration time.Duration @@ -287,7 +287,7 @@ func (hp *HostPort) UnmarshalJSON(b []byte) error { } func (*HostPort) Type() string { - return "bind-address" + return "host:port" } var ( @@ -344,16 +344,50 @@ func (s *Struct[T]) UnmarshalJSON(b []byte) error { // DiscardValue does nothing but implements the pflag.Value interface. // It's useful in cases where you want to accept an option, but access the // underlying value directly instead of through the Option methods. -type DiscardValue struct{} +var DiscardValue discardValue -func (DiscardValue) Set(string) error { +type discardValue struct{} + +func (discardValue) Set(string) error { return nil } -func (DiscardValue) String() string { +func (discardValue) String() string { return "" } -func (DiscardValue) Type() string { +func (discardValue) Type() string { return "discard" } + +var _ pflag.Value = (*Enum)(nil) + +type Enum struct { + Choices []string + Value *string +} + +func EnumOf(v *string, choices ...string) *Enum { + return &Enum{ + Choices: choices, + Value: v, + } +} + +func (e *Enum) Set(v string) error { + for _, c := range e.Choices { + if v == c { + *e.Value = v + return nil + } + } + return xerrors.Errorf("invalid choice: %s, should be one of %v", v, e.Choices) +} + +func (e *Enum) Type() string { + return fmt.Sprintf("enum[%v]", strings.Join(e.Choices, "|")) +} + +func (e *Enum) String() string { + return *e.Value +} diff --git a/cli/clibase/yaml_test.go b/cli/clibase/yaml_test.go index 1e148738816d3..3efad6ee54ed8 100644 --- a/cli/clibase/yaml_test.go +++ b/cli/clibase/yaml_test.go @@ -38,7 +38,7 @@ func TestOption_ToYAML(t *testing.T) { Name: "Workspace Name", Value: &workspaceName, Default: "billie", - Description: "The workspace's name", + Description: "The workspace's name.", Group: &clibase.Group{Name: "Names"}, YAML: "workspaceName", }, diff --git a/cli/cliflag/cliflag.go b/cli/cliflag/cliflag.go deleted file mode 100644 index 4d93f8a77bc15..0000000000000 --- a/cli/cliflag/cliflag.go +++ /dev/null @@ -1,185 +0,0 @@ -// Package cliflag extends flagset with environment variable defaults. -// -// Usage: -// -// cliflag.String(root.Flags(), &address, "address", "a", "CODER_ADDRESS", "127.0.0.1:3000", "The address to serve the API and dashboard") -// -// Will produce the following usage docs: -// -// -a, --address string The address to serve the API and dashboard (uses $CODER_ADDRESS). (default "127.0.0.1:3000") -package cliflag - -import ( - "fmt" - "os" - "strconv" - "strings" - "time" - - "github.com/spf13/cobra" - "github.com/spf13/pflag" - - "github.com/coder/coder/cli/cliui" -) - -// IsSetBool returns the value of the boolean flag if it is set. -// It returns false if the flag isn't set or if any error occurs attempting -// to parse the value of the flag. -func IsSetBool(cmd *cobra.Command, name string) bool { - val, ok := IsSet(cmd, name) - if !ok { - return false - } - - b, err := strconv.ParseBool(val) - return err == nil && b -} - -// IsSet returns the string value of the flag and whether it was set. -func IsSet(cmd *cobra.Command, name string) (string, bool) { - flag := cmd.Flag(name) - if flag == nil { - return "", false - } - - return flag.Value.String(), flag.Changed -} - -// String sets a string flag on the given flag set. -func String(flagset *pflag.FlagSet, name, shorthand, env, def, usage string) { - v, ok := os.LookupEnv(env) - if !ok || v == "" { - v = def - } - flagset.StringP(name, shorthand, v, fmtUsage(usage, env)) -} - -// StringVarP sets a string flag on the given flag set. -func StringVarP(flagset *pflag.FlagSet, p *string, name string, shorthand string, env string, def string, usage string) { - v, ok := os.LookupEnv(env) - if !ok || v == "" { - v = def - } - flagset.StringVarP(p, name, shorthand, v, fmtUsage(usage, env)) -} - -func StringArray(flagset *pflag.FlagSet, name, shorthand, env string, def []string, usage string) { - v, ok := os.LookupEnv(env) - if !ok || v == "" { - if v == "" { - def = []string{} - } else { - def = strings.Split(v, ",") - } - } - flagset.StringArrayP(name, shorthand, def, fmtUsage(usage, env)) -} - -func StringArrayVarP(flagset *pflag.FlagSet, ptr *[]string, name string, shorthand string, env string, def []string, usage string) { - val, ok := os.LookupEnv(env) - if ok { - if val == "" { - def = []string{} - } else { - def = strings.Split(val, ",") - } - } - flagset.StringArrayVarP(ptr, name, shorthand, def, fmtUsage(usage, env)) -} - -// Uint8VarP sets a uint8 flag on the given flag set. -func Uint8VarP(flagset *pflag.FlagSet, ptr *uint8, name string, shorthand string, env string, def uint8, usage string) { - val, ok := os.LookupEnv(env) - if !ok || val == "" { - flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env)) - return - } - - vi64, err := strconv.ParseUint(val, 10, 8) - if err != nil { - flagset.Uint8VarP(ptr, name, shorthand, def, fmtUsage(usage, env)) - return - } - - flagset.Uint8VarP(ptr, name, shorthand, uint8(vi64), fmtUsage(usage, env)) -} - -// IntVarP sets a uint8 flag on the given flag set. -func IntVarP(flagset *pflag.FlagSet, ptr *int, name string, shorthand string, env string, def int, usage string) { - val, ok := os.LookupEnv(env) - if !ok || val == "" { - flagset.IntVarP(ptr, name, shorthand, def, fmtUsage(usage, env)) - return - } - - vi64, err := strconv.ParseUint(val, 10, 8) - if err != nil { - flagset.IntVarP(ptr, name, shorthand, def, fmtUsage(usage, env)) - return - } - - flagset.IntVarP(ptr, name, shorthand, int(vi64), fmtUsage(usage, env)) -} - -func Bool(flagset *pflag.FlagSet, name, shorthand, env string, def bool, usage string) { - val, ok := os.LookupEnv(env) - if !ok || val == "" { - flagset.BoolP(name, shorthand, def, fmtUsage(usage, env)) - return - } - - valb, err := strconv.ParseBool(val) - if err != nil { - flagset.BoolP(name, shorthand, def, fmtUsage(usage, env)) - return - } - - flagset.BoolP(name, shorthand, valb, fmtUsage(usage, env)) -} - -// BoolVarP sets a bool flag on the given flag set. -func BoolVarP(flagset *pflag.FlagSet, ptr *bool, name string, shorthand string, env string, def bool, usage string) { - val, ok := os.LookupEnv(env) - if !ok || val == "" { - flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env)) - return - } - - valb, err := strconv.ParseBool(val) - if err != nil { - flagset.BoolVarP(ptr, name, shorthand, def, fmtUsage(usage, env)) - return - } - - flagset.BoolVarP(ptr, name, shorthand, valb, fmtUsage(usage, env)) -} - -// DurationVarP sets a time.Duration flag on the given flag set. -func DurationVarP(flagset *pflag.FlagSet, ptr *time.Duration, name string, shorthand string, env string, def time.Duration, usage string) { - val, ok := os.LookupEnv(env) - if !ok || val == "" { - flagset.DurationVarP(ptr, name, shorthand, def, fmtUsage(usage, env)) - return - } - - valb, err := time.ParseDuration(val) - if err != nil { - flagset.DurationVarP(ptr, name, shorthand, def, fmtUsage(usage, env)) - return - } - - flagset.DurationVarP(ptr, name, shorthand, valb, fmtUsage(usage, env)) -} - -func fmtUsage(u string, env string) string { - if env != "" { - // Avoid double dotting. - dot := "." - if strings.HasSuffix(u, ".") { - dot = "" - } - u = fmt.Sprintf("%s%s\n"+cliui.Styles.Placeholder.Render("Consumes $%s"), u, dot, env) - } - - return u -} diff --git a/cli/cliflag/cliflag_test.go b/cli/cliflag/cliflag_test.go deleted file mode 100644 index 5d826166307a5..0000000000000 --- a/cli/cliflag/cliflag_test.go +++ /dev/null @@ -1,277 +0,0 @@ -package cliflag_test - -import ( - "fmt" - "strconv" - "testing" - "time" - - "github.com/spf13/pflag" - "github.com/stretchr/testify/require" - - "github.com/coder/coder/cli/cliflag" - "github.com/coder/coder/cryptorand" -) - -// Testcliflag cannot run in parallel because it uses t.Setenv. -// -//nolint:paralleltest -func TestCliflag(t *testing.T) { - t.Run("StringDefault", func(t *testing.T) { - flagset, name, shorthand, env, usage := randomFlag() - def, _ := cryptorand.String(10) - cliflag.String(flagset, name, shorthand, env, def, usage) - got, err := flagset.GetString(name) - require.NoError(t, err) - require.Equal(t, def, got) - require.Contains(t, flagset.FlagUsages(), usage) - require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env)) - }) - - t.Run("StringEnvVar", func(t *testing.T) { - flagset, name, shorthand, env, usage := randomFlag() - envValue, _ := cryptorand.String(10) - t.Setenv(env, envValue) - def, _ := cryptorand.String(10) - cliflag.String(flagset, name, shorthand, env, def, usage) - got, err := flagset.GetString(name) - require.NoError(t, err) - require.Equal(t, envValue, got) - }) - - t.Run("StringVarPDefault", func(t *testing.T) { - var ptr string - flagset, name, shorthand, env, usage := randomFlag() - def, _ := cryptorand.String(10) - - cliflag.StringVarP(flagset, &ptr, name, shorthand, env, def, usage) - got, err := flagset.GetString(name) - require.NoError(t, err) - require.Equal(t, def, got) - require.Contains(t, flagset.FlagUsages(), usage) - require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env)) - }) - - t.Run("StringVarPEnvVar", func(t *testing.T) { - var ptr string - flagset, name, shorthand, env, usage := randomFlag() - envValue, _ := cryptorand.String(10) - t.Setenv(env, envValue) - def, _ := cryptorand.String(10) - - cliflag.StringVarP(flagset, &ptr, name, shorthand, env, def, usage) - got, err := flagset.GetString(name) - require.NoError(t, err) - require.Equal(t, envValue, got) - }) - - t.Run("EmptyEnvVar", func(t *testing.T) { - var ptr string - flagset, name, shorthand, _, usage := randomFlag() - def, _ := cryptorand.String(10) - - cliflag.StringVarP(flagset, &ptr, name, shorthand, "", def, usage) - got, err := flagset.GetString(name) - require.NoError(t, err) - require.Equal(t, def, got) - require.Contains(t, flagset.FlagUsages(), usage) - require.NotContains(t, flagset.FlagUsages(), "Consumes") - }) - - t.Run("StringArrayDefault", func(t *testing.T) { - var ptr []string - flagset, name, shorthand, env, usage := randomFlag() - def := []string{"hello"} - cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, def, usage) - got, err := flagset.GetStringArray(name) - require.NoError(t, err) - require.Equal(t, def, got) - }) - - t.Run("StringArrayEnvVar", func(t *testing.T) { - var ptr []string - flagset, name, shorthand, env, usage := randomFlag() - t.Setenv(env, "wow,test") - cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, nil, usage) - got, err := flagset.GetStringArray(name) - require.NoError(t, err) - require.Equal(t, []string{"wow", "test"}, got) - }) - - t.Run("StringArrayEnvVarEmpty", func(t *testing.T) { - var ptr []string - flagset, name, shorthand, env, usage := randomFlag() - t.Setenv(env, "") - cliflag.StringArrayVarP(flagset, &ptr, name, shorthand, env, nil, usage) - got, err := flagset.GetStringArray(name) - require.NoError(t, err) - require.Equal(t, []string{}, got) - }) - - t.Run("UInt8Default", func(t *testing.T) { - var ptr uint8 - flagset, name, shorthand, env, usage := randomFlag() - def, _ := cryptorand.Int63n(10) - - cliflag.Uint8VarP(flagset, &ptr, name, shorthand, env, uint8(def), usage) - got, err := flagset.GetUint8(name) - require.NoError(t, err) - require.Equal(t, uint8(def), got) - require.Contains(t, flagset.FlagUsages(), usage) - require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env)) - }) - - t.Run("UInt8EnvVar", func(t *testing.T) { - var ptr uint8 - flagset, name, shorthand, env, usage := randomFlag() - envValue, _ := cryptorand.Int63n(10) - t.Setenv(env, strconv.FormatUint(uint64(envValue), 10)) - def, _ := cryptorand.Int() - - cliflag.Uint8VarP(flagset, &ptr, name, shorthand, env, uint8(def), usage) - got, err := flagset.GetUint8(name) - require.NoError(t, err) - require.Equal(t, uint8(envValue), got) - }) - - t.Run("UInt8FailParse", func(t *testing.T) { - var ptr uint8 - flagset, name, shorthand, env, usage := randomFlag() - envValue, _ := cryptorand.String(10) - t.Setenv(env, envValue) - def, _ := cryptorand.Int63n(10) - - cliflag.Uint8VarP(flagset, &ptr, name, shorthand, env, uint8(def), usage) - got, err := flagset.GetUint8(name) - require.NoError(t, err) - require.Equal(t, uint8(def), got) - }) - - t.Run("IntDefault", func(t *testing.T) { - var ptr int - flagset, name, shorthand, env, usage := randomFlag() - def, _ := cryptorand.Int63n(10) - - cliflag.IntVarP(flagset, &ptr, name, shorthand, env, int(def), usage) - got, err := flagset.GetInt(name) - require.NoError(t, err) - require.Equal(t, int(def), got) - require.Contains(t, flagset.FlagUsages(), usage) - require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env)) - }) - - t.Run("IntEnvVar", func(t *testing.T) { - var ptr int - flagset, name, shorthand, env, usage := randomFlag() - envValue, _ := cryptorand.Int63n(10) - t.Setenv(env, strconv.FormatUint(uint64(envValue), 10)) - def, _ := cryptorand.Int() - - cliflag.IntVarP(flagset, &ptr, name, shorthand, env, def, usage) - got, err := flagset.GetInt(name) - require.NoError(t, err) - require.Equal(t, int(envValue), got) - }) - - t.Run("IntFailParse", func(t *testing.T) { - var ptr int - flagset, name, shorthand, env, usage := randomFlag() - envValue, _ := cryptorand.String(10) - t.Setenv(env, envValue) - def, _ := cryptorand.Int63n(10) - - cliflag.IntVarP(flagset, &ptr, name, shorthand, env, int(def), usage) - got, err := flagset.GetInt(name) - require.NoError(t, err) - require.Equal(t, int(def), got) - }) - - t.Run("BoolDefault", func(t *testing.T) { - var ptr bool - flagset, name, shorthand, env, usage := randomFlag() - def, _ := cryptorand.Bool() - - cliflag.BoolVarP(flagset, &ptr, name, shorthand, env, def, usage) - got, err := flagset.GetBool(name) - require.NoError(t, err) - require.Equal(t, def, got) - require.Contains(t, flagset.FlagUsages(), usage) - require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env)) - }) - - t.Run("BoolEnvVar", func(t *testing.T) { - var ptr bool - flagset, name, shorthand, env, usage := randomFlag() - envValue, _ := cryptorand.Bool() - t.Setenv(env, strconv.FormatBool(envValue)) - def, _ := cryptorand.Bool() - - cliflag.BoolVarP(flagset, &ptr, name, shorthand, env, def, usage) - got, err := flagset.GetBool(name) - require.NoError(t, err) - require.Equal(t, envValue, got) - }) - - t.Run("BoolFailParse", func(t *testing.T) { - var ptr bool - flagset, name, shorthand, env, usage := randomFlag() - envValue, _ := cryptorand.String(10) - t.Setenv(env, envValue) - def, _ := cryptorand.Bool() - - cliflag.BoolVarP(flagset, &ptr, name, shorthand, env, def, usage) - got, err := flagset.GetBool(name) - require.NoError(t, err) - require.Equal(t, def, got) - }) - - t.Run("DurationDefault", func(t *testing.T) { - var ptr time.Duration - flagset, name, shorthand, env, usage := randomFlag() - def, _ := cryptorand.Duration() - - cliflag.DurationVarP(flagset, &ptr, name, shorthand, env, def, usage) - got, err := flagset.GetDuration(name) - require.NoError(t, err) - require.Equal(t, def, got) - require.Contains(t, flagset.FlagUsages(), usage) - require.Contains(t, flagset.FlagUsages(), fmt.Sprintf("Consumes $%s", env)) - }) - - t.Run("DurationEnvVar", func(t *testing.T) { - var ptr time.Duration - flagset, name, shorthand, env, usage := randomFlag() - envValue, _ := cryptorand.Duration() - t.Setenv(env, envValue.String()) - def, _ := cryptorand.Duration() - - cliflag.DurationVarP(flagset, &ptr, name, shorthand, env, def, usage) - got, err := flagset.GetDuration(name) - require.NoError(t, err) - require.Equal(t, envValue, got) - }) - - t.Run("DurationFailParse", func(t *testing.T) { - var ptr time.Duration - flagset, name, shorthand, env, usage := randomFlag() - envValue, _ := cryptorand.String(10) - t.Setenv(env, envValue) - def, _ := cryptorand.Duration() - - cliflag.DurationVarP(flagset, &ptr, name, shorthand, env, def, usage) - got, err := flagset.GetDuration(name) - require.NoError(t, err) - require.Equal(t, def, got) - }) -} - -func randomFlag() (*pflag.FlagSet, string, string, string, string) { - fsname, _ := cryptorand.String(10) - flagset := pflag.NewFlagSet(fsname, pflag.PanicOnError) - name, _ := cryptorand.String(10) - shorthand, _ := cryptorand.String(1) - env, _ := cryptorand.String(10) - usage, _ := cryptorand.String(10) - - return flagset, name, shorthand, env, usage -} diff --git a/cli/clitest/clitest.go b/cli/clitest/clitest.go index a0e235df4fd4b..7680a06981a05 100644 --- a/cli/clitest/clitest.go +++ b/cli/clitest/clitest.go @@ -10,14 +10,16 @@ import ( "os" "path/filepath" "strings" + "sync" + "sync/atomic" "testing" "time" - "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/coder/coder/cli" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/config" "github.com/coder/coder/codersdk" "github.com/coder/coder/provisioner/echo" @@ -26,8 +28,13 @@ import ( // New creates a CLI instance with a configuration pointed to a // temporary testing directory. -func New(t *testing.T, args ...string) (*cobra.Command, config.Root) { - return NewWithSubcommands(t, cli.AGPL(), args...) +func New(t *testing.T, args ...string) (*clibase.Invocation, config.Root) { + var root cli.RootCmd + + cmd, err := root.Command(root.AGPL()) + require.NoError(t, err) + + return NewWithCommand(t, cmd, args...) } type logWriter struct { @@ -46,19 +53,21 @@ func (l *logWriter) Write(p []byte) (n int, err error) { return len(p), nil } -func NewWithSubcommands( - t *testing.T, subcommands []*cobra.Command, args ...string, -) (*cobra.Command, config.Root) { - cmd := cli.Root(subcommands) - dir := t.TempDir() - root := config.Root(dir) - cmd.SetArgs(append([]string{"--global-config", dir}, args...)) +func NewWithCommand( + t *testing.T, cmd *clibase.Cmd, args ...string, +) (*clibase.Invocation, config.Root) { + configDir := config.Root(t.TempDir()) + i := &clibase.Invocation{ + Command: cmd, + Args: append([]string{"--global-config", string(configDir)}, args...), + Stdin: io.LimitReader(nil, 0), + Stdout: (&logWriter{prefix: "stdout", t: t}), + Stderr: (&logWriter{prefix: "stderr", t: t}), + } + t.Logf("invoking command: %s %s", cmd.Name(), strings.Join(i.Args, " ")) // These can be overridden by the test. - cmd.SetOut(&logWriter{prefix: "stdout", t: t}) - cmd.SetErr(&logWriter{prefix: "stderr", t: t}) - - return cmd, root + return i, configDir } // SetupConfig applies the URL and SessionToken of the client to the config. @@ -120,31 +129,111 @@ func extractTar(t *testing.T, data []byte, directory string) { // Start runs the command in a goroutine and cleans it up when // the test completed. -func Start(ctx context.Context, t *testing.T, cmd *cobra.Command) { +func Start(t *testing.T, inv *clibase.Invocation) { t.Helper() closeCh := make(chan struct{}) + go func() { + defer close(closeCh) + err := StartWithWaiter(t, inv).Wait() + switch { + case errors.Is(err, context.Canceled): + return + default: + assert.NoError(t, err) + } + }() + + t.Cleanup(func() { + <-closeCh + }) +} + +// Run runs the command and asserts that there is no error. +func Run(t *testing.T, inv *clibase.Invocation) { + t.Helper() + + err := inv.Run() + require.NoError(t, err) +} + +type ErrorWaiter struct { + waitOnce sync.Once + cachedError error + + c <-chan error + t *testing.T +} - deadline, hasDeadline := ctx.Deadline() - if !hasDeadline { - // We don't want to wait the full 5 minutes for a test to time out. - deadline = time.Now().Add(testutil.WaitMedium) +func (w *ErrorWaiter) Wait() error { + w.waitOnce.Do(func() { + var ok bool + w.cachedError, ok = <-w.c + if !ok { + panic("unexpoected channel close") + } + }) + return w.cachedError +} + +func (w *ErrorWaiter) RequireSuccess() { + require.NoError(w.t, w.Wait()) +} + +func (w *ErrorWaiter) RequireError() { + require.Error(w.t, w.Wait()) +} + +func (w *ErrorWaiter) RequireContains(s string) { + require.ErrorContains(w.t, w.Wait(), s) +} + +func (w *ErrorWaiter) RequireIs(want error) { + require.ErrorIs(w.t, w.Wait(), want) +} + +func (w *ErrorWaiter) RequireAs(want interface{}) { + require.ErrorAs(w.t, w.Wait(), want) +} + +// StartWithWaiter runs the command in a goroutine but returns the error +// instead of asserting it. This is useful for testing error cases. +func StartWithWaiter(t *testing.T, inv *clibase.Invocation) *ErrorWaiter { + t.Helper() + + errCh := make(chan error, 1) + + var cleaningUp atomic.Bool + + var ( + ctx = inv.Context() + cancel func() + ) + if _, ok := ctx.Deadline(); !ok { + ctx, cancel = context.WithDeadline(ctx, time.Now().Add(testutil.WaitMedium)) + } else { + ctx, cancel = context.WithCancel(inv.Context()) } - ctx, cancel := context.WithDeadline(ctx, deadline) + inv = inv.WithContext(ctx) go func() { - defer cancel() - defer close(closeCh) - err := cmd.ExecuteContext(ctx) - if ctx.Err() == nil { - assert.NoError(t, err) + defer close(errCh) + err := inv.Run() + if cleaningUp.Load() && errors.Is(err, context.DeadlineExceeded) { + // If we're cleaning up, this error is likely related to the + // CLI teardown process. E.g., the server could be slow to shut + // down Postgres. + t.Logf("command %q timed out during test cleanup", inv.Command.FullName()) } + errCh <- err }() // Don't exit test routine until server is done. t.Cleanup(func() { cancel() - <-closeCh + cleaningUp.Store(true) + <-errCh }) + return &ErrorWaiter{c: errCh, t: t} } diff --git a/cli/clitest/clitest_test.go b/cli/clitest/clitest_test.go index 441b84048d05e..283f7b48ca588 100644 --- a/cli/clitest/clitest_test.go +++ b/cli/clitest/clitest_test.go @@ -18,13 +18,9 @@ func TestCli(t *testing.T) { t.Parallel() clitest.CreateTemplateVersionSource(t, nil) client := coderdtest.New(t, nil) - cmd, config := clitest.New(t) + i, config := clitest.New(t) clitest.SetupConfig(t, client, config) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - go func() { - _ = cmd.Execute() - }() + pty := ptytest.New(t).Attach(i) + clitest.Start(t, i) pty.ExpectMatch("coder") } diff --git a/cli/cliui/agent_test.go b/cli/cliui/agent_test.go index 5122baaa755a1..c6b13b3bbe8d0 100644 --- a/cli/cliui/agent_test.go +++ b/cli/cliui/agent_test.go @@ -5,11 +5,11 @@ import ( "testing" "time" - "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/atomic" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" "github.com/coder/coder/pty/ptytest" @@ -24,9 +24,9 @@ func TestAgent(t *testing.T) { var disconnected atomic.Bool ptty := ptytest.New(t) - cmd := &cobra.Command{ - RunE: func(cmd *cobra.Command, _ []string) error { - err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{ + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { + err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{ WorkspaceName: "example", Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { agent := codersdk.WorkspaceAgent{ @@ -44,12 +44,13 @@ func TestAgent(t *testing.T) { return err }, } - cmd.SetOutput(ptty.Output()) - cmd.SetIn(ptty.Input()) + + inv := cmd.Invoke() + ptty.Attach(inv) done := make(chan struct{}) go func() { defer close(done) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() ptty.ExpectMatchContext(ctx, "lost connection") @@ -66,9 +67,9 @@ func TestAgent_TimeoutWithTroubleshootingURL(t *testing.T) { wantURL := "https://coder.com/troubleshoot" var connected, timeout atomic.Bool - cmd := &cobra.Command{ - RunE: func(cmd *cobra.Command, _ []string) error { - err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{ + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { + err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{ WorkspaceName: "example", Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { agent := codersdk.WorkspaceAgent{ @@ -91,11 +92,12 @@ func TestAgent_TimeoutWithTroubleshootingURL(t *testing.T) { }, } ptty := ptytest.New(t) - cmd.SetOutput(ptty.Output()) - cmd.SetIn(ptty.Input()) + + inv := cmd.Invoke() + ptty.Attach(inv) done := make(chan error, 1) go func() { - done <- cmd.ExecuteContext(ctx) + done <- inv.WithContext(ctx).Run() }() ptty.ExpectMatchContext(ctx, "Don't panic, your workspace is booting") timeout.Store(true) @@ -115,9 +117,10 @@ func TestAgent_StartupTimeout(t *testing.T) { var status, state atomic.String setStatus := func(s codersdk.WorkspaceAgentStatus) { status.Store(string(s)) } setState := func(s codersdk.WorkspaceAgentLifecycle) { state.Store(string(s)) } - cmd := &cobra.Command{ - RunE: func(cmd *cobra.Command, _ []string) error { - err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{ + + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { + err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{ WorkspaceName: "example", Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { agent := codersdk.WorkspaceAgent{ @@ -144,11 +147,12 @@ func TestAgent_StartupTimeout(t *testing.T) { } ptty := ptytest.New(t) - cmd.SetOutput(ptty.Output()) - cmd.SetIn(ptty.Input()) + + inv := cmd.Invoke() + ptty.Attach(inv) done := make(chan error, 1) go func() { - done <- cmd.ExecuteContext(ctx) + done <- inv.WithContext(ctx).Run() }() setStatus(codersdk.WorkspaceAgentConnecting) ptty.ExpectMatchContext(ctx, "Don't panic, your workspace is booting") @@ -173,9 +177,9 @@ func TestAgent_StartErrorExit(t *testing.T) { var status, state atomic.String setStatus := func(s codersdk.WorkspaceAgentStatus) { status.Store(string(s)) } setState := func(s codersdk.WorkspaceAgentLifecycle) { state.Store(string(s)) } - cmd := &cobra.Command{ - RunE: func(cmd *cobra.Command, _ []string) error { - err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{ + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { + err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{ WorkspaceName: "example", Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { agent := codersdk.WorkspaceAgent{ @@ -202,11 +206,12 @@ func TestAgent_StartErrorExit(t *testing.T) { } ptty := ptytest.New(t) - cmd.SetOutput(ptty.Output()) - cmd.SetIn(ptty.Input()) + + inv := cmd.Invoke() + ptty.Attach(inv) done := make(chan error, 1) go func() { - done <- cmd.ExecuteContext(ctx) + done <- inv.WithContext(ctx).Run() }() setStatus(codersdk.WorkspaceAgentConnected) setState(codersdk.WorkspaceAgentLifecycleStarting) @@ -228,9 +233,9 @@ func TestAgent_NoWait(t *testing.T) { var status, state atomic.String setStatus := func(s codersdk.WorkspaceAgentStatus) { status.Store(string(s)) } setState := func(s codersdk.WorkspaceAgentLifecycle) { state.Store(string(s)) } - cmd := &cobra.Command{ - RunE: func(cmd *cobra.Command, _ []string) error { - err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{ + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { + err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{ WorkspaceName: "example", Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { agent := codersdk.WorkspaceAgent{ @@ -257,11 +262,12 @@ func TestAgent_NoWait(t *testing.T) { } ptty := ptytest.New(t) - cmd.SetOutput(ptty.Output()) - cmd.SetIn(ptty.Input()) + + inv := cmd.Invoke() + ptty.Attach(inv) done := make(chan error, 1) go func() { - done <- cmd.ExecuteContext(ctx) + done <- inv.WithContext(ctx).Run() }() setStatus(codersdk.WorkspaceAgentConnecting) ptty.ExpectMatchContext(ctx, "Don't panic, your workspace is booting") @@ -270,19 +276,19 @@ func TestAgent_NoWait(t *testing.T) { require.NoError(t, <-done, "created - should exit early") setState(codersdk.WorkspaceAgentLifecycleStarting) - go func() { done <- cmd.ExecuteContext(ctx) }() + go func() { done <- inv.WithContext(ctx).Run() }() require.NoError(t, <-done, "starting - should exit early") setState(codersdk.WorkspaceAgentLifecycleStartTimeout) - go func() { done <- cmd.ExecuteContext(ctx) }() + go func() { done <- inv.WithContext(ctx).Run() }() require.NoError(t, <-done, "start timeout - should exit early") setState(codersdk.WorkspaceAgentLifecycleStartError) - go func() { done <- cmd.ExecuteContext(ctx) }() + go func() { done <- inv.WithContext(ctx).Run() }() require.NoError(t, <-done, "start error - should exit early") setState(codersdk.WorkspaceAgentLifecycleReady) - go func() { done <- cmd.ExecuteContext(ctx) }() + go func() { done <- inv.WithContext(ctx).Run() }() require.NoError(t, <-done, "ready - should exit early") } @@ -297,9 +303,9 @@ func TestAgent_LoginBeforeReadyEnabled(t *testing.T) { var status, state atomic.String setStatus := func(s codersdk.WorkspaceAgentStatus) { status.Store(string(s)) } setState := func(s codersdk.WorkspaceAgentLifecycle) { state.Store(string(s)) } - cmd := &cobra.Command{ - RunE: func(cmd *cobra.Command, _ []string) error { - err := cliui.Agent(cmd.Context(), cmd.OutOrStdout(), cliui.AgentOptions{ + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { + err := cliui.Agent(inv.Context(), inv.Stdout, cliui.AgentOptions{ WorkspaceName: "example", Fetch: func(_ context.Context) (codersdk.WorkspaceAgent, error) { agent := codersdk.WorkspaceAgent{ @@ -325,12 +331,13 @@ func TestAgent_LoginBeforeReadyEnabled(t *testing.T) { }, } + inv := cmd.Invoke() + ptty := ptytest.New(t) - cmd.SetOutput(ptty.Output()) - cmd.SetIn(ptty.Input()) + ptty.Attach(inv) done := make(chan error, 1) go func() { - done <- cmd.ExecuteContext(ctx) + done <- inv.WithContext(ctx).Run() }() setStatus(codersdk.WorkspaceAgentConnecting) ptty.ExpectMatchContext(ctx, "Don't panic, your workspace is booting") @@ -339,18 +346,18 @@ func TestAgent_LoginBeforeReadyEnabled(t *testing.T) { require.NoError(t, <-done, "created - should exit early") setState(codersdk.WorkspaceAgentLifecycleStarting) - go func() { done <- cmd.ExecuteContext(ctx) }() + go func() { done <- inv.WithContext(ctx).Run() }() require.NoError(t, <-done, "starting - should exit early") setState(codersdk.WorkspaceAgentLifecycleStartTimeout) - go func() { done <- cmd.ExecuteContext(ctx) }() + go func() { done <- inv.WithContext(ctx).Run() }() require.NoError(t, <-done, "start timeout - should exit early") setState(codersdk.WorkspaceAgentLifecycleStartError) - go func() { done <- cmd.ExecuteContext(ctx) }() + go func() { done <- inv.WithContext(ctx).Run() }() require.NoError(t, <-done, "start error - should exit early") setState(codersdk.WorkspaceAgentLifecycleReady) - go func() { done <- cmd.ExecuteContext(ctx) }() + go func() { done <- inv.WithContext(ctx).Run() }() require.NoError(t, <-done, "ready - should exit early") } diff --git a/cli/cliui/cliui.go b/cli/cliui/cliui.go index c9ff5c003ab97..d0c1ef86a6ba0 100644 --- a/cli/cliui/cliui.go +++ b/cli/cliui/cliui.go @@ -53,6 +53,8 @@ var Styles = struct { FocusedPrompt: defaultStyles.FocusedPrompt.Foreground(lipgloss.Color("#651fff")), Fuchsia: defaultStyles.SelectedMenuItem.Copy(), Logo: defaultStyles.Logo.SetString("Coder"), - Warn: lipgloss.NewStyle().Foreground(lipgloss.AdaptiveColor{Light: "#04B575", Dark: "#ECFD65"}), - Wrap: lipgloss.NewStyle().Width(80), + Warn: lipgloss.NewStyle().Foreground( + lipgloss.AdaptiveColor{Light: "#04B575", Dark: "#ECFD65"}, + ), + Wrap: lipgloss.NewStyle().Width(80), } diff --git a/cli/cliui/gitauth_test.go b/cli/cliui/gitauth_test.go index de2198798e8d3..13310ab85ffda 100644 --- a/cli/cliui/gitauth_test.go +++ b/cli/cliui/gitauth_test.go @@ -7,9 +7,9 @@ import ( "testing" "time" - "github.com/spf13/cobra" "github.com/stretchr/testify/assert" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" "github.com/coder/coder/pty/ptytest" @@ -23,10 +23,10 @@ func TestGitAuth(t *testing.T) { defer cancel() ptty := ptytest.New(t) - cmd := &cobra.Command{ - RunE: func(cmd *cobra.Command, args []string) error { + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { var fetched atomic.Bool - return cliui.GitAuth(cmd.Context(), cmd.OutOrStdout(), cliui.GitAuthOptions{ + return cliui.GitAuth(inv.Context(), inv.Stdout, cliui.GitAuthOptions{ Fetch: func(ctx context.Context) ([]codersdk.TemplateVersionGitAuth, error) { defer fetched.Store(true) return []codersdk.TemplateVersionGitAuth{{ @@ -40,12 +40,14 @@ func TestGitAuth(t *testing.T) { }) }, } - cmd.SetOutput(ptty.Output()) - cmd.SetIn(ptty.Input()) + + inv := cmd.Invoke().WithContext(ctx) + + ptty.Attach(inv) done := make(chan struct{}) go func() { defer close(done) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() ptty.ExpectMatchContext(ctx, "You must authenticate with") diff --git a/cli/cliui/log.go b/cli/cliui/log.go index 62b4ccd872ee3..f76c3d7a1653e 100644 --- a/cli/cliui/log.go +++ b/cli/cliui/log.go @@ -10,17 +10,22 @@ import ( // cliMessage provides a human-readable message for CLI errors and messages. type cliMessage struct { - Level string Style lipgloss.Style Header string + Prefix string Lines []string } // String formats the CLI message for consumption by a human. func (m cliMessage) String() string { var str strings.Builder - _, _ = fmt.Fprintf(&str, "%s\r\n", - Styles.Bold.Render(m.Header)) + + if m.Prefix != "" { + _, _ = str.WriteString(m.Style.Bold(true).Render(m.Prefix)) + } + + _, _ = str.WriteString(m.Style.Bold(false).Render(m.Header)) + _, _ = str.WriteString("\r\n") for _, line := range m.Lines { _, _ = fmt.Fprintf(&str, " %s %s\r\n", m.Style.Render("|"), line) } @@ -30,9 +35,42 @@ func (m cliMessage) String() string { // Warn writes a log to the writer provided. func Warn(wtr io.Writer, header string, lines ...string) { _, _ = fmt.Fprint(wtr, cliMessage{ - Level: "warning", Style: Styles.Warn, + Prefix: "WARN: ", + Header: header, + Lines: lines, + }.String()) +} + +// Warn writes a formatted log to the writer provided. +func Warnf(wtr io.Writer, fmtStr string, args ...interface{}) { + Warn(wtr, fmt.Sprintf(fmtStr, args...)) +} + +// Info writes a log to the writer provided. +func Info(wtr io.Writer, header string, lines ...string) { + _, _ = fmt.Fprint(wtr, cliMessage{ Header: header, Lines: lines, }.String()) } + +// Infof writes a formatted log to the writer provided. +func Infof(wtr io.Writer, fmtStr string, args ...interface{}) { + Info(wtr, fmt.Sprintf(fmtStr, args...)) +} + +// Error writes a log to the writer provided. +func Error(wtr io.Writer, header string, lines ...string) { + _, _ = fmt.Fprint(wtr, cliMessage{ + Style: Styles.Error, + Prefix: "ERROR: ", + Header: header, + Lines: lines, + }.String()) +} + +// Errorf writes a formatted log to the writer provided. +func Errorf(wtr io.Writer, fmtStr string, args ...interface{}) { + Error(wtr, fmt.Sprintf(fmtStr, args...)) +} diff --git a/cli/cliui/output.go b/cli/cliui/output.go index e537e30473da1..cf3a981fd5a86 100644 --- a/cli/cliui/output.go +++ b/cli/cliui/output.go @@ -6,13 +6,14 @@ import ( "reflect" "strings" - "github.com/spf13/cobra" "golang.org/x/xerrors" + + "github.com/coder/coder/cli/clibase" ) type OutputFormat interface { ID() string - AttachFlags(cmd *cobra.Command) + AttachOptions(opts *clibase.OptionSet) Format(ctx context.Context, data any) (string, error) } @@ -45,11 +46,11 @@ func NewOutputFormatter(formats ...OutputFormat) *OutputFormatter { } } -// AttachFlags attaches the --output flag to the given command, and any +// AttachOptions attaches the --output flag to the given command, and any // additional flags required by the output formatters. -func (f *OutputFormatter) AttachFlags(cmd *cobra.Command) { +func (f *OutputFormatter) AttachOptions(opts *clibase.OptionSet) { for _, format := range f.formats { - format.AttachFlags(cmd) + format.AttachOptions(opts) } formatNames := make([]string, 0, len(f.formats)) @@ -57,7 +58,15 @@ func (f *OutputFormatter) AttachFlags(cmd *cobra.Command) { formatNames = append(formatNames, format.ID()) } - cmd.Flags().StringVarP(&f.formatID, "output", "o", f.formats[0].ID(), "Output format. Available formats: "+strings.Join(formatNames, ", ")) + *opts = append(*opts, + clibase.Option{ + Flag: "output", + FlagShorthand: "o", + Default: f.formats[0].ID(), + Value: clibase.StringOf(&f.formatID), + Description: "Output format. Available formats: " + strings.Join(formatNames, ", ") + ".", + }, + ) } // Format formats the given data using the format specified by the --output @@ -118,9 +127,17 @@ func (*tableFormat) ID() string { return "table" } -// AttachFlags implements OutputFormat. -func (f *tableFormat) AttachFlags(cmd *cobra.Command) { - cmd.Flags().StringSliceVarP(&f.columns, "column", "c", f.defaultColumns, "Columns to display in table output. Available columns: "+strings.Join(f.allColumns, ", ")) +// AttachOptions implements OutputFormat. +func (f *tableFormat) AttachOptions(opts *clibase.OptionSet) { + *opts = append(*opts, + clibase.Option{ + Flag: "column", + FlagShorthand: "c", + Default: strings.Join(f.defaultColumns, ","), + Value: clibase.StringArrayOf(&f.columns), + Description: "Columns to display in table output. Available columns: " + strings.Join(f.allColumns, ", ") + ".", + }, + ) } // Format implements OutputFormat. @@ -142,8 +159,8 @@ func (jsonFormat) ID() string { return "json" } -// AttachFlags implements OutputFormat. -func (jsonFormat) AttachFlags(_ *cobra.Command) {} +// AttachOptions implements OutputFormat. +func (jsonFormat) AttachOptions(_ *clibase.OptionSet) {} // Format implements OutputFormat. func (jsonFormat) Format(_ context.Context, data any) (string, error) { diff --git a/cli/cliui/output_test.go b/cli/cliui/output_test.go index 7a31df9ab8afd..6dbe2fa144b62 100644 --- a/cli/cliui/output_test.go +++ b/cli/cliui/output_test.go @@ -6,16 +6,16 @@ import ( "sync/atomic" "testing" - "github.com/spf13/cobra" "github.com/stretchr/testify/require" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" ) type format struct { - id string - attachFlagsFn func(cmd *cobra.Command) - formatFn func(ctx context.Context, data any) (string, error) + id string + attachOptionsFn func(opts *clibase.OptionSet) + formatFn func(ctx context.Context, data any) (string, error) } var _ cliui.OutputFormat = &format{} @@ -24,9 +24,9 @@ func (f *format) ID() string { return f.id } -func (f *format) AttachFlags(cmd *cobra.Command) { - if f.attachFlagsFn != nil { - f.attachFlagsFn(cmd) +func (f *format) AttachOptions(opts *clibase.OptionSet) { + if f.attachOptionsFn != nil { + f.attachOptionsFn(opts) } } @@ -82,8 +82,14 @@ func Test_OutputFormatter(t *testing.T) { cliui.JSONFormat(), &format{ id: "foo", - attachFlagsFn: func(cmd *cobra.Command) { - cmd.Flags().StringP("foo", "f", "", "foo flag 1234") + attachOptionsFn: func(opts *clibase.OptionSet) { + opts.Add(clibase.Option{ + Name: "foo", + Flag: "foo", + FlagShorthand: "f", + Value: clibase.DiscardValue, + Description: "foo flag 1234", + }) }, formatFn: func(_ context.Context, _ any) (string, error) { atomic.AddInt64(&called, 1) @@ -92,13 +98,15 @@ func Test_OutputFormatter(t *testing.T) { }, ) - cmd := &cobra.Command{} - f.AttachFlags(cmd) + cmd := &clibase.Cmd{} + f.AttachOptions(&cmd.Options) - selected, err := cmd.Flags().GetString("output") + fs := cmd.Options.FlagSet() + + selected, err := fs.GetString("output") require.NoError(t, err) require.Equal(t, "json", selected) - usage := cmd.Flags().FlagUsages() + usage := fs.FlagUsages() require.Contains(t, usage, "Available formats: json, foo") require.Contains(t, usage, "foo flag 1234") @@ -112,13 +120,13 @@ func Test_OutputFormatter(t *testing.T) { require.Equal(t, data, got) require.EqualValues(t, 0, atomic.LoadInt64(&called)) - require.NoError(t, cmd.Flags().Set("output", "foo")) + require.NoError(t, fs.Set("output", "foo")) out, err = f.Format(ctx, data) require.NoError(t, err) require.Equal(t, "foo", out) require.EqualValues(t, 1, atomic.LoadInt64(&called)) - require.NoError(t, cmd.Flags().Set("output", "bar")) + require.NoError(t, fs.Set("output", "bar")) out, err = f.Format(ctx, data) require.Error(t, err) require.ErrorContains(t, err, "bar") diff --git a/cli/cliui/parameter.go b/cli/cliui/parameter.go index f57891a6c8ffd..96e8eacbf157c 100644 --- a/cli/cliui/parameter.go +++ b/cli/cliui/parameter.go @@ -5,16 +5,15 @@ import ( "fmt" "strings" - "github.com/spf13/cobra" - + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/coderd/parameter" "github.com/coder/coder/codersdk" ) -func ParameterSchema(cmd *cobra.Command, parameterSchema codersdk.ParameterSchema) (string, error) { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), Styles.Bold.Render("var."+parameterSchema.Name)) +func ParameterSchema(inv *clibase.Invocation, parameterSchema codersdk.ParameterSchema) (string, error) { + _, _ = fmt.Fprintln(inv.Stdout, Styles.Bold.Render("var."+parameterSchema.Name)) if parameterSchema.Description != "" { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), " "+strings.TrimSpace(strings.Join(strings.Split(parameterSchema.Description, "\n"), "\n "))+"\n") + _, _ = fmt.Fprintln(inv.Stdout, " "+strings.TrimSpace(strings.Join(strings.Split(parameterSchema.Description, "\n"), "\n "))+"\n") } var err error @@ -28,15 +27,15 @@ func ParameterSchema(cmd *cobra.Command, parameterSchema codersdk.ParameterSchem var value string if len(options) > 0 { // Move the cursor up a single line for nicer display! - _, _ = fmt.Fprint(cmd.OutOrStdout(), "\033[1A") - value, err = Select(cmd, SelectOptions{ + _, _ = fmt.Fprint(inv.Stdout, "\033[1A") + value, err = Select(inv, SelectOptions{ Options: options, Default: parameterSchema.DefaultSourceValue, HideSearch: true, }) if err == nil { - _, _ = fmt.Fprintln(cmd.OutOrStdout()) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), " "+Styles.Prompt.String()+Styles.Field.Render(value)) + _, _ = fmt.Fprintln(inv.Stdout) + _, _ = fmt.Fprintln(inv.Stdout, " "+Styles.Prompt.String()+Styles.Field.Render(value)) } } else { text := "Enter a value" @@ -45,7 +44,7 @@ func ParameterSchema(cmd *cobra.Command, parameterSchema codersdk.ParameterSchem } text += ":" - value, err = Prompt(cmd, PromptOptions{ + value, err = Prompt(inv, PromptOptions{ Text: Styles.Bold.Render(text), }) value = strings.TrimSpace(value) @@ -62,17 +61,17 @@ func ParameterSchema(cmd *cobra.Command, parameterSchema codersdk.ParameterSchem return value, nil } -func RichParameter(cmd *cobra.Command, templateVersionParameter codersdk.TemplateVersionParameter) (string, error) { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), Styles.Bold.Render(templateVersionParameter.Name)) +func RichParameter(inv *clibase.Invocation, templateVersionParameter codersdk.TemplateVersionParameter) (string, error) { + _, _ = fmt.Fprintln(inv.Stdout, Styles.Bold.Render(templateVersionParameter.Name)) if templateVersionParameter.DescriptionPlaintext != "" { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), " "+strings.TrimSpace(strings.Join(strings.Split(templateVersionParameter.DescriptionPlaintext, "\n"), "\n "))+"\n") + _, _ = fmt.Fprintln(inv.Stdout, " "+strings.TrimSpace(strings.Join(strings.Split(templateVersionParameter.DescriptionPlaintext, "\n"), "\n "))+"\n") } var err error var value string if templateVersionParameter.Type == "list(string)" { // Move the cursor up a single line for nicer display! - _, _ = fmt.Fprint(cmd.OutOrStdout(), "\033[1A") + _, _ = fmt.Fprint(inv.Stdout, "\033[1A") var options []string err = json.Unmarshal([]byte(templateVersionParameter.DefaultValue), &options) @@ -80,29 +79,29 @@ func RichParameter(cmd *cobra.Command, templateVersionParameter codersdk.Templat return "", err } - values, err := MultiSelect(cmd, options) + values, err := MultiSelect(inv, options) if err == nil { v, err := json.Marshal(&values) if err != nil { return "", err } - _, _ = fmt.Fprintln(cmd.OutOrStdout()) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), " "+Styles.Prompt.String()+Styles.Field.Render(strings.Join(values, ", "))) + _, _ = fmt.Fprintln(inv.Stdout) + _, _ = fmt.Fprintln(inv.Stdout, " "+Styles.Prompt.String()+Styles.Field.Render(strings.Join(values, ", "))) value = string(v) } } else if len(templateVersionParameter.Options) > 0 { // Move the cursor up a single line for nicer display! - _, _ = fmt.Fprint(cmd.OutOrStdout(), "\033[1A") + _, _ = fmt.Fprint(inv.Stdout, "\033[1A") var richParameterOption *codersdk.TemplateVersionParameterOption - richParameterOption, err = RichSelect(cmd, RichSelectOptions{ + richParameterOption, err = RichSelect(inv, RichSelectOptions{ Options: templateVersionParameter.Options, Default: templateVersionParameter.DefaultValue, HideSearch: true, }) if err == nil { - _, _ = fmt.Fprintln(cmd.OutOrStdout()) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), " "+Styles.Prompt.String()+Styles.Field.Render(richParameterOption.Name)) + _, _ = fmt.Fprintln(inv.Stdout) + _, _ = fmt.Fprintln(inv.Stdout, " "+Styles.Prompt.String()+Styles.Field.Render(richParameterOption.Name)) value = richParameterOption.Value } } else { @@ -112,7 +111,7 @@ func RichParameter(cmd *cobra.Command, templateVersionParameter codersdk.Templat } text += ":" - value, err = Prompt(cmd, PromptOptions{ + value, err = Prompt(inv, PromptOptions{ Text: Styles.Bold.Render(text), Validate: func(value string) error { return validateRichPrompt(value, templateVersionParameter) diff --git a/cli/cliui/prompt.go b/cli/cliui/prompt.go index 86c2aa0e506fd..7ce927c0b6b7d 100644 --- a/cli/cliui/prompt.go +++ b/cli/cliui/prompt.go @@ -11,8 +11,9 @@ import ( "github.com/bgentry/speakeasy" "github.com/mattn/go-isatty" - "github.com/spf13/cobra" "golang.org/x/xerrors" + + "github.com/coder/coder/cli/clibase" ) // PromptOptions supply a set of options to the prompt. @@ -26,8 +27,16 @@ type PromptOptions struct { const skipPromptFlag = "yes" -func AllowSkipPrompt(cmd *cobra.Command) { - cmd.Flags().BoolP(skipPromptFlag, "y", false, "Bypass prompts") +// SkipPromptOption adds a "--yes/-y" flag to the cmd that can be used to skip +// prompts. +func SkipPromptOption() clibase.Option { + return clibase.Option{ + Flag: skipPromptFlag, + FlagShorthand: "y", + Description: "Bypass prompts.", + // Discard + Value: clibase.BoolOf(new(bool)), + } } const ( @@ -36,17 +45,17 @@ const ( ) // Prompt asks the user for input. -func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) { +func Prompt(inv *clibase.Invocation, opts PromptOptions) (string, error) { // If the cmd has a "yes" flag for skipping confirm prompts, honor it. // If it's not a "Confirm" prompt, then don't skip. As the default value of // "yes" makes no sense. - if opts.IsConfirm && cmd.Flags().Lookup(skipPromptFlag) != nil { - if skip, _ := cmd.Flags().GetBool(skipPromptFlag); skip { + if opts.IsConfirm && inv.ParsedFlags().Lookup(skipPromptFlag) != nil { + if skip, _ := inv.ParsedFlags().GetBool(skipPromptFlag); skip { return ConfirmYes, nil } } - _, _ = fmt.Fprint(cmd.OutOrStdout(), Styles.FocusedPrompt.String()+opts.Text+" ") + _, _ = fmt.Fprint(inv.Stdout, Styles.FocusedPrompt.String()+opts.Text+" ") if opts.IsConfirm { if len(opts.Default) == 0 { opts.Default = ConfirmYes @@ -58,19 +67,24 @@ func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) { } else { renderedNo = Styles.Bold.Render(ConfirmNo) } - _, _ = fmt.Fprint(cmd.OutOrStdout(), Styles.Placeholder.Render("("+renderedYes+Styles.Placeholder.Render("/"+renderedNo+Styles.Placeholder.Render(") ")))) + _, _ = fmt.Fprint(inv.Stdout, Styles.Placeholder.Render("("+renderedYes+Styles.Placeholder.Render("/"+renderedNo+Styles.Placeholder.Render(") ")))) } else if opts.Default != "" { - _, _ = fmt.Fprint(cmd.OutOrStdout(), Styles.Placeholder.Render("("+opts.Default+") ")) + _, _ = fmt.Fprint(inv.Stdout, Styles.Placeholder.Render("("+opts.Default+") ")) } interrupt := make(chan os.Signal, 1) + if inv.Stdin == nil { + panic("inv.Stdin is nil") + } + errCh := make(chan error, 1) lineCh := make(chan string) + go func() { var line string var err error - inFile, isInputFile := cmd.InOrStdin().(*os.File) + inFile, isInputFile := inv.Stdin.(*os.File) if opts.Secret && isInputFile && isatty.IsTerminal(inFile.Fd()) { // we don't install a signal handler here because speakeasy has its own line, err = speakeasy.Ask("") @@ -78,7 +92,7 @@ func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) { signal.Notify(interrupt, os.Interrupt) defer signal.Stop(interrupt) - reader := bufio.NewReader(cmd.InOrStdin()) + reader := bufio.NewReader(inv.Stdin) line, err = reader.ReadString('\n') // Check if the first line beings with JSON object or array chars. @@ -96,7 +110,10 @@ func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) { if line == "" { line = opts.Default } - lineCh <- line + select { + case <-inv.Context().Done(): + case lineCh <- line: + } }() select { @@ -109,16 +126,16 @@ func Prompt(cmd *cobra.Command, opts PromptOptions) (string, error) { if opts.Validate != nil { err := opts.Validate(line) if err != nil { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), defaultStyles.Error.Render(err.Error())) - return Prompt(cmd, opts) + _, _ = fmt.Fprintln(inv.Stdout, defaultStyles.Error.Render(err.Error())) + return Prompt(inv, opts) } } return line, nil - case <-cmd.Context().Done(): - return "", cmd.Context().Err() + case <-inv.Context().Done(): + return "", inv.Context().Err() case <-interrupt: // Print a newline so that any further output starts properly on a new line. - _, _ = fmt.Fprintln(cmd.OutOrStdout()) + _, _ = fmt.Fprintln(inv.Stdout) return "", Canceled } } diff --git a/cli/cliui/prompt_test.go b/cli/cliui/prompt_test.go index 6c7f233c872e6..49f6dee46e957 100644 --- a/cli/cliui/prompt_test.go +++ b/cli/cliui/prompt_test.go @@ -8,10 +8,10 @@ import ( "os/exec" "testing" - "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/pty" "github.com/coder/coder/pty/ptytest" @@ -77,9 +77,9 @@ func TestPrompt(t *testing.T) { resp, err := newPrompt(ptty, cliui.PromptOptions{ Text: "ShouldNotSeeThis", IsConfirm: true, - }, func(cmd *cobra.Command) { - cliui.AllowSkipPrompt(cmd) - cmd.SetArgs([]string{"-y"}) + }, func(inv *clibase.Invocation) { + inv.Command.Options = append(inv.Command.Options, cliui.SkipPromptOption()) + inv.Args = []string{"-y"} }) assert.NoError(t, err) doneChan <- resp @@ -145,23 +145,25 @@ func TestPrompt(t *testing.T) { }) } -func newPrompt(ptty *ptytest.PTY, opts cliui.PromptOptions, cmdOpt func(cmd *cobra.Command)) (string, error) { +func newPrompt(ptty *ptytest.PTY, opts cliui.PromptOptions, invOpt func(inv *clibase.Invocation)) (string, error) { value := "" - cmd := &cobra.Command{ - RunE: func(cmd *cobra.Command, args []string) error { + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { var err error - value, err = cliui.Prompt(cmd, opts) + value, err = cliui.Prompt(inv, opts) return err }, } + + inv := cmd.Invoke() // Optionally modify the cmd - if cmdOpt != nil { - cmdOpt(cmd) + if invOpt != nil { + invOpt(inv) } - cmd.SetOut(ptty.Output()) - cmd.SetErr(ptty.Output()) - cmd.SetIn(ptty.Input()) - return value, cmd.ExecuteContext(context.Background()) + inv.Stdout = ptty.Output() + inv.Stderr = ptty.Output() + inv.Stdin = ptty.Input() + return value, inv.WithContext(context.Background()).Run() } func TestPasswordTerminalState(t *testing.T) { @@ -208,13 +210,17 @@ func TestPasswordTerminalState(t *testing.T) { // nolint:unused func passwordHelper() { - cmd := &cobra.Command{ - Run: func(cmd *cobra.Command, args []string) { - cliui.Prompt(cmd, cliui.PromptOptions{ + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { + cliui.Prompt(inv, cliui.PromptOptions{ Text: "Password:", Secret: true, }) + return nil }, } - cmd.ExecuteContext(context.Background()) + err := cmd.Invoke().WithOS().Run() + if err != nil { + panic(err) + } } diff --git a/cli/cliui/provisionerjob_test.go b/cli/cliui/provisionerjob_test.go index 122ff513dd79e..4795867843b74 100644 --- a/cli/cliui/provisionerjob_test.go +++ b/cli/cliui/provisionerjob_test.go @@ -9,9 +9,9 @@ import ( "testing" "time" - "github.com/spf13/cobra" "github.com/stretchr/testify/assert" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/database" "github.com/coder/coder/codersdk" @@ -125,9 +125,9 @@ func newProvisionerJob(t *testing.T) provisionerJobTest { } jobLock := sync.Mutex{} logs := make(chan codersdk.ProvisionerJobLog, 1) - cmd := &cobra.Command{ - RunE: func(cmd *cobra.Command, args []string) error { - return cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{ + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { + return cliui.ProvisionerJob(inv.Context(), inv.Stdout, cliui.ProvisionerJobOptions{ FetchInterval: time.Millisecond, Fetch: func() (codersdk.ProvisionerJob, error) { jobLock.Lock() @@ -145,13 +145,14 @@ func newProvisionerJob(t *testing.T) provisionerJobTest { }) }, } + inv := cmd.Invoke() + ptty := ptytest.New(t) - cmd.SetOutput(ptty.Output()) - cmd.SetIn(ptty.Input()) + ptty.Attach(inv) done := make(chan struct{}) go func() { defer close(done) - err := cmd.ExecuteContext(context.Background()) + err := inv.WithContext(context.Background()).Run() if err != nil { assert.ErrorIs(t, err, cliui.Canceled) } diff --git a/cli/cliui/select.go b/cli/cliui/select.go index 1b6412f51f675..86f8521fe4525 100644 --- a/cli/cliui/select.go +++ b/cli/cliui/select.go @@ -8,9 +8,9 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/AlecAivazis/survey/v2/terminal" - "github.com/spf13/cobra" "golang.org/x/xerrors" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/codersdk" ) @@ -68,7 +68,7 @@ type RichSelectOptions struct { } // RichSelect displays a list of user options including name and description. -func RichSelect(cmd *cobra.Command, richOptions RichSelectOptions) (*codersdk.TemplateVersionParameterOption, error) { +func RichSelect(inv *clibase.Invocation, richOptions RichSelectOptions) (*codersdk.TemplateVersionParameterOption, error) { opts := make([]string, len(richOptions.Options)) for i, option := range richOptions.Options { line := option.Name @@ -78,7 +78,7 @@ func RichSelect(cmd *cobra.Command, richOptions RichSelectOptions) (*codersdk.Te opts[i] = line } - selected, err := Select(cmd, SelectOptions{ + selected, err := Select(inv, SelectOptions{ Options: opts, Default: richOptions.Default, Size: richOptions.Size, @@ -97,7 +97,7 @@ func RichSelect(cmd *cobra.Command, richOptions RichSelectOptions) (*codersdk.Te } // Select displays a list of user options. -func Select(cmd *cobra.Command, opts SelectOptions) (string, error) { +func Select(inv *clibase.Invocation, opts SelectOptions) (string, error) { // The survey library used *always* fails when testing on Windows, // as it requires a live TTY (can't be a conpty). We should fork // this library to add a dummy fallback, that simply reads/writes @@ -123,17 +123,17 @@ func Select(cmd *cobra.Command, opts SelectOptions) (string, error) { is.Help.Text = "" } }), survey.WithStdio(fileReadWriter{ - Reader: cmd.InOrStdin(), + Reader: inv.Stdin, }, fileReadWriter{ - Writer: cmd.OutOrStdout(), - }, cmd.OutOrStdout())) + Writer: inv.Stdout, + }, inv.Stdout)) if errors.Is(err, terminal.InterruptErr) { return value, Canceled } return value, err } -func MultiSelect(cmd *cobra.Command, items []string) ([]string, error) { +func MultiSelect(inv *clibase.Invocation, items []string) ([]string, error) { // Similar hack is applied to Select() if flag.Lookup("test.v") != nil { return items, nil @@ -146,10 +146,10 @@ func MultiSelect(cmd *cobra.Command, items []string) ([]string, error) { var values []string err := survey.AskOne(prompt, &values, survey.WithStdio(fileReadWriter{ - Reader: cmd.InOrStdin(), + Reader: inv.Stdin, }, fileReadWriter{ - Writer: cmd.OutOrStdout(), - }, cmd.OutOrStdout())) + Writer: inv.Stdout, + }, inv.Stdout)) if errors.Is(err, terminal.InterruptErr) { return nil, Canceled } diff --git a/cli/cliui/select_test.go b/cli/cliui/select_test.go index c22df1af83097..f7467098cb263 100644 --- a/cli/cliui/select_test.go +++ b/cli/cliui/select_test.go @@ -1,13 +1,12 @@ package cliui_test import ( - "context" "testing" - "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" "github.com/coder/coder/pty/ptytest" @@ -32,16 +31,16 @@ func TestSelect(t *testing.T) { func newSelect(ptty *ptytest.PTY, opts cliui.SelectOptions) (string, error) { value := "" - cmd := &cobra.Command{ - RunE: func(cmd *cobra.Command, args []string) error { + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { var err error - value, err = cliui.Select(cmd, opts) + value, err = cliui.Select(inv, opts) return err }, } - cmd.SetOutput(ptty.Output()) - cmd.SetIn(ptty.Input()) - return value, cmd.ExecuteContext(context.Background()) + inv := cmd.Invoke() + ptty.Attach(inv) + return value, inv.Run() } func TestRichSelect(t *testing.T) { @@ -56,11 +55,11 @@ func TestRichSelect(t *testing.T) { { Name: "A-Name", Value: "A-Value", - Description: "A-Description", + Description: "A-Description.", }, { Name: "B-Name", Value: "B-Value", - Description: "B-Description", + Description: "B-Description.", }, }, }) @@ -73,18 +72,18 @@ func TestRichSelect(t *testing.T) { func newRichSelect(ptty *ptytest.PTY, opts cliui.RichSelectOptions) (string, error) { value := "" - cmd := &cobra.Command{ - RunE: func(cmd *cobra.Command, args []string) error { - richOption, err := cliui.RichSelect(cmd, opts) + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { + richOption, err := cliui.RichSelect(inv, opts) if err == nil { value = richOption.Value } return err }, } - cmd.SetOutput(ptty.Output()) - cmd.SetIn(ptty.Input()) - return value, cmd.ExecuteContext(context.Background()) + inv := cmd.Invoke() + ptty.Attach(inv) + return value, inv.Run() } func TestMultiSelect(t *testing.T) { @@ -106,16 +105,16 @@ func TestMultiSelect(t *testing.T) { func newMultiSelect(ptty *ptytest.PTY, items []string) ([]string, error) { var values []string - cmd := &cobra.Command{ - RunE: func(cmd *cobra.Command, args []string) error { - selectedItems, err := cliui.MultiSelect(cmd, items) + cmd := &clibase.Cmd{ + Handler: func(inv *clibase.Invocation) error { + selectedItems, err := cliui.MultiSelect(inv, items) if err == nil { values = selectedItems } return err }, } - cmd.SetOutput(ptty.Output()) - cmd.SetIn(ptty.Input()) - return values, cmd.ExecuteContext(context.Background()) + inv := cmd.Invoke() + ptty.Attach(inv) + return values, inv.Run() } diff --git a/cli/config/file.go b/cli/config/file.go index b3707b3c2a57c..59b7b74a862d2 100644 --- a/cli/config/file.go +++ b/cli/config/file.go @@ -6,6 +6,7 @@ import ( "path/filepath" "github.com/kirsle/configdir" + "golang.org/x/xerrors" ) const ( @@ -15,36 +16,53 @@ const ( // Root represents the configuration directory. type Root string +// mustNotBeEmpty prevents us from accidentally writing configuration to the +// current directory. This is primarily valuable in development, where we may +// accidentally use an empty root. +func (r Root) mustNotEmpty() { + if r == "" { + panic("config root must not be empty") + } +} + func (r Root) Session() File { + r.mustNotEmpty() return File(filepath.Join(string(r), "session")) } // ReplicaID is a unique identifier for the Coder server. func (r Root) ReplicaID() File { + r.mustNotEmpty() return File(filepath.Join(string(r), "replica_id")) } func (r Root) URL() File { + r.mustNotEmpty() return File(filepath.Join(string(r), "url")) } func (r Root) Organization() File { + r.mustNotEmpty() return File(filepath.Join(string(r), "organization")) } func (r Root) DotfilesURL() File { + r.mustNotEmpty() return File(filepath.Join(string(r), "dotfilesurl")) } func (r Root) PostgresPath() string { + r.mustNotEmpty() return filepath.Join(string(r), "postgres") } func (r Root) PostgresPassword() File { + r.mustNotEmpty() return File(filepath.Join(r.PostgresPath(), "password")) } func (r Root) PostgresPort() File { + r.mustNotEmpty() return File(filepath.Join(r.PostgresPath(), "port")) } @@ -53,16 +71,25 @@ type File string // Delete deletes the file. func (f File) Delete() error { + if f == "" { + return xerrors.Errorf("empty file path") + } return os.Remove(string(f)) } // Write writes the string to the file. func (f File) Write(s string) error { + if f == "" { + return xerrors.Errorf("empty file path") + } return write(string(f), 0o600, []byte(s)) } // Read reads the file to a string. func (f File) Read() (string, error) { + if f == "" { + return "", xerrors.Errorf("empty file path") + } byt, err := read(string(f)) return string(byt), err } diff --git a/cli/configssh.go b/cli/configssh.go index 71d6e45107988..71c8aac526cf8 100644 --- a/cli/configssh.go +++ b/cli/configssh.go @@ -18,12 +18,11 @@ import ( "github.com/cli/safeexec" "github.com/pkg/diff" "github.com/pkg/diff/write" - "github.com/spf13/cobra" "golang.org/x/exp/slices" "golang.org/x/sync/errgroup" "golang.org/x/xerrors" - "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) @@ -170,7 +169,7 @@ func sshPrepareWorkspaceConfigs(ctx context.Context, client *codersdk.Client) (r } } -func configSSH() *cobra.Command { +func (r *RootCmd) configSSH() *clibase.Cmd { var ( sshConfigFile string sshConfigOpts sshConfigOptions @@ -179,11 +178,12 @@ func configSSH() *cobra.Command { skipProxyCommand bool userHostPrefix string ) - cmd := &cobra.Command{ + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Annotations: workspaceCommand, Use: "config-ssh", Short: "Add an SSH Host entry for your workspaces \"ssh coder.workspace\"", - Example: formatExamples( + Long: formatExamples( example{ Description: "You can use -o (or --ssh-option) so set SSH options to be used for all your workspaces", Command: "coder config-ssh -o ForwardAgent=yes", @@ -193,21 +193,18 @@ func configSSH() *cobra.Command { Command: "coder config-ssh --dry-run", }, ), - Args: cobra.ExactArgs(0), - RunE: func(cmd *cobra.Command, _ []string) error { - ctx := cmd.Context() - client, err := CreateClient(cmd) - if err != nil { - return err - } - - recvWorkspaceConfigs := sshPrepareWorkspaceConfigs(ctx, client) + Middleware: clibase.Chain( + clibase.RequireNArgs(0), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + recvWorkspaceConfigs := sshPrepareWorkspaceConfigs(inv.Context(), client) - out := cmd.OutOrStdout() + out := inv.Stdout if dryRun { // Print everything except diff to stderr so // that it's possible to capture the diff. - out = cmd.OutOrStderr() + out = inv.Stderr } coderBinary, err := currentBinPath(out) if err != nil { @@ -218,7 +215,7 @@ func configSSH() *cobra.Command { return xerrors.Errorf("escape coder binary for ssh failed: %w", err) } - root := createConfig(cmd) + root := r.createConfig() escapedGlobalConfig, err := sshConfigExecEscape(string(root)) if err != nil { return xerrors.Errorf("escape global config for ssh failed: %w", err) @@ -278,7 +275,7 @@ func configSSH() *cobra.Command { oldOptsMsg = fmt.Sprintf("\n\n Previous options:\n * %s", strings.Join(oldOpts, "\n * ")) } - line, err := cliui.Prompt(cmd, cliui.PromptOptions{ + line, err := cliui.Prompt(inv, cliui.PromptOptions{ Text: fmt.Sprintf("New options differ from previous options:%s%s\n\n Use new options?", newOptsMsg, oldOptsMsg), IsConfirm: true, }) @@ -292,7 +289,7 @@ func configSSH() *cobra.Command { changes = append(changes, "Use new SSH options") } // Only print when prompts are shown. - if yes, _ := cmd.Flags().GetBool("yes"); !yes { + if yes, _ := inv.ParsedFlags().GetBool("yes"); !yes { _, _ = fmt.Fprint(out, "\n") } } @@ -317,7 +314,7 @@ func configSSH() *cobra.Command { return xerrors.Errorf("fetch workspace configs failed: %w", err) } - coderdConfig, err := client.SSHConfiguration(ctx) + coderdConfig, err := client.SSHConfiguration(inv.Context()) if err != nil { // If the error is 404, this deployment does not support // this endpoint yet. Do not error, just assume defaults. @@ -417,21 +414,21 @@ func configSSH() *cobra.Command { if dryRun { _, _ = fmt.Fprintf(out, "Dry run, the following changes would be made to your SSH configuration:\n\n * %s\n\n", strings.Join(changes, "\n * ")) - color := isTTYOut(cmd) + color := isTTYOut(inv) diff, err := diffBytes(sshConfigFile, configRaw, configModified, color) if err != nil { return xerrors.Errorf("diff failed: %w", err) } if len(diff) > 0 { // Write diff to stdout. - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s", diff) + _, _ = fmt.Fprintf(inv.Stdout, "%s", diff) } return nil } if len(changes) > 0 { - _, err = cliui.Prompt(cmd, cliui.PromptOptions{ + _, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: fmt.Sprintf("The following changes will be made to your SSH configuration:\n\n * %s\n\n Continue?", strings.Join(changes, "\n * ")), IsConfirm: true, }) @@ -439,7 +436,7 @@ func configSSH() *cobra.Command { return nil } // Only print when prompts are shown. - if yes, _ := cmd.Flags().GetBool("yes"); !yes { + if yes, _ := inv.ParsedFlags().GetBool("yes"); !yes { _, _ = fmt.Fprint(out, "\n") } } @@ -449,6 +446,7 @@ func configSSH() *cobra.Command { if err != nil { return xerrors.Errorf("write ssh config failed: %w", err) } + _, _ = fmt.Fprintf(out, "Updated %q\n", sshConfigFile) } if len(workspaceConfigs) > 0 { @@ -460,14 +458,50 @@ func configSSH() *cobra.Command { return nil }, } - cliflag.StringVarP(cmd.Flags(), &sshConfigFile, "ssh-config-file", "", "CODER_SSH_CONFIG_FILE", sshDefaultConfigFileName, "Specifies the path to an SSH config.") - cmd.Flags().StringArrayVarP(&sshConfigOpts.sshOptions, "ssh-option", "o", []string{}, "Specifies additional SSH options to embed in each host stanza.") - cmd.Flags().BoolVarP(&dryRun, "dry-run", "n", false, "Perform a trial run with no changes made, showing a diff at the end.") - cmd.Flags().BoolVarP(&skipProxyCommand, "skip-proxy-command", "", false, "Specifies whether the ProxyCommand option should be skipped. Useful for testing.") - _ = cmd.Flags().MarkHidden("skip-proxy-command") - cliflag.BoolVarP(cmd.Flags(), &usePreviousOpts, "use-previous-options", "", "CODER_SSH_USE_PREVIOUS_OPTIONS", false, "Specifies whether or not to keep options from previous run of config-ssh.") - cmd.Flags().StringVarP(&userHostPrefix, "ssh-host-prefix", "", "", "Override the default host prefix.") - cliui.AllowSkipPrompt(cmd) + + cmd.Options = clibase.OptionSet{ + { + Flag: "ssh-config-file", + Env: "CODER_SSH_CONFIG_FILE", + Default: sshDefaultConfigFileName, + Description: "Specifies the path to an SSH config.", + Value: clibase.StringOf(&sshConfigFile), + }, + { + Flag: "ssh-option", + FlagShorthand: "o", + Env: "CODER_SSH_CONFIG_OPTS", + Description: "Specifies additional SSH options to embed in each host stanza.", + Value: clibase.StringArrayOf(&sshConfigOpts.sshOptions), + }, + { + Flag: "dry-run", + FlagShorthand: "n", + Env: "CODER_SSH_DRY_RUN", + Description: "Perform a trial run with no changes made, showing a diff at the end.", + Value: clibase.BoolOf(&dryRun), + }, + { + Flag: "skip-proxy-command", + Env: "CODER_SSH_SKIP_PROXY_COMMAND", + Description: "Specifies whether the ProxyCommand option should be skipped. Useful for testing.", + Value: clibase.BoolOf(&skipProxyCommand), + Hidden: true, + }, + { + Flag: "use-previous-options", + Env: "CODER_SSH_USE_PREVIOUS_OPTIONS", + Description: "Specifies whether or not to keep options from previous run of config-ssh.", + Value: clibase.BoolOf(&usePreviousOpts), + }, + { + Flag: "ssh-host-prefix", + Env: "", + Description: "Override the default host prefix.", + Value: clibase.StringOf(&userHostPrefix), + }, + cliui.SkipPromptOption(), + } return cmd } diff --git a/cli/configssh_test.go b/cli/configssh_test.go index 112ee82850e9b..a5767dc4a168e 100644 --- a/cli/configssh_test.go +++ b/cli/configssh_test.go @@ -149,21 +149,17 @@ func TestConfigSSH(t *testing.T) { tcpAddr, valid := listener.Addr().(*net.TCPAddr) require.True(t, valid) - cmd, root := clitest.New(t, "config-ssh", + inv, root := clitest.New(t, "config-ssh", "--ssh-option", "HostName "+tcpAddr.IP.String(), "--ssh-option", "Port "+strconv.Itoa(tcpAddr.Port), "--ssh-config-file", sshConfigFile, "--skip-proxy-command") clitest.SetupConfig(t, client, root) - doneChan := make(chan struct{}) pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - go func() { - defer close(doneChan) - err := cmd.Execute() - assert.NoError(t, err) - }() + inv.Stdin = pty.Input() + inv.Stdout = pty.Output() + + waiter := clitest.StartWithWaiter(t, inv) matches := []struct { match, write string @@ -175,7 +171,7 @@ func TestConfigSSH(t *testing.T) { pty.WriteLine(m.write) } - <-doneChan + waiter.RequireSuccess() fileContents, err := os.ReadFile(sshConfigFile) require.NoError(t, err, "read ssh config file") @@ -187,7 +183,7 @@ func TestConfigSSH(t *testing.T) { pty = ptytest.New(t) // Set HOME because coder config is included from ~/.ssh/coder. sshCmd.Env = append(sshCmd.Env, fmt.Sprintf("HOME=%s", home)) - sshCmd.Stderr = pty.Output() + inv.Stderr = pty.Output() data, err := sshCmd.Output() require.NoError(t, err) require.Equal(t, "test", strings.TrimSpace(string(data))) @@ -586,14 +582,14 @@ func TestConfigSSH_FileWriteAndOptionsFlow(t *testing.T) { "--ssh-config-file", sshConfigName, } args = append(args, tt.args...) - cmd, root := clitest.New(t, args...) + inv, root := clitest.New(t, args...) clitest.SetupConfig(t, client, root) pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + inv.Stdin = pty.Input() + inv.Stdout = pty.Output() done := tGo(t, func() { - err := cmd.Execute() + err := inv.Run() if !tt.wantErr { assert.NoError(t, err) } else { @@ -703,17 +699,13 @@ func TestConfigSSH_Hostnames(t *testing.T) { sshConfigFile := sshConfigFileName(t) - cmd, root := clitest.New(t, "config-ssh", "--ssh-config-file", sshConfigFile) + inv, root := clitest.New(t, "config-ssh", "--ssh-config-file", sshConfigFile) clitest.SetupConfig(t, client, root) - doneChan := make(chan struct{}) + pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - go func() { - defer close(doneChan) - err := cmd.Execute() - assert.NoError(t, err) - }() + inv.Stdin = pty.Input() + inv.Stdout = pty.Output() + clitest.Start(t, inv) matches := []struct { match, write string @@ -725,7 +717,7 @@ func TestConfigSSH_Hostnames(t *testing.T) { pty.WriteLine(m.write) } - <-doneChan + pty.ExpectMatch("Updated") var expectedHosts []string for _, hostnamePattern := range tt.expected { diff --git a/cli/create.go b/cli/create.go index e1db7b65ceed5..06901cf43d22e 100644 --- a/cli/create.go +++ b/cli/create.go @@ -6,17 +6,16 @@ import ( "io" "time" - "github.com/spf13/cobra" "golang.org/x/exp/slices" "golang.org/x/xerrors" - "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/util/ptr" "github.com/coder/coder/codersdk" ) -func create() *cobra.Command { +func (r *RootCmd) create() *clibase.Cmd { var ( parameterFile string richParameterFile string @@ -25,30 +24,27 @@ func create() *cobra.Command { stopAfter time.Duration workspaceName string ) - cmd := &cobra.Command{ + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Annotations: workspaceCommand, Use: "create [name]", Short: "Create a workspace", - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) + Middleware: clibase.Chain(r.InitClient(client)), + Handler: func(inv *clibase.Invocation) error { + organization, err := CurrentOrganization(inv, client) if err != nil { return err } - organization, err := CurrentOrganization(cmd, client) - if err != nil { - return err - } - - if len(args) >= 1 { - workspaceName = args[0] + if len(inv.Args) >= 1 { + workspaceName = inv.Args[0] } if workspaceName == "" { - workspaceName, err = cliui.Prompt(cmd, cliui.PromptOptions{ + workspaceName, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "Specify a name for your workspace:", Validate: func(workspaceName string) error { - _, err = client.WorkspaceByOwnerAndName(cmd.Context(), codersdk.Me, workspaceName, codersdk.WorkspaceOptions{}) + _, err = client.WorkspaceByOwnerAndName(inv.Context(), codersdk.Me, workspaceName, codersdk.WorkspaceOptions{}) if err == nil { return xerrors.Errorf("A workspace already exists named %q!", workspaceName) } @@ -60,16 +56,16 @@ func create() *cobra.Command { } } - _, err = client.WorkspaceByOwnerAndName(cmd.Context(), codersdk.Me, workspaceName, codersdk.WorkspaceOptions{}) + _, err = client.WorkspaceByOwnerAndName(inv.Context(), codersdk.Me, workspaceName, codersdk.WorkspaceOptions{}) if err == nil { return xerrors.Errorf("A workspace already exists named %q!", workspaceName) } var template codersdk.Template if templateName == "" { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Wrap.Render("Select a template below to preview the provisioned infrastructure:")) + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Wrap.Render("Select a template below to preview the provisioned infrastructure:")) - templates, err := client.TemplatesByOrganization(cmd.Context(), organization.ID) + templates, err := client.TemplatesByOrganization(inv.Context(), organization.ID) if err != nil { return err } @@ -98,7 +94,7 @@ func create() *cobra.Command { } // Move the cursor up a single line for nicer display! - option, err := cliui.Select(cmd, cliui.SelectOptions{ + option, err := cliui.Select(inv, cliui.SelectOptions{ Options: templateNames, HideSearch: true, }) @@ -108,7 +104,7 @@ func create() *cobra.Command { template = templateByName[option] } else { - template, err = client.TemplateByName(cmd.Context(), organization.ID, templateName) + template, err = client.TemplateByName(inv.Context(), organization.ID, templateName) if err != nil { return xerrors.Errorf("get template by name: %w", err) } @@ -123,7 +119,7 @@ func create() *cobra.Command { schedSpec = ptr.Ref(sched.String()) } - buildParams, err := prepWorkspaceBuild(cmd, client, prepWorkspaceBuildArgs{ + buildParams, err := prepWorkspaceBuild(inv, client, prepWorkspaceBuildArgs{ Template: template, ExistingParams: []codersdk.Parameter{}, ParameterFile: parameterFile, @@ -131,10 +127,10 @@ func create() *cobra.Command { NewWorkspaceName: workspaceName, }) if err != nil { - return err + return xerrors.Errorf("prepare build: %w", err) } - _, err = cliui.Prompt(cmd, cliui.PromptOptions{ + _, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "Confirm create?", IsConfirm: true, }) @@ -149,7 +145,7 @@ func create() *cobra.Command { ttlMillis = &template.MaxTTLMillis } - workspace, err := client.CreateWorkspace(cmd.Context(), organization.ID, codersdk.Me, codersdk.CreateWorkspaceRequest{ + workspace, err := client.CreateWorkspace(inv.Context(), organization.ID, codersdk.Me, codersdk.CreateWorkspaceRequest{ TemplateID: template.ID, Name: workspaceName, AutostartSchedule: schedSpec, @@ -158,25 +154,53 @@ func create() *cobra.Command { RichParameterValues: buildParams.richParameters, }) if err != nil { - return err + return xerrors.Errorf("create workspace: %w", err) } - err = cliui.WorkspaceBuild(cmd.Context(), cmd.OutOrStdout(), client, workspace.LatestBuild.ID) + err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, workspace.LatestBuild.ID) if err != nil { - return err + return xerrors.Errorf("watch build: %w", err) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nThe %s workspace has been created at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) + _, _ = fmt.Fprintf(inv.Stdout, "\nThe %s workspace has been created at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) return nil }, } + cmd.Options = append(cmd.Options, + clibase.Option{ + Flag: "template", + FlagShorthand: "t", + Env: "CODER_TEMPLATE_NAME", + Description: "Specify a template name.", + Value: clibase.StringOf(&templateName), + }, + clibase.Option{ + Flag: "parameter-file", + Env: "CODER_PARAMETER_FILE", + Description: "Specify a file path with parameter values.", + Value: clibase.StringOf(¶meterFile), + }, + clibase.Option{ + Flag: "rich-parameter-file", + Env: "CODER_RICH_PARAMETER_FILE", + Description: "Specify a file path with values for rich parameters defined in the template.", + Value: clibase.StringOf(&richParameterFile), + }, + clibase.Option{ + Flag: "start-at", + Env: "CODER_WORKSPACE_START_AT", + Description: "Specify the workspace autostart schedule. Check coder schedule start --help for the syntax.", + Value: clibase.StringOf(&startAt), + }, + clibase.Option{ + Flag: "stop-after", + Env: "CODER_WORKSPACE_STOP_AFTER", + Description: "Specify a duration after which the workspace should shut down (e.g. 8h).", + Value: clibase.DurationOf(&stopAfter), + }, + cliui.SkipPromptOption(), + ) - cliui.AllowSkipPrompt(cmd) - cliflag.StringVarP(cmd.Flags(), &templateName, "template", "t", "CODER_TEMPLATE_NAME", "", "Specify a template name.") - cliflag.StringVarP(cmd.Flags(), ¶meterFile, "parameter-file", "", "CODER_PARAMETER_FILE", "", "Specify a file path with parameter values.") - cliflag.StringVarP(cmd.Flags(), &richParameterFile, "rich-parameter-file", "", "CODER_RICH_PARAMETER_FILE", "", "Specify a file path with values for rich parameters defined in the template.") - cliflag.StringVarP(cmd.Flags(), &startAt, "start-at", "", "CODER_WORKSPACE_START_AT", "", "Specify the workspace autostart schedule. Check `coder schedule start --help` for the syntax.") - cliflag.DurationVarP(cmd.Flags(), &stopAfter, "stop-after", "", "CODER_WORKSPACE_STOP_AFTER", 0, "Specify a duration after which the workspace should shut down (e.g. 8h).") return cmd } @@ -200,8 +224,8 @@ type buildParameters struct { // prepWorkspaceBuild will ensure a workspace build will succeed on the latest template version. // Any missing params will be prompted to the user. It supports legacy and rich parameters. -func prepWorkspaceBuild(cmd *cobra.Command, client *codersdk.Client, args prepWorkspaceBuildArgs) (*buildParameters, error) { - ctx := cmd.Context() +func prepWorkspaceBuild(inv *clibase.Invocation, client *codersdk.Client, args prepWorkspaceBuildArgs) (*buildParameters, error) { + ctx := inv.Context() var useRichParameters bool if len(args.ExistingRichParams) > 0 && len(args.RichParameterFile) > 0 { @@ -233,7 +257,7 @@ func prepWorkspaceBuild(cmd *cobra.Command, client *codersdk.Client, args prepWo useParamFile := false if args.ParameterFile != "" { useParamFile = true - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("Attempting to read the variables from the parameter file.")+"\r\n") + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("Attempting to read the variables from the parameter file.")+"\r\n") parameterMapFromFile, err = createParameterMapFromFile(args.ParameterFile) if err != nil { return nil, err @@ -247,7 +271,7 @@ PromptParamLoop: continue } if !disclaimerPrinted { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("This template has customizable parameters. Values can be changed after create, but may have unintended side effects (like data loss).")+"\r\n") + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("This template has customizable parameters. Values can be changed after create, but may have unintended side effects (like data loss).")+"\r\n") disclaimerPrinted = true } @@ -262,7 +286,7 @@ PromptParamLoop: } } - parameterValue, err := getParameterValueFromMapOrInput(cmd, parameterMapFromFile, parameterSchema) + parameterValue, err := getParameterValueFromMapOrInput(inv, parameterMapFromFile, parameterSchema) if err != nil { return nil, err } @@ -276,11 +300,11 @@ PromptParamLoop: } if disclaimerPrinted { - _, _ = fmt.Fprintln(cmd.OutOrStdout()) + _, _ = fmt.Fprintln(inv.Stdout) } // Rich parameters - templateVersionParameters, err := client.TemplateVersionRichParameters(cmd.Context(), templateVersion.ID) + templateVersionParameters, err := client.TemplateVersionRichParameters(inv.Context(), templateVersion.ID) if err != nil { return nil, xerrors.Errorf("get template version rich parameters: %w", err) } @@ -289,7 +313,7 @@ PromptParamLoop: useParamFile = false if args.RichParameterFile != "" { useParamFile = true - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("Attempting to read the variables from the rich parameter file.")+"\r\n") + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("Attempting to read the variables from the rich parameter file.")+"\r\n") parameterMapFromFile, err = createParameterMapFromFile(args.RichParameterFile) if err != nil { return nil, err @@ -300,7 +324,7 @@ PromptParamLoop: PromptRichParamLoop: for _, templateVersionParameter := range templateVersionParameters { if !disclaimerPrinted { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("This template has customizable parameters. Values can be changed after create, but may have unintended side effects (like data loss).")+"\r\n") + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("This template has customizable parameters. Values can be changed after create, but may have unintended side effects (like data loss).")+"\r\n") disclaimerPrinted = true } @@ -316,11 +340,11 @@ PromptRichParamLoop: } if args.UpdateWorkspace && !templateVersionParameter.Mutable { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Warn.Render(fmt.Sprintf(`Parameter %q is not mutable, so can't be customized after workspace creation.`, templateVersionParameter.Name))) + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Warn.Render(fmt.Sprintf(`Parameter %q is not mutable, so can't be customized after workspace creation.`, templateVersionParameter.Name))) continue } - parameterValue, err := getWorkspaceBuildParameterValueFromMapOrInput(cmd, parameterMapFromFile, templateVersionParameter) + parameterValue, err := getWorkspaceBuildParameterValueFromMapOrInput(inv, parameterMapFromFile, templateVersionParameter) if err != nil { return nil, err } @@ -329,10 +353,10 @@ PromptRichParamLoop: } if disclaimerPrinted { - _, _ = fmt.Fprintln(cmd.OutOrStdout()) + _, _ = fmt.Fprintln(inv.Stdout) } - err = cliui.GitAuth(ctx, cmd.OutOrStdout(), cliui.GitAuthOptions{ + err = cliui.GitAuth(ctx, inv.Stdout, cliui.GitAuthOptions{ Fetch: func(ctx context.Context) ([]codersdk.TemplateVersionGitAuth, error) { return client.TemplateVersionGitAuth(ctx, templateVersion.ID) }, @@ -342,7 +366,7 @@ PromptRichParamLoop: } // Run a dry-run with the given parameters to check correctness - dryRun, err := client.CreateTemplateVersionDryRun(cmd.Context(), templateVersion.ID, codersdk.CreateTemplateVersionDryRunRequest{ + dryRun, err := client.CreateTemplateVersionDryRun(inv.Context(), templateVersion.ID, codersdk.CreateTemplateVersionDryRunRequest{ WorkspaceName: args.NewWorkspaceName, ParameterValues: legacyParameters, RichParameterValues: richParameters, @@ -350,16 +374,16 @@ PromptRichParamLoop: if err != nil { return nil, xerrors.Errorf("begin workspace dry-run: %w", err) } - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Planning workspace...") - err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{ + _, _ = fmt.Fprintln(inv.Stdout, "Planning workspace...") + err = cliui.ProvisionerJob(inv.Context(), inv.Stdout, cliui.ProvisionerJobOptions{ Fetch: func() (codersdk.ProvisionerJob, error) { - return client.TemplateVersionDryRun(cmd.Context(), templateVersion.ID, dryRun.ID) + return client.TemplateVersionDryRun(inv.Context(), templateVersion.ID, dryRun.ID) }, Cancel: func() error { - return client.CancelTemplateVersionDryRun(cmd.Context(), templateVersion.ID, dryRun.ID) + return client.CancelTemplateVersionDryRun(inv.Context(), templateVersion.ID, dryRun.ID) }, Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) { - return client.TemplateVersionDryRunLogsAfter(cmd.Context(), templateVersion.ID, dryRun.ID, 0) + return client.TemplateVersionDryRunLogsAfter(inv.Context(), templateVersion.ID, dryRun.ID, 0) }, // Don't show log output for the dry-run unless there's an error. Silent: true, @@ -370,19 +394,19 @@ PromptRichParamLoop: return nil, xerrors.Errorf("dry-run workspace: %w", err) } - resources, err := client.TemplateVersionDryRunResources(cmd.Context(), templateVersion.ID, dryRun.ID) + resources, err := client.TemplateVersionDryRunResources(inv.Context(), templateVersion.ID, dryRun.ID) if err != nil { return nil, xerrors.Errorf("get workspace dry-run resources: %w", err) } - err = cliui.WorkspaceResources(cmd.OutOrStdout(), resources, cliui.WorkspaceResourcesOptions{ + err = cliui.WorkspaceResources(inv.Stdout, resources, cliui.WorkspaceResourcesOptions{ WorkspaceName: args.NewWorkspaceName, // Since agents haven't connected yet, hiding this makes more sense. HideAgentState: true, Title: "Workspace Preview", }) if err != nil { - return nil, err + return nil, xerrors.Errorf("get resources: %w", err) } return &buildParameters{ diff --git a/cli/create_test.go b/cli/create_test.go index 304fb7c49957a..f215e8ae8bf35 100644 --- a/cli/create_test.go +++ b/cli/create_test.go @@ -42,15 +42,13 @@ func TestCreate(t *testing.T) { "--start-at", "9:30AM Mon-Fri US/Central", "--stop-after", "8h", } - cmd, root := clitest.New(t, args...) + inv, root := clitest.New(t, args...) clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() matches := []struct { @@ -100,17 +98,10 @@ func TestCreate(t *testing.T) { "my-workspace", "--template", template.Name, } - cmd, root := clitest.New(t, args...) + inv, root := clitest.New(t, args...) clitest.SetupConfig(t, client, root) - doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - go func() { - defer close(doneChan) - err := cmd.Execute() - assert.NoError(t, err) - }() + pty := ptytest.New(t).Attach(inv) + waiter := clitest.StartWithWaiter(t, inv) matches := []struct { match string write string @@ -125,7 +116,7 @@ func TestCreate(t *testing.T) { pty.WriteLine(m.write) } } - <-doneChan + waiter.RequireSuccess() ws, err := client.WorkspaceByOwnerAndName(context.Background(), "testuser", "my-workspace", codersdk.WorkspaceOptions{}) require.NoError(t, err, "expected workspace to be created") @@ -140,14 +131,14 @@ func TestCreate(t *testing.T) { version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) coderdtest.AwaitTemplateVersionJob(t, client, version.ID) _ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - cmd, root := clitest.New(t, "create", "my-workspace", "-y") + inv, root := clitest.New(t, "create", "my-workspace", "-y") member, _ := coderdtest.CreateAnotherUser(t, client, user.OrganizationID) clitest.SetupConfig(t, member, root) cmdCtx, done := context.WithTimeout(context.Background(), testutil.WaitLong) go func() { defer done() - err := cmd.ExecuteContext(cmdCtx) + err := inv.WithContext(cmdCtx).Run() assert.NoError(t, err) }() // No pty interaction needed since we use the -y skip prompt flag @@ -162,15 +153,13 @@ func TestCreate(t *testing.T) { version := coderdtest.CreateTemplateVersion(t, client, user.OrganizationID, nil) coderdtest.AwaitTemplateVersionJob(t, client, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - cmd, root := clitest.New(t, "create", "") + inv, root := clitest.New(t, "create", "") clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() matches := []string{ @@ -185,7 +174,7 @@ func TestCreate(t *testing.T) { } <-doneChan - ws, err := client.WorkspaceByOwnerAndName(cmd.Context(), "testuser", "my-workspace", codersdk.WorkspaceOptions{}) + ws, err := client.WorkspaceByOwnerAndName(inv.Context(), "testuser", "my-workspace", codersdk.WorkspaceOptions{}) if assert.NoError(t, err, "expected workspace to be created") { assert.Equal(t, ws.TemplateName, template.Name) assert.Nil(t, ws.AutostartSchedule, "expected workspace autostart schedule to be nil") @@ -206,15 +195,13 @@ func TestCreate(t *testing.T) { coderdtest.AwaitTemplateVersionJob(t, client, version.ID) _ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - cmd, root := clitest.New(t, "create", "") + inv, root := clitest.New(t, "create", "") clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() @@ -251,15 +238,13 @@ func TestCreate(t *testing.T) { removeTmpDirUntilSuccessAfterTest(t, tempDir) parameterFile, _ := os.CreateTemp(tempDir, "testParameterFile*.yaml") _, _ = parameterFile.WriteString("region: \"bingo\"\nusername: \"boingo\"") - cmd, root := clitest.New(t, "create", "", "--parameter-file", parameterFile.Name()) + inv, root := clitest.New(t, "create", "", "--parameter-file", parameterFile.Name()) clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() @@ -296,15 +281,13 @@ func TestCreate(t *testing.T) { parameterFile, _ := os.CreateTemp(tempDir, "testParameterFile*.yaml") _, _ = parameterFile.WriteString("username: \"boingo\"") - cmd, root := clitest.New(t, "create", "", "--parameter-file", parameterFile.Name()) + inv, root := clitest.New(t, "create", "", "--parameter-file", parameterFile.Name()) clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() matches := []struct { @@ -364,13 +347,11 @@ func TestCreate(t *testing.T) { require.Equal(t, codersdk.ProvisionerJobSucceeded, version.Job.Status, "job is not failed") _ = coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - cmd, root := clitest.New(t, "create", "test", "--parameter-file", parameterFile.Name()) + inv, root := clitest.New(t, "create", "test", "--parameter-file", parameterFile.Name(), "-y") clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + ptytest.New(t).Attach(inv) - err = cmd.Execute() + err = inv.Run() require.Error(t, err) require.ErrorContains(t, err, "dry-run workspace") }) @@ -425,15 +406,13 @@ func TestCreateWithRichParameters(t *testing.T) { template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) + inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() @@ -469,16 +448,14 @@ func TestCreateWithRichParameters(t *testing.T) { firstParameterName + ": " + firstParameterValue + "\n" + secondParameterName + ": " + secondParameterValue + "\n" + immutableParameterName + ": " + immutableParameterValue) - cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name, "--rich-parameter-file", parameterFile.Name()) + inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name, "--rich-parameter-file", parameterFile.Name()) clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() @@ -559,15 +536,13 @@ func TestCreateValidateRichParameters(t *testing.T) { template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) + inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() @@ -596,15 +571,13 @@ func TestCreateValidateRichParameters(t *testing.T) { template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) + inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() @@ -636,15 +609,13 @@ func TestCreateValidateRichParameters(t *testing.T) { template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) + inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() @@ -672,17 +643,10 @@ func TestCreateValidateRichParameters(t *testing.T) { coderdtest.AwaitTemplateVersionJob(t, client, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) + inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) clitest.SetupConfig(t, client, root) - doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - go func() { - defer close(doneChan) - err := cmd.Execute() - assert.NoError(t, err) - }() + pty := ptytest.New(t).Attach(inv) + clitest.Start(t, inv) matches := []string{ listOfStringsParameterName, "", @@ -697,7 +661,6 @@ func TestCreateValidateRichParameters(t *testing.T) { pty.WriteLine(value) } } - <-doneChan }) t.Run("ValidateListOfStrings_YAMLFile", func(t *testing.T) { @@ -716,17 +679,11 @@ func TestCreateValidateRichParameters(t *testing.T) { - ddd - eee - fff`) - cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name, "--rich-parameter-file", parameterFile.Name()) + inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name, "--rich-parameter-file", parameterFile.Name()) clitest.SetupConfig(t, client, root) - doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - go func() { - defer close(doneChan) - err := cmd.Execute() - assert.NoError(t, err) - }() + pty := ptytest.New(t).Attach(inv) + + clitest.Start(t, inv) matches := []string{ "Confirm create?", "yes", @@ -739,7 +696,6 @@ func TestCreateValidateRichParameters(t *testing.T) { pty.WriteLine(value) } } - <-doneChan }) } @@ -777,17 +733,10 @@ func TestCreateWithGitAuth(t *testing.T) { coderdtest.AwaitTemplateVersionJob(t, client, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - cmd, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) + inv, root := clitest.New(t, "create", "my-workspace", "--template", template.Name) clitest.SetupConfig(t, client, root) - doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - go func() { - defer close(doneChan) - err := cmd.Execute() - assert.NoError(t, err) - }() + pty := ptytest.New(t).Attach(inv) + clitest.Start(t, inv) pty.ExpectMatch("You must authenticate with GitHub to create a workspace") resp := coderdtest.RequestGitAuthCallback(t, "github", client) @@ -795,7 +744,6 @@ func TestCreateWithGitAuth(t *testing.T) { require.Equal(t, http.StatusTemporaryRedirect, resp.StatusCode) pty.ExpectMatch("Confirm create?") pty.WriteLine("yes") - <-doneChan } func createTestParseResponseWithDefault(defaultValue string) []*proto.Parse_Response { diff --git a/cli/delete.go b/cli/delete.go index 4c655339d8a8f..24443402fbe72 100644 --- a/cli/delete.go +++ b/cli/delete.go @@ -4,23 +4,25 @@ import ( "fmt" "time" - "github.com/spf13/cobra" - + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) // nolint -func deleteWorkspace() *cobra.Command { +func (r *RootCmd) deleteWorkspace() *clibase.Cmd { var orphan bool - cmd := &cobra.Command{ + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Annotations: workspaceCommand, Use: "delete ", Short: "Delete a workspace", - Aliases: []string{"rm"}, - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - _, err := cliui.Prompt(cmd, cliui.PromptOptions{ + Middleware: clibase.Chain( + clibase.RequireNArgs(1), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + _, err := cliui.Prompt(inv, cliui.PromptOptions{ Text: "Confirm delete workspace?", IsConfirm: true, Default: cliui.ConfirmNo, @@ -29,11 +31,7 @@ func deleteWorkspace() *cobra.Command { return err } - client, err := CreateClient(cmd) - if err != nil { - return err - } - workspace, err := namedWorkspace(cmd, client, args[0]) + workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return err } @@ -42,12 +40,12 @@ func deleteWorkspace() *cobra.Command { if orphan { cliui.Warn( - cmd.ErrOrStderr(), + inv.Stderr, "Orphaning workspace requires template edit permission", ) } - build, err := client.CreateWorkspaceBuild(cmd.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{ + build, err := client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{ Transition: codersdk.WorkspaceTransitionDelete, ProvisionerState: state, Orphan: orphan, @@ -56,19 +54,23 @@ func deleteWorkspace() *cobra.Command { return err } - err = cliui.WorkspaceBuild(cmd.Context(), cmd.OutOrStdout(), client, build.ID) + err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, build.ID) if err != nil { return err } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nThe %s workspace has been deleted at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) + _, _ = fmt.Fprintf(inv.Stdout, "\nThe %s workspace has been deleted at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) return nil }, } - cmd.Flags().BoolVar(&orphan, "orphan", false, - `Delete a workspace without deleting its resources. This can delete a -workspace in a broken state, but may also lead to unaccounted cloud resources.`, - ) - cliui.AllowSkipPrompt(cmd) + cmd.Options = clibase.OptionSet{ + { + Flag: "orphan", + Description: "Delete a workspace without deleting its resources. This can delete a workspace in a broken state, but may also lead to unaccounted cloud resources.", + + Value: clibase.BoolOf(&orphan), + }, + cliui.SkipPromptOption(), + } return cmd } diff --git a/cli/delete_test.go b/cli/delete_test.go index 2f2cf404fc6ab..c79d07b075425 100644 --- a/cli/delete_test.go +++ b/cli/delete_test.go @@ -25,15 +25,13 @@ func TestDelete(t *testing.T) { template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, root := clitest.New(t, "delete", workspace.Name, "-y") + inv, root := clitest.New(t, "delete", workspace.Name, "-y") clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() // When running with the race detector on, we sometimes get an EOF. if err != nil { assert.ErrorIs(t, err, io.EOF) @@ -52,17 +50,15 @@ func TestDelete(t *testing.T) { template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, root := clitest.New(t, "delete", workspace.Name, "-y", "--orphan") + inv, root := clitest.New(t, "delete", workspace.Name, "-y", "--orphan") clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - cmd.SetErr(pty.Output()) + pty := ptytest.New(t).Attach(inv) + inv.Stderr = pty.Output() go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() // When running with the race detector on, we sometimes get an EOF. if err != nil { assert.ErrorIs(t, err, io.EOF) @@ -87,15 +83,13 @@ func TestDelete(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, orgID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, root := clitest.New(t, "delete", user.Username+"/"+workspace.Name, "-y") + inv, root := clitest.New(t, "delete", user.Username+"/"+workspace.Name, "-y") clitest.SetupConfig(t, adminClient, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() // When running with the race detector on, we sometimes get an EOF. if err != nil { assert.ErrorIs(t, err, io.EOF) @@ -112,12 +106,12 @@ func TestDelete(t *testing.T) { t.Run("InvalidWorkspaceIdentifier", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) - cmd, root := clitest.New(t, "delete", "a/b/c", "-y") + inv, root := clitest.New(t, "delete", "a/b/c", "-y") clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() assert.ErrorContains(t, err, "invalid workspace name: \"a/b/c\"") }() <-doneChan diff --git a/cli/dotfiles.go b/cli/dotfiles.go index 4adb06cc05f9a..c0473dae336a1 100644 --- a/cli/dotfiles.go +++ b/cli/dotfiles.go @@ -10,30 +10,29 @@ import ( "strings" "time" - "github.com/spf13/cobra" "golang.org/x/xerrors" - "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" ) -func dotfiles() *cobra.Command { +func (r *RootCmd) dotfiles() *clibase.Cmd { var symlinkDir string - cmd := &cobra.Command{ - Use: "dotfiles [git_repo_url]", - Args: cobra.ExactArgs(1), - Short: "Checkout and install a dotfiles repository from a Git URL", - Example: formatExamples( + cmd := &clibase.Cmd{ + Use: "dotfiles ", + Middleware: clibase.RequireNArgs(1), + Short: "Personalize your workspace by applying a canonical dotfiles repository", + Long: formatExamples( example{ Description: "Check out and install a dotfiles repository without prompts", Command: "coder dotfiles --yes git@github.com:example/dotfiles.git", }, ), - RunE: func(cmd *cobra.Command, args []string) error { + Handler: func(inv *clibase.Invocation) error { var ( dotfilesRepoDir = "dotfiles" - gitRepo = args[0] - cfg = createConfig(cmd) + gitRepo = inv.Args[0] + cfg = r.createConfig() cfgDir = string(cfg) dotfilesDir = filepath.Join(cfgDir, dotfilesRepoDir) // This follows the same pattern outlined by others in the market: @@ -50,7 +49,11 @@ func dotfiles() *cobra.Command { } ) - _, _ = fmt.Fprint(cmd.OutOrStdout(), "Checking if dotfiles repository already exists...\n") + if cfg == "" { + return xerrors.Errorf("no config directory") + } + + _, _ = fmt.Fprint(inv.Stdout, "Checking if dotfiles repository already exists...\n") dotfilesExists, err := dirExists(dotfilesDir) if err != nil { return xerrors.Errorf("checking dir %s: %w", dotfilesDir, err) @@ -65,7 +68,7 @@ func dotfiles() *cobra.Command { // if the git url has changed we create a backup and clone fresh if gitRepo != du { backupDir := fmt.Sprintf("%s_backup_%s", dotfilesDir, time.Now().Format(time.RFC3339)) - _, err = cliui.Prompt(cmd, cliui.PromptOptions{ + _, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: fmt.Sprintf("The dotfiles URL has changed from %q to %q.\n Coder will backup the existing repo to %s.\n\n Continue?", du, gitRepo, backupDir), IsConfirm: true, }) @@ -77,7 +80,7 @@ func dotfiles() *cobra.Command { if err != nil { return xerrors.Errorf("renaming dir %s: %w", dotfilesDir, err) } - _, _ = fmt.Fprint(cmd.OutOrStdout(), "Done backup up dotfiles.\n") + _, _ = fmt.Fprint(inv.Stdout, "Done backup up dotfiles.\n") dotfilesExists = false moved = true } @@ -89,20 +92,20 @@ func dotfiles() *cobra.Command { promptText string ) if dotfilesExists { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Found dotfiles repository at %s\n", dotfilesDir) + _, _ = fmt.Fprintf(inv.Stdout, "Found dotfiles repository at %s\n", dotfilesDir) gitCmdDir = dotfilesDir subcommands = []string{"pull", "--ff-only"} promptText = fmt.Sprintf("Pulling latest from %s into directory %s.\n Continue?", gitRepo, dotfilesDir) } else { if !moved { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Did not find dotfiles repository at %s\n", dotfilesDir) + _, _ = fmt.Fprintf(inv.Stdout, "Did not find dotfiles repository at %s\n", dotfilesDir) } gitCmdDir = cfgDir - subcommands = []string{"clone", args[0], dotfilesRepoDir} + subcommands = []string{"clone", inv.Args[0], dotfilesRepoDir} promptText = fmt.Sprintf("Cloning %s into directory %s.\n\n Continue?", gitRepo, dotfilesDir) } - _, err = cliui.Prompt(cmd, cliui.PromptOptions{ + _, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: promptText, IsConfirm: true, }) @@ -113,7 +116,7 @@ func dotfiles() *cobra.Command { // ensure command dir exists err = os.MkdirAll(gitCmdDir, 0o750) if err != nil { - return xerrors.Errorf("ensuring dir at %s: %w", gitCmdDir, err) + return xerrors.Errorf("ensuring dir at %q: %w", gitCmdDir, err) } // check if git ssh command already exists so we can just wrap it @@ -123,18 +126,18 @@ func dotfiles() *cobra.Command { } // clone or pull repo - c := exec.CommandContext(cmd.Context(), "git", subcommands...) + c := exec.CommandContext(inv.Context(), "git", subcommands...) c.Dir = gitCmdDir - c.Env = append(os.Environ(), fmt.Sprintf(`GIT_SSH_COMMAND=%s -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no`, gitsshCmd)) - c.Stdout = cmd.OutOrStdout() - c.Stderr = cmd.ErrOrStderr() + c.Env = append(inv.Environ.ToOS(), fmt.Sprintf(`GIT_SSH_COMMAND=%s -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no`, gitsshCmd)) + c.Stdout = inv.Stdout + c.Stderr = inv.Stderr err = c.Run() if err != nil { if !dotfilesExists { return err } // if the repo exists we soft fail the update operation and try to continue - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Error.Render("Failed to update repo, continuing...")) + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Error.Render("Failed to update repo, continuing...")) } // save git repo url so we can detect changes next time @@ -158,7 +161,7 @@ func dotfiles() *cobra.Command { script := findScript(installScriptSet, files) if script != "" { - _, err = cliui.Prompt(cmd, cliui.PromptOptions{ + _, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: fmt.Sprintf("Running install script %s.\n\n Continue?", script), IsConfirm: true, }) @@ -166,29 +169,29 @@ func dotfiles() *cobra.Command { return err } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Running %s...\n", script) + _, _ = fmt.Fprintf(inv.Stdout, "Running %s...\n", script) // it is safe to use a variable command here because it's from // a filtered list of pre-approved install scripts // nolint:gosec - scriptCmd := exec.CommandContext(cmd.Context(), filepath.Join(dotfilesDir, script)) + scriptCmd := exec.CommandContext(inv.Context(), filepath.Join(dotfilesDir, script)) scriptCmd.Dir = dotfilesDir - scriptCmd.Stdout = cmd.OutOrStdout() - scriptCmd.Stderr = cmd.ErrOrStderr() + scriptCmd.Stdout = inv.Stdout + scriptCmd.Stderr = inv.Stderr err = scriptCmd.Run() if err != nil { return xerrors.Errorf("running %s: %w", script, err) } - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Dotfiles installation complete.") + _, _ = fmt.Fprintln(inv.Stdout, "Dotfiles installation complete.") return nil } if len(dotfiles) == 0 { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "No install scripts or dotfiles found, nothing to do.") + _, _ = fmt.Fprintln(inv.Stdout, "No install scripts or dotfiles found, nothing to do.") return nil } - _, err = cliui.Prompt(cmd, cliui.PromptOptions{ + _, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "No install scripts found, symlinking dotfiles to home directory.\n\n Continue?", IsConfirm: true, }) @@ -206,7 +209,7 @@ func dotfiles() *cobra.Command { for _, df := range dotfiles { from := filepath.Join(dotfilesDir, df) to := filepath.Join(symlinkDir, df) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Symlinking %s to %s...\n", from, to) + _, _ = fmt.Fprintf(inv.Stdout, "Symlinking %s to %s...\n", from, to) isRegular, err := isRegular(to) if err != nil { @@ -215,7 +218,7 @@ func dotfiles() *cobra.Command { // move conflicting non-symlink files to file.ext.bak if isRegular { backup := fmt.Sprintf("%s.bak", to) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Moving %s to %s...\n", to, backup) + _, _ = fmt.Fprintf(inv.Stdout, "Moving %s to %s...\n", to, backup) err = os.Rename(to, backup) if err != nil { return xerrors.Errorf("renaming dir %s: %w", to, err) @@ -228,13 +231,19 @@ func dotfiles() *cobra.Command { } } - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Dotfiles installation complete.") + _, _ = fmt.Fprintln(inv.Stdout, "Dotfiles installation complete.") return nil }, } - cliui.AllowSkipPrompt(cmd) - cliflag.StringVarP(cmd.Flags(), &symlinkDir, "symlink-dir", "", "CODER_SYMLINK_DIR", "", "Specifies the directory for the dotfiles symlink destinations. If empty will use $HOME.") - + cmd.Options = clibase.OptionSet{ + { + Flag: "symlink-dir", + Env: "CODER_SYMLINK_DIR", + Description: "Specifies the directory for the dotfiles symlink destinations. If empty, will use $HOME.", + Value: clibase.StringOf(&symlinkDir), + }, + cliui.SkipPromptOption(), + } return cmd } diff --git a/cli/dotfiles_test.go b/cli/dotfiles_test.go index 479baf3f9a05a..e579010912306 100644 --- a/cli/dotfiles_test.go +++ b/cli/dotfiles_test.go @@ -15,14 +15,16 @@ import ( "github.com/coder/coder/cryptorand" ) -// nolint:paralleltest func TestDotfiles(t *testing.T) { + t.Parallel() t.Run("MissingArg", func(t *testing.T) { - cmd, _ := clitest.New(t, "dotfiles") - err := cmd.Execute() + t.Parallel() + inv, _ := clitest.New(t, "dotfiles") + err := inv.Run() require.Error(t, err) }) t.Run("NoInstallScript", func(t *testing.T) { + t.Parallel() _, root := clitest.New(t) testRepo := testGitRepo(t, root) @@ -40,8 +42,8 @@ func TestDotfiles(t *testing.T) { out, err := c.CombinedOutput() require.NoError(t, err, string(out)) - cmd, _ := clitest.New(t, "dotfiles", "--global-config", string(root), "--symlink-dir", string(root), "-y", testRepo) - err = cmd.Execute() + inv, _ := clitest.New(t, "dotfiles", "--global-config", string(root), "--symlink-dir", string(root), "-y", testRepo) + err = inv.Run() require.NoError(t, err) b, err := os.ReadFile(filepath.Join(string(root), ".bashrc")) @@ -49,6 +51,7 @@ func TestDotfiles(t *testing.T) { require.Equal(t, string(b), "wow") }) t.Run("InstallScript", func(t *testing.T) { + t.Parallel() if runtime.GOOS == "windows" { t.Skip("install scripts on windows require sh and aren't very practical") } @@ -69,8 +72,8 @@ func TestDotfiles(t *testing.T) { err = c.Run() require.NoError(t, err) - cmd, _ := clitest.New(t, "dotfiles", "--global-config", string(root), "--symlink-dir", string(root), "-y", testRepo) - err = cmd.Execute() + inv, _ := clitest.New(t, "dotfiles", "--global-config", string(root), "--symlink-dir", string(root), "-y", testRepo) + err = inv.Run() require.NoError(t, err) b, err := os.ReadFile(filepath.Join(string(root), ".bashrc")) @@ -78,6 +81,7 @@ func TestDotfiles(t *testing.T) { require.Equal(t, string(b), "wow\n") }) t.Run("SymlinkBackup", func(t *testing.T) { + t.Parallel() _, root := clitest.New(t) testRepo := testGitRepo(t, root) @@ -100,8 +104,8 @@ func TestDotfiles(t *testing.T) { out, err := c.CombinedOutput() require.NoError(t, err, string(out)) - cmd, _ := clitest.New(t, "dotfiles", "--global-config", string(root), "--symlink-dir", string(root), "-y", testRepo) - err = cmd.Execute() + inv, _ := clitest.New(t, "dotfiles", "--global-config", string(root), "--symlink-dir", string(root), "-y", testRepo) + err = inv.Run() require.NoError(t, err) b, err := os.ReadFile(filepath.Join(string(root), ".bashrc")) diff --git a/cli/gitaskpass.go b/cli/gitaskpass.go index 4c78c47728db3..cbf5bd3315ad2 100644 --- a/cli/gitaskpass.go +++ b/cli/gitaskpass.go @@ -7,9 +7,9 @@ import ( "os/signal" "time" - "github.com/spf13/cobra" "golang.org/x/xerrors" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/gitauth" "github.com/coder/coder/codersdk" @@ -18,23 +18,22 @@ import ( // gitAskpass is used by the Coder agent to automatically authenticate // with Git providers based on a hostname. -func gitAskpass() *cobra.Command { - return &cobra.Command{ +func (r *RootCmd) gitAskpass() *clibase.Cmd { + return &clibase.Cmd{ Use: "gitaskpass", Hidden: true, - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() + Handler: func(inv *clibase.Invocation) error { + ctx := inv.Context() ctx, stop := signal.NotifyContext(ctx, InterruptSignals...) defer stop() - user, host, err := gitauth.ParseAskpass(args[0]) + user, host, err := gitauth.ParseAskpass(inv.Args[0]) if err != nil { return xerrors.Errorf("parse host: %w", err) } - client, err := createAgentClient(cmd) + client, err := r.createAgentClient() if err != nil { return xerrors.Errorf("create agent client: %w", err) } @@ -45,16 +44,16 @@ func gitAskpass() *cobra.Command { if errors.As(err, &apiError) && apiError.StatusCode() == http.StatusNotFound { // This prevents the "Run 'coder --help' for usage" // message from occurring. - cmd.Printf("%s\n", apiError.Message) + cliui.Errorf(inv.Stderr, "%s\n", apiError.Message) return cliui.Canceled } return xerrors.Errorf("get git token: %w", err) } if token.URL != "" { - if err := openURL(cmd, token.URL); err == nil { - cmd.Printf("Your browser has been opened to authenticate with Git:\n\n\t%s\n\n", token.URL) + if err := openURL(inv, token.URL); err == nil { + cliui.Infof(inv.Stdout, "Your browser has been opened to authenticate with Git:\n\n\t%s\n\n", token.URL) } else { - cmd.Printf("Open the following URL to authenticate with Git:\n\n\t%s\n\n", token.URL) + cliui.Infof(inv.Stdout, "Open the following URL to authenticate with Git:\n\n\t%s\n\n", token.URL) } for r := retry.New(250*time.Millisecond, 10*time.Second); r.Wait(ctx); { @@ -62,19 +61,19 @@ func gitAskpass() *cobra.Command { if err != nil { continue } - cmd.Printf("You've been authenticated with Git!\n") + cliui.Infof(inv.Stdout, "You've been authenticated with Git!\n") break } } if token.Password != "" { if user == "" { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), token.Username) + _, _ = fmt.Fprintln(inv.Stdout, token.Username) } else { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), token.Password) + _, _ = fmt.Fprintln(inv.Stdout, token.Password) } } else { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), token.Username) + _, _ = fmt.Fprintln(inv.Stdout, token.Username) } return nil diff --git a/cli/gitaskpass_test.go b/cli/gitaskpass_test.go index 2e3bdc88505e7..db64a522aeb57 100644 --- a/cli/gitaskpass_test.go +++ b/cli/gitaskpass_test.go @@ -18,10 +18,10 @@ import ( "github.com/coder/coder/pty/ptytest" ) -// nolint:paralleltest func TestGitAskpass(t *testing.T) { - t.Setenv("GIT_PREFIX", "/") + t.Parallel() t.Run("UsernameAndPassword", func(t *testing.T) { + t.Parallel() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httpapi.Write(context.Background(), w, http.StatusOK, agentsdk.GitAuthResponse{ Username: "something", @@ -30,22 +30,23 @@ func TestGitAskpass(t *testing.T) { })) t.Cleanup(srv.Close) url := srv.URL - cmd, _ := clitest.New(t, "--agent-url", url, "Username for 'https://github.com':") + inv, _ := clitest.New(t, "--agent-url", url, "Username for 'https://github.com':") + inv.Environ.Set("GIT_PREFIX", "/") pty := ptytest.New(t) - cmd.SetOutput(pty.Output()) - err := cmd.Execute() - require.NoError(t, err) + inv.Stdout = pty.Output() + clitest.Start(t, inv) pty.ExpectMatch("something") - cmd, _ = clitest.New(t, "--agent-url", url, "Password for 'https://potato@github.com':") + inv, _ = clitest.New(t, "--agent-url", url, "Password for 'https://potato@github.com':") + inv.Environ.Set("GIT_PREFIX", "/") pty = ptytest.New(t) - cmd.SetOutput(pty.Output()) - err = cmd.Execute() - require.NoError(t, err) + inv.Stdout = pty.Output() + clitest.Start(t, inv) pty.ExpectMatch("bananas") }) t.Run("NoHost", func(t *testing.T) { + t.Parallel() srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { httpapi.Write(context.Background(), w, http.StatusNotFound, codersdk.Response{ Message: "Nope!", @@ -53,15 +54,17 @@ func TestGitAskpass(t *testing.T) { })) t.Cleanup(srv.Close) url := srv.URL - cmd, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':") + inv, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':") + inv.Environ.Set("GIT_PREFIX", "/") pty := ptytest.New(t) - cmd.SetOutput(pty.Output()) - err := cmd.Execute() + inv.Stderr = pty.Output() + err := inv.Run() require.ErrorIs(t, err, cliui.Canceled) pty.ExpectMatch("Nope!") }) t.Run("Poll", func(t *testing.T) { + t.Parallel() resp := atomic.Pointer[agentsdk.GitAuthResponse]{} resp.Store(&agentsdk.GitAuthResponse{ URL: "https://something.org", @@ -81,11 +84,12 @@ func TestGitAskpass(t *testing.T) { t.Cleanup(srv.Close) url := srv.URL - cmd, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':") + inv, _ := clitest.New(t, "--agent-url", url, "--no-open", "Username for 'https://github.com':") + inv.Environ.Set("GIT_PREFIX", "/") pty := ptytest.New(t) - cmd.SetOutput(pty.Output()) + inv.Stdout = pty.Output() go func() { - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() <-poll diff --git a/cli/gitssh.go b/cli/gitssh.go index 02a985abee22f..70af9ebd3ef08 100644 --- a/cli/gitssh.go +++ b/cli/gitssh.go @@ -12,19 +12,19 @@ import ( "path/filepath" "strings" - "github.com/spf13/cobra" "golang.org/x/xerrors" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" ) -func gitssh() *cobra.Command { - cmd := &cobra.Command{ +func (r *RootCmd) gitssh() *clibase.Cmd { + cmd := &clibase.Cmd{ Use: "gitssh", Hidden: true, Short: `Wraps the "ssh" command and uses the coder gitssh key for authentication`, - RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() + Handler: func(inv *clibase.Invocation) error { + ctx := inv.Context() env := os.Environ() // Catch interrupt signals to ensure the temporary private @@ -33,12 +33,12 @@ func gitssh() *cobra.Command { defer stop() // Early check so errors are reported immediately. - identityFiles, err := parseIdentityFilesForHost(ctx, args, env) + identityFiles, err := parseIdentityFilesForHost(ctx, inv.Args, env) if err != nil { return err } - client, err := createAgentClient(cmd) + client, err := r.createAgentClient() if err != nil { return xerrors.Errorf("create agent client: %w", err) } @@ -78,24 +78,25 @@ func gitssh() *cobra.Command { identityArgs = append(identityArgs, "-i", id) } + args := inv.Args args = append(identityArgs, args...) c := exec.CommandContext(ctx, "ssh", args...) c.Env = append(c.Env, env...) - c.Stderr = cmd.ErrOrStderr() - c.Stdout = cmd.OutOrStdout() - c.Stdin = cmd.InOrStdin() + c.Stderr = inv.Stderr + c.Stdout = inv.Stdout + c.Stdin = inv.Stdin err = c.Run() if err != nil { exitErr := &exec.ExitError{} if xerrors.As(err, &exitErr) && exitErr.ExitCode() == 255 { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), + _, _ = fmt.Fprintln(inv.Stderr, "\n"+cliui.Styles.Wrap.Render("Coder authenticates with "+cliui.Styles.Field.Render("git")+ " using the public key below. All clones with SSH are authenticated automatically 🪄.")+"\n") - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Code.Render(strings.TrimSpace(key.PublicKey))+"\n") - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Add to GitHub and GitLab:") - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Prompt.String()+"https://github.com/settings/ssh/new") - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Prompt.String()+"https://gitlab.com/-/profile/keys") - _, _ = fmt.Fprintln(cmd.ErrOrStderr()) + _, _ = fmt.Fprintln(inv.Stderr, cliui.Styles.Code.Render(strings.TrimSpace(key.PublicKey))+"\n") + _, _ = fmt.Fprintln(inv.Stderr, "Add to GitHub and GitLab:") + _, _ = fmt.Fprintln(inv.Stderr, cliui.Styles.Prompt.String()+"https://github.com/settings/ssh/new") + _, _ = fmt.Fprintln(inv.Stderr, cliui.Styles.Prompt.String()+"https://gitlab.com/-/profile/keys") + _, _ = fmt.Fprintln(inv.Stderr) return err } return xerrors.Errorf("run ssh command: %w", err) diff --git a/cli/gitssh_test.go b/cli/gitssh_test.go index cc7a37ca00578..6d7dfe7518e22 100644 --- a/cli/gitssh_test.go +++ b/cli/gitssh_test.go @@ -57,15 +57,12 @@ func prepareTestGitSSH(ctx context.Context, t *testing.T) (*codersdk.Client, str coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) // start workspace agent - cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) + inv, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) agentClient := client clitest.SetupConfig(t, agentClient, root) - errC := make(chan error, 1) - go func() { - errC <- cmd.ExecuteContext(ctx) - }() - t.Cleanup(func() { require.NoError(t, <-errC) }) + clitest.Start(t, inv) + coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) return agentClient, agentToken, pubkey } @@ -141,7 +138,7 @@ func TestGitSSH(t *testing.T) { }, pubkey) // set to agent config dir - cmd, _ := clitest.New(t, + inv, _ := clitest.New(t, "gitssh", "--agent-url", client.URL.String(), "--agent-token", token, @@ -151,7 +148,7 @@ func TestGitSSH(t *testing.T) { "-o", "IdentitiesOnly=yes", "127.0.0.1", ) - err := cmd.ExecuteContext(ctx) + err := inv.WithContext(ctx).Run() require.NoError(t, err) require.EqualValues(t, 1, inc) @@ -213,10 +210,10 @@ func TestGitSSH(t *testing.T) { "mytest", } // Test authentication via local private key. - cmd, _ := clitest.New(t, cmdArgs...) - cmd.SetOut(pty.Output()) - cmd.SetErr(pty.Output()) - err = cmd.ExecuteContext(ctx) + inv, _ := clitest.New(t, cmdArgs...) + inv.Stdout = pty.Output() + inv.Stderr = pty.Output() + err = inv.WithContext(ctx).Run() require.NoError(t, err) select { case key := <-authkey: @@ -230,10 +227,10 @@ func TestGitSSH(t *testing.T) { require.NoError(t, err) // With the local file deleted, the coder key should be used. - cmd, _ = clitest.New(t, cmdArgs...) - cmd.SetOut(pty.Output()) - cmd.SetErr(pty.Output()) - err = cmd.ExecuteContext(ctx) + inv, _ = clitest.New(t, cmdArgs...) + inv.Stdout = pty.Output() + inv.Stderr = pty.Output() + err = inv.WithContext(ctx).Run() require.NoError(t, err) select { case key := <-authkey: diff --git a/cli/help.go b/cli/help.go new file mode 100644 index 0000000000000..5720e2a7af71b --- /dev/null +++ b/cli/help.go @@ -0,0 +1,292 @@ +package cli + +import ( + "bufio" + "bytes" + _ "embed" + "fmt" + "io" + "regexp" + "sort" + "strings" + "text/tabwriter" + "text/template" + "unicode" + + "github.com/mitchellh/go-wordwrap" + "golang.org/x/crypto/ssh/terminal" + "golang.org/x/xerrors" + + "github.com/coder/coder/cli/clibase" + "github.com/coder/coder/cli/cliui" +) + +//go:embed help.tpl +var helpTemplateRaw string + +type optionGroup struct { + Name string + Description string + Options clibase.OptionSet +} + +func ttyWidth() int { + width, _, err := terminal.GetSize(0) + if err != nil { + return 80 + } + return width +} + +// wrapTTY wraps a string to the width of the terminal, or 80 no terminal +// is detected. +func wrapTTY(s string) string { + return wordwrap.WrapString(s, uint(ttyWidth())) +} + +var usageTemplate = template.Must( + template.New("usage").Funcs( + template.FuncMap{ + "wrapTTY": func(s string) string { + return wrapTTY(s) + }, + "trimNewline": func(s string) string { + return strings.TrimSuffix(s, "\n") + }, + "typeHelper": func(opt *clibase.Option) string { + switch v := opt.Value.(type) { + case *clibase.Enum: + return strings.Join(v.Choices, "|") + default: + return v.Type() + } + }, + "joinStrings": func(s []string) string { + return strings.Join(s, ", ") + }, + "indent": func(body string, spaces int) string { + twidth := ttyWidth() + + spacing := strings.Repeat(" ", spaces) + + body = wordwrap.WrapString(body, uint(twidth-len(spacing))) + + var sb strings.Builder + for _, line := range strings.Split(body, "\n") { + // Remove existing indent, if any. + line = strings.TrimSpace(line) + // Use spaces so we can easily calculate wrapping. + _, _ = sb.WriteString(spacing) + _, _ = sb.WriteString(line) + _, _ = sb.WriteString("\n") + } + return sb.String() + }, + "formatSubcommand": func(cmd *clibase.Cmd) string { + // Minimize padding by finding the longest neighboring name. + maxNameLength := len(cmd.Name()) + if parent := cmd.Parent; parent != nil { + for _, c := range parent.Children { + if len(c.Name()) > maxNameLength { + maxNameLength = len(c.Name()) + } + } + } + + var sb strings.Builder + _, _ = fmt.Fprintf( + &sb, "%s%s%s", + strings.Repeat(" ", 4), cmd.Name(), strings.Repeat(" ", maxNameLength-len(cmd.Name())+4), + ) + + // This is the point at which indentation begins if there's a + // next line. + descStart := sb.Len() + + twidth := ttyWidth() + + for i, line := range strings.Split( + wordwrap.WrapString(cmd.Short, uint(twidth-descStart)), "\n", + ) { + if i > 0 { + _, _ = sb.WriteString(strings.Repeat(" ", descStart)) + } + _, _ = sb.WriteString(line) + _, _ = sb.WriteString("\n") + } + + return sb.String() + }, + "envName": func(opt clibase.Option) string { + if opt.Env == "" { + return "" + } + return opt.Env + }, + "flagName": func(opt clibase.Option) string { + return opt.Flag + }, + "prettyHeader": func(s string) string { + return cliui.Styles.Bold.Render(s) + }, + "isEnterprise": func(opt clibase.Option) bool { + return opt.Annotations.IsSet("enterprise") + }, + "isDeprecated": func(opt clibase.Option) bool { + return len(opt.UseInstead) > 0 + }, + "formatLong": func(long string) string { + // We intentionally don't wrap here because it would misformat + // examples, where the new line would start without the prior + // line's indentation. + return strings.TrimSpace(long) + }, + "formatGroupDescription": func(s string) string { + s = strings.ReplaceAll(s, "\n", "") + s = s + "\n" + s = wrapTTY(s) + return s + }, + "visibleChildren": func(cmd *clibase.Cmd) []*clibase.Cmd { + return filterSlice(cmd.Children, func(c *clibase.Cmd) bool { + return !c.Hidden + }) + }, + "optionGroups": func(cmd *clibase.Cmd) []optionGroup { + groups := []optionGroup{{ + // Default group. + Name: "", + Description: "", + }} + + enterpriseGroup := optionGroup{ + Name: "Enterprise", + Description: `These options are only available in the Enterprise Edition.`, + } + + // Sort options lexicographically. + sort.Slice(cmd.Options, func(i, j int) bool { + return cmd.Options[i].Name < cmd.Options[j].Name + }) + + optionLoop: + for _, opt := range cmd.Options { + if opt.Hidden { + continue + } + // Enterprise options are always grouped separately. + if opt.Annotations.IsSet("enterprise") { + enterpriseGroup.Options = append(enterpriseGroup.Options, opt) + continue + } + if len(opt.Group.Ancestry()) == 0 { + // Just add option to default group. + groups[0].Options = append(groups[0].Options, opt) + continue + } + + groupName := opt.Group.FullName() + + for i, foundGroup := range groups { + if foundGroup.Name != groupName { + continue + } + groups[i].Options = append(groups[i].Options, opt) + continue optionLoop + } + + groups = append(groups, optionGroup{ + Name: groupName, + Description: opt.Group.Description, + Options: clibase.OptionSet{opt}, + }) + } + sort.Slice(groups, func(i, j int) bool { + // Sort groups lexicographically. + return groups[i].Name < groups[j].Name + }) + + // Always show enterprise group last. + groups = append(groups, enterpriseGroup) + + return filterSlice(groups, func(g optionGroup) bool { + return len(g.Options) > 0 + }) + }, + }, + ).Parse(helpTemplateRaw), +) + +func filterSlice[T any](s []T, f func(T) bool) []T { + var r []T + for _, v := range s { + if f(v) { + r = append(r, v) + } + } + return r +} + +// newLineLimiter makes working with Go templates more bearable. Without this, +// modifying the template is a slow toil of counting newlines and constantly +// checking that a change to one command's help doesn't clobber break another. +type newlineLimiter struct { + w io.Writer + limit int + + newLineCounter int +} + +func (lm *newlineLimiter) Write(p []byte) (int, error) { + rd := bytes.NewReader(p) + for r, n, _ := rd.ReadRune(); n > 0; r, n, _ = rd.ReadRune() { + switch { + case r == '\r': + // Carriage returns can sneak into `help.tpl` when `git clone` + // is configured to automatically convert line endings. + continue + case r == '\n': + lm.newLineCounter++ + if lm.newLineCounter > lm.limit { + continue + } + case !unicode.IsSpace(r): + lm.newLineCounter = 0 + } + _, err := lm.w.Write([]byte(string(r))) + if err != nil { + return 0, err + } + } + return len(p), nil +} + +var usageWantsArgRe = regexp.MustCompile(`<.*>`) + +// helpFn returns a function that generates usage (help) +// output for a given command. +func helpFn() clibase.HandlerFunc { + return func(inv *clibase.Invocation) error { + // We buffer writes to stderr because the newlineLimiter writes one + // rune at a time. + stderrBuf := bufio.NewWriter(inv.Stderr) + out := newlineLimiter{w: stderrBuf, limit: 2} + tabwriter := tabwriter.NewWriter(&out, 0, 0, 2, ' ', 0) + err := usageTemplate.Execute(tabwriter, inv.Command) + if err != nil { + return xerrors.Errorf("execute template: %w", err) + } + err = tabwriter.Flush() + if err != nil { + return err + } + err = stderrBuf.Flush() + if err != nil { + return err + } + if len(inv.Args) > 0 && !usageWantsArgRe.MatchString(inv.Command.Use) { + _, _ = fmt.Fprintf(inv.Stderr, "---\nerror: unknown subcommand %q\n", inv.Args[0]) + } + return nil + } +} diff --git a/cli/help.tpl b/cli/help.tpl new file mode 100644 index 0000000000000..b464cbb248273 --- /dev/null +++ b/cli/help.tpl @@ -0,0 +1,55 @@ +{{- /* Heavily inspired by the Go toolchain formatting. */ -}} +Usage: {{.FullUsage}} + + +{{ with .Short }} +{{- wrapTTY . }} +{{"\n"}} +{{- end}} + +{{ with .Aliases }} +{{ "\n" }} +{{ "Aliases:"}} {{ joinStrings .}} +{{ "\n" }} +{{- end }} + +{{- with .Long}} +{{- formatLong . }} +{{ "\n" }} +{{- end }} +{{ with visibleChildren . }} +{{- range $index, $child := . }} +{{- if eq $index 0 }} +{{ prettyHeader "Subcommands"}} +{{- end }} + {{- "\n" }} + {{- formatSubcommand . | trimNewline }} +{{- end }} +{{- "\n" }} +{{- end }} +{{- range $index, $group := optionGroups . }} +{{ with $group.Name }} {{- print $group.Name " Options" | prettyHeader }} {{ else -}} {{ prettyHeader "Options"}}{{- end -}} +{{- with $group.Description }} +{{ formatGroupDescription . }} +{{- else }} +{{- end }} + {{- range $index, $option := $group.Options }} + {{- if not (eq $option.FlagShorthand "") }}{{- print "\n -" $option.FlagShorthand ", " -}} + {{- else }}{{- print "\n " -}} + {{- end }} + {{- with flagName $option }}--{{ . }}{{ end }} {{- with typeHelper $option }} {{ . }}{{ end }} + {{- with envName $option }}, ${{ . }}{{ end }} + {{- with $option.Default }} (default: {{ . }}){{ end }} + {{- with $option.Description }} + {{- $desc := $option.Description }} +{{ indent $desc 10 }} +{{- if isDeprecated $option }} DEPRECATED {{ end }} + {{- end -}} + {{- end }} +{{- end }} +--- +{{- if .Parent }} +Run `coder --help` for a list of global options. +{{- else }} +Report bugs and request features at https://github.com/coder/coder/issues/new +{{- end }} diff --git a/cli/list.go b/cli/list.go index 33493cf807080..384ff923fa2f5 100644 --- a/cli/list.go +++ b/cli/list.go @@ -5,8 +5,8 @@ import ( "time" "github.com/google/uuid" - "github.com/spf13/cobra" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/schedule" "github.com/coder/coder/coderd/util/ptr" @@ -64,7 +64,7 @@ func workspaceListRowFromWorkspace(now time.Time, usersByID map[uuid.UUID]coders } } -func list() *cobra.Command { +func (r *RootCmd) list() *clibase.Cmd { var ( all bool defaultQuery = "owner:me" @@ -75,18 +75,17 @@ func list() *cobra.Command { cliui.JSONFormat(), ) ) - cmd := &cobra.Command{ + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Annotations: workspaceCommand, Use: "list", Short: "List workspaces", Aliases: []string{"ls"}, - Args: cobra.ExactArgs(0), - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) - if err != nil { - return err - } - + Middleware: clibase.Chain( + clibase.RequireNArgs(0), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { filter := codersdk.WorkspaceFilter{ FilterQuery: searchQuery, } @@ -94,19 +93,19 @@ func list() *cobra.Command { filter.FilterQuery = "" } - res, err := client.Workspaces(cmd.Context(), filter) + res, err := client.Workspaces(inv.Context(), filter) if err != nil { return err } if len(res.Workspaces) == 0 { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Prompt.String()+"No workspaces found! Create one:") - _, _ = fmt.Fprintln(cmd.ErrOrStderr()) - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), " "+cliui.Styles.Code.Render("coder create ")) - _, _ = fmt.Fprintln(cmd.ErrOrStderr()) + _, _ = fmt.Fprintln(inv.Stderr, cliui.Styles.Prompt.String()+"No workspaces found! Create one:") + _, _ = fmt.Fprintln(inv.Stderr) + _, _ = fmt.Fprintln(inv.Stderr, " "+cliui.Styles.Code.Render("coder create ")) + _, _ = fmt.Fprintln(inv.Stderr) return nil } - userRes, err := client.Users(cmd.Context(), codersdk.UsersRequest{}) + userRes, err := client.Users(inv.Context(), codersdk.UsersRequest{}) if err != nil { return err } @@ -122,20 +121,31 @@ func list() *cobra.Command { displayWorkspaces[i] = workspaceListRowFromWorkspace(now, usersByID, workspace) } - out, err := formatter.Format(cmd.Context(), displayWorkspaces) + out, err := formatter.Format(inv.Context(), displayWorkspaces) if err != nil { return err } - _, err = fmt.Fprintln(cmd.OutOrStdout(), out) + _, err = fmt.Fprintln(inv.Stdout, out) return err }, } + cmd.Options = clibase.OptionSet{ + { + Flag: "all", + FlagShorthand: "a", + Description: "Specifies whether all workspaces will be listed or not.", - cmd.Flags().BoolVarP(&all, "all", "a", false, - "Specifies whether all workspaces will be listed or not.") - cmd.Flags().StringVar(&searchQuery, "search", defaultQuery, "Search for a workspace with a query.") + Value: clibase.BoolOf(&all), + }, + { + Flag: "search", + Description: "Search for a workspace with a query.", + Default: defaultQuery, + Value: clibase.StringOf(&searchQuery), + }, + } - formatter.AttachFlags(cmd) + formatter.AttachOptions(&cmd.Options) return cmd } diff --git a/cli/list_test.go b/cli/list_test.go index 19b265724b817..39567cd6d9167 100644 --- a/cli/list_test.go +++ b/cli/list_test.go @@ -27,17 +27,15 @@ func TestList(t *testing.T) { template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, root := clitest.New(t, "ls") + inv, root := clitest.New(t, "ls") clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancelFunc() done := make(chan any) go func() { - errC := cmd.ExecuteContext(ctx) + errC := inv.WithContext(ctx).Run() assert.NoError(t, errC) close(done) }() @@ -57,15 +55,15 @@ func TestList(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, root := clitest.New(t, "list", "--output=json") + inv, root := clitest.New(t, "list", "--output=json") clitest.SetupConfig(t, client, root) ctx, cancelFunc := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancelFunc() out := bytes.NewBuffer(nil) - cmd.SetOut(out) - err := cmd.ExecuteContext(ctx) + inv.Stdout = out + err := inv.WithContext(ctx).Run() require.NoError(t, err) var templates []codersdk.Workspace diff --git a/cli/login.go b/cli/login.go index bda389c3fe771..8f83b77cc5520 100644 --- a/cli/login.go +++ b/cli/login.go @@ -14,10 +14,9 @@ import ( "github.com/go-playground/validator/v10" "github.com/pkg/browser" - "github.com/spf13/cobra" "golang.org/x/xerrors" - "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/userpassword" "github.com/coder/coder/codersdk" @@ -38,7 +37,7 @@ func init() { browser.Stdout = io.Discard } -func login() *cobra.Command { +func (r *RootCmd) login() *clibase.Cmd { const firstUserTrialEnv = "CODER_FIRST_USER_TRIAL" var ( @@ -47,20 +46,16 @@ func login() *cobra.Command { password string trial bool ) - cmd := &cobra.Command{ - Use: "login ", - Short: "Authenticate with Coder deployment", - Args: cobra.MaximumNArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { + cmd := &clibase.Cmd{ + Use: "login ", + Short: "Authenticate with Coder deployment", + Middleware: clibase.RequireRangeArgs(0, 1), + Handler: func(inv *clibase.Invocation) error { rawURL := "" - if len(args) == 0 { - var err error - rawURL, err = cmd.Flags().GetString(varURL) - if err != nil { - return xerrors.Errorf("get global url flag") - } + if len(inv.Args) == 0 { + rawURL = r.clientURL.String() } else { - rawURL = args[0] + rawURL = inv.Args[0] } if !strings.HasPrefix(rawURL, "http://") && !strings.HasPrefix(rawURL, "https://") { @@ -79,7 +74,7 @@ func login() *cobra.Command { serverURL.Scheme = "https" } - client, err := createUnauthenticatedClient(cmd, serverURL) + client, err := r.createUnauthenticatedClient(serverURL) if err != nil { return err } @@ -87,25 +82,25 @@ func login() *cobra.Command { // Try to check the version of the server prior to logging in. // It may be useful to warn the user if they are trying to login // on a very old client. - err = checkVersions(cmd, client) + err = r.checkVersions(inv, client) if err != nil { // Checking versions isn't a fatal error so we print a warning // and proceed. - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(err.Error())) + _, _ = fmt.Fprintln(inv.Stderr, cliui.Styles.Warn.Render(err.Error())) } - hasInitialUser, err := client.HasFirstUser(cmd.Context()) + hasInitialUser, err := client.HasFirstUser(inv.Context()) if err != nil { return xerrors.Errorf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser? Error - has initial user: %w", serverURL.String(), err) } if !hasInitialUser { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), Caret+"Your Coder deployment hasn't been set up!\n") + _, _ = fmt.Fprintf(inv.Stdout, Caret+"Your Coder deployment hasn't been set up!\n") if username == "" { - if !isTTY(cmd) { + if !isTTY(inv) { return xerrors.New("the initial user cannot be created in non-interactive mode. use the API") } - _, err := cliui.Prompt(cmd, cliui.PromptOptions{ + _, err := cliui.Prompt(inv, cliui.PromptOptions{ Text: "Would you like to create the first user?", Default: cliui.ConfirmYes, IsConfirm: true, @@ -120,7 +115,7 @@ func login() *cobra.Command { if err != nil { return xerrors.Errorf("get current user: %w", err) } - username, err = cliui.Prompt(cmd, cliui.PromptOptions{ + username, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "What " + cliui.Styles.Field.Render("username") + " would you like?", Default: currentUser.Username, }) @@ -133,7 +128,7 @@ func login() *cobra.Command { } if email == "" { - email, err = cliui.Prompt(cmd, cliui.PromptOptions{ + email, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "What's your " + cliui.Styles.Field.Render("email") + "?", Validate: func(s string) error { err := validator.New().Var(s, "email") @@ -152,7 +147,7 @@ func login() *cobra.Command { var matching bool for !matching { - password, err = cliui.Prompt(cmd, cliui.PromptOptions{ + password, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "Enter a " + cliui.Styles.Field.Render("password") + ":", Secret: true, Validate: func(s string) error { @@ -162,7 +157,7 @@ func login() *cobra.Command { if err != nil { return xerrors.Errorf("specify password prompt: %w", err) } - confirm, err := cliui.Prompt(cmd, cliui.PromptOptions{ + confirm, err := cliui.Prompt(inv, cliui.PromptOptions{ Text: "Confirm " + cliui.Styles.Field.Render("password") + ":", Secret: true, Validate: cliui.ValidateNotEmpty, @@ -173,13 +168,13 @@ func login() *cobra.Command { matching = confirm == password if !matching { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Error.Render("Passwords do not match")) + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Error.Render("Passwords do not match")) } } } - if !cmd.Flags().Changed("first-user-trial") && os.Getenv(firstUserTrialEnv) == "" { - v, _ := cliui.Prompt(cmd, cliui.PromptOptions{ + if !inv.ParsedFlags().Changed("first-user-trial") && os.Getenv(firstUserTrialEnv) == "" { + v, _ := cliui.Prompt(inv, cliui.PromptOptions{ Text: "Start a 30-day trial of Enterprise?", IsConfirm: true, Default: "yes", @@ -187,7 +182,7 @@ func login() *cobra.Command { trial = v == "yes" || v == "y" } - _, err = client.CreateFirstUser(cmd.Context(), codersdk.CreateFirstUserRequest{ + _, err = client.CreateFirstUser(inv.Context(), codersdk.CreateFirstUserRequest{ Email: email, Username: username, Password: password, @@ -196,7 +191,7 @@ func login() *cobra.Command { if err != nil { return xerrors.Errorf("create initial user: %w", err) } - resp, err := client.LoginWithPassword(cmd.Context(), codersdk.LoginWithPasswordRequest{ + resp, err := client.LoginWithPassword(inv.Context(), codersdk.LoginWithPasswordRequest{ Email: email, Password: password, }) @@ -205,7 +200,7 @@ func login() *cobra.Command { } sessionToken := resp.SessionToken - config := createConfig(cmd) + config := r.createConfig() err = config.Session().Write(sessionToken) if err != nil { return xerrors.Errorf("write session token: %w", err) @@ -215,32 +210,32 @@ func login() *cobra.Command { return xerrors.Errorf("write server url: %w", err) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), + _, _ = fmt.Fprintf(inv.Stdout, cliui.Styles.Paragraph.Render(fmt.Sprintf("Welcome to Coder, %s! You're authenticated.", cliui.Styles.Keyword.Render(username)))+"\n") - _, _ = fmt.Fprintf(cmd.OutOrStdout(), + _, _ = fmt.Fprintf(inv.Stdout, cliui.Styles.Paragraph.Render("Get started by creating a template: "+cliui.Styles.Code.Render("coder templates init"))+"\n") return nil } - sessionToken, _ := cmd.Flags().GetString(varToken) + sessionToken, _ := inv.ParsedFlags().GetString(varToken) if sessionToken == "" { authURL := *serverURL // Don't use filepath.Join, we don't want to use the os separator // for a url. authURL.Path = path.Join(serverURL.Path, "/cli-auth") - if err := openURL(cmd, authURL.String()); err != nil { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Open the following in your browser:\n\n\t%s\n\n", authURL.String()) + if err := openURL(inv, authURL.String()); err != nil { + _, _ = fmt.Fprintf(inv.Stdout, "Open the following in your browser:\n\n\t%s\n\n", authURL.String()) } else { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "Your browser has been opened to visit:\n\n\t%s\n\n", authURL.String()) + _, _ = fmt.Fprintf(inv.Stdout, "Your browser has been opened to visit:\n\n\t%s\n\n", authURL.String()) } - sessionToken, err = cliui.Prompt(cmd, cliui.PromptOptions{ + sessionToken, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "Paste your token here:", Secret: true, Validate: func(token string) error { client.SetSessionToken(token) - _, err := client.User(cmd.Context(), codersdk.Me) + _, err := client.User(inv.Context(), codersdk.Me) if err != nil { return xerrors.New("That's not a valid token!") } @@ -254,12 +249,12 @@ func login() *cobra.Command { // Login to get user data - verify it is OK before persisting client.SetSessionToken(sessionToken) - resp, err := client.User(cmd.Context(), codersdk.Me) + resp, err := client.User(inv.Context(), codersdk.Me) if err != nil { return xerrors.Errorf("get user: %w", err) } - config := createConfig(cmd) + config := r.createConfig() err = config.Session().Write(sessionToken) if err != nil { return xerrors.Errorf("write session token: %w", err) @@ -269,14 +264,36 @@ func login() *cobra.Command { return xerrors.Errorf("write server url: %w", err) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), Caret+"Welcome to Coder, %s! You're authenticated.\n", cliui.Styles.Keyword.Render(resp.Username)) + _, _ = fmt.Fprintf(inv.Stdout, Caret+"Welcome to Coder, %s! You're authenticated.\n", cliui.Styles.Keyword.Render(resp.Username)) return nil }, } - cliflag.StringVarP(cmd.Flags(), &email, "first-user-email", "", "CODER_FIRST_USER_EMAIL", "", "Specifies an email address to use if creating the first user for the deployment.") - cliflag.StringVarP(cmd.Flags(), &username, "first-user-username", "", "CODER_FIRST_USER_USERNAME", "", "Specifies a username to use if creating the first user for the deployment.") - cliflag.StringVarP(cmd.Flags(), &password, "first-user-password", "", "CODER_FIRST_USER_PASSWORD", "", "Specifies a password to use if creating the first user for the deployment.") - cliflag.BoolVarP(cmd.Flags(), &trial, "first-user-trial", "", firstUserTrialEnv, false, "Specifies whether a trial license should be provisioned for the Coder deployment or not.") + cmd.Options = clibase.OptionSet{ + { + Flag: "first-user-email", + Env: "CODER_FIRST_USER_EMAIL", + Description: "Specifies an email address to use if creating the first user for the deployment.", + Value: clibase.StringOf(&email), + }, + { + Flag: "first-user-username", + Env: "CODER_FIRST_USER_USERNAME", + Description: "Specifies a username to use if creating the first user for the deployment.", + Value: clibase.StringOf(&username), + }, + { + Flag: "first-user-password", + Env: "CODER_FIRST_USER_PASSWORD", + Description: "Specifies a password to use if creating the first user for the deployment.", + Value: clibase.StringOf(&password), + }, + { + Flag: "first-user-trial", + Env: firstUserTrialEnv, + Description: "Specifies whether a trial license should be provisioned for the Coder deployment or not.", + Value: clibase.BoolOf(&trial), + }, + } return cmd } @@ -293,8 +310,8 @@ func isWSL() (bool, error) { } // openURL opens the provided URL via user's default browser -func openURL(cmd *cobra.Command, urlToOpen string) error { - noOpen, err := cmd.Flags().GetBool(varNoOpen) +func openURL(inv *clibase.Invocation, urlToOpen string) error { + noOpen, err := inv.ParsedFlags().GetBool(varNoOpen) if err != nil { panic(err) } @@ -314,7 +331,7 @@ func openURL(cmd *cobra.Command, urlToOpen string) error { browserEnv := os.Getenv("BROWSER") if browserEnv != "" { browserSh := fmt.Sprintf("%s '%s'", browserEnv, urlToOpen) - cmd := exec.CommandContext(cmd.Context(), "sh", "-c", browserSh) + cmd := exec.CommandContext(inv.Context(), "sh", "-c", browserSh) out, err := cmd.CombinedOutput() if err != nil { return xerrors.Errorf("failed to run %v (out: %q): %w", cmd.Args, out, err) diff --git a/cli/login_test.go b/cli/login_test.go index 14f9208360002..7e552fbe503dc 100644 --- a/cli/login_test.go +++ b/cli/login_test.go @@ -20,7 +20,7 @@ func TestLogin(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) root, _ := clitest.New(t, "login", client.URL.String()) - err := root.Execute() + err := root.Run() require.Error(t, err) }) @@ -28,7 +28,7 @@ func TestLogin(t *testing.T) { t.Parallel() badLoginURL := "https://fcca2077f06e68aaf9" root, _ := clitest.New(t, "login", badLoginURL) - err := root.Execute() + err := root.Run() errMsg := fmt.Sprintf("Failed to check server %q for first user, is the URL correct and is coder accessible from your browser?", badLoginURL) require.ErrorContains(t, err, errMsg) }) @@ -41,12 +41,10 @@ func TestLogin(t *testing.T) { // https://github.com/mattn/go-isatty/issues/59 doneChan := make(chan struct{}) root, _ := clitest.New(t, "login", "--force-tty", client.URL.String()) - pty := ptytest.New(t) - root.SetIn(pty.Input()) - root.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(root) go func() { defer close(doneChan) - err := root.Execute() + err := root.Run() assert.NoError(t, err) }() @@ -74,16 +72,10 @@ func TestLogin(t *testing.T) { // The --force-tty flag is required on Windows, because the `isatty` library does not // accurately detect Windows ptys when they are not attached to a process: // https://github.com/mattn/go-isatty/issues/59 - doneChan := make(chan struct{}) - root, _ := clitest.New(t, "--url", client.URL.String(), "login", "--force-tty") - pty := ptytest.New(t) - root.SetIn(pty.Input()) - root.SetOut(pty.Output()) - go func() { - defer close(doneChan) - err := root.Execute() - assert.NoError(t, err) - }() + inv, _ := clitest.New(t, "--url", client.URL.String(), "login", "--force-tty") + pty := ptytest.New(t).Attach(inv) + + clitest.Start(t, inv) matches := []string{ "first user?", "yes", @@ -100,7 +92,6 @@ func TestLogin(t *testing.T) { pty.WriteLine(value) } pty.ExpectMatch("Welcome to Coder") - <-doneChan }) t.Run("InitialUserFlags", func(t *testing.T) { @@ -108,12 +99,10 @@ func TestLogin(t *testing.T) { client := coderdtest.New(t, nil) doneChan := make(chan struct{}) root, _ := clitest.New(t, "login", client.URL.String(), "--first-user-username", "testuser", "--first-user-email", "user@coder.com", "--first-user-password", "SomeSecurePassword!", "--first-user-trial") - pty := ptytest.New(t) - root.SetIn(pty.Input()) - root.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(root) go func() { defer close(doneChan) - err := root.Execute() + err := root.Run() assert.NoError(t, err) }() pty.ExpectMatch("Welcome to Coder") @@ -130,12 +119,10 @@ func TestLogin(t *testing.T) { // https://github.com/mattn/go-isatty/issues/59 doneChan := make(chan struct{}) root, _ := clitest.New(t, "login", "--force-tty", client.URL.String()) - pty := ptytest.New(t) - root.SetIn(pty.Input()) - root.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(root) go func() { defer close(doneChan) - err := root.ExecuteContext(ctx) + err := root.WithContext(ctx).Run() assert.NoError(t, err) }() @@ -173,12 +160,10 @@ func TestLogin(t *testing.T) { doneChan := make(chan struct{}) root, _ := clitest.New(t, "login", "--force-tty", client.URL.String(), "--no-open") - pty := ptytest.New(t) - root.SetIn(pty.Input()) - root.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(root) go func() { defer close(doneChan) - err := root.Execute() + err := root.Run() assert.NoError(t, err) }() @@ -197,12 +182,10 @@ func TestLogin(t *testing.T) { defer cancelFunc() doneChan := make(chan struct{}) root, _ := clitest.New(t, "login", client.URL.String(), "--no-open") - pty := ptytest.New(t) - root.SetIn(pty.Input()) - root.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(root) go func() { defer close(doneChan) - err := root.ExecuteContext(ctx) + err := root.WithContext(ctx).Run() // An error is expected in this case, since the login wasn't successful: assert.Error(t, err) }() @@ -219,7 +202,7 @@ func TestLogin(t *testing.T) { client := coderdtest.New(t, nil) coderdtest.CreateFirstUser(t, client) root, cfg := clitest.New(t, "login", client.URL.String(), "--token", client.SessionToken()) - err := root.Execute() + err := root.Run() require.NoError(t, err) sessionFile, err := cfg.Session().Read() require.NoError(t, err) diff --git a/cli/logout.go b/cli/logout.go index d40a9ef45940c..6a4e8872bd227 100644 --- a/cli/logout.go +++ b/cli/logout.go @@ -5,27 +5,28 @@ import ( "os" "strings" - "github.com/spf13/cobra" "golang.org/x/xerrors" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" + "github.com/coder/coder/codersdk" ) -func logout() *cobra.Command { - cmd := &cobra.Command{ +func (r *RootCmd) logout() *clibase.Cmd { + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Use: "logout", Short: "Unauthenticate your local session", - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) - if err != nil { - return err - } - + Middleware: clibase.Chain( + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { var errors []error - config := createConfig(cmd) + config := r.createConfig() - _, err = cliui.Prompt(cmd, cliui.PromptOptions{ + var err error + _, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "Are you sure you want to log out?", IsConfirm: true, Default: cliui.ConfirmYes, @@ -34,7 +35,7 @@ func logout() *cobra.Command { return err } - err = client.Logout(cmd.Context()) + err = client.Logout(inv.Context()) if err != nil { errors = append(errors, xerrors.Errorf("logout api: %w", err)) } @@ -67,11 +68,10 @@ func logout() *cobra.Command { errorString := strings.TrimRight(errorStringBuilder.String(), "\n") return xerrors.New("Failed to log out.\n" + errorString) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), Caret+"You are no longer logged in. You can log in using 'coder login '.\n") + _, _ = fmt.Fprintf(inv.Stdout, Caret+"You are no longer logged in. You can log in using 'coder login '.\n") return nil }, } - - cliui.AllowSkipPrompt(cmd) + cmd.Options = append(cmd.Options, cliui.SkipPromptOption()) return cmd } diff --git a/cli/logout_test.go b/cli/logout_test.go index dea70710baf97..849016a68ce81 100644 --- a/cli/logout_test.go +++ b/cli/logout_test.go @@ -1,9 +1,7 @@ package cli_test import ( - "fmt" "os" - "regexp" "runtime" "testing" @@ -30,12 +28,12 @@ func TestLogout(t *testing.T) { logoutChan := make(chan struct{}) logout, _ := clitest.New(t, "logout", "--global-config", string(config)) - logout.SetIn(pty.Input()) - logout.SetOut(pty.Output()) + logout.Stdin = pty.Input() + logout.Stdout = pty.Output() go func() { defer close(logoutChan) - err := logout.Execute() + err := logout.Run() assert.NoError(t, err) assert.NoFileExists(t, string(config.URL())) assert.NoFileExists(t, string(config.Session())) @@ -58,12 +56,12 @@ func TestLogout(t *testing.T) { logoutChan := make(chan struct{}) logout, _ := clitest.New(t, "logout", "--global-config", string(config), "-y") - logout.SetIn(pty.Input()) - logout.SetOut(pty.Output()) + logout.Stdin = pty.Input() + logout.Stdout = pty.Output() go func() { defer close(logoutChan) - err := logout.Execute() + err := logout.Run() assert.NoError(t, err) assert.NoFileExists(t, string(config.URL())) assert.NoFileExists(t, string(config.Session())) @@ -88,13 +86,13 @@ func TestLogout(t *testing.T) { logoutChan := make(chan struct{}) logout, _ := clitest.New(t, "logout", "--global-config", string(config)) - logout.SetIn(pty.Input()) - logout.SetOut(pty.Output()) + logout.Stdin = pty.Input() + logout.Stdout = pty.Output() go func() { defer close(logoutChan) - err := logout.Execute() - assert.EqualError(t, err, "You are not logged in. Try logging in using 'coder login '.") + err := logout.Run() + assert.ErrorContains(t, err, "You are not logged in. Try logging in using 'coder login '.") }() <-logoutChan @@ -115,13 +113,13 @@ func TestLogout(t *testing.T) { logoutChan := make(chan struct{}) logout, _ := clitest.New(t, "logout", "--global-config", string(config)) - logout.SetIn(pty.Input()) - logout.SetOut(pty.Output()) + logout.Stdin = pty.Input() + logout.Stdout = pty.Output() go func() { defer close(logoutChan) - err = logout.Execute() - assert.EqualError(t, err, "You are not logged in. Try logging in using 'coder login '.") + err = logout.Run() + assert.ErrorContains(t, err, "You are not logged in. Try logging in using 'coder login '.") }() <-logoutChan @@ -166,29 +164,27 @@ func TestLogout(t *testing.T) { } }() - logoutChan := make(chan struct{}) logout, _ := clitest.New(t, "logout", "--global-config", string(config)) - logout.SetIn(pty.Input()) - logout.SetOut(pty.Output()) + logout.Stdin = pty.Input() + logout.Stdout = pty.Output() go func() { - defer close(logoutChan) - err := logout.Execute() - assert.NotNil(t, err) - var errorMessage string - if runtime.GOOS == "windows" { - errorMessage = "The process cannot access the file because it is being used by another process." - } else { - errorMessage = "permission denied" - } - errRegex := regexp.MustCompile(fmt.Sprintf("Failed to log out.\n\tremove URL file: .+: %s\n\tremove session file: .+: %s", errorMessage, errorMessage)) - assert.Regexp(t, errRegex, err.Error()) + pty.ExpectMatch("Are you sure you want to log out?") + pty.WriteLine("yes") }() + err = logout.Run() + require.Error(t, err) - pty.ExpectMatch("Are you sure you want to log out?") - pty.WriteLine("yes") - <-logoutChan + t.Logf("err: %v", err) + + var wantError string + if runtime.GOOS == "windows" { + wantError = "The process cannot access the file because it is being used by another process." + } else { + wantError = "permission denied" + } + require.ErrorContains(t, err, wantError) }) } @@ -200,11 +196,11 @@ func login(t *testing.T, pty *ptytest.PTY) config.Root { doneChan := make(chan struct{}) root, cfg := clitest.New(t, "login", "--force-tty", client.URL.String(), "--no-open") - root.SetIn(pty.Input()) - root.SetOut(pty.Output()) + root.Stdin = pty.Input() + root.Stdout = pty.Output() go func() { defer close(doneChan) - err := root.Execute() + err := root.Run() assert.NoError(t, err) }() diff --git a/cli/parameter.go b/cli/parameter.go index 9d2853b3d2d03..8da63b209233b 100644 --- a/cli/parameter.go +++ b/cli/parameter.go @@ -5,10 +5,10 @@ import ( "fmt" "os" - "github.com/spf13/cobra" "golang.org/x/xerrors" "gopkg.in/yaml.v3" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) @@ -51,20 +51,20 @@ func createParameterMapFromFile(parameterFile string) (map[string]string, error) // Returns a parameter value from a given map, if the map does not exist or does not contain the item, it takes input from the user. // Throws an error if there are any errors with the users input. -func getParameterValueFromMapOrInput(cmd *cobra.Command, parameterMap map[string]string, parameterSchema codersdk.ParameterSchema) (string, error) { +func getParameterValueFromMapOrInput(inv *clibase.Invocation, parameterMap map[string]string, parameterSchema codersdk.ParameterSchema) (string, error) { var parameterValue string var err error if parameterMap != nil { var ok bool parameterValue, ok = parameterMap[parameterSchema.Name] if !ok { - parameterValue, err = cliui.ParameterSchema(cmd, parameterSchema) + parameterValue, err = cliui.ParameterSchema(inv, parameterSchema) if err != nil { return "", err } } } else { - parameterValue, err = cliui.ParameterSchema(cmd, parameterSchema) + parameterValue, err = cliui.ParameterSchema(inv, parameterSchema) if err != nil { return "", err } @@ -72,20 +72,20 @@ func getParameterValueFromMapOrInput(cmd *cobra.Command, parameterMap map[string return parameterValue, nil } -func getWorkspaceBuildParameterValueFromMapOrInput(cmd *cobra.Command, parameterMap map[string]string, templateVersionParameter codersdk.TemplateVersionParameter) (*codersdk.WorkspaceBuildParameter, error) { +func getWorkspaceBuildParameterValueFromMapOrInput(inv *clibase.Invocation, parameterMap map[string]string, templateVersionParameter codersdk.TemplateVersionParameter) (*codersdk.WorkspaceBuildParameter, error) { var parameterValue string var err error if parameterMap != nil { var ok bool parameterValue, ok = parameterMap[templateVersionParameter.Name] if !ok { - parameterValue, err = cliui.RichParameter(cmd, templateVersionParameter) + parameterValue, err = cliui.RichParameter(inv, templateVersionParameter) if err != nil { return nil, err } } } else { - parameterValue, err = cliui.RichParameter(cmd, templateVersionParameter) + parameterValue, err = cliui.RichParameter(inv, templateVersionParameter) if err != nil { return nil, err } diff --git a/cli/parameters.go b/cli/parameters.go index 3fb54973ab98b..021d94521aaad 100644 --- a/cli/parameters.go +++ b/cli/parameters.go @@ -1,13 +1,13 @@ package cli import ( - "github.com/spf13/cobra" + "github.com/coder/coder/cli/clibase" ) -func parameters() *cobra.Command { - cmd := &cobra.Command{ +func (r *RootCmd) parameters() *clibase.Cmd { + cmd := &clibase.Cmd{ Short: "List parameters for a given scope", - Example: formatExamples( + Long: formatExamples( example{ Command: "coder parameters list workspace my-workspace", }, @@ -20,12 +20,9 @@ func parameters() *cobra.Command { // constructing curl requests. Hidden: true, Aliases: []string{"params"}, - RunE: func(cmd *cobra.Command, args []string) error { - return cmd.Help() + Children: []*clibase.Cmd{ + r.parameterList(), }, } - cmd.AddCommand( - parameterList(), - ) return cmd } diff --git a/cli/parameterslist.go b/cli/parameterslist.go index 1249f2a642be7..86829ae69b5ce 100644 --- a/cli/parameterslist.go +++ b/cli/parameterslist.go @@ -4,32 +4,32 @@ import ( "fmt" "github.com/google/uuid" - "github.com/spf13/cobra" "golang.org/x/xerrors" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) -func parameterList() *cobra.Command { +func (r *RootCmd) parameterList() *clibase.Cmd { formatter := cliui.NewOutputFormatter( cliui.TableFormat([]codersdk.Parameter{}, []string{"name", "scope", "destination scheme"}), cliui.JSONFormat(), ) - cmd := &cobra.Command{ + client := new(codersdk.Client) + + cmd := &clibase.Cmd{ Use: "list", Aliases: []string{"ls"}, - Args: cobra.ExactArgs(2), - RunE: func(cmd *cobra.Command, args []string) error { - scope, name := args[0], args[1] - - client, err := CreateClient(cmd) - if err != nil { - return err - } + Middleware: clibase.Chain( + clibase.RequireNArgs(2), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + scope, name := inv.Args[0], inv.Args[1] - organization, err := CurrentOrganization(cmd, client) + organization, err := CurrentOrganization(inv, client) if err != nil { return xerrors.Errorf("get current organization: %w", err) } @@ -37,13 +37,13 @@ func parameterList() *cobra.Command { var scopeID uuid.UUID switch codersdk.ParameterScope(scope) { case codersdk.ParameterWorkspace: - workspace, err := namedWorkspace(cmd, client, name) + workspace, err := namedWorkspace(inv.Context(), client, name) if err != nil { return err } scopeID = workspace.ID case codersdk.ParameterTemplate: - template, err := client.TemplateByName(cmd.Context(), organization.ID, name) + template, err := client.TemplateByName(inv.Context(), organization.ID, name) if err != nil { return xerrors.Errorf("get workspace template: %w", err) } @@ -57,7 +57,7 @@ func parameterList() *cobra.Command { // Could be a template_version id or a job id. Check for the // version id. - tv, err := client.TemplateVersion(cmd.Context(), scopeID) + tv, err := client.TemplateVersion(inv.Context(), scopeID) if err == nil { scopeID = tv.Job.ID } @@ -68,21 +68,21 @@ func parameterList() *cobra.Command { }) } - params, err := client.Parameters(cmd.Context(), codersdk.ParameterScope(scope), scopeID) + params, err := client.Parameters(inv.Context(), codersdk.ParameterScope(scope), scopeID) if err != nil { return xerrors.Errorf("fetch params: %w", err) } - out, err := formatter.Format(cmd.Context(), params) + out, err := formatter.Format(inv.Context(), params) if err != nil { return xerrors.Errorf("render output: %w", err) } - _, err = fmt.Fprintln(cmd.OutOrStdout(), out) + _, err = fmt.Fprintln(inv.Stdout, out) return err }, } - formatter.AttachFlags(cmd) + formatter.AttachOptions(&cmd.Options) return cmd } diff --git a/cli/ping.go b/cli/ping.go index 09cdca42747dc..4ef022c7febfc 100644 --- a/cli/ping.go +++ b/cli/ping.go @@ -5,46 +5,48 @@ import ( "fmt" "time" - "github.com/spf13/cobra" "golang.org/x/xerrors" "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) -func ping() *cobra.Command { +func (r *RootCmd) ping() *clibase.Cmd { var ( - pingNum int + pingNum int64 pingTimeout time.Duration pingWait time.Duration - verbose bool ) - cmd := &cobra.Command{ + + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Annotations: workspaceCommand, Use: "ping ", Short: "Ping a workspace", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - ctx, cancel := context.WithCancel(cmd.Context()) + Middleware: clibase.Chain( + clibase.RequireNArgs(1), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + ctx, cancel := context.WithCancel(inv.Context()) defer cancel() - client, err := CreateClient(cmd) - if err != nil { - return err - } - - workspaceName := args[0] - _, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, workspaceName, false) + workspaceName := inv.Args[0] + _, workspaceAgent, err := getWorkspaceAndAgent( + ctx, inv, client, + codersdk.Me, workspaceName, + ) if err != nil { return err } var logger slog.Logger - if verbose { - logger = slog.Make(sloghuman.Sink(cmd.OutOrStdout())).Leveled(slog.LevelDebug) + if r.verbose { + logger = slog.Make(sloghuman.Sink(inv.Stdout)).Leveled(slog.LevelDebug) } conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{Logger: logger}) @@ -70,8 +72,8 @@ func ping() *cobra.Command { cancel() if err != nil { if xerrors.Is(err, context.DeadlineExceeded) { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "ping to %q timed out \n", workspaceName) - if n == pingNum { + _, _ = fmt.Fprintf(inv.Stdout, "ping to %q timed out \n", workspaceName) + if n == int(pingNum) { return nil } continue @@ -84,8 +86,8 @@ func ping() *cobra.Command { continue } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "ping to %q failed %s\n", workspaceName, err.Error()) - if n == pingNum { + _, _ = fmt.Fprintf(inv.Stdout, "ping to %q failed %s\n", workspaceName, err.Error()) + if n == int(pingNum) { return nil } continue @@ -95,7 +97,7 @@ func ping() *cobra.Command { var via string if p2p { if !didP2p { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "p2p connection established in", + _, _ = fmt.Fprintln(inv.Stdout, "p2p connection established in", cliui.Styles.DateTimeStamp.Render(time.Since(start).Round(time.Millisecond).String()), ) } @@ -117,22 +119,40 @@ func ping() *cobra.Command { ) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "pong from %s %s in %s\n", + _, _ = fmt.Fprintf(inv.Stdout, "pong from %s %s in %s\n", cliui.Styles.Keyword.Render(workspaceName), via, cliui.Styles.DateTimeStamp.Render(dur.String()), ) - if n == pingNum { + if n == int(pingNum) { return nil } } }, } - cmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Enables verbose logging.") - cmd.Flags().DurationVarP(&pingWait, "wait", "", time.Second, "Specifies how long to wait between pings.") - cmd.Flags().DurationVarP(&pingTimeout, "timeout", "t", 5*time.Second, "Specifies how long to wait for a ping to complete.") - cmd.Flags().IntVarP(&pingNum, "num", "n", 10, "Specifies the number of pings to perform.") + cmd.Options = clibase.OptionSet{ + { + Flag: "wait", + Description: "Specifies how long to wait between pings.", + Default: "1s", + Value: clibase.DurationOf(&pingWait), + }, + { + Flag: "timeout", + FlagShorthand: "t", + Default: "5s", + Description: "Specifies how long to wait for a ping to complete.", + Value: clibase.DurationOf(&pingTimeout), + }, + { + Flag: "num", + FlagShorthand: "n", + Default: "10", + Description: "Specifies the number of pings to perform.", + Value: clibase.Int64Of(&pingNum), + }, + } return cmd } diff --git a/cli/ping_test.go b/cli/ping_test.go index f599e38b4cd8c..959c11c8ed9b4 100644 --- a/cli/ping_test.go +++ b/cli/ping_test.go @@ -22,12 +22,12 @@ func TestPing(t *testing.T) { t.Parallel() client, workspace, agentToken := setupWorkspaceForAgent(t, nil) - cmd, root := clitest.New(t, "ping", workspace.Name) + inv, root := clitest.New(t, "ping", workspace.Name) clitest.SetupConfig(t, client, root) pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetErr(pty.Output()) - cmd.SetOut(pty.Output()) + inv.Stdin = pty.Input() + inv.Stderr = pty.Output() + inv.Stdout = pty.Output() agentClient := agentsdk.New(client.URL) agentClient.SetSessionToken(agentToken) @@ -43,7 +43,7 @@ func TestPing(t *testing.T) { defer cancel() cmdDone := tGo(t, func() { - err := cmd.ExecuteContext(ctx) + err := inv.WithContext(ctx).Run() assert.NoError(t, err) }) diff --git a/cli/portforward.go b/cli/portforward.go index b3728212a904a..c746216889a55 100644 --- a/cli/portforward.go +++ b/cli/portforward.go @@ -12,26 +12,25 @@ import ( "syscall" "github.com/pion/udp" - "github.com/spf13/cobra" "golang.org/x/xerrors" "github.com/coder/coder/agent" - "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) -func portForward() *cobra.Command { +func (r *RootCmd) portForward() *clibase.Cmd { var ( tcpForwards []string // : udpForwards []string // : ) - cmd := &cobra.Command{ + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Use: "port-forward ", Short: "Forward ports from machine to a workspace", Aliases: []string{"tunnel"}, - Args: cobra.ExactArgs(1), - Example: formatExamples( + Long: formatExamples( example{ Description: "Port forward a single TCP port from 1234 in the workspace to port 5678 on your local machine", Command: "coder port-forward --tcp 5678:1234", @@ -49,8 +48,12 @@ func portForward() *cobra.Command { Command: "coder port-forward --tcp 8080,9000:3000,9090-9092,10000-10002:10010-10012", }, ), - RunE: func(cmd *cobra.Command, args []string) error { - ctx, cancel := context.WithCancel(cmd.Context()) + Middleware: clibase.Chain( + clibase.RequireNArgs(1), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + ctx, cancel := context.WithCancel(inv.Context()) defer cancel() specs, err := parsePortForwards(tcpForwards, udpForwards) @@ -58,19 +61,14 @@ func portForward() *cobra.Command { return xerrors.Errorf("parse port-forward specs: %w", err) } if len(specs) == 0 { - err = cmd.Help() + err = inv.Command.HelpHandler(inv) if err != nil { return xerrors.Errorf("generate help output: %w", err) } return xerrors.New("no port-forwards requested") } - client, err := CreateClient(cmd) - if err != nil { - return err - } - - workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false) + workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, inv, client, codersdk.Me, inv.Args[0]) if err != nil { return err } @@ -78,13 +76,13 @@ func portForward() *cobra.Command { return xerrors.New("workspace must be in start transition to port-forward") } if workspace.LatestBuild.Job.CompletedAt == nil { - err = cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID) + err = cliui.WorkspaceBuild(ctx, inv.Stderr, client, workspace.LatestBuild.ID) if err != nil { return err } } - err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ + err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{ WorkspaceName: workspace.Name, Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { return client.WorkspaceAgent(ctx, workspaceAgent.ID) @@ -116,7 +114,7 @@ func portForward() *cobra.Command { defer closeAllListeners() for i, spec := range specs { - l, err := listenAndPortForward(ctx, cmd, conn, wg, spec) + l, err := listenAndPortForward(ctx, inv, conn, wg, spec) if err != nil { return err } @@ -137,7 +135,7 @@ func portForward() *cobra.Command { case <-ctx.Done(): closeErr = ctx.Err() case <-sigs: - _, _ = fmt.Fprintln(cmd.OutOrStderr(), "\nReceived signal, closing all listeners and active connections") + _, _ = fmt.Fprintln(inv.Stderr, "\nReceived signal, closing all listeners and active connections") } cancel() @@ -145,19 +143,33 @@ func portForward() *cobra.Command { }() conn.AwaitReachable(ctx) - _, _ = fmt.Fprintln(cmd.OutOrStderr(), "Ready!") + _, _ = fmt.Fprintln(inv.Stderr, "Ready!") wg.Wait() return closeErr }, } - cliflag.StringArrayVarP(cmd.Flags(), &tcpForwards, "tcp", "p", "CODER_PORT_FORWARD_TCP", nil, "Forward TCP port(s) from the workspace to the local machine") - cliflag.StringArrayVarP(cmd.Flags(), &udpForwards, "udp", "", "CODER_PORT_FORWARD_UDP", nil, "Forward UDP port(s) from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols") + cmd.Options = clibase.OptionSet{ + { + Flag: "tcp", + FlagShorthand: "p", + Env: "CODER_PORT_FORWARD_TCP", + Description: "Forward TCP port(s) from the workspace to the local machine.", + Value: clibase.StringArrayOf(&tcpForwards), + }, + { + Flag: "udp", + Env: "CODER_PORT_FORWARD_UDP", + Description: "Forward UDP port(s) from the workspace to the local machine. The UDP connection has TCP-like semantics to support stateful UDP protocols.", + Value: clibase.StringArrayOf(&udpForwards), + }, + } + return cmd } -func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *codersdk.WorkspaceAgentConn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { - _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) +func listenAndPortForward(ctx context.Context, inv *clibase.Invocation, conn *codersdk.WorkspaceAgentConn, wg *sync.WaitGroup, spec portForwardSpec) (net.Listener, error) { + _, _ = fmt.Fprintf(inv.Stderr, "Forwarding '%v://%v' locally to '%v://%v' in the workspace\n", spec.listenNetwork, spec.listenAddress, spec.dialNetwork, spec.dialAddress) var ( l net.Listener @@ -200,8 +212,8 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *codersd if xerrors.Is(err, net.ErrClosed) { return } - _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err) - _, _ = fmt.Fprintln(cmd.OutOrStderr(), "Killing listener") + _, _ = fmt.Fprintf(inv.Stderr, "Error accepting connection from '%v://%v': %v\n", spec.listenNetwork, spec.listenAddress, err) + _, _ = fmt.Fprintln(inv.Stderr, "Killing listener") return } @@ -209,7 +221,7 @@ func listenAndPortForward(ctx context.Context, cmd *cobra.Command, conn *codersd defer netConn.Close() remoteConn, err := conn.DialContext(ctx, spec.dialNetwork, spec.dialAddress) if err != nil { - _, _ = fmt.Fprintf(cmd.OutOrStderr(), "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) + _, _ = fmt.Fprintf(inv.Stderr, "Failed to dial '%v://%v' in workspace: %s\n", spec.dialNetwork, spec.dialAddress, err) return } defer remoteConn.Close() diff --git a/cli/portforward_test.go b/cli/portforward_test.go index f74dd8a6429ea..cf3cc99a7d6bf 100644 --- a/cli/portforward_test.go +++ b/cli/portforward_test.go @@ -31,14 +31,12 @@ func TestPortForward(t *testing.T) { client := coderdtest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) - cmd, root := clitest.New(t, "port-forward", "blah") + inv, root := clitest.New(t, "port-forward", "blah") clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - cmd.SetErr(pty.Output()) + pty := ptytest.New(t).Attach(inv) + inv.Stderr = pty.Output() - err := cmd.Execute() + err := inv.Run() require.Error(t, err) require.ErrorContains(t, err, "no port-forwards") @@ -133,17 +131,17 @@ func TestPortForward(t *testing.T) { // Launch port-forward in a goroutine so we can start dialing // the "local" listener. - cmd, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag) + inv, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag) clitest.SetupConfig(t, client, root) pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - cmd.SetErr(pty.Output()) + inv.Stdin = pty.Input() + inv.Stdout = pty.Output() + inv.Stderr = pty.Output() ctx, cancel := context.WithCancel(context.Background()) defer cancel() errC := make(chan error) go func() { - errC <- cmd.ExecuteContext(ctx) + errC <- inv.WithContext(ctx).Run() }() pty.ExpectMatch("Ready!") @@ -181,17 +179,17 @@ func TestPortForward(t *testing.T) { // Launch port-forward in a goroutine so we can start dialing // the "local" listeners. - cmd, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag1, flag2) + inv, root := clitest.New(t, "-v", "port-forward", workspace.Name, flag1, flag2) clitest.SetupConfig(t, client, root) pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - cmd.SetErr(pty.Output()) + inv.Stdin = pty.Input() + inv.Stdout = pty.Output() + inv.Stderr = pty.Output() ctx, cancel := context.WithCancel(context.Background()) defer cancel() errC := make(chan error) go func() { - errC <- cmd.ExecuteContext(ctx) + errC <- inv.WithContext(ctx).Run() }() pty.ExpectMatch("Ready!") @@ -238,17 +236,15 @@ func TestPortForward(t *testing.T) { // Launch port-forward in a goroutine so we can start dialing // the "local" listeners. - cmd, root := clitest.New(t, append([]string{"-v", "port-forward", workspace.Name}, flags...)...) + inv, root := clitest.New(t, append([]string{"-v", "port-forward", workspace.Name}, flags...)...) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - cmd.SetErr(pty.Output()) + pty := ptytest.New(t).Attach(inv) + inv.Stderr = pty.Output() ctx, cancel := context.WithCancel(context.Background()) defer cancel() errC := make(chan error) go func() { - errC <- cmd.ExecuteContext(ctx) + errC <- inv.WithContext(ctx).Run() }() pty.ExpectMatch("Ready!") @@ -304,12 +300,12 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) codersdk. coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) // Start workspace agent in a goroutine - cmd, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) + inv, root := clitest.New(t, "agent", "--agent-token", agentToken, "--agent-url", client.URL.String()) clitest.SetupConfig(t, client, root) pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - cmd.SetErr(pty.Output()) + inv.Stdin = pty.Input() + inv.Stdout = pty.Output() + inv.Stderr = pty.Output() errC := make(chan error) agentCtx, agentCancel := context.WithCancel(ctx) t.Cleanup(func() { @@ -318,7 +314,7 @@ func runAgent(t *testing.T, client *codersdk.Client, userID uuid.UUID) codersdk. require.NoError(t, err) }) go func() { - errC <- cmd.ExecuteContext(agentCtx) + errC <- inv.WithContext(agentCtx).Run() }() coderdtest.AwaitWorkspaceAgents(t, client, workspace.ID) diff --git a/cli/publickey.go b/cli/publickey.go index 3872baf594946..7d4501c9cd26e 100644 --- a/cli/publickey.go +++ b/cli/publickey.go @@ -3,30 +3,26 @@ package cli import ( "strings" - "github.com/spf13/cobra" "golang.org/x/xerrors" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) -func publickey() *cobra.Command { +func (r *RootCmd) publickey() *clibase.Cmd { var reset bool - - cmd := &cobra.Command{ - Use: "publickey", - Aliases: []string{"pubkey"}, - Short: "Output your Coder public key used for Git operations", - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) - if err != nil { - return xerrors.Errorf("create codersdk client: %w", err) - } - + client := new(codersdk.Client) + cmd := &clibase.Cmd{ + Use: "publickey", + Aliases: []string{"pubkey"}, + Short: "Output your Coder public key used for Git operations", + Middleware: r.InitClient(client), + Handler: func(inv *clibase.Invocation) error { if reset { // Confirm prompt if using --reset. We don't want to accidentally // reset our public key. - _, err := cliui.Prompt(cmd, cliui.PromptOptions{ + _, err := cliui.Prompt(inv, cliui.PromptOptions{ Text: "Confirm regenerate a new sshkey for your workspaces? This will require updating the key " + "on any services it is registered with. This action cannot be reverted.", IsConfirm: true, @@ -36,33 +32,38 @@ func publickey() *cobra.Command { } // Reset the public key, let the retrieve re-read it. - _, err = client.RegenerateGitSSHKey(cmd.Context(), codersdk.Me) + _, err = client.RegenerateGitSSHKey(inv.Context(), codersdk.Me) if err != nil { return err } } - key, err := client.GitSSHKey(cmd.Context(), codersdk.Me) + key, err := client.GitSSHKey(inv.Context(), codersdk.Me) if err != nil { return xerrors.Errorf("create codersdk client: %w", err) } - cmd.Println(cliui.Styles.Wrap.Render( - "This is your public key for using " + cliui.Styles.Field.Render("git") + " in " + - "Coder. All clones with SSH will be authenticated automatically 🪄.", - )) - cmd.Println() - cmd.Println(cliui.Styles.Code.Render(strings.TrimSpace(key.PublicKey))) - cmd.Println() - cmd.Println("Add to GitHub and GitLab:") - cmd.Println(cliui.Styles.Prompt.String() + "https://github.com/settings/ssh/new") - cmd.Println(cliui.Styles.Prompt.String() + "https://gitlab.com/-/profile/keys") + cliui.Infof(inv.Stdout, + "This is your public key for using "+cliui.Styles.Field.Render("git")+" in "+ + "Coder. All clones with SSH will be authenticated automatically 🪄.\n\n", + ) + cliui.Infof(inv.Stdout, cliui.Styles.Code.Render(strings.TrimSpace(key.PublicKey))+"\n\n") + cliui.Infof(inv.Stdout, "Add to GitHub and GitLab:"+"\n") + cliui.Infof(inv.Stdout, cliui.Styles.Prompt.String()+"https://github.com/settings/ssh/new"+"\n") + cliui.Infof(inv.Stdout, cliui.Styles.Prompt.String()+"https://gitlab.com/-/profile/keys"+"\n") return nil }, } - cmd.Flags().BoolVar(&reset, "reset", false, "Regenerate your public key. This will require updating the key on any services it's registered with.") - cliui.AllowSkipPrompt(cmd) + + cmd.Options = clibase.OptionSet{ + { + Flag: "reset", + Description: "Regenerate your public key. This will require updating the key on any services it's registered with.", + Value: clibase.BoolOf(&reset), + }, + cliui.SkipPromptOption(), + } return cmd } diff --git a/cli/publickey_test.go b/cli/publickey_test.go index f0bef2c65359e..a5664ec2bda07 100644 --- a/cli/publickey_test.go +++ b/cli/publickey_test.go @@ -16,11 +16,11 @@ func TestPublicKey(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) - cmd, root := clitest.New(t, "publickey") + inv, root := clitest.New(t, "publickey") clitest.SetupConfig(t, client, root) buf := new(bytes.Buffer) - cmd.SetOut(buf) - err := cmd.Execute() + inv.Stdout = buf + err := inv.Run() require.NoError(t, err) publicKey := buf.String() require.NotEmpty(t, publicKey) diff --git a/cli/rename.go b/cli/rename.go index ac364b80ea93b..e0443e75ed6ff 100644 --- a/cli/rename.go +++ b/cli/rename.go @@ -3,34 +3,34 @@ package cli import ( "fmt" - "github.com/spf13/cobra" "golang.org/x/xerrors" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) -func rename() *cobra.Command { - cmd := &cobra.Command{ +func (r *RootCmd) rename() *clibase.Cmd { + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Annotations: workspaceCommand, Use: "rename ", Short: "Rename a workspace", - Args: cobra.ExactArgs(2), - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) - if err != nil { - return err - } - workspace, err := namedWorkspace(cmd, client, args[0]) + Middleware: clibase.Chain( + clibase.RequireNArgs(2), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return xerrors.Errorf("get workspace: %w", err) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n\n", + _, _ = fmt.Fprintf(inv.Stdout, "%s\n\n", cliui.Styles.Wrap.Render("WARNING: A rename can result in data loss if a resource references the workspace name in the template (e.g volumes). Please backup any data before proceeding."), ) - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "See: %s\n\n", "https://coder.com/docs/coder-oss/latest/templates/resource-persistence#%EF%B8%8F-persistence-pitfalls") - _, err = cliui.Prompt(cmd, cliui.PromptOptions{ + _, _ = fmt.Fprintf(inv.Stdout, "See: %s\n\n", "https://coder.com/docs/coder-oss/latest/templates/resource-persistence#%EF%B8%8F-persistence-pitfalls") + _, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: fmt.Sprintf("Type %q to confirm rename:", workspace.Name), Validate: func(s string) error { if s == workspace.Name { @@ -43,17 +43,18 @@ func rename() *cobra.Command { return err } - err = client.UpdateWorkspace(cmd.Context(), workspace.ID, codersdk.UpdateWorkspaceRequest{ - Name: args[1], + err = client.UpdateWorkspace(inv.Context(), workspace.ID, codersdk.UpdateWorkspaceRequest{ + Name: inv.Args[1], }) if err != nil { return xerrors.Errorf("rename workspace: %w", err) } + _, _ = fmt.Fprintf(inv.Stdout, "Workspace %q renamed to %q\n", workspace.Name, inv.Args[1]) return nil }, } - cliui.AllowSkipPrompt(cmd) + cmd.Options = append(cmd.Options, cliui.SkipPromptOption()) return cmd } diff --git a/cli/rename_test.go b/cli/rename_test.go index f965bfa3e3636..6cd92ff9e1451 100644 --- a/cli/rename_test.go +++ b/cli/rename_test.go @@ -5,7 +5,6 @@ import ( "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/coder/coder/cli/clitest" "github.com/coder/coder/coderd/coderdtest" @@ -30,21 +29,15 @@ func TestRename(t *testing.T) { // Only append one letter because it's easy to exceed maximum length: // E.g. "compassionate-chandrasekhar82" + "t". want := workspace.Name + "t" - cmd, root := clitest.New(t, "rename", workspace.Name, want, "--yes") + inv, root := clitest.New(t, "rename", workspace.Name, want, "--yes") clitest.SetupConfig(t, client, root) pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - - errC := make(chan error, 1) - go func() { - errC <- cmd.ExecuteContext(ctx) - }() + pty.Attach(inv) + clitest.Start(t, inv) pty.ExpectMatch("confirm rename:") pty.WriteLine(workspace.Name) - - require.NoError(t, <-errC) + pty.ExpectMatch("renamed to") ws, err := client.Workspace(ctx, workspace.ID) assert.NoError(t, err) diff --git a/cli/resetpassword.go b/cli/resetpassword.go index 8aea553730f1c..dcf206dd680d6 100644 --- a/cli/resetpassword.go +++ b/cli/resetpassword.go @@ -4,25 +4,24 @@ import ( "database/sql" "fmt" - "github.com/spf13/cobra" "golang.org/x/xerrors" - "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/database/migrations" "github.com/coder/coder/coderd/userpassword" ) -func resetPassword() *cobra.Command { +func (*RootCmd) resetPassword() *clibase.Cmd { var postgresURL string - root := &cobra.Command{ - Use: "reset-password ", - Short: "Directly connect to the database to reset a user's password", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - username := args[0] + root := &clibase.Cmd{ + Use: "reset-password ", + Short: "Directly connect to the database to reset a user's password", + Middleware: clibase.RequireNArgs(1), + Handler: func(inv *clibase.Invocation) error { + username := inv.Args[0] sqlDB, err := sql.Open("postgres", postgresURL) if err != nil { @@ -40,14 +39,14 @@ func resetPassword() *cobra.Command { } db := database.New(sqlDB) - user, err := db.GetUserByEmailOrUsername(cmd.Context(), database.GetUserByEmailOrUsernameParams{ + user, err := db.GetUserByEmailOrUsername(inv.Context(), database.GetUserByEmailOrUsernameParams{ Username: username, }) if err != nil { return xerrors.Errorf("retrieving user: %w", err) } - password, err := cliui.Prompt(cmd, cliui.PromptOptions{ + password, err := cliui.Prompt(inv, cliui.PromptOptions{ Text: "Enter new " + cliui.Styles.Field.Render("password") + ":", Secret: true, Validate: func(s string) error { @@ -57,7 +56,7 @@ func resetPassword() *cobra.Command { if err != nil { return xerrors.Errorf("password prompt: %w", err) } - confirmedPassword, err := cliui.Prompt(cmd, cliui.PromptOptions{ + confirmedPassword, err := cliui.Prompt(inv, cliui.PromptOptions{ Text: "Confirm " + cliui.Styles.Field.Render("password") + ":", Secret: true, Validate: cliui.ValidateNotEmpty, @@ -74,7 +73,7 @@ func resetPassword() *cobra.Command { return xerrors.Errorf("hash password: %w", err) } - err = db.UpdateUserHashedPassword(cmd.Context(), database.UpdateUserHashedPasswordParams{ + err = db.UpdateUserHashedPassword(inv.Context(), database.UpdateUserHashedPasswordParams{ ID: user.ID, HashedPassword: []byte(hashedPassword), }) @@ -82,12 +81,19 @@ func resetPassword() *cobra.Command { return xerrors.Errorf("updating password: %w", err) } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nPassword has been reset for user %s!\n", cliui.Styles.Keyword.Render(user.Username)) + _, _ = fmt.Fprintf(inv.Stdout, "\nPassword has been reset for user %s!\n", cliui.Styles.Keyword.Render(user.Username)) return nil }, } - cliflag.StringVarP(root.Flags(), &postgresURL, "postgres-url", "", "CODER_PG_CONNECTION_URL", "", "URL of a PostgreSQL database to connect to") + root.Options = clibase.OptionSet{ + { + Flag: "postgres-url", + Description: "URL of a PostgreSQL database to connect to.", + Env: "CODER_PG_CONNECTION_URL", + Value: clibase.StringOf(&postgresURL), + }, + } return root } diff --git a/cli/resetpassword_test.go b/cli/resetpassword_test.go index 3bf45c271b758..40cfc1042dcdc 100644 --- a/cli/resetpassword_test.go +++ b/cli/resetpassword_test.go @@ -37,7 +37,7 @@ func TestResetPassword(t *testing.T) { defer closeFunc() ctx, cancelFunc := context.WithCancel(context.Background()) serverDone := make(chan struct{}) - serverCmd, cfg := clitest.New(t, + serverinv, cfg := clitest.New(t, "server", "--http-address", ":0", "--access-url", "http://example.com", @@ -46,7 +46,7 @@ func TestResetPassword(t *testing.T) { ) go func() { defer close(serverDone) - err = serverCmd.ExecuteContext(ctx) + err = serverinv.WithContext(ctx).Run() assert.NoError(t, err) }() var rawURL string @@ -67,15 +67,15 @@ func TestResetPassword(t *testing.T) { // reset the password - resetCmd, cmdCfg := clitest.New(t, "reset-password", "--postgres-url", connectionURL, username) + resetinv, cmdCfg := clitest.New(t, "reset-password", "--postgres-url", connectionURL, username) clitest.SetupConfig(t, client, cmdCfg) cmdDone := make(chan struct{}) pty := ptytest.New(t) - resetCmd.SetIn(pty.Input()) - resetCmd.SetOut(pty.Output()) + resetinv.Stdin = pty.Input() + resetinv.Stdout = pty.Output() go func() { defer close(cmdDone) - err = resetCmd.Execute() + err = resetinv.Run() assert.NoError(t, err) }() diff --git a/cli/restart.go b/cli/restart.go index 687297d371f5f..51ffb2abbf871 100644 --- a/cli/restart.go +++ b/cli/restart.go @@ -4,23 +4,29 @@ import ( "fmt" "time" - "github.com/spf13/cobra" - + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) -func restart() *cobra.Command { - cmd := &cobra.Command{ +func (r *RootCmd) restart() *clibase.Cmd { + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Annotations: workspaceCommand, Use: "restart ", Short: "Restart a workspace", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - out := cmd.OutOrStdout() + Middleware: clibase.Chain( + clibase.RequireNArgs(1), + r.InitClient(client), + ), + Options: clibase.OptionSet{ + cliui.SkipPromptOption(), + }, + Handler: func(inv *clibase.Invocation) error { + ctx := inv.Context() + out := inv.Stdout - _, err := cliui.Prompt(cmd, cliui.PromptOptions{ + _, err := cliui.Prompt(inv, cliui.PromptOptions{ Text: "Confirm restart workspace?", IsConfirm: true, }) @@ -28,11 +34,7 @@ func restart() *cobra.Command { return err } - client, err := CreateClient(cmd) - if err != nil { - return err - } - workspace, err := namedWorkspace(cmd, client, args[0]) + workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return err } @@ -63,6 +65,5 @@ func restart() *cobra.Command { return nil }, } - cliui.AllowSkipPrompt(cmd) return cmd } diff --git a/cli/restart_test.go b/cli/restart_test.go index 9ad55a05137da..d1dfa6bd3b497 100644 --- a/cli/restart_test.go +++ b/cli/restart_test.go @@ -25,18 +25,16 @@ func TestRestart(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - ctx, _ := testutil.Context(t) + ctx := testutil.Context(t, testutil.WaitLong) - cmd, root := clitest.New(t, "restart", workspace.Name, "--yes") + inv, root := clitest.New(t, "restart", workspace.Name, "--yes") clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) done := make(chan error, 1) go func() { - done <- cmd.ExecuteContext(ctx) + done <- inv.WithContext(ctx).Run() }() pty.ExpectMatch("Stopping workspace") pty.ExpectMatch("Starting workspace") diff --git a/cli/root.go b/cli/root.go index 8d6f14b06f9a1..280455258e111 100644 --- a/cli/root.go +++ b/cli/root.go @@ -1,32 +1,37 @@ package cli import ( + "bufio" "context" + "errors" "flag" "fmt" "io" + "math/rand" "net" "net/http" "net/url" "os" "os/signal" "path/filepath" + "regexp" "runtime" "strings" "syscall" - "text/template" "time" + "unicode/utf8" + "golang.org/x/crypto/ssh/terminal" + "golang.org/x/exp/slices" "golang.org/x/xerrors" "cdr.dev/slog" "github.com/charmbracelet/lipgloss" "github.com/mattn/go-isatty" - "github.com/spf13/cobra" "github.com/coder/coder/buildinfo" - "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/cli/config" "github.com/coder/coder/coderd" @@ -66,84 +71,82 @@ const ( var errUnauthenticated = xerrors.New(notLoggedInMessage) -func init() { - // Set cobra template functions in init to avoid conflicts in tests. - cobra.AddTemplateFuncs(templateFunctions) -} - -func Core() []*cobra.Command { +func (r *RootCmd) Core() []*clibase.Cmd { // Please re-sort this list alphabetically if you change it! - return []*cobra.Command{ - configSSH(), - create(), - deleteWorkspace(), - dotfiles(), - gitssh(), - list(), - login(), - logout(), - parameters(), - ping(), - portForward(), - publickey(), - rename(), - resetPassword(), - restart(), - scaletest(), - schedules(), - show(), - speedtest(), - ssh(), - start(), - state(), - stop(), - templates(), - tokens(), - update(), - users(), - versionCmd(), - vscodeSSH(), - workspaceAgent(), - } -} - -func AGPL() []*cobra.Command { - all := append(Core(), Server(func(_ context.Context, o *coderd.Options) (*coderd.API, io.Closer, error) { + return []*clibase.Cmd{ + r.dotfiles(), + r.login(), + r.logout(), + r.portForward(), + r.publickey(), + r.resetPassword(), + r.state(), + r.templates(), + r.users(), + r.tokens(), + r.version(), + + // Workspace Commands + r.configSSH(), + r.rename(), + r.ping(), + r.create(), + r.deleteWorkspace(), + r.list(), + r.schedules(), + r.show(), + r.speedtest(), + r.ssh(), + r.start(), + r.stop(), + r.update(), + r.restart(), + r.parameters(), + + // Hidden + r.workspaceAgent(), + r.scaletest(), + r.gitssh(), + r.vscodeSSH(), + } +} + +func (r *RootCmd) AGPL() []*clibase.Cmd { + all := append(r.Core(), r.Server(func(_ context.Context, o *coderd.Options) (*coderd.API, io.Closer, error) { api := coderd.New(o) return api, api, nil })) return all } -func Root(subcommands []*cobra.Command) *cobra.Command { - // The GIT_ASKPASS environment variable must point at - // a binary with no arguments. To prevent writing - // cross-platform scripts to invoke the Coder binary - // with a `gitaskpass` subcommand, we override the entrypoint - // to check if the command was invoked. - isGitAskpass := false +// Main is the entrypoint for the Coder CLI. +func (r *RootCmd) RunMain(subcommands []*clibase.Cmd) { + rand.Seed(time.Now().UnixMicro()) + + cmd, err := r.Command(subcommands) + if err != nil { + panic(err) + } + + err = cmd.Invoke().WithOS().Run() + if err != nil { + if errors.Is(err, cliui.Canceled) { + //nolint:revive + os.Exit(1) + } + f := prettyErrorFormatter{w: os.Stderr} + f.format(err) + //nolint:revive + os.Exit(1) + } +} +func (r *RootCmd) Command(subcommands []*clibase.Cmd) (*clibase.Cmd, error) { fmtLong := `Coder %s — A tool for provisioning self-hosted development environments with Terraform. ` - cmd := &cobra.Command{ - Use: "coder", - SilenceErrors: true, - SilenceUsage: true, - Long: fmt.Sprintf(fmtLong, buildinfo.Version()), - Args: func(cmd *cobra.Command, args []string) error { - if gitauth.CheckCommand(args, os.Environ()) { - isGitAskpass = true - return nil - } - return cobra.NoArgs(cmd, args) - }, - RunE: func(cmd *cobra.Command, args []string) error { - if isGitAskpass { - return gitAskpass().RunE(cmd, args) - } - return cmd.Help() - }, - Example: formatExamples( + cmd := &clibase.Cmd{ + Use: "coder [global-flags] ", + Long: fmt.Sprintf(fmtLong, buildinfo.Version()) + formatExamples( example{ Description: "Start a Coder server", Command: "coder server", @@ -153,30 +156,204 @@ func Root(subcommands []*cobra.Command) *cobra.Command { Command: "coder templates init", }, ), + Handler: func(i *clibase.Invocation) error { + // fmt.Fprintf(i.Stderr, "env debug: %+v", i.Environ) + // The GIT_ASKPASS environment variable must point at + // a binary with no arguments. To prevent writing + // cross-platform scripts to invoke the Coder binary + // with a `gitaskpass` subcommand, we override the entrypoint + // to check if the command was invoked. + if gitauth.CheckCommand(i.Args, i.Environ.ToOS()) { + return r.gitAskpass().Handler(i) + } + return i.Command.HelpHandler(i) + }, + } + + cmd.AddSubcommands(subcommands...) + + // Set default help handler for all commands. + cmd.Walk(func(c *clibase.Cmd) { + if c.HelpHandler == nil { + c.HelpHandler = helpFn() + } + }) + + var merr error + // Add [flags] to usage when appropriate. + cmd.Walk(func(cmd *clibase.Cmd) { + const flags = "[flags]" + if strings.Contains(cmd.Use, flags) { + merr = errors.Join( + merr, + xerrors.Errorf( + "command %q shouldn't have %q in usage since it's automatically populated", + cmd.FullUsage(), + flags, + ), + ) + return + } + + var hasFlag bool + for _, opt := range cmd.Options { + if opt.Flag != "" { + hasFlag = true + break + } + } + + if !hasFlag { + return + } + + // We insert [flags] between the command's name and its arguments. + tokens := strings.SplitN(cmd.Use, " ", 2) + if len(tokens) == 1 { + cmd.Use = fmt.Sprintf("%s %s", tokens[0], flags) + return + } + cmd.Use = fmt.Sprintf("%s %s %s", tokens[0], flags, tokens[1]) + }) + + // Add alises when appropriate. + cmd.Walk(func(cmd *clibase.Cmd) { + // TODO: we should really be consistent about naming. + if cmd.Name() == "delete" || cmd.Name() == "remove" { + if slices.Contains(cmd.Aliases, "rm") { + merr = errors.Join( + merr, + xerrors.Errorf("command %q shouldn't have alias %q since it's added automatically", cmd.FullName(), "rm"), + ) + return + } + cmd.Aliases = append(cmd.Aliases, "rm") + } + }) + + // Sanity-check command options. + cmd.Walk(func(cmd *clibase.Cmd) { + for _, opt := range cmd.Options { + // Verify that every option is configurable. + if opt.Flag == "" && opt.Env == "" { + if cmd.Name() == "server" { + // The server command is funky and has YAML-only options, e.g. + // support links. + return + } + merr = errors.Join( + merr, + xerrors.Errorf("option %q in %q should have a flag or env", opt.Name, cmd.FullName()), + ) + } + } + }) + if merr != nil { + return nil, merr } - cmd.AddCommand(subcommands...) - fixUnknownSubcommandError(cmd.Commands()) + if r.agentURL == nil { + r.agentURL = new(url.URL) + } + if r.clientURL == nil { + r.clientURL = new(url.URL) + } - cmd.SetUsageTemplate(usageTemplateCobra()) + globalGroup := &clibase.Group{ + Name: "Global", + Description: `Global options are applied to all commands. They can be set using environment variables or flags.`, + } + cmd.Options = clibase.OptionSet{ + { + Flag: varURL, + Env: envURL, + Description: "URL to a deployment.", + Value: clibase.URLOf(r.clientURL), + Group: globalGroup, + }, + { + Flag: varToken, + Env: envSessionToken, + Description: fmt.Sprintf("Specify an authentication token. For security reasons setting %s is preferred.", envSessionToken), + Value: clibase.StringOf(&r.token), + Group: globalGroup, + }, + { + Flag: varAgentToken, + Description: "An agent authentication token.", + Value: clibase.StringOf(&r.agentToken), + Hidden: true, + Group: globalGroup, + }, + { + Flag: varAgentURL, + Env: "CODER_AGENT_URL", + Description: "URL for an agent to access your deployment.", + Value: clibase.URLOf(r.agentURL), + Hidden: true, + Group: globalGroup, + }, + { + Flag: varNoVersionCheck, + Env: envNoVersionCheck, + Description: "Suppress warning when client and server versions do not match.", + Value: clibase.BoolOf(&r.noVersionCheck), + Group: globalGroup, + }, + { + Flag: varNoFeatureWarning, + Env: envNoFeatureWarning, + Description: "Suppress warnings about unlicensed features.", + Value: clibase.BoolOf(&r.noFeatureWarning), + Group: globalGroup, + }, + { + Flag: varHeader, + Env: "CODER_HEADER", + Description: "Additional HTTP headers added to all requests. Provide as " + `key=value` + ". Can be specified multiple times.", + Value: clibase.StringArrayOf(&r.header), + Group: globalGroup, + }, + { + Flag: varNoOpen, + Env: "CODER_NO_OPEN", + Description: "Suppress opening the browser after logging in.", + Value: clibase.BoolOf(&r.noOpen), + Hidden: true, + Group: globalGroup, + }, + { + Flag: varForceTty, + Env: "CODER_FORCE_TTY", + Hidden: true, + Description: "Force the use of a TTY.", + Value: clibase.BoolOf(&r.forceTTY), + Group: globalGroup, + }, + { + Flag: varVerbose, + FlagShorthand: "v", + Env: "CODER_VERBOSE", + Description: "Enable verbose output.", + Value: clibase.BoolOf(&r.verbose), + Group: globalGroup, + }, + { + Flag: config.FlagName, + Env: "CODER_CONFIG_DIR", + Description: "Path to the global `coder` config directory.", + Default: config.DefaultDir(), + Value: clibase.StringOf(&r.globalConfig), + Group: globalGroup, + }, + } - cliflag.String(cmd.PersistentFlags(), varURL, "", envURL, "", "URL to a deployment.") - cliflag.Bool(cmd.PersistentFlags(), varNoVersionCheck, "", envNoVersionCheck, false, "Suppress warning when client and server versions do not match.") - cliflag.Bool(cmd.PersistentFlags(), varNoFeatureWarning, "", envNoFeatureWarning, false, "Suppress warnings about unlicensed features.") - cliflag.String(cmd.PersistentFlags(), varToken, "", envSessionToken, "", fmt.Sprintf("Specify an authentication token. For security reasons setting %s is preferred.", envSessionToken)) - cliflag.String(cmd.PersistentFlags(), varAgentToken, "", "CODER_AGENT_TOKEN", "", "An agent authentication token.") - _ = cmd.PersistentFlags().MarkHidden(varAgentToken) - cliflag.String(cmd.PersistentFlags(), varAgentURL, "", "CODER_AGENT_URL", "", "URL for an agent to access your deployment.") - _ = cmd.PersistentFlags().MarkHidden(varAgentURL) - cliflag.String(cmd.PersistentFlags(), config.FlagName, "", "CODER_CONFIG_DIR", config.DefaultDir(), "Path to the global `coder` config directory.") - cliflag.StringArray(cmd.PersistentFlags(), varHeader, "", "CODER_HEADER", []string{}, "HTTP headers added to all requests. Provide as \"Key=Value\"") - cmd.PersistentFlags().Bool(varForceTty, false, "Force the `coder` command to run as if connected to a TTY.") - _ = cmd.PersistentFlags().MarkHidden(varForceTty) - cmd.PersistentFlags().Bool(varNoOpen, false, "Block automatically opening URLs in the browser.") - _ = cmd.PersistentFlags().MarkHidden(varNoOpen) - cliflag.Bool(cmd.PersistentFlags(), varVerbose, "v", "CODER_VERBOSE", false, "Enable verbose output.") + err := cmd.PrepareAll() + if err != nil { + return nil, err + } - return cmd + return cmd, nil } type contextKey int @@ -194,41 +371,12 @@ func LoggerFromContext(ctx context.Context) (slog.Logger, bool) { return l, ok } -// fixUnknownSubcommandError modifies the provided commands so that the -// ones with subcommands output the correct error message when an -// unknown subcommand is invoked. -// -// Example: -// -// unknown command "bad" for "coder templates" -func fixUnknownSubcommandError(commands []*cobra.Command) { - for _, sc := range commands { - if sc.HasSubCommands() { - if sc.Run == nil && sc.RunE == nil { - if sc.Args != nil { - // In case the developer does not know about this - // behavior in Cobra they must verify correct - // behavior. For instance, settings Args to - // `cobra.ExactArgs(0)` will not give the same - // message as `cobra.NoArgs`. Likewise, omitting the - // run function will not give the wanted error. - panic("developer error: subcommand has subcommands and Args but no Run or RunE") - } - sc.Args = cobra.NoArgs - sc.Run = func(*cobra.Command, []string) {} - } - - fixUnknownSubcommandError(sc.Commands()) - } - } -} - -// versionCmd prints the coder version -func versionCmd() *cobra.Command { - return &cobra.Command{ +// version prints the coder version +func (*RootCmd) version() *clibase.Cmd { + return &clibase.Cmd{ Use: "version", Short: "Show coder version", - RunE: func(cmd *cobra.Command, args []string) error { + Handler: func(inv *clibase.Invocation) error { var str strings.Builder _, _ = str.WriteString("Coder ") if buildinfo.IsAGPL() { @@ -247,7 +395,7 @@ func versionCmd() *cobra.Command { _, _ = str.WriteString(fmt.Sprintf("Full build of Coder, supports the %s subcommand.\n", cliui.Styles.Code.Render("server"))) } - _, _ = fmt.Fprint(cmd.OutOrStdout(), str.String()) + _, _ = fmt.Fprint(inv.Stdout, str.String()) return nil }, } @@ -257,119 +405,140 @@ func isTest() bool { return flag.Lookup("test.v") != nil } -// CreateClient returns a new client from the command context. +// RootCmd contains parameters and helpers useful to all commands. +type RootCmd struct { + clientURL *url.URL + token string + globalConfig string + header []string + agentToken string + agentURL *url.URL + forceTTY bool + noOpen bool + verbose bool + + noVersionCheck bool + noFeatureWarning bool +} + +// InitClient sets client to a new client. // It reads from global configuration files if flags are not set. -func CreateClient(cmd *cobra.Command) (*codersdk.Client, error) { - root := createConfig(cmd) - rawURL, err := cmd.Flags().GetString(varURL) - if err != nil || rawURL == "" { - rawURL, err = root.URL().Read() - if err != nil { - // If the configuration files are absent, the user is logged out - if os.IsNotExist(err) { - return nil, errUnauthenticated +func (r *RootCmd) InitClient(client *codersdk.Client) clibase.MiddlewareFunc { + if client == nil { + panic("client is nil") + } + if r == nil { + panic("root is nil") + } + return func(next clibase.HandlerFunc) clibase.HandlerFunc { + return func(i *clibase.Invocation) error { + conf := r.createConfig() + var err error + if r.clientURL == nil || r.clientURL.String() == "" { + rawURL, err := conf.URL().Read() + // If the configuration files are absent, the user is logged out + if os.IsNotExist(err) { + return (errUnauthenticated) + } + if err != nil { + return err + } + + r.clientURL, err = url.Parse(strings.TrimSpace(rawURL)) + if err != nil { + return err + } } - return nil, err - } - } - serverURL, err := url.Parse(strings.TrimSpace(rawURL)) - if err != nil { - return nil, err - } - token, err := cmd.Flags().GetString(varToken) - if err != nil || token == "" { - token, err = root.Session().Read() - if err != nil { - // If the configuration files are absent, the user is logged out - if os.IsNotExist(err) { - return nil, errUnauthenticated + + if r.token == "" { + r.token, err = conf.Session().Read() + // If the configuration files are absent, the user is logged out + if os.IsNotExist(err) { + return (errUnauthenticated) + } + if err != nil { + return err + } } - return nil, err - } - } - client, err := createUnauthenticatedClient(cmd, serverURL) - if err != nil { - return nil, err - } - client.SetSessionToken(token) - // We send these requests in parallel to minimize latency. - var ( - versionErr = make(chan error) - warningErr = make(chan error) - ) - go func() { - versionErr <- checkVersions(cmd, client) - close(versionErr) - }() + err = r.setClient(client, r.clientURL) + if err != nil { + return err + } - go func() { - warningErr <- checkWarnings(cmd, client) - close(warningErr) - }() + client.SetSessionToken(r.token) + + // We send these requests in parallel to minimize latency. + var ( + versionErr = make(chan error) + warningErr = make(chan error) + ) + go func() { + versionErr <- r.checkVersions(i, client) + close(versionErr) + }() + + go func() { + warningErr <- r.checkWarnings(i, client) + close(warningErr) + }() + + if err = <-versionErr; err != nil { + // Just log the error here. We never want to fail a command + // due to a pre-run. + _, _ = fmt.Fprintf(i.Stderr, + cliui.Styles.Warn.Render("check versions error: %s"), err) + _, _ = fmt.Fprintln(i.Stderr) + } - if err = <-versionErr; err != nil { - // Just log the error here. We never want to fail a command - // due to a pre-run. - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), - cliui.Styles.Warn.Render("check versions error: %s"), err) - _, _ = fmt.Fprintln(cmd.ErrOrStderr()) - } + if err = <-warningErr; err != nil { + // Same as above + _, _ = fmt.Fprintf(i.Stderr, + cliui.Styles.Warn.Render("check entitlement warnings error: %s"), err) + _, _ = fmt.Fprintln(i.Stderr) + } - if err = <-warningErr; err != nil { - // Same as above - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), - cliui.Styles.Warn.Render("check entitlement warnings error: %s"), err) - _, _ = fmt.Fprintln(cmd.ErrOrStderr()) + return next(i) + } } - - return client, nil } -func createUnauthenticatedClient(cmd *cobra.Command, serverURL *url.URL) (*codersdk.Client, error) { - client := codersdk.New(serverURL) - headers, err := cmd.Flags().GetStringArray(varHeader) - if err != nil { - return nil, err - } +func (r *RootCmd) setClient(client *codersdk.Client, serverURL *url.URL) error { transport := &headerTransport{ transport: http.DefaultTransport, header: http.Header{}, } - for _, header := range headers { + for _, header := range r.header { parts := strings.SplitN(header, "=", 2) if len(parts) < 2 { - return nil, xerrors.Errorf("split header %q had less than two parts", header) + return xerrors.Errorf("split header %q had less than two parts", header) } transport.header.Add(parts[0], parts[1]) } - client.HTTPClient.Transport = transport - return client, nil + client.URL = serverURL + client.HTTPClient = &http.Client{ + Transport: transport, + } + return nil +} + +func (r *RootCmd) createUnauthenticatedClient(serverURL *url.URL) (*codersdk.Client, error) { + var client codersdk.Client + err := r.setClient(&client, serverURL) + return &client, err } // createAgentClient returns a new client from the command context. // It works just like CreateClient, but uses the agent token and URL instead. -func createAgentClient(cmd *cobra.Command) (*agentsdk.Client, error) { - rawURL, err := cmd.Flags().GetString(varAgentURL) - if err != nil { - return nil, err - } - serverURL, err := url.Parse(rawURL) - if err != nil { - return nil, err - } - token, err := cmd.Flags().GetString(varAgentToken) - if err != nil { - return nil, err - } - client := agentsdk.New(serverURL) - client.SetSessionToken(token) +func (r *RootCmd) createAgentClient() (*agentsdk.Client, error) { + client := agentsdk.New(r.agentURL) + client.SetSessionToken(r.agentToken) return client, nil } // CurrentOrganization returns the currently active organization for the authenticated user. -func CurrentOrganization(cmd *cobra.Command, client *codersdk.Client) (codersdk.Organization, error) { - orgs, err := client.OrganizationsByUser(cmd.Context(), codersdk.Me) +func CurrentOrganization(inv *clibase.Invocation, client *codersdk.Client) (codersdk.Organization, error) { + orgs, err := client.OrganizationsByUser(inv.Context(), codersdk.Me) if err != nil { return codersdk.Organization{}, nil } @@ -381,7 +550,7 @@ func CurrentOrganization(cmd *cobra.Command, client *codersdk.Client) (codersdk. // namedWorkspace fetches and returns a workspace by an identifier, which may be either // a bare name (for a workspace owned by the current user) or a "user/workspace" combination, // where user is either a username or UUID. -func namedWorkspace(cmd *cobra.Command, client *codersdk.Client, identifier string) (codersdk.Workspace, error) { +func namedWorkspace(ctx context.Context, client *codersdk.Client, identifier string) (codersdk.Workspace, error) { parts := strings.Split(identifier, "/") var owner, name string @@ -396,30 +565,24 @@ func namedWorkspace(cmd *cobra.Command, client *codersdk.Client, identifier stri return codersdk.Workspace{}, xerrors.Errorf("invalid workspace name: %q", identifier) } - return client.WorkspaceByOwnerAndName(cmd.Context(), owner, name, codersdk.WorkspaceOptions{}) + return client.WorkspaceByOwnerAndName(ctx, owner, name, codersdk.WorkspaceOptions{}) } // createConfig consumes the global configuration flag to produce a config root. -func createConfig(cmd *cobra.Command) config.Root { - globalRoot, err := cmd.Flags().GetString(config.FlagName) - if err != nil { - panic(err) - } - return config.Root(globalRoot) +func (r *RootCmd) createConfig() config.Root { + return config.Root(r.globalConfig) } // isTTY returns whether the passed reader is a TTY or not. -// This accepts a reader to work with Cobra's "InOrStdin" -// function for simple testing. -func isTTY(cmd *cobra.Command) bool { +func isTTY(inv *clibase.Invocation) bool { // If the `--force-tty` command is available, and set, // assume we're in a tty. This is primarily for cases on Windows // where we may not be able to reliably detect this automatically (ie, tests) - forceTty, err := cmd.Flags().GetBool(varForceTty) + forceTty, err := inv.ParsedFlags().GetBool(varForceTty) if forceTty && err == nil { return true } - file, ok := cmd.InOrStdin().(*os.File) + file, ok := inv.Stdin.(*os.File) if !ok { return false } @@ -427,125 +590,30 @@ func isTTY(cmd *cobra.Command) bool { } // isTTYOut returns whether the passed reader is a TTY or not. -// This accepts a reader to work with Cobra's "OutOrStdout" -// function for simple testing. -func isTTYOut(cmd *cobra.Command) bool { - return isTTYWriter(cmd, cmd.OutOrStdout) +func isTTYOut(inv *clibase.Invocation) bool { + return isTTYWriter(inv, inv.Stdout) } // isTTYErr returns whether the passed reader is a TTY or not. -// This accepts a reader to work with Cobra's "ErrOrStderr" -// function for simple testing. -func isTTYErr(cmd *cobra.Command) bool { - return isTTYWriter(cmd, cmd.ErrOrStderr) +func isTTYErr(inv *clibase.Invocation) bool { + return isTTYWriter(inv, inv.Stderr) } -func isTTYWriter(cmd *cobra.Command, writer func() io.Writer) bool { +func isTTYWriter(inv *clibase.Invocation, writer io.Writer) bool { // If the `--force-tty` command is available, and set, // assume we're in a tty. This is primarily for cases on Windows // where we may not be able to reliably detect this automatically (ie, tests) - forceTty, err := cmd.Flags().GetBool(varForceTty) + forceTty, err := inv.ParsedFlags().GetBool(varForceTty) if forceTty && err == nil { return true } - file, ok := writer().(*os.File) + file, ok := writer.(*os.File) if !ok { return false } return isatty.IsTerminal(file.Fd()) } -var templateFunctions = template.FuncMap{ - "usageHeader": usageHeader, - "isWorkspaceCommand": isWorkspaceCommand, -} - -func usageHeader(s string) string { - // Customizes the color of headings to make subcommands more visually - // appealing. - return cliui.Styles.Placeholder.Render(s) -} - -func isWorkspaceCommand(cmd *cobra.Command) bool { - if _, ok := cmd.Annotations["workspaces"]; ok { - return true - } - var ws bool - cmd.VisitParents(func(cmd *cobra.Command) { - if _, ok := cmd.Annotations["workspaces"]; ok { - ws = true - } - }) - return ws -} - -// We will eventually replace this with the clibase template describedc -// in usage.go. We don't want to continue working around -// Cobra's feature-set. -func usageTemplateCobra() string { - // usageHeader is defined in init(). - return `{{usageHeader "Usage:"}} -{{- if .Runnable}} - {{.UseLine}} -{{end}} -{{- if .HasAvailableSubCommands}} - {{.CommandPath}} [command] -{{end}} - -{{- if gt (len .Aliases) 0}} -{{usageHeader "Aliases:"}} - {{.NameAndAliases}} -{{end}} - -{{- if .HasExample}} -{{usageHeader "Get Started:"}} -{{.Example}} -{{end}} - -{{- $isRootHelp := (not .HasParent)}} -{{- if .HasAvailableSubCommands}} -{{usageHeader "Commands:"}} - {{- range .Commands}} - {{- $isRootWorkspaceCommand := (and $isRootHelp (isWorkspaceCommand .))}} - {{- if (or (and .IsAvailableCommand (not $isRootWorkspaceCommand)) (eq .Name "help"))}} - {{rpad .Name .NamePadding }} {{.Short}} - {{- end}} - {{- end}} -{{end}} - -{{- if (and $isRootHelp .HasAvailableSubCommands)}} -{{usageHeader "Workspace Commands:"}} - {{- range .Commands}} - {{- if (and .IsAvailableCommand (isWorkspaceCommand .))}} - {{rpad .Name .NamePadding }} {{.Short}} - {{- end}} - {{- end}} -{{end}} - -{{- if .HasAvailableLocalFlags}} -{{usageHeader "Flags:"}} -{{.LocalFlags.FlagUsagesWrapped 100 | trimTrailingWhitespaces}} -{{end}} - -{{- if .HasAvailableInheritedFlags}} -{{usageHeader "Global Flags:"}} -{{.InheritedFlags.FlagUsagesWrapped 100 | trimTrailingWhitespaces}} -{{end}} - -{{- if .HasHelpSubCommands}} -{{usageHeader "Additional help topics:"}} - {{- range .Commands}} - {{- if .IsAdditionalHelpTopicCommand}} - {{rpad .CommandPath .CommandPathPadding}} {{.Short}} - {{- end}} - {{- end}} -{{end}} - -{{- if .HasAvailableSubCommands}} -Use "{{.CommandPath}} [command] --help" for more information about a command. -{{end}}` -} - // example represents a standard example for command usage, to be used // with formatExamples. type example struct { @@ -574,36 +642,12 @@ func formatExamples(examples ...example) string { return sb.String() } -// FormatCobraError colorizes and adds "--help" docs to cobra commands. -func FormatCobraError(err error, cmd *cobra.Command) string { - helpErrMsg := fmt.Sprintf("Run '%s --help' for usage.", cmd.CommandPath()) - - var ( - httpErr *codersdk.Error - output strings.Builder - ) - - if xerrors.As(err, &httpErr) { - _, _ = fmt.Fprintln(&output, httpErr.Friendly()) - } - - // If the httpErr is nil then we just have a regular error in which - // case we want to print out what's happening. - if httpErr == nil || cliflag.IsSetBool(cmd, varVerbose) { - _, _ = fmt.Fprintln(&output, err.Error()) - } - - _, _ = fmt.Fprint(&output, helpErrMsg) - - return cliui.Styles.Error.Render(output.String()) -} - -func checkVersions(cmd *cobra.Command, client *codersdk.Client) error { - if cliflag.IsSetBool(cmd, varNoVersionCheck) { +func (r *RootCmd) checkVersions(i *clibase.Invocation, client *codersdk.Client) error { + if r.noVersionCheck { return nil } - ctx, cancel := context.WithTimeout(cmd.Context(), 10*time.Second) + ctx, cancel := context.WithTimeout(i.Context(), 10*time.Second) defer cancel() clientVersion := buildinfo.Version() @@ -629,25 +673,25 @@ func checkVersions(cmd *cobra.Command, client *codersdk.Client) error { if !buildinfo.VersionsMatch(clientVersion, info.Version) { warn := cliui.Styles.Warn.Copy().Align(lipgloss.Left) - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), warn.Render(fmtWarningText), clientVersion, info.Version, strings.TrimPrefix(info.CanonicalVersion(), "v")) - _, _ = fmt.Fprintln(cmd.ErrOrStderr()) + _, _ = fmt.Fprintf(i.Stderr, warn.Render(fmtWarningText), clientVersion, info.Version, strings.TrimPrefix(info.CanonicalVersion(), "v")) + _, _ = fmt.Fprintln(i.Stderr) } return nil } -func checkWarnings(cmd *cobra.Command, client *codersdk.Client) error { - if cliflag.IsSetBool(cmd, varNoFeatureWarning) { +func (r *RootCmd) checkWarnings(i *clibase.Invocation, client *codersdk.Client) error { + if r.noFeatureWarning { return nil } - ctx, cancel := context.WithTimeout(cmd.Context(), 10*time.Second) + ctx, cancel := context.WithTimeout(i.Context(), 10*time.Second) defer cancel() entitlements, err := client.Entitlements(ctx) if err == nil { for _, w := range entitlements.Warnings { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), cliui.Styles.Warn.Render(w)) + _, _ = fmt.Fprintln(i.Stderr, cliui.Styles.Warn.Render(w)) } } return nil @@ -773,3 +817,94 @@ func isConnectionError(err error) bool { return xerrors.As(err, &dnsErr) || xerrors.As(err, &opErr) } + +type prettyErrorFormatter struct { + level int + w io.Writer +} + +func (prettyErrorFormatter) prefixLines(spaces int, s string) string { + twidth, _, err := terminal.GetSize(0) + if err != nil { + twidth = 80 + } + + s = lipgloss.NewStyle().Width(twidth - spaces).Render(s) + + var b strings.Builder + scanner := bufio.NewScanner(strings.NewReader(s)) + for i := 0; scanner.Scan(); i++ { + // The first line is already padded. + if i == 0 { + _, _ = fmt.Fprintf(&b, "%s\n", scanner.Text()) + continue + } + _, _ = fmt.Fprintf(&b, "%s%s\n", strings.Repeat(" ", spaces), scanner.Text()) + } + return strings.TrimSuffix(strings.TrimSuffix(b.String(), "\n"), " ") +} + +func (p *prettyErrorFormatter) format(err error) { + underErr := errors.Unwrap(err) + + arrowStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("#515151")) + + //nolint:errorlint + if _, ok := err.(*clibase.RunCommandError); ok && p.level == 0 && underErr != nil { + // We can do a better job now. + p.format(underErr) + return + } + + var ( + padding string + arrowWidth int + ) + if p.level > 0 { + const arrow = "┗━ " + arrowWidth = utf8.RuneCount([]byte(arrow)) + padding = strings.Repeat(" ", arrowWidth*p.level) + _, _ = fmt.Fprintf(p.w, "%v%v", padding, arrowStyle.Render(arrow)) + } + + if underErr != nil { + header := strings.TrimSuffix(err.Error(), ": "+underErr.Error()) + _, _ = fmt.Fprintf(p.w, "%s\n", p.prefixLines(len(padding)+arrowWidth, header)) + p.level++ + p.format(underErr) + return + } + + { + style := lipgloss.NewStyle().Foreground(lipgloss.Color("#D16644")).Background(lipgloss.Color("#000000")).Bold(false) + // This is the last error in a tree. + p.wrappedPrintf( + "%s\n", + p.prefixLines( + len(padding)+arrowWidth, + fmt.Sprintf( + "%s%s%s", + lipgloss.NewStyle().Inherit(style).Underline(true).Render("ERROR"), + lipgloss.NewStyle().Inherit(style).Foreground(arrowStyle.GetForeground()).Render(" ► "), + style.Render(err.Error()), + ), + ), + ) + } +} + +func (p *prettyErrorFormatter) wrappedPrintf(format string, a ...interface{}) { + s := lipgloss.NewStyle().Width(ttyWidth()).Render( + fmt.Sprintf(format, a...), + ) + + // Not sure why, but lipgloss is adding extra spaces we need to remove. + excessSpaceRe := regexp.MustCompile(`[[:blank:]]*\n[[:blank:]]*$`) + s = excessSpaceRe.ReplaceAllString(s, "\n") + + _, _ = p.w.Write( + []byte( + s, + ), + ) +} diff --git a/cli/root_internal_test.go b/cli/root_internal_test.go index fc5cebfb7daa6..e8c463e95cc90 100644 --- a/cli/root_internal_test.go +++ b/cli/root_internal_test.go @@ -24,11 +24,11 @@ func Test_formatExamples(t *testing.T) { name: "Output examples", examples: []example{ { - Description: "Hello world", + Description: "Hello world.", Command: "echo hello", }, { - Description: "Bye bye", + Description: "Bye bye.", Command: "echo bye", }, }, @@ -73,5 +73,7 @@ func TestMain(m *testing.M) { // https://github.com/natefinch/lumberjack/pull/100 goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"), goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).mill.func1"), + // The pq library appears to leave around a goroutine after Close(). + goleak.IgnoreTopFunction("github.com/lib/pq.NewDialListener"), ) } diff --git a/cli/root_test.go b/cli/root_test.go index 0f76782707723..22c1c3e36ae85 100644 --- a/cli/root_test.go +++ b/cli/root_test.go @@ -10,18 +10,17 @@ import ( "os" "path/filepath" "regexp" - "runtime" "strings" "testing" - "github.com/spf13/cobra" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/xerrors" "github.com/coder/coder/buildinfo" "github.com/coder/coder/cli" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/clitest" + "github.com/coder/coder/cli/config" "github.com/coder/coder/coderd/coderdtest" "github.com/coder/coder/coderd/database/dbtestutil" "github.com/coder/coder/codersdk" @@ -34,39 +33,26 @@ var updateGoldenFiles = flag.Bool("update", false, "update .golden files") var timestampRegex = regexp.MustCompile(`(?i)\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(.\d+)?Z`) -//nolint:tparallel,paralleltest // These test sets env vars. func TestCommandHelp(t *testing.T) { - commonEnv := map[string]string{ - "HOME": "~", - "CODER_CONFIG_DIR": "~/.config/coderv2", - } - + t.Parallel() rootClient, replacements := prepareTestData(t) type testCase struct { name string cmd []string - env map[string]string } tests := []testCase{ { name: "coder --help", cmd: []string{"--help"}, }, - // Re-enable after clibase migrations. - // { - // name: "coder server --help", - // cmd: []string{"server", "--help"}, - // env: map[string]string{ - // "CODER_CACHE_DIRECTORY": "~/.cache/coder", - // }, - // }, + { + name: "coder server --help", + cmd: []string{"server", "--help"}, + }, { name: "coder agent --help", cmd: []string{"agent", "--help"}, - env: map[string]string{ - "CODER_AGENT_LOG_DIR": "/tmp", - }, }, { name: "coder list --output json", @@ -78,9 +64,12 @@ func TestCommandHelp(t *testing.T) { }, } - root := cli.Root(cli.AGPL()) + rootCmd := new(cli.RootCmd) + root, err := rootCmd.Command(rootCmd.AGPL()) + require.NoError(t, err) + ExtractCommandPathsLoop: - for _, cp := range extractVisibleCommandPaths(nil, root.Commands()) { + for _, cp := range extractVisibleCommandPaths(nil, root.Children) { name := fmt.Sprintf("coder %s --help", strings.Join(cp, " ")) cmd := append(cp, "--help") for _, tt := range tests { @@ -91,100 +80,88 @@ ExtractCommandPathsLoop: tests = append(tests, testCase{name: name, cmd: cmd}) } - wd, err := os.Getwd() - require.NoError(t, err) - if runtime.GOOS == "windows" { - wd = strings.ReplaceAll(wd, "\\", "\\\\") - } - for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - env := make(map[string]string) - for k, v := range commonEnv { - env[k] = v - } - for k, v := range tt.env { - env[k] = v - } - - // Unset all CODER_ environment variables for a clean slate. - for _, kv := range os.Environ() { - name := strings.Split(kv, "=")[0] - if _, ok := env[name]; !ok && strings.HasPrefix(name, "CODER_") { - t.Setenv(name, "") - } - } - // Override environment variables for a reproducible test. - for k, v := range env { - t.Setenv(k, v) - } + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) - ctx, _ := testutil.Context(t) + var outBuf bytes.Buffer + inv, cfg := clitest.New(t, tt.cmd...) + inv.Stderr = &outBuf + inv.Stdout = &outBuf + inv.Environ.Set("CODER_URL", rootClient.URL.String()) + inv.Environ.Set("CODER_SESSION_TOKEN", rootClient.SessionToken()) + inv.Environ.Set("CODER_CACHE_DIRECTORY", "~/.cache") - tmpwd := "/" - if runtime.GOOS == "windows" { - tmpwd = "C:\\" - } - err := os.Chdir(tmpwd) - var buf bytes.Buffer - cmd, cfg := clitest.New(t, tt.cmd...) clitest.SetupConfig(t, rootClient, cfg) - cmd.SetOut(&buf) - assert.NoError(t, err) - err = cmd.ExecuteContext(ctx) - err2 := os.Chdir(wd) - require.NoError(t, err) - require.NoError(t, err2) - got := buf.Bytes() + clitest.StartWithWaiter(t, inv.WithContext(ctx)).RequireSuccess() - replace := map[string][]byte{ - // Remove CRLF newlines (Windows). - string([]byte{'\r', '\n'}): []byte("\n"), - // The `coder templates create --help` command prints the path - // to the working directory (--directory flag default value). - fmt.Sprintf("%q", tmpwd): []byte("\"[current directory]\""), + actual := outBuf.Bytes() + if len(actual) == 0 { + t.Fatal("no output") } + for k, v := range replacements { - replace[k] = []byte(v) - } - for k, v := range replace { - got = bytes.ReplaceAll(got, []byte(k), v) + actual = bytes.ReplaceAll(actual, []byte(k), []byte(v)) } // Replace any timestamps with a placeholder. - got = timestampRegex.ReplaceAll(got, []byte("[timestamp]")) + actual = timestampRegex.ReplaceAll(actual, []byte("[timestamp]")) + + homeDir, err := os.UserHomeDir() + require.NoError(t, err) + + configDir := config.DefaultDir() + actual = bytes.ReplaceAll(actual, []byte(configDir), []byte("~/.config/coderv2")) - gf := filepath.Join("testdata", strings.Replace(tt.name, " ", "_", -1)+".golden") + actual = bytes.ReplaceAll(actual, []byte(codersdk.DefaultCacheDir()), []byte("[cache dir]")) + + // The home directory changes depending on the test environment. + actual = bytes.ReplaceAll(actual, []byte(homeDir), []byte("~")) + + goldenPath := filepath.Join("testdata", strings.Replace(tt.name, " ", "_", -1)+".golden") if *updateGoldenFiles { - t.Logf("update golden file for: %q: %s", tt.name, gf) - err = os.WriteFile(gf, got, 0o600) + t.Logf("update golden file for: %q: %s", tt.name, goldenPath) + err = os.WriteFile(goldenPath, actual, 0o600) require.NoError(t, err, "update golden file") } - want, err := os.ReadFile(gf) + expected, err := os.ReadFile(goldenPath) require.NoError(t, err, "read golden file, run \"make update-golden-files\" and commit the changes") - // Remove CRLF newlines (Windows). - want = bytes.ReplaceAll(want, []byte{'\r', '\n'}, []byte{'\n'}) - require.Equal(t, string(want), string(got), "golden file mismatch: %s, run \"make update-golden-files\", verify and commit the changes", gf) + + // Normalize files to tolerate different operating systems. + for _, r := range []struct { + old string + new string + }{ + {"\r\n", "\n"}, + {`~\.cache\coder`, "~/.cache/coder"}, + {`C:\Users\RUNNER~1\AppData\Local\Temp`, "/tmp"}, + {os.TempDir(), "/tmp"}, + } { + expected = bytes.ReplaceAll(expected, []byte(r.old), []byte(r.new)) + actual = bytes.ReplaceAll(actual, []byte(r.old), []byte(r.new)) + } + require.Equal( + t, string(expected), string(actual), + "golden file mismatch: %s, run \"make update-golden-files\", verify and commit the changes", + goldenPath, + ) }) } } -func extractVisibleCommandPaths(cmdPath []string, cmds []*cobra.Command) [][]string { +func extractVisibleCommandPaths(cmdPath []string, cmds []*clibase.Cmd) [][]string { var cmdPaths [][]string for _, c := range cmds { if c.Hidden { continue } - // TODO: re-enable after clibase migration. - if c.Name() == "server" { - continue - } cmdPath := append(cmdPath, c.Name()) cmdPaths = append(cmdPaths, cmdPath) - cmdPaths = append(cmdPaths, extractVisibleCommandPaths(cmdPath, c.Commands())...) + cmdPaths = append(cmdPaths, extractVisibleCommandPaths(cmdPath, c.Children)...) } return cmdPaths } @@ -241,113 +218,13 @@ func prepareTestData(t *testing.T) (*codersdk.Client, map[string]string) { func TestRoot(t *testing.T) { t.Parallel() - t.Run("FormatCobraError", func(t *testing.T) { - t.Parallel() - - t.Run("OK", func(t *testing.T) { - t.Parallel() - - cmd, _ := clitest.New(t, "delete") - - cmd, err := cmd.ExecuteC() - errStr := cli.FormatCobraError(err, cmd) - require.Contains(t, errStr, "Run 'coder delete --help' for usage.") - }) - - t.Run("Verbose", func(t *testing.T) { - t.Parallel() - - // Test that the verbose error is masked without verbose flag. - t.Run("NoVerboseAPIError", func(t *testing.T) { - t.Parallel() - - cmd, _ := clitest.New(t) - - cmd.RunE = func(cmd *cobra.Command, args []string) error { - var err error = &codersdk.Error{ - Response: codersdk.Response{ - Message: "This is a message.", - }, - Helper: "Try this instead.", - } - - err = xerrors.Errorf("wrap me: %w", err) - - return err - } - - cmd, err := cmd.ExecuteC() - errStr := cli.FormatCobraError(err, cmd) - require.Contains(t, errStr, "This is a message. Try this instead.") - require.NotContains(t, errStr, err.Error()) - }) - - // Assert that a regular error is not masked when verbose is not - // specified. - t.Run("NoVerboseRegularError", func(t *testing.T) { - t.Parallel() - - cmd, _ := clitest.New(t) - - cmd.RunE = func(cmd *cobra.Command, args []string) error { - return xerrors.Errorf("this is a non-codersdk error: %w", xerrors.Errorf("a wrapped error")) - } - - cmd, err := cmd.ExecuteC() - errStr := cli.FormatCobraError(err, cmd) - require.Contains(t, errStr, err.Error()) - }) - - // Test that both the friendly error and the verbose error are - // displayed when verbose is passed. - t.Run("APIError", func(t *testing.T) { - t.Parallel() - - cmd, _ := clitest.New(t, "--verbose") - - cmd.RunE = func(cmd *cobra.Command, args []string) error { - var err error = &codersdk.Error{ - Response: codersdk.Response{ - Message: "This is a message.", - }, - Helper: "Try this instead.", - } - - err = xerrors.Errorf("wrap me: %w", err) - - return err - } - - cmd, err := cmd.ExecuteC() - errStr := cli.FormatCobraError(err, cmd) - require.Contains(t, errStr, "This is a message. Try this instead.") - require.Contains(t, errStr, err.Error()) - }) - - // Assert that a regular error is not masked when verbose specified. - t.Run("RegularError", func(t *testing.T) { - t.Parallel() - - cmd, _ := clitest.New(t, "--verbose") - - cmd.RunE = func(cmd *cobra.Command, args []string) error { - return xerrors.Errorf("this is a non-codersdk error: %w", xerrors.Errorf("a wrapped error")) - } - - cmd, err := cmd.ExecuteC() - errStr := cli.FormatCobraError(err, cmd) - require.Contains(t, errStr, err.Error()) - }) - }) - }) - t.Run("Version", func(t *testing.T) { t.Parallel() buf := new(bytes.Buffer) - cmd, _ := clitest.New(t, "version") - cmd.SetOut(buf) - err := cmd.Execute() + inv, _ := clitest.New(t, "version") + inv.Stdout = buf + err := inv.Run() require.NoError(t, err) output := buf.String() @@ -370,9 +247,9 @@ func TestRoot(t *testing.T) { })) defer srv.Close() buf := new(bytes.Buffer) - cmd, _ := clitest.New(t, "--header", "X-Testing=wow", "login", srv.URL) - cmd.SetOut(buf) + inv, _ := clitest.New(t, "--header", "X-Testing=wow", "login", srv.URL) + inv.Stdout = buf // This won't succeed, because we're using the login cmd to assert requests. - _ = cmd.Execute() + _ = inv.Run() }) } diff --git a/cli/scaletest.go b/cli/scaletest.go index bea8c7fd17c9d..7e6c79ad8abfa 100644 --- a/cli/scaletest.go +++ b/cli/scaletest.go @@ -14,11 +14,10 @@ import ( "time" "github.com/google/uuid" - "github.com/spf13/cobra" "go.opentelemetry.io/otel/trace" "golang.org/x/xerrors" - "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/httpapi" "github.com/coder/coder/coderd/tracing" @@ -33,21 +32,19 @@ import ( const scaletestTracerName = "coder_scaletest" -func scaletest() *cobra.Command { - cmd := &cobra.Command{ +func (r *RootCmd) scaletest() *clibase.Cmd { + cmd := &clibase.Cmd{ Use: "scaletest", Short: "Run a scale test against the Coder API", - Long: "Perform scale tests against the Coder server.", - RunE: func(cmd *cobra.Command, args []string) error { - return cmd.Help() + Handler: func(inv *clibase.Invocation) error { + return inv.Command.HelpHandler(inv) + }, + Children: []*clibase.Cmd{ + r.scaletestCleanup(), + r.scaletestCreateWorkspaces(), }, } - cmd.AddCommand( - scaletestCleanup(), - scaletestCreateWorkspaces(), - ) - return cmd } @@ -58,11 +55,34 @@ type scaletestTracingFlags struct { tracePropagate bool } -func (s *scaletestTracingFlags) attach(cmd *cobra.Command) { - cliflag.BoolVarP(cmd.Flags(), &s.traceEnable, "trace", "", "CODER_LOADTEST_TRACE", false, "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md") - cliflag.BoolVarP(cmd.Flags(), &s.traceCoder, "trace-coder", "", "CODER_LOADTEST_TRACE_CODER", false, "Whether opentelemetry traces are sent to Coder. We recommend keeping this disabled unless we advise you to enable it.") - cliflag.StringVarP(cmd.Flags(), &s.traceHoneycombAPIKey, "trace-honeycomb-api-key", "", "CODER_LOADTEST_TRACE_HONEYCOMB_API_KEY", "", "Enables trace exporting to Honeycomb.io using the provided API key.") - cliflag.BoolVarP(cmd.Flags(), &s.tracePropagate, "trace-propagate", "", "CODER_LOADTEST_TRACE_PROPAGATE", false, "Enables trace propagation to the Coder backend, which will be used to correlate server-side spans with client-side spans. Only enable this if the server is configured with the exact same tracing configuration as the client.") +func (s *scaletestTracingFlags) attach(opts *clibase.OptionSet) { + *opts = append( + *opts, + clibase.Option{ + Flag: "trace", + Env: "CODER_SCALETEST_TRACE", + Description: "Whether application tracing data is collected. It exports to a backend configured by environment variables. See: https://github.com/open-telemetry/opentelemetry-specification/blob/main/specification/protocol/exporter.md.", + Value: clibase.BoolOf(&s.traceEnable), + }, + clibase.Option{ + Flag: "trace-coder", + Env: "CODER_SCALETEST_TRACE_CODER", + Description: "Whether opentelemetry traces are sent to Coder. We recommend keeping this disabled unless we advise you to enable it.", + Value: clibase.BoolOf(&s.traceCoder), + }, + clibase.Option{ + Flag: "trace-honeycomb-api-key", + Env: "CODER_SCALETEST_TRACE_HONEYCOMB_API_KEY", + Description: "Enables trace exporting to Honeycomb.io using the provided API key.", + Value: clibase.StringOf(&s.traceHoneycombAPIKey), + }, + clibase.Option{ + Flag: "trace-propagate", + Env: "CODER_SCALETEST_TRACE_PROPAGATE", + Description: "Enables trace propagation to the Coder backend, which will be used to correlate server-side spans with client-side spans. Only enable this if the server is configured with the exact same tracing configuration as the client.", + Value: clibase.BoolOf(&s.tracePropagate), + }, + ) } // provider returns a trace.TracerProvider, a close function and a bool showing @@ -96,24 +116,45 @@ func (s *scaletestTracingFlags) provider(ctx context.Context) (trace.TracerProvi type scaletestStrategyFlags struct { cleanup bool - concurrency int + concurrency int64 timeout time.Duration timeoutPerJob time.Duration } -func (s *scaletestStrategyFlags) attach(cmd *cobra.Command) { - concurrencyLong, concurrencyEnv, concurrencyDescription := "concurrency", "CODER_LOADTEST_CONCURRENCY", "Number of concurrent jobs to run. 0 means unlimited." - timeoutLong, timeoutEnv, timeoutDescription := "timeout", "CODER_LOADTEST_TIMEOUT", "Timeout for the entire test run. 0 means unlimited." - jobTimeoutLong, jobTimeoutEnv, jobTimeoutDescription := "job-timeout", "CODER_LOADTEST_JOB_TIMEOUT", "Timeout per job. Jobs may take longer to complete under higher concurrency limits." +func (s *scaletestStrategyFlags) attach(opts *clibase.OptionSet) { + concurrencyLong, concurrencyEnv, concurrencyDescription := "concurrency", "CODER_SCALETEST_CONCURRENCY", "Number of concurrent jobs to run. 0 means unlimited." + timeoutLong, timeoutEnv, timeoutDescription := "timeout", "CODER_SCALETEST_TIMEOUT", "Timeout for the entire test run. 0 means unlimited." + jobTimeoutLong, jobTimeoutEnv, jobTimeoutDescription := "job-timeout", "CODER_SCALETEST_JOB_TIMEOUT", "Timeout per job. Jobs may take longer to complete under higher concurrency limits." if s.cleanup { - concurrencyLong, concurrencyEnv, concurrencyDescription = "cleanup-"+concurrencyLong, "CODER_LOADTEST_CLEANUP_CONCURRENCY", strings.ReplaceAll(concurrencyDescription, "jobs", "cleanup jobs") - timeoutLong, timeoutEnv, timeoutDescription = "cleanup-"+timeoutLong, "CODER_LOADTEST_CLEANUP_TIMEOUT", strings.ReplaceAll(timeoutDescription, "test", "cleanup") - jobTimeoutLong, jobTimeoutEnv, jobTimeoutDescription = "cleanup-"+jobTimeoutLong, "CODER_LOADTEST_CLEANUP_JOB_TIMEOUT", strings.ReplaceAll(jobTimeoutDescription, "jobs", "cleanup jobs") + concurrencyLong, concurrencyEnv, concurrencyDescription = "cleanup-"+concurrencyLong, "CODER_SCALETEST_CLEANUP_CONCURRENCY", strings.ReplaceAll(concurrencyDescription, "jobs", "cleanup jobs") + timeoutLong, timeoutEnv, timeoutDescription = "cleanup-"+timeoutLong, "CODER_SCALETEST_CLEANUP_TIMEOUT", strings.ReplaceAll(timeoutDescription, "test", "cleanup") + jobTimeoutLong, jobTimeoutEnv, jobTimeoutDescription = "cleanup-"+jobTimeoutLong, "CODER_SCALETEST_CLEANUP_JOB_TIMEOUT", strings.ReplaceAll(jobTimeoutDescription, "jobs", "cleanup jobs") } - cliflag.IntVarP(cmd.Flags(), &s.concurrency, concurrencyLong, "", concurrencyEnv, 1, concurrencyDescription) - cliflag.DurationVarP(cmd.Flags(), &s.timeout, timeoutLong, "", timeoutEnv, 30*time.Minute, timeoutDescription) - cliflag.DurationVarP(cmd.Flags(), &s.timeoutPerJob, jobTimeoutLong, "", jobTimeoutEnv, 5*time.Minute, jobTimeoutDescription) + *opts = append( + *opts, + clibase.Option{ + Flag: concurrencyLong, + Env: concurrencyEnv, + Description: concurrencyDescription, + Default: "1", + Value: clibase.Int64Of(&s.concurrency), + }, + clibase.Option{ + Flag: timeoutLong, + Env: timeoutEnv, + Description: timeoutDescription, + Default: "30m", + Value: clibase.DurationOf(&s.timeout), + }, + clibase.Option{ + Flag: jobTimeoutLong, + Env: jobTimeoutEnv, + Description: jobTimeoutDescription, + Default: "5m", + Value: clibase.DurationOf(&s.timeoutPerJob), + }, + ) } func (s *scaletestStrategyFlags) toStrategy() harness.ExecutionStrategy { @@ -124,7 +165,7 @@ func (s *scaletestStrategyFlags) toStrategy() harness.ExecutionStrategy { strategy = harness.ConcurrentExecutionStrategy{} } else { strategy = harness.ParallelExecutionStrategy{ - Limit: s.concurrency, + Limit: int(s.concurrency), } } @@ -208,8 +249,14 @@ type scaletestOutputFlags struct { outputSpecs []string } -func (s *scaletestOutputFlags) attach(cmd *cobra.Command) { - cliflag.StringArrayVarP(cmd.Flags(), &s.outputSpecs, "output", "", "CODER_SCALETEST_OUTPUTS", []string{"text"}, `Output format specs in the format "[:]". Not specifying a path will default to stdout. Available formats: text, json.`) +func (s *scaletestOutputFlags) attach(opts *clibase.OptionSet) { + *opts = append(*opts, clibase.Option{ + Flag: "output", + Env: "CODER_SCALETEST_OUTPUTS", + Description: `Output format specs in the format "[:]". Not specifying a path will default to stdout. Available formats: text, json.`, + Default: "text", + Value: clibase.StringArrayOf(&s.outputSpecs), + }) } func (s *scaletestOutputFlags) parse() ([]scaleTestOutput, error) { @@ -308,21 +355,21 @@ func (r *userCleanupRunner) Run(ctx context.Context, _ string, _ io.Writer) erro return nil } -func scaletestCleanup() *cobra.Command { +func (r *RootCmd) scaletestCleanup() *clibase.Cmd { cleanupStrategy := &scaletestStrategyFlags{cleanup: true} + client := new(codersdk.Client) - cmd := &cobra.Command{ + cmd := &clibase.Cmd{ Use: "cleanup", - Short: "Cleanup any orphaned scaletest resources", - Long: "Cleanup scaletest workspaces, then cleanup scaletest users. The strategy flags will apply to each stage of the cleanup process.", - RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - client, err := CreateClient(cmd) - if err != nil { - return err - } - - _, err = requireAdmin(ctx, client) + Short: "Cleanup scaletest workspaces, then cleanup scaletest users.", + Long: "The strategy flags will apply to each stage of the cleanup process.", + Middleware: clibase.Chain( + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + ctx := inv.Context() + + _, err := requireAdmin(ctx, client) if err != nil { return err } @@ -336,7 +383,7 @@ func scaletestCleanup() *cobra.Command { }, } - cmd.PrintErrln("Fetching scaletest workspaces...") + cliui.Infof(inv.Stdout, "Fetching scaletest workspaces...") var ( pageNumber = 0 limit = 100 @@ -366,9 +413,9 @@ func scaletestCleanup() *cobra.Command { workspaces = append(workspaces, pageWorkspaces...) } - cmd.PrintErrf("Found %d scaletest workspaces\n", len(workspaces)) + cliui.Errorf(inv.Stderr, "Found %d scaletest workspaces\n", len(workspaces)) if len(workspaces) != 0 { - cmd.Println("Deleting scaletest workspaces...") + cliui.Infof(inv.Stdout, "Deleting scaletest workspaces..."+"\n") harness := harness.NewTestHarness(cleanupStrategy.toStrategy(), harness.ConcurrentExecutionStrategy{}) for i, w := range workspaces { @@ -384,16 +431,16 @@ func scaletestCleanup() *cobra.Command { return xerrors.Errorf("run test harness to delete workspaces (harness failure, not a test failure): %w", err) } - cmd.Println("Done deleting scaletest workspaces:") + cliui.Infof(inv.Stdout, "Done deleting scaletest workspaces:"+"\n") res := harness.Results() - res.PrintText(cmd.ErrOrStderr()) + res.PrintText(inv.Stderr) if res.TotalFail > 0 { return xerrors.Errorf("failed to delete scaletest workspaces") } } - cmd.PrintErrln("Fetching scaletest users...") + cliui.Infof(inv.Stdout, "Fetching scaletest users...") pageNumber = 0 limit = 100 var users []codersdk.User @@ -423,9 +470,9 @@ func scaletestCleanup() *cobra.Command { users = append(users, pageUsers...) } - cmd.PrintErrf("Found %d scaletest users\n", len(users)) + cliui.Errorf(inv.Stderr, "Found %d scaletest users\n", len(users)) if len(workspaces) != 0 { - cmd.Println("Deleting scaletest users...") + cliui.Infof(inv.Stdout, "Deleting scaletest users..."+"\n") harness := harness.NewTestHarness(cleanupStrategy.toStrategy(), harness.ConcurrentExecutionStrategy{}) for i, u := range users { @@ -444,9 +491,9 @@ func scaletestCleanup() *cobra.Command { return xerrors.Errorf("run test harness to delete users (harness failure, not a test failure): %w", err) } - cmd.Println("Done deleting scaletest users:") + cliui.Infof(inv.Stdout, "Done deleting scaletest users:"+"\n") res := harness.Results() - res.PrintText(cmd.ErrOrStderr()) + res.PrintText(inv.Stderr) if res.TotalFail > 0 { return xerrors.Errorf("failed to delete scaletest users") @@ -457,13 +504,13 @@ func scaletestCleanup() *cobra.Command { }, } - cleanupStrategy.attach(cmd) + cleanupStrategy.attach(&cmd.Options) return cmd } -func scaletestCreateWorkspaces() *cobra.Command { +func (r *RootCmd) scaletestCreateWorkspaces() *clibase.Cmd { var ( - count int + count int64 template string parametersFile string parameters []string // key=value @@ -494,18 +541,15 @@ func scaletestCreateWorkspaces() *cobra.Command { output = &scaletestOutputFlags{} ) - cmd := &cobra.Command{ - Use: "create-workspaces", - Short: "Creates many workspaces and waits for them to be ready", - Long: `Creates many users, then creates a workspace for each user and waits for them finish building and fully come online. Optionally runs a command inside each workspace, and connects to the workspace over WireGuard. + client := new(codersdk.Client) -It is recommended that all rate limits are disabled on the server before running this scaletest. This test generates many login events which will be rate limited against the (most likely single) IP.`, - RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() - client, err := CreateClient(cmd) - if err != nil { - return err - } + cmd := &clibase.Cmd{ + Use: "create-workspaces", + Short: "Creates many users, then creates a workspace for each user and waits for them finish building and fully come online. Optionally runs a command inside each workspace, and connects to the workspace over WireGuard.", + Long: `It is recommended that all rate limits are disabled on the server before running this scaletest. This test generates many login events which will be rate limited against the (most likely single) IP.`, + Middleware: r.InitClient(client), + Handler: func(inv *clibase.Invocation) error { + ctx := inv.Context() me, err := requireAdmin(ctx, client) if err != nil { @@ -612,16 +656,16 @@ It is recommended that all rate limits are disabled on the server before running if err != nil { return xerrors.Errorf("start dry run workspace creation: %w", err) } - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Planning workspace...") - err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{ + _, _ = fmt.Fprintln(inv.Stdout, "Planning workspace...") + err = cliui.ProvisionerJob(inv.Context(), inv.Stdout, cliui.ProvisionerJobOptions{ Fetch: func() (codersdk.ProvisionerJob, error) { - return client.TemplateVersionDryRun(cmd.Context(), templateVersion.ID, dryRun.ID) + return client.TemplateVersionDryRun(inv.Context(), templateVersion.ID, dryRun.ID) }, Cancel: func() error { - return client.CancelTemplateVersionDryRun(cmd.Context(), templateVersion.ID, dryRun.ID) + return client.CancelTemplateVersionDryRun(inv.Context(), templateVersion.ID, dryRun.ID) }, Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) { - return client.TemplateVersionDryRunLogsAfter(cmd.Context(), templateVersion.ID, dryRun.ID, 0) + return client.TemplateVersionDryRunLogsAfter(inv.Context(), templateVersion.ID, dryRun.ID, 0) }, // Don't show log output for the dry-run unless there's an error. Silent: true, @@ -645,7 +689,7 @@ It is recommended that all rate limits are disabled on the server before running tracer := tracerProvider.Tracer(scaletestTracerName) th := harness.NewTestHarness(strategy.toStrategy(), cleanupStrategy.toStrategy()) - for i := 0; i < count; i++ { + for i := 0; i < int(count); i++ { const name = "workspacebuild" id := strconv.Itoa(i) @@ -728,7 +772,7 @@ It is recommended that all rate limits are disabled on the server before running } // TODO: live progress output - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Running load test...") + _, _ = fmt.Fprintln(inv.Stderr, "Running load test...") testCtx, testCancel := strategy.toContext(ctx) defer testCancel() err = th.Run(testCtx) @@ -738,13 +782,13 @@ It is recommended that all rate limits are disabled on the server before running res := th.Results() for _, o := range outputs { - err = o.write(res, cmd.OutOrStdout()) + err = o.write(res, inv.Stdout) if err != nil { return xerrors.Errorf("write output %q to %q: %w", o.format, o.path, err) } } - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "\nCleaning up...") + _, _ = fmt.Fprintln(inv.Stderr, "\nCleaning up...") cleanupCtx, cleanupCancel := cleanupStrategy.toContext(ctx) defer cleanupCancel() err = th.Cleanup(cleanupCtx) @@ -754,12 +798,12 @@ It is recommended that all rate limits are disabled on the server before running // Upload traces. if tracingEnabled { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "\nUploading traces...") + _, _ = fmt.Fprintln(inv.Stderr, "\nUploading traces...") ctx, cancel := context.WithTimeout(ctx, 1*time.Minute) defer cancel() err := closeTracing(ctx) if err != nil { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "\nError uploading traces: %+v\n", err) + _, _ = fmt.Fprintf(inv.Stderr, "\nError uploading traces: %+v\n", err) } } @@ -771,32 +815,124 @@ It is recommended that all rate limits are disabled on the server before running }, } - cliflag.IntVarP(cmd.Flags(), &count, "count", "c", "CODER_LOADTEST_COUNT", 1, "Required: Number of workspaces to create.") - cliflag.StringVarP(cmd.Flags(), &template, "template", "t", "CODER_LOADTEST_TEMPLATE", "", "Required: Name or ID of the template to use for workspaces.") - cliflag.StringVarP(cmd.Flags(), ¶metersFile, "parameters-file", "", "CODER_LOADTEST_PARAMETERS_FILE", "", "Path to a YAML file containing the parameters to use for each workspace.") - cliflag.StringArrayVarP(cmd.Flags(), ¶meters, "parameter", "", "CODER_LOADTEST_PARAMETERS", []string{}, "Parameters to use for each workspace. Can be specified multiple times. Overrides any existing parameters with the same name from --parameters-file. Format: key=value") - - cliflag.BoolVarP(cmd.Flags(), &noPlan, "no-plan", "", "CODER_LOADTEST_NO_PLAN", false, "Skip the dry-run step to plan the workspace creation. This step ensures that the given parameters are valid for the given template.") - cliflag.BoolVarP(cmd.Flags(), &noCleanup, "no-cleanup", "", "CODER_LOADTEST_NO_CLEANUP", false, "Do not clean up resources after the test completes. You can cleanup manually using `coder scaletest cleanup`.") - // cliflag.BoolVarP(cmd.Flags(), &noCleanupFailures, "no-cleanup-failures", "", "CODER_LOADTEST_NO_CLEANUP_FAILURES", false, "Do not clean up resources from failed jobs to aid in debugging failures. You can cleanup manually using `coder scaletest cleanup`.") - cliflag.BoolVarP(cmd.Flags(), &noWaitForAgents, "no-wait-for-agents", "", "CODER_LOADTEST_NO_WAIT_FOR_AGENTS", false, "Do not wait for agents to start before marking the test as succeeded. This can be useful if you are running the test against a template that does not start the agent quickly.") - - cliflag.StringVarP(cmd.Flags(), &runCommand, "run-command", "", "CODER_LOADTEST_RUN_COMMAND", "", "Command to run inside each workspace using reconnecting-pty (i.e. web terminal protocol). If not specified, no command will be run.") - cliflag.DurationVarP(cmd.Flags(), &runTimeout, "run-timeout", "", "CODER_LOADTEST_RUN_TIMEOUT", 5*time.Second, "Timeout for the command to complete.") - cliflag.BoolVarP(cmd.Flags(), &runExpectTimeout, "run-expect-timeout", "", "CODER_LOADTEST_RUN_EXPECT_TIMEOUT", false, "Expect the command to timeout. If the command does not finish within the given --run-timeout, it will be marked as succeeded. If the command finishes before the timeout, it will be marked as failed.") - cliflag.StringVarP(cmd.Flags(), &runExpectOutput, "run-expect-output", "", "CODER_LOADTEST_RUN_EXPECT_OUTPUT", "", "Expect the command to output the given string (on a single line). If the command does not output the given string, it will be marked as failed.") - cliflag.BoolVarP(cmd.Flags(), &runLogOutput, "run-log-output", "", "CODER_LOADTEST_RUN_LOG_OUTPUT", false, "Log the output of the command to the test logs. This should be left off unless you expect small amounts of output. Large amounts of output will cause high memory usage.") - - cliflag.StringVarP(cmd.Flags(), &connectURL, "connect-url", "", "CODER_LOADTEST_CONNECT_URL", "", "URL to connect to inside the the workspace over WireGuard. If not specified, no connections will be made over WireGuard.") - cliflag.StringVarP(cmd.Flags(), &connectMode, "connect-mode", "", "CODER_LOADTEST_CONNECT_MODE", "derp", "Mode to use for connecting to the workspace. Can be 'derp' or 'direct'.") - cliflag.DurationVarP(cmd.Flags(), &connectHold, "connect-hold", "", "CODER_LOADTEST_CONNECT_HOLD", 30*time.Second, "How long to hold the WireGuard connection open for.") - cliflag.DurationVarP(cmd.Flags(), &connectInterval, "connect-interval", "", "CODER_LOADTEST_CONNECT_INTERVAL", time.Second, "How long to wait between making requests to the --connect-url once the connection is established.") - cliflag.DurationVarP(cmd.Flags(), &connectTimeout, "connect-timeout", "", "CODER_LOADTEST_CONNECT_TIMEOUT", 5*time.Second, "Timeout for each request to the --connect-url.") - - tracingFlags.attach(cmd) - strategy.attach(cmd) - cleanupStrategy.attach(cmd) - output.attach(cmd) + cmd.Options = clibase.OptionSet{ + { + Flag: "count", + FlagShorthand: "c", + Env: "CODER_SCALETEST_COUNT", + Default: "1", + Description: "Required: Number of workspaces to create.", + Value: clibase.Int64Of(&count), + }, + { + Flag: "template", + FlagShorthand: "t", + Env: "CODER_SCALETEST_TEMPLATE", + Description: "Required: Name or ID of the template to use for workspaces.", + Value: clibase.StringOf(&template), + }, + { + Flag: "parameters-file", + Env: "CODER_SCALETEST_PARAMETERS_FILE", + Description: "Path to a YAML file containing the parameters to use for each workspace.", + Value: clibase.StringOf(¶metersFile), + }, + { + Flag: "parameter", + Env: "CODER_SCALETEST_PARAMETERS", + Description: "Parameters to use for each workspace. Can be specified multiple times. Overrides any existing parameters with the same name from --parameters-file. Format: key=value.", + Value: clibase.StringArrayOf(¶meters), + }, + { + Flag: "no-plan", + Env: "CODER_SCALETEST_NO_PLAN", + Description: `Skip the dry-run step to plan the workspace creation. This step ensures that the given parameters are valid for the given template.`, + Value: clibase.BoolOf(&noPlan), + }, + { + Flag: "no-cleanup", + Env: "CODER_SCALETEST_NO_CLEANUP", + Description: "Do not clean up resources after the test completes. You can cleanup manually using coder scaletest cleanup.", + Value: clibase.BoolOf(&noCleanup), + }, + { + Flag: "no-wait-for-agents", + Env: "CODER_SCALETEST_NO_WAIT_FOR_AGENTS", + Description: `Do not wait for agents to start before marking the test as succeeded. This can be useful if you are running the test against a template that does not start the agent quickly.`, + Value: clibase.BoolOf(&noWaitForAgents), + }, + { + Flag: "run-command", + Env: "CODER_SCALETEST_RUN_COMMAND", + Description: "Command to run inside each workspace using reconnecting-pty (i.e. web terminal protocol). " + "If not specified, no command will be run.", + Value: clibase.StringOf(&runCommand), + }, + { + Flag: "run-timeout", + Env: "CODER_SCALETEST_RUN_TIMEOUT", + Default: "5s", + Description: "Timeout for the command to complete.", + Value: clibase.DurationOf(&runTimeout), + }, + { + Flag: "run-expect-timeout", + Env: "CODER_SCALETEST_RUN_EXPECT_TIMEOUT", + + Description: "Expect the command to timeout." + " If the command does not finish within the given --run-timeout, it will be marked as succeeded." + " If the command finishes before the timeout, it will be marked as failed.", + Value: clibase.BoolOf(&runExpectTimeout), + }, + { + Flag: "run-expect-output", + Env: "CODER_SCALETEST_RUN_EXPECT_OUTPUT", + Description: "Expect the command to output the given string (on a single line). " + "If the command does not output the given string, it will be marked as failed.", + Value: clibase.StringOf(&runExpectOutput), + }, + { + Flag: "run-log-output", + Env: "CODER_SCALETEST_RUN_LOG_OUTPUT", + Description: "Log the output of the command to the test logs. " + "This should be left off unless you expect small amounts of output. " + "Large amounts of output will cause high memory usage.", + Value: clibase.BoolOf(&runLogOutput), + }, + { + Flag: "connect-url", + Env: "CODER_SCALETEST_CONNECT_URL", + Description: "URL to connect to inside the the workspace over WireGuard. " + "If not specified, no connections will be made over WireGuard.", + Value: clibase.StringOf(&connectURL), + }, + { + Flag: "connect-mode", + Env: "CODER_SCALETEST_CONNECT_MODE", + Default: "derp", + Description: "Mode to use for connecting to the workspace.", + Value: clibase.EnumOf(&connectMode, "derp", "direct"), + }, + { + Flag: "connect-hold", + Env: "CODER_SCALETEST_CONNECT_HOLD", + Default: "30s", + Description: "How long to hold the WireGuard connection open for.", + Value: clibase.DurationOf(&connectHold), + }, + { + Flag: "connect-interval", + Env: "CODER_SCALETEST_CONNECT_INTERVAL", + Default: "1s", + Value: clibase.DurationOf(&connectInterval), + Description: "How long to wait between making requests to the --connect-url once the connection is established.", + }, + { + Flag: "connect-timeout", + Env: "CODER_SCALETEST_CONNECT_TIMEOUT", + Default: "5s", + Description: "Timeout for each request to the --connect-url.", + Value: clibase.DurationOf(&connectTimeout), + }, + } + + tracingFlags.attach(&cmd.Options) + strategy.attach(&cmd.Options) + cleanupStrategy.attach(&cmd.Options) + output.attach(&cmd.Options) return cmd } diff --git a/cli/scaletest_test.go b/cli/scaletest_test.go index 4052d4f0e4d15..3636b8ef40dc4 100644 --- a/cli/scaletest_test.go +++ b/cli/scaletest_test.go @@ -54,7 +54,7 @@ param3: 1 err = f.Close() require.NoError(t, err) - cmd, root := clitest.New(t, "scaletest", "create-workspaces", + inv, root := clitest.New(t, "scaletest", "create-workspaces", "--count", "2", "--template", template.Name, "--parameters-file", paramsFile, @@ -77,12 +77,12 @@ param3: 1 ) clitest.SetupConfig(t, client, root) pty := ptytest.New(t) - cmd.SetOut(pty.Output()) - cmd.SetErr(pty.Output()) + inv.Stdout = pty.Output() + inv.Stderr = pty.Output() done := make(chan any) go func() { - err := cmd.ExecuteContext(ctx) + err := inv.WithContext(ctx).Run() assert.NoError(t, err) close(done) }() @@ -148,19 +148,19 @@ param3: 1 require.Len(t, users.Users, len(seenUsers)+1) // Cleanup. - cmd, root = clitest.New(t, "scaletest", "cleanup", + inv, root = clitest.New(t, "scaletest", "cleanup", "--cleanup-concurrency", "1", "--cleanup-timeout", "30s", "--cleanup-job-timeout", "15s", ) clitest.SetupConfig(t, client, root) pty = ptytest.New(t) - cmd.SetOut(pty.Output()) - cmd.SetErr(pty.Output()) + inv.Stdout = pty.Output() + inv.Stderr = pty.Output() done = make(chan any) go func() { - err := cmd.ExecuteContext(ctx) + err := inv.WithContext(ctx).Run() assert.NoError(t, err) close(done) }() diff --git a/cli/schedule.go b/cli/schedule.go index ff81b8e81dc50..8fff0121ae8db 100644 --- a/cli/schedule.go +++ b/cli/schedule.go @@ -6,9 +6,9 @@ import ( "time" "github.com/jedib0t/go-pretty/v6/table" - "github.com/spf13/cobra" "golang.org/x/xerrors" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/schedule" "github.com/coder/coder/coderd/util/ptr" @@ -46,82 +46,78 @@ When enabling scheduled stop, enter a duration in one of the following formats: * 2m (2 minutes) * 2 (2 minutes) ` - scheduleOverrideDescriptionLong = `Override the stop time of a currently running workspace instance. + scheduleOverrideDescriptionLong = ` * The new stop time is calculated from *now*. * The new stop time must be at least 30 minutes in the future. * The workspace template may restrict the maximum workspace runtime. ` ) -func schedules() *cobra.Command { - scheduleCmd := &cobra.Command{ +func (r *RootCmd) schedules() *clibase.Cmd { + scheduleCmd := &clibase.Cmd{ Annotations: workspaceCommand, Use: "schedule { show | start | stop | override } ", Short: "Schedule automated start and stop times for workspaces", - RunE: func(cmd *cobra.Command, args []string) error { - return cmd.Help() + Handler: func(inv *clibase.Invocation) error { + return inv.Command.HelpHandler(inv) + }, + Children: []*clibase.Cmd{ + r.scheduleShow(), + r.scheduleStart(), + r.scheduleStop(), + r.scheduleOverride(), }, } - scheduleCmd.AddCommand( - scheduleShow(), - scheduleStart(), - scheduleStop(), - scheduleOverride(), - ) - return scheduleCmd } -func scheduleShow() *cobra.Command { - showCmd := &cobra.Command{ +func (r *RootCmd) scheduleShow() *clibase.Cmd { + client := new(codersdk.Client) + showCmd := &clibase.Cmd{ Use: "show ", Short: "Show workspace schedule", Long: scheduleShowDescriptionLong, - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) - if err != nil { - return err - } - - workspace, err := namedWorkspace(cmd, client, args[0]) + Middleware: clibase.Chain( + clibase.RequireNArgs(1), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return err } - return displaySchedule(workspace, cmd.OutOrStdout()) + return displaySchedule(workspace, inv.Stdout) }, } return showCmd } -func scheduleStart() *cobra.Command { - cmd := &cobra.Command{ +func (r *RootCmd) scheduleStart() *clibase.Cmd { + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Use: "start { [day-of-week] [location] | manual }", - Example: formatExamples( + Long: scheduleStartDescriptionLong + "\n" + formatExamples( example{ Description: "Set the workspace to start at 9:30am (in Dublin) from Monday to Friday", Command: "coder schedule start my-workspace 9:30AM Mon-Fri Europe/Dublin", }, ), Short: "Edit workspace start schedule", - Long: scheduleStartDescriptionLong, - Args: cobra.RangeArgs(2, 4), - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) - if err != nil { - return err - } - - workspace, err := namedWorkspace(cmd, client, args[0]) + Middleware: clibase.Chain( + clibase.RequireRangeArgs(2, 4), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return err } var schedStr *string - if args[1] != "manual" { - sched, err := parseCLISchedule(args[1:]...) + if inv.Args[1] != "manual" { + sched, err := parseCLISchedule(inv.Args[1:]...) if err != nil { return err } @@ -129,93 +125,89 @@ func scheduleStart() *cobra.Command { schedStr = ptr.Ref(sched.String()) } - err = client.UpdateWorkspaceAutostart(cmd.Context(), workspace.ID, codersdk.UpdateWorkspaceAutostartRequest{ + err = client.UpdateWorkspaceAutostart(inv.Context(), workspace.ID, codersdk.UpdateWorkspaceAutostartRequest{ Schedule: schedStr, }) if err != nil { return err } - updated, err := namedWorkspace(cmd, client, args[0]) + updated, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return err } - return displaySchedule(updated, cmd.OutOrStdout()) + return displaySchedule(updated, inv.Stdout) }, } return cmd } -func scheduleStop() *cobra.Command { - return &cobra.Command{ - Args: cobra.ExactArgs(2), - Use: "stop { | manual }", - Example: formatExamples( +func (r *RootCmd) scheduleStop() *clibase.Cmd { + client := new(codersdk.Client) + return &clibase.Cmd{ + Use: "stop { | manual }", + Long: scheduleStopDescriptionLong + "\n" + formatExamples( example{ Command: "coder schedule stop my-workspace 2h30m", }, ), Short: "Edit workspace stop schedule", - Long: scheduleStopDescriptionLong, - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) - if err != nil { - return err - } - - workspace, err := namedWorkspace(cmd, client, args[0]) + Middleware: clibase.Chain( + clibase.RequireNArgs(2), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return err } var durMillis *int64 - if args[1] != "manual" { - dur, err := parseDuration(args[1]) + if inv.Args[1] != "manual" { + dur, err := parseDuration(inv.Args[1]) if err != nil { return err } durMillis = ptr.Ref(dur.Milliseconds()) } - if err := client.UpdateWorkspaceTTL(cmd.Context(), workspace.ID, codersdk.UpdateWorkspaceTTLRequest{ + if err := client.UpdateWorkspaceTTL(inv.Context(), workspace.ID, codersdk.UpdateWorkspaceTTLRequest{ TTLMillis: durMillis, }); err != nil { return err } - updated, err := namedWorkspace(cmd, client, args[0]) + updated, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return err } - return displaySchedule(updated, cmd.OutOrStdout()) + return displaySchedule(updated, inv.Stdout) }, } } -func scheduleOverride() *cobra.Command { - overrideCmd := &cobra.Command{ - Args: cobra.ExactArgs(2), - Use: "override-stop ", - Example: formatExamples( +func (r *RootCmd) scheduleOverride() *clibase.Cmd { + client := new(codersdk.Client) + overrideCmd := &clibase.Cmd{ + Use: "override-stop ", + Short: "Override the stop time of a currently running workspace instance.", + Long: scheduleOverrideDescriptionLong + "\n" + formatExamples( example{ Command: "coder schedule override-stop my-workspace 90m", }, ), - Short: "Edit stop time of active workspace", - Long: scheduleOverrideDescriptionLong, - RunE: func(cmd *cobra.Command, args []string) error { - overrideDuration, err := parseDuration(args[1]) + Middleware: clibase.Chain( + clibase.RequireNArgs(2), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + overrideDuration, err := parseDuration(inv.Args[1]) if err != nil { return err } - client, err := CreateClient(cmd) - if err != nil { - return xerrors.Errorf("create client: %w", err) - } - - workspace, err := namedWorkspace(cmd, client, args[0]) + workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return xerrors.Errorf("get workspace: %w", err) } @@ -227,24 +219,24 @@ func scheduleOverride() *cobra.Command { if overrideDuration < 29*time.Minute { _, _ = fmt.Fprintf( - cmd.OutOrStdout(), + inv.Stdout, "Please specify a duration of at least 30 minutes.\n", ) return nil } newDeadline := time.Now().In(loc).Add(overrideDuration) - if err := client.PutExtendWorkspace(cmd.Context(), workspace.ID, codersdk.PutExtendWorkspaceRequest{ + if err := client.PutExtendWorkspace(inv.Context(), workspace.ID, codersdk.PutExtendWorkspaceRequest{ Deadline: newDeadline, }); err != nil { return err } - updated, err := namedWorkspace(cmd, client, args[0]) + updated, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return err } - return displaySchedule(updated, cmd.OutOrStdout()) + return displaySchedule(updated, inv.Stdout) }, } return overrideCmd diff --git a/cli/schedule_test.go b/cli/schedule_test.go index cd30de7d7f551..a3a3a781ff578 100644 --- a/cli/schedule_test.go +++ b/cli/schedule_test.go @@ -42,11 +42,11 @@ func TestScheduleShow(t *testing.T) { stdoutBuf = &bytes.Buffer{} ) - cmd, root := clitest.New(t, cmdArgs...) + inv, root := clitest.New(t, cmdArgs...) clitest.SetupConfig(t, client, root) - cmd.SetOut(stdoutBuf) + inv.Stdout = stdoutBuf - err := cmd.Execute() + err := inv.Run() require.NoError(t, err, "unexpected error") lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") if assert.Len(t, lines, 4) { @@ -79,11 +79,11 @@ func TestScheduleShow(t *testing.T) { stdoutBuf = &bytes.Buffer{} ) - cmd, root := clitest.New(t, cmdArgs...) + inv, root := clitest.New(t, cmdArgs...) clitest.SetupConfig(t, client, root) - cmd.SetOut(stdoutBuf) + inv.Stdout = stdoutBuf - err := cmd.Execute() + err := inv.Run() require.NoError(t, err, "unexpected error") lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") if assert.Len(t, lines, 4) { @@ -104,10 +104,10 @@ func TestScheduleShow(t *testing.T) { _ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID) ) - cmd, root := clitest.New(t, "schedule", "show", "doesnotexist") + inv, root := clitest.New(t, "schedule", "show", "doesnotexist") clitest.SetupConfig(t, client, root) - err := cmd.Execute() + err := inv.Run() require.ErrorContains(t, err, "status code 404", "unexpected error") }) } @@ -132,11 +132,11 @@ func TestScheduleStart(t *testing.T) { ) // Set a well-specified autostart schedule - cmd, root := clitest.New(t, "schedule", "start", workspace.Name, "9:30AM", "Mon-Fri", tz) + inv, root := clitest.New(t, "schedule", "start", workspace.Name, "9:30AM", "Mon-Fri", tz) clitest.SetupConfig(t, client, root) - cmd.SetOut(stdoutBuf) + inv.Stdout = stdoutBuf - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err, "unexpected error") lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") if assert.Len(t, lines, 4) { @@ -157,11 +157,11 @@ func TestScheduleStart(t *testing.T) { stdoutBuf = &bytes.Buffer{} // unset schedule - cmd, root = clitest.New(t, "schedule", "start", workspace.Name, "manual") + inv, root = clitest.New(t, "schedule", "start", workspace.Name, "manual") clitest.SetupConfig(t, client, root) - cmd.SetOut(stdoutBuf) + inv.Stdout = stdoutBuf - err = cmd.Execute() + err = inv.Run() assert.NoError(t, err, "unexpected error") lines = strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") if assert.Len(t, lines, 4) { @@ -186,11 +186,11 @@ func TestScheduleStop(t *testing.T) { ) // Set the workspace TTL - cmd, root := clitest.New(t, "schedule", "stop", workspace.Name, ttl.String()) + inv, root := clitest.New(t, "schedule", "stop", workspace.Name, ttl.String()) clitest.SetupConfig(t, client, root) - cmd.SetOut(stdoutBuf) + inv.Stdout = stdoutBuf - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err, "unexpected error") lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") if assert.Len(t, lines, 4) { @@ -203,11 +203,11 @@ func TestScheduleStop(t *testing.T) { stdoutBuf = &bytes.Buffer{} // Unset the workspace TTL - cmd, root = clitest.New(t, "schedule", "stop", workspace.Name, "manual") + inv, root = clitest.New(t, "schedule", "stop", workspace.Name, "manual") clitest.SetupConfig(t, client, root) - cmd.SetOut(stdoutBuf) + inv.Stdout = stdoutBuf - err = cmd.Execute() + err = inv.Run() assert.NoError(t, err, "unexpected error") lines = strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") if assert.Len(t, lines, 4) { @@ -247,12 +247,12 @@ func TestScheduleOverride(t *testing.T) { initDeadline := time.Now().Add(time.Duration(*workspace.TTLMillis) * time.Millisecond) require.WithinDuration(t, initDeadline, workspace.LatestBuild.Deadline.Time, time.Minute) - cmd, root := clitest.New(t, cmdArgs...) + inv, root := clitest.New(t, cmdArgs...) clitest.SetupConfig(t, client, root) - cmd.SetOut(stdoutBuf) + inv.Stdout = stdoutBuf // When: we execute `coder schedule override workspace ` - err = cmd.ExecuteContext(ctx) + err = inv.WithContext(ctx).Run() require.NoError(t, err) // Then: the deadline of the latest build is updated assuming the units are minutes @@ -287,12 +287,12 @@ func TestScheduleOverride(t *testing.T) { initDeadline := time.Now().Add(time.Duration(*workspace.TTLMillis) * time.Millisecond) require.WithinDuration(t, initDeadline, workspace.LatestBuild.Deadline.Time, time.Minute) - cmd, root := clitest.New(t, cmdArgs...) + inv, root := clitest.New(t, cmdArgs...) clitest.SetupConfig(t, client, root) - cmd.SetOut(stdoutBuf) + inv.Stdout = stdoutBuf // When: we execute `coder bump workspace ` - err = cmd.ExecuteContext(ctx) + err = inv.WithContext(ctx).Run() // Then: the command fails require.ErrorContains(t, err, "invalid duration") }) @@ -339,12 +339,12 @@ func TestScheduleOverride(t *testing.T) { require.Zero(t, workspace.LatestBuild.Deadline) require.NoError(t, err) - cmd, root := clitest.New(t, cmdArgs...) + inv, root := clitest.New(t, cmdArgs...) clitest.SetupConfig(t, client, root) - cmd.SetOut(stdoutBuf) + inv.Stdout = stdoutBuf // When: we execute `coder bump workspace`` - err = cmd.ExecuteContext(ctx) + err = inv.WithContext(ctx).Run() require.Error(t, err) // Then: nothing happens and the deadline remains unset @@ -370,11 +370,10 @@ func TestScheduleStartDefaults(t *testing.T) { ) // Set an underspecified schedule - cmd, root := clitest.New(t, "schedule", "start", workspace.Name, "9:30AM") + inv, root := clitest.New(t, "schedule", "start", workspace.Name, "9:30AM") clitest.SetupConfig(t, client, root) - cmd.SetOut(stdoutBuf) - - err := cmd.Execute() + inv.Stdout = stdoutBuf + err := inv.Run() require.NoError(t, err, "unexpected error") lines := strings.Split(strings.TrimSpace(stdoutBuf.String()), "\n") if assert.Len(t, lines, 4) { diff --git a/cli/server.go b/cli/server.go index bb53b4218e290..28c57f5580132 100644 --- a/cli/server.go +++ b/cli/server.go @@ -41,7 +41,6 @@ import ( "github.com/prometheus/client_golang/prometheus/collectors" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/spf13/afero" - "github.com/spf13/cobra" "go.opentelemetry.io/otel/trace" "golang.org/x/mod/semver" "golang.org/x/oauth2" @@ -97,7 +96,7 @@ func ReadGitAuthProvidersFromEnv(environ []string) ([]codersdk.GitAuthConfig, er sort.Strings(environ) var providers []codersdk.GitAuthConfig - for _, v := range clibase.ParseEnviron(environ, envPrefix+"GITAUTH_") { + for _, v := range clibase.ParseEnviron(environ, "CODER_GITAUTH_") { tokens := strings.SplitN(v.Name, "_", 2) if len(tokens) != 2 { return nil, xerrors.Errorf("invalid env var: %s", v.Name) @@ -157,92 +156,29 @@ func ReadGitAuthProvidersFromEnv(environ []string) ([]codersdk.GitAuthConfig, er } // nolint:gocyclo -func Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *cobra.Command { - root := &cobra.Command{ - Use: "server", - Short: "Start a Coder server", - DisableFlagParsing: true, - RunE: func(cmd *cobra.Command, args []string) error { +func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *clibase.Cmd { + var ( + cfg = new(codersdk.DeploymentValues) + opts = cfg.Options() + ) + serverCmd := &clibase.Cmd{ + Use: "server", + Short: "Start a Coder server", + Options: opts, + Middleware: clibase.RequireNArgs(0), + Handler: func(inv *clibase.Invocation) error { // Main command context for managing cancellation of running // services. - ctx, cancel := context.WithCancel(cmd.Context()) + ctx, cancel := context.WithCancel(inv.Context()) defer cancel() - cfg := &codersdk.DeploymentValues{} - cliOpts := cfg.Options() - var configDir clibase.String - // This is a hack to get around the fact that the Cobra-defined - // flags are not available. - cliOpts.Add(clibase.Option{ - Name: "Global Config", - Flag: config.FlagName, - Description: "Global Config is ignored in server mode.", - Hidden: true, - Default: config.DefaultDir(), - Value: &configDir, - }) - - err := cliOpts.SetDefaults() - if err != nil { - return xerrors.Errorf("set defaults: %w", err) - } - - err = cliOpts.ParseEnv(clibase.ParseEnviron(os.Environ(), envPrefix)) - if err != nil { - return xerrors.Errorf("parse env: %w", err) - } - - flagSet := cliOpts.FlagSet() - // These parents and children will be moved once we convert the - // rest of the `cli` package to clibase. - flagSet.Usage = usageFn(cmd.ErrOrStderr(), &clibase.Cmd{ - Parent: &clibase.Cmd{ - Use: "coder", - }, - Children: []*clibase.Cmd{ - { - Use: "postgres-builtin-url", - Short: "Output the connection URL for the built-in PostgreSQL deployment.", - }, - { - Use: "postgres-builtin-serve", - Short: "Run the built-in PostgreSQL deployment.", - }, - }, - Use: "server [flags]", - Short: "Start a Coder server", - Long: ` -The server provides the Coder dashboard, API, and provisioners. -If no options are provided, the server will start with a built-in postgres -and an access URL provided by Coder's cloud service. - -Use the following command to print the built-in postgres URL: - $ coder server postgres-builtin-url - -Use the following command to manually run the built-in postgres: - $ coder server postgres-builtin-serve - -Options may be provided via environment variables prefixed with "CODER_", -flags, and YAML configuration. The precedence is as follows: - 1. Defaults - 2. YAML configuration - 3. Environment variables - 4. Flags - `, - Options: cliOpts, - }) - err = flagSet.Parse(args) - if err != nil { - return xerrors.Errorf("parse flags: %w", err) - } - if cfg.WriteConfig { // TODO: this should output to a file. - n, err := cliOpts.ToYAML() + n, err := opts.ToYAML() if err != nil { return xerrors.Errorf("generate yaml: %w", err) } - enc := yaml.NewEncoder(cmd.ErrOrStderr()) + enc := yaml.NewEncoder(inv.Stderr) err = enc.Encode(n) if err != nil { return xerrors.Errorf("encode yaml: %w", err) @@ -255,7 +191,7 @@ flags, and YAML configuration. The precedence is as follows: } // Print deprecation warnings. - for _, opt := range cliOpts { + for _, opt := range opts { if opt.UseInstead == nil { continue } @@ -273,8 +209,8 @@ flags, and YAML configuration. The precedence is as follows: } warnStr += "instead.\n" - cmd.PrintErr( - cliui.Styles.Warn.Render("WARN: ") + warnStr, + cliui.Warn(inv.Stderr, + warnStr, ) } @@ -313,8 +249,8 @@ flags, and YAML configuration. The precedence is as follows: filesRateLimit = -1 } - printLogo(cmd) - logger, logCloser, err := buildLogger(cmd, cfg) + printLogo(inv) + logger, logCloser, err := buildLogger(inv, cfg) if err != nil { return xerrors.Errorf("make logger: %w", err) } @@ -360,7 +296,7 @@ flags, and YAML configuration. The precedence is as follows: shouldCoderTrace := cfg.Telemetry.Enable.Value() && !isTest() // Only override if telemetryTraceEnable was specifically set. // By default we want it to be controlled by telemetryEnable. - if cmd.Flags().Changed("telemetry-trace") { + if inv.ParsedFlags().Changed("telemetry-trace") { shouldCoderTrace = cfg.Telemetry.Trace.Value() } @@ -389,12 +325,13 @@ flags, and YAML configuration. The precedence is as follows: } } - config := config.Root(configDir) + config := r.createConfig() + builtinPostgres := false // Only use built-in if PostgreSQL URL isn't specified! if !cfg.InMemoryDatabase && cfg.PostgresURL == "" { var closeFunc func() error - cmd.Printf("Using built-in PostgreSQL (%s)\n", config.PostgresPath()) + cliui.Infof(inv.Stdout, "Using built-in PostgreSQL (%s)", config.PostgresPath()) pgURL, closeFunc, err := startBuiltinPostgres(ctx, config, logger) if err != nil { return err @@ -406,12 +343,12 @@ flags, and YAML configuration. The precedence is as follows: } builtinPostgres = true defer func() { - cmd.Printf("Stopping built-in PostgreSQL...\n") + cliui.Infof(inv.Stdout, "Stopping built-in PostgreSQL...") // Gracefully shut PostgreSQL down! if err := closeFunc(); err != nil { - cmd.Printf("Failed to stop built-in PostgreSQL: %v\n", err) + cliui.Errorf(inv.Stderr, "Failed to stop built-in PostgreSQL: %v", err) } else { - cmd.Printf("Stopped built-in PostgreSQL\n") + cliui.Infof(inv.Stdout, "Stopped built-in PostgreSQL") } }() } @@ -423,7 +360,7 @@ flags, and YAML configuration. The precedence is as follows: if cfg.HTTPAddress.String() != "" { httpListener, err = net.Listen("tcp", cfg.HTTPAddress.String()) if err != nil { - return xerrors.Errorf("listen %q: %w", cfg.HTTPAddress.String(), err) + return err } defer httpListener.Close() @@ -438,7 +375,7 @@ flags, and YAML configuration. The precedence is as follows: // We want to print out the address the user supplied, not the // loopback device. - cmd.Println("Started HTTP listener at", (&url.URL{Scheme: "http", Host: listenAddrStr}).String()) + _, _ = fmt.Fprintf(inv.Stdout, "Started HTTP listener at %s\n", (&url.URL{Scheme: "http", Host: listenAddrStr}).String()) // Set the http URL we want to use when connecting to ourselves. tcpAddr, tcpAddrValid := httpListener.Addr().(*net.TCPAddr) @@ -466,8 +403,8 @@ flags, and YAML configuration. The precedence is as follows: // DEPRECATED: This redirect used to default to true. // It made more sense to have the redirect be opt-in. - if os.Getenv("CODER_TLS_REDIRECT_HTTP") == "true" || cmd.Flags().Changed("tls-redirect-http-to-https") { - cmd.PrintErr(cliui.Styles.Warn.Render("WARN:") + " --tls-redirect-http-to-https is deprecated, please use --redirect-to-access-url instead\n") + if inv.Environ.Get("CODER_TLS_REDIRECT_HTTP") == "true" || inv.ParsedFlags().Changed("tls-redirect-http-to-https") { + cliui.Warn(inv.Stderr, "--tls-redirect-http-to-https is deprecated, please use --redirect-to-access-url instead") cfg.RedirectToAccessURL = cfg.TLS.RedirectHTTP } @@ -483,7 +420,7 @@ flags, and YAML configuration. The precedence is as follows: } httpsListenerInner, err := net.Listen("tcp", cfg.TLS.Address.String()) if err != nil { - return xerrors.Errorf("listen %q: %w", cfg.TLS.Address.String(), err) + return err } defer httpsListenerInner.Close() @@ -502,7 +439,7 @@ flags, and YAML configuration. The precedence is as follows: // We want to print out the address the user supplied, not the // loopback device. - cmd.Println("Started TLS/HTTPS listener at", (&url.URL{Scheme: "https", Host: listenAddrStr}).String()) + _, _ = fmt.Fprintf(inv.Stdout, "Started TLS/HTTPS listener at %s\n", (&url.URL{Scheme: "https", Host: listenAddrStr}).String()) // Set the https URL we want to use when connecting to // ourselves. @@ -547,7 +484,7 @@ flags, and YAML configuration. The precedence is as follows: tunnelDone <-chan struct{} = make(chan struct{}, 1) ) if cfg.AccessURL.String() == "" { - cmd.Printf("Opening tunnel so workspaces can connect to your deployment. For production scenarios, specify an external access URL\n") + cliui.Infof(inv.Stderr, "Opening tunnel so workspaces can connect to your deployment. For production scenarios, specify an external access URL\n") tunnel, err = devtunnel.New(ctx, logger.Named("devtunnel"), cfg.WgtunnelHost.String()) if err != nil { return xerrors.Errorf("create tunnel: %w", err) @@ -586,14 +523,15 @@ flags, and YAML configuration. The precedence is as follows: if isLocal { reason = "isn't externally reachable" } - cmd.Printf( - "%s The access URL %s %s, this may cause unexpected problems when creating workspaces. Generate a unique *.try.coder.app URL by not specifying an access URL.\n", - cliui.Styles.Warn.Render("Warning:"), cliui.Styles.Field.Render(cfg.AccessURL.String()), reason, + cliui.Warnf( + inv.Stderr, + "The access URL %s %s, this may cause unexpected problems when creating workspaces. Generate a unique *.try.coder.app URL by not specifying an access URL.\n", + cliui.Styles.Field.Render(cfg.AccessURL.String()), reason, ) } // A newline is added before for visibility in terminal output. - cmd.Printf("\nView the Web UI: %s\n", cfg.AccessURL.String()) + cliui.Infof(inv.Stdout, "\nView the Web UI: %s\n", cfg.AccessURL.String()) // Used for zero-trust instance identity with Google Cloud. googleTokenValidator, err := idtoken.NewValidator(ctx, option.WithoutAuthentication()) @@ -943,7 +881,7 @@ flags, and YAML configuration. The precedence is as follows: // than abstracting the Coder API itself. coderAPI, coderAPICloser, err := newAPI(ctx, options) if err != nil { - return err + return xerrors.Errorf("create coder API: %w", err) } client := codersdk.New(localURL) @@ -981,10 +919,15 @@ flags, and YAML configuration. The precedence is as follows: _ = daemon.Close() } }() + + var provisionerdWaitGroup sync.WaitGroup + defer provisionerdWaitGroup.Wait() provisionerdMetrics := provisionerd.NewMetrics(options.PrometheusRegistry) for i := int64(0); i < cfg.Provisioner.Daemons.Value(); i++ { daemonCacheDir := filepath.Join(cacheDir, fmt.Sprintf("provisioner-%d", i)) - daemon, err := newProvisionerDaemon(ctx, coderAPI, provisionerdMetrics, logger, cfg, daemonCacheDir, errCh, false) + daemon, err := newProvisionerDaemon( + ctx, coderAPI, provisionerdMetrics, logger, cfg, daemonCacheDir, errCh, false, &provisionerdWaitGroup, + ) if err != nil { return xerrors.Errorf("create provisioner daemon: %w", err) } @@ -1064,7 +1007,7 @@ flags, and YAML configuration. The precedence is as follows: } }() - cmd.Println("\n==> Logs will stream in below (press ctrl+c to gracefully exit):") + cliui.Infof(inv.Stdout, "\n==> Logs will stream in below (press ctrl+c to gracefully exit):") // Updates the systemd status from activating to activated. _, err = daemon.SdNotify(false, daemon.SdNotifyReady) @@ -1084,7 +1027,7 @@ flags, and YAML configuration. The precedence is as follows: select { case <-notifyCtx.Done(): exitErr = notifyCtx.Err() - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render( + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Bold.Render( "Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit", )) case <-tunnelDone: @@ -1092,7 +1035,7 @@ flags, and YAML configuration. The precedence is as follows: case exitErr = <-errCh: } if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) { - cmd.Printf("Unexpected error, shutting down server: %s\n", exitErr) + cliui.Errorf(inv.Stderr, "Unexpected error, shutting down server: %s\n", exitErr) } // Begin clean shut down stage, we try to shut down services @@ -1104,18 +1047,18 @@ flags, and YAML configuration. The precedence is as follows: _, err = daemon.SdNotify(false, daemon.SdNotifyStopping) if err != nil { - cmd.Printf("Notify systemd failed: %s", err) + cliui.Errorf(inv.Stderr, "Notify systemd failed: %s", err) } // Stop accepting new connections without interrupting // in-flight requests, give in-flight requests 5 seconds to // complete. - cmd.Println("Shutting down API server...") + cliui.Info(inv.Stdout, "Shutting down API server..."+"\n") err = shutdownWithTimeout(httpServer.Shutdown, 3*time.Second) if err != nil { - cmd.Printf("API server shutdown took longer than 3s: %s\n", err) + cliui.Errorf(inv.Stderr, "API server shutdown took longer than 3s: %s\n", err) } else { - cmd.Printf("Gracefully shut down API server\n") + cliui.Info(inv.Stdout, "Gracefully shut down API server\n") } // Cancel any remaining in-flight requests. shutdownConns() @@ -1130,36 +1073,36 @@ flags, and YAML configuration. The precedence is as follows: go func() { defer wg.Done() - if ok, _ := cmd.Flags().GetBool(varVerbose); ok { - cmd.Printf("Shutting down provisioner daemon %d...\n", id) + if ok, _ := inv.ParsedFlags().GetBool(varVerbose); ok { + cliui.Infof(inv.Stdout, "Shutting down provisioner daemon %d...\n", id) } err := shutdownWithTimeout(provisionerDaemon.Shutdown, 5*time.Second) if err != nil { - cmd.PrintErrf("Failed to shutdown provisioner daemon %d: %s\n", id, err) + cliui.Errorf(inv.Stderr, "Failed to shutdown provisioner daemon %d: %s\n", id, err) return } err = provisionerDaemon.Close() if err != nil { - cmd.PrintErrf("Close provisioner daemon %d: %s\n", id, err) + cliui.Errorf(inv.Stderr, "Close provisioner daemon %d: %s\n", id, err) return } - if ok, _ := cmd.Flags().GetBool(varVerbose); ok { - cmd.Printf("Gracefully shut down provisioner daemon %d\n", id) + if ok, _ := inv.ParsedFlags().GetBool(varVerbose); ok { + cliui.Infof(inv.Stdout, "Gracefully shut down provisioner daemon %d\n", id) } }() } wg.Wait() - cmd.Println("Waiting for WebSocket connections to close...") + cliui.Info(inv.Stdout, "Waiting for WebSocket connections to close..."+"\n") _ = coderAPICloser.Close() - cmd.Println("Done waiting for WebSocket connections") + cliui.Info(inv.Stdout, "Done waiting for WebSocket connections"+"\n") // Close tunnel after we no longer have in-flight connections. if tunnel != nil { - cmd.Println("Waiting for tunnel to close...") + cliui.Infof(inv.Stdout, "Waiting for tunnel to close...") _ = tunnel.Close() <-tunnel.Wait() - cmd.Println("Done waiting for tunnel") + cliui.Infof(inv.Stdout, "Done waiting for tunnel") } // Ensures a last report can be sent before exit! @@ -1168,40 +1111,49 @@ flags, and YAML configuration. The precedence is as follows: // Trigger context cancellation for any remaining services. cancel() - if xerrors.Is(exitErr, context.Canceled) { + switch { + case xerrors.Is(exitErr, context.DeadlineExceeded): + cliui.Warnf(inv.Stderr, "Graceful shutdown timed out") + // Errors here cause a significant number of benign CI failures. + return nil + case xerrors.Is(exitErr, context.Canceled): + return nil + case exitErr != nil: + return xerrors.Errorf("graceful shutdown: %w", exitErr) + default: return nil } - return exitErr }, } var pgRawURL bool - postgresBuiltinURLCmd := &cobra.Command{ + + postgresBuiltinURLCmd := &clibase.Cmd{ Use: "postgres-builtin-url", Short: "Output the connection URL for the built-in PostgreSQL deployment.", - RunE: func(cmd *cobra.Command, _ []string) error { - cfg := createConfig(cmd) - url, err := embeddedPostgresURL(cfg) + Handler: func(inv *clibase.Invocation) error { + url, err := embeddedPostgresURL(r.createConfig()) if err != nil { return err } if pgRawURL { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", url) + _, _ = fmt.Fprintf(inv.Stdout, "%s\n", url) } else { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", cliui.Styles.Code.Render(fmt.Sprintf("psql %q", url))) + _, _ = fmt.Fprintf(inv.Stdout, "%s\n", cliui.Styles.Code.Render(fmt.Sprintf("psql %q", url))) } return nil }, } - postgresBuiltinServeCmd := &cobra.Command{ + + postgresBuiltinServeCmd := &clibase.Cmd{ Use: "postgres-builtin-serve", Short: "Run the built-in PostgreSQL deployment.", - RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() + Handler: func(inv *clibase.Invocation) error { + ctx := inv.Context() - cfg := createConfig(cmd) - logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr())) - if ok, _ := cmd.Flags().GetBool(varVerbose); ok { + cfg := r.createConfig() + logger := slog.Make(sloghuman.Sink(inv.Stderr)) + if ok, _ := inv.ParsedFlags().GetBool(varVerbose); ok { logger = logger.Leveled(slog.LevelDebug) } @@ -1215,25 +1167,34 @@ flags, and YAML configuration. The precedence is as follows: defer func() { _ = closePg() }() if pgRawURL { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", url) + _, _ = fmt.Fprintf(inv.Stdout, "%s\n", url) } else { - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s\n", cliui.Styles.Code.Render(fmt.Sprintf("psql %q", url))) + _, _ = fmt.Fprintf(inv.Stdout, "%s\n", cliui.Styles.Code.Render(fmt.Sprintf("psql %q", url))) } <-ctx.Done() return nil }, } - postgresBuiltinURLCmd.Flags().BoolVar(&pgRawURL, "raw-url", false, "Output the raw connection URL instead of a psql command.") - postgresBuiltinServeCmd.Flags().BoolVar(&pgRawURL, "raw-url", false, "Output the raw connection URL instead of a psql command.") - createAdminUserCommand := newCreateAdminUserCommand() - root.SetHelpFunc(func(cmd *cobra.Command, args []string) { - // Help is handled by clibase in command body. - }) - root.AddCommand(postgresBuiltinURLCmd, postgresBuiltinServeCmd, createAdminUserCommand) + createAdminUserCmd := r.newCreateAdminUserCommand() + + rawURLOpt := clibase.Option{ + Flag: "raw-url", + + Value: clibase.BoolOf(&pgRawURL), + Description: "Output the raw connection URL instead of a psql command.", + } + createAdminUserCmd.Options.Add(rawURLOpt) + postgresBuiltinURLCmd.Options.Add(rawURLOpt) + postgresBuiltinServeCmd.Options.Add(rawURLOpt) + + serverCmd.Children = append( + serverCmd.Children, + createAdminUserCmd, postgresBuiltinURLCmd, postgresBuiltinServeCmd, + ) - return root + return serverCmd } // isLocalURL returns true if the hostname of the provided URL appears to @@ -1269,6 +1230,7 @@ func newProvisionerDaemon( cacheDir string, errCh chan error, dev bool, + wg *sync.WaitGroup, ) (srv *provisionerd.Server, err error) { ctx, cancel := context.WithCancel(ctx) defer func() { @@ -1283,12 +1245,16 @@ func newProvisionerDaemon( } terraformClient, terraformServer := provisionersdk.MemTransportPipe() + wg.Add(1) go func() { + defer wg.Done() <-ctx.Done() _ = terraformClient.Close() _ = terraformServer.Close() }() + wg.Add(1) go func() { + defer wg.Done() defer cancel() err := terraform.Serve(ctx, &terraform.ServeOptions{ @@ -1317,12 +1283,16 @@ func newProvisionerDaemon( // include echo provisioner when in dev mode if dev { echoClient, echoServer := provisionersdk.MemTransportPipe() + wg.Add(1) go func() { + defer wg.Done() <-ctx.Done() _ = echoClient.Close() _ = echoServer.Close() }() + wg.Add(1) go func() { + defer wg.Done() defer cancel() err := echo.Serve(ctx, afero.NewOsFs(), &provisionersdk.ServeOptions{Listener: echoServer}) @@ -1355,13 +1325,13 @@ func newProvisionerDaemon( } // nolint: revive -func printLogo(cmd *cobra.Command) { +func printLogo(inv *clibase.Invocation) { // Only print the logo in TTYs. - if !isTTYOut(cmd) { + if !isTTYOut(inv) { return } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "%s - Your Self-Hosted Remote Development Platform\n", cliui.Styles.Bold.Render("Coder "+buildinfo.Version())) + _, _ = fmt.Fprintf(inv.Stdout, "%s - Your Self-Hosted Remote Development Platform\n", cliui.Styles.Bold.Render("Coder "+buildinfo.Version())) } func loadCertificates(tlsCertFiles, tlsKeyFiles []string) ([]tls.Certificate, error) { @@ -1760,7 +1730,7 @@ func isLocalhost(host string) bool { return host == "localhost" || host == "127.0.0.1" || host == "::1" } -func buildLogger(cmd *cobra.Command, cfg *codersdk.DeploymentValues) (slog.Logger, func(), error) { +func buildLogger(inv *clibase.Invocation, cfg *codersdk.DeploymentValues) (slog.Logger, func(), error) { var ( sinks = []slog.Sink{} closers = []func() error{} @@ -1771,10 +1741,10 @@ func buildLogger(cmd *cobra.Command, cfg *codersdk.DeploymentValues) (slog.Logge case "": case "/dev/stdout": - sinks = append(sinks, sinkFn(cmd.OutOrStdout())) + sinks = append(sinks, sinkFn(inv.Stdout)) case "/dev/stderr": - sinks = append(sinks, sinkFn(cmd.ErrOrStderr())) + sinks = append(sinks, sinkFn(inv.Stderr)) default: fi, err := os.OpenFile(loc, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0o644) diff --git a/cli/server_createadminuser.go b/cli/server_createadminuser.go index d21a7f07cce1e..c947675e287bb 100644 --- a/cli/server_createadminuser.go +++ b/cli/server_createadminuser.go @@ -4,16 +4,15 @@ package cli import ( "fmt" - "os" "os/signal" "sort" "github.com/google/uuid" - "github.com/spf13/cobra" "golang.org/x/xerrors" "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/gitsshkey" @@ -23,7 +22,7 @@ import ( "github.com/coder/coder/codersdk" ) -func newCreateAdminUserCommand() *cobra.Command { +func (r *RootCmd) newCreateAdminUserCommand() *clibase.Cmd { var ( newUserDBURL string newUserSSHKeygenAlgorithm string @@ -31,36 +30,20 @@ func newCreateAdminUserCommand() *cobra.Command { newUserEmail string newUserPassword string ) - createAdminUserCommand := &cobra.Command{ + createAdminUserCommand := &clibase.Cmd{ Use: "create-admin-user", Short: "Create a new admin user with the given username, email and password and adds it to every organization.", - RunE: func(cmd *cobra.Command, args []string) error { - ctx := cmd.Context() + Handler: func(inv *clibase.Invocation) error { + ctx := inv.Context() sshKeygenAlgorithm, err := gitsshkey.ParseAlgorithm(newUserSSHKeygenAlgorithm) if err != nil { return xerrors.Errorf("parse ssh keygen algorithm %q: %w", newUserSSHKeygenAlgorithm, err) } - if val, exists := os.LookupEnv("CODER_POSTGRES_URL"); exists { - newUserDBURL = val - } - if val, exists := os.LookupEnv("CODER_SSH_KEYGEN_ALGORITHM"); exists { - newUserSSHKeygenAlgorithm = val - } - if val, exists := os.LookupEnv("CODER_USERNAME"); exists { - newUserUsername = val - } - if val, exists := os.LookupEnv("CODER_EMAIL"); exists { - newUserEmail = val - } - if val, exists := os.LookupEnv("CODER_PASSWORD"); exists { - newUserPassword = val - } - - cfg := createConfig(cmd) - logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr())) - if ok, _ := cmd.Flags().GetBool(varVerbose); ok { + cfg := r.createConfig() + logger := slog.Make(sloghuman.Sink(inv.Stderr)) + if r.verbose { logger = logger.Leveled(slog.LevelDebug) } @@ -68,7 +51,7 @@ func newCreateAdminUserCommand() *cobra.Command { defer cancel() if newUserDBURL == "" { - cmd.Printf("Using built-in PostgreSQL (%s)\n", cfg.PostgresPath()) + cliui.Infof(inv.Stdout, "Using built-in PostgreSQL (%s)\n", cfg.PostgresPath()) url, closePg, err := startBuiltinPostgres(ctx, cfg, logger) if err != nil { return err @@ -110,7 +93,7 @@ func newCreateAdminUserCommand() *cobra.Command { } if newUserUsername == "" { - newUserUsername, err = cliui.Prompt(cmd, cliui.PromptOptions{ + newUserUsername, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "Username", Validate: func(val string) error { if val == "" { @@ -124,7 +107,7 @@ func newCreateAdminUserCommand() *cobra.Command { } } if newUserEmail == "" { - newUserEmail, err = cliui.Prompt(cmd, cliui.PromptOptions{ + newUserEmail, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "Email", Validate: func(val string) error { if val == "" { @@ -138,7 +121,7 @@ func newCreateAdminUserCommand() *cobra.Command { } } if newUserPassword == "" { - newUserPassword, err = cliui.Prompt(cmd, cliui.PromptOptions{ + newUserPassword, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "Password", Secret: true, Validate: func(val string) error { @@ -153,7 +136,7 @@ func newCreateAdminUserCommand() *cobra.Command { } // Prompt again. - _, err = cliui.Prompt(cmd, cliui.PromptOptions{ + _, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "Confirm password", Secret: true, Validate: func(val string) error { @@ -191,7 +174,7 @@ func newCreateAdminUserCommand() *cobra.Command { return orgs[i].Name < orgs[j].Name }) - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Creating user...") + _, _ = fmt.Fprintln(inv.Stderr, "Creating user...") newUser, err = tx.InsertUser(ctx, database.InsertUserParams{ ID: uuid.New(), Email: newUserEmail, @@ -206,7 +189,7 @@ func newCreateAdminUserCommand() *cobra.Command { return xerrors.Errorf("insert user: %w", err) } - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Generating user SSH key...") + _, _ = fmt.Fprintln(inv.Stderr, "Generating user SSH key...") privateKey, publicKey, err := gitsshkey.Generate(sshKeygenAlgorithm) if err != nil { return xerrors.Errorf("generate user gitsshkey: %w", err) @@ -223,7 +206,7 @@ func newCreateAdminUserCommand() *cobra.Command { } for _, org := range orgs { - _, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Adding user to organization %q (%s) as admin...\n", org.Name, org.ID.String()) + _, _ = fmt.Fprintf(inv.Stderr, "Adding user to organization %q (%s) as admin...\n", org.Name, org.ID.String()) _, err := tx.InsertOrganizationMember(ctx, database.InsertOrganizationMemberParams{ OrganizationID: org.ID, UserID: newUser.ID, @@ -242,21 +225,50 @@ func newCreateAdminUserCommand() *cobra.Command { return err } - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "") - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "User created successfully.") - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "ID: "+newUser.ID.String()) - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Username: "+newUser.Username) - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Email: "+newUser.Email) - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), "Password: ********") + _, _ = fmt.Fprintln(inv.Stderr, "") + _, _ = fmt.Fprintln(inv.Stderr, "User created successfully.") + _, _ = fmt.Fprintln(inv.Stderr, "ID: "+newUser.ID.String()) + _, _ = fmt.Fprintln(inv.Stderr, "Username: "+newUser.Username) + _, _ = fmt.Fprintln(inv.Stderr, "Email: "+newUser.Email) + _, _ = fmt.Fprintln(inv.Stderr, "Password: ********") return nil }, } - createAdminUserCommand.Flags().StringVar(&newUserDBURL, "postgres-url", "", "URL of a PostgreSQL database. If empty, the built-in PostgreSQL deployment will be used (Coder must not be already running in this case). Consumes $CODER_POSTGRES_URL.") - createAdminUserCommand.Flags().StringVar(&newUserSSHKeygenAlgorithm, "ssh-keygen-algorithm", "ed25519", "The algorithm to use for generating ssh keys. Accepted values are \"ed25519\", \"ecdsa\", or \"rsa4096\". Consumes $CODER_SSH_KEYGEN_ALGORITHM.") - createAdminUserCommand.Flags().StringVar(&newUserUsername, "username", "", "The username of the new user. If not specified, you will be prompted via stdin. Consumes $CODER_USERNAME.") - createAdminUserCommand.Flags().StringVar(&newUserEmail, "email", "", "The email of the new user. If not specified, you will be prompted via stdin. Consumes $CODER_EMAIL.") - createAdminUserCommand.Flags().StringVar(&newUserPassword, "password", "", "The password of the new user. If not specified, you will be prompted via stdin. Consumes $CODER_PASSWORD.") + + createAdminUserCommand.Options.Add( + clibase.Option{ + Env: "CODER_POSTGRES_URL", + Flag: "postgres-url", + Description: "URL of a PostgreSQL database. If empty, the built-in PostgreSQL deployment will be used (Coder must not be already running in this case).", + Value: clibase.StringOf(&newUserDBURL), + }, + clibase.Option{ + Env: "CODER_SSH_KEYGEN_ALGORITHM", + Flag: "ssh-keygen-algorithm", + Description: "The algorithm to use for generating ssh keys. Accepted values are \"ed25519\", \"ecdsa\", or \"rsa4096\".", + Default: "ed25519", + Value: clibase.StringOf(&newUserSSHKeygenAlgorithm), + }, + clibase.Option{ + Env: "CODER_USERNAME", + Flag: "username", + Description: "The username of the new user. If not specified, you will be prompted via stdin.", + Value: clibase.StringOf(&newUserUsername), + }, + clibase.Option{ + Env: "CODER_EMAIL", + Flag: "email", + Description: "The email of the new user. If not specified, you will be prompted via stdin.", + Value: clibase.StringOf(&newUserEmail), + }, + clibase.Option{ + Env: "CODER_PASSWORD", + Flag: "password", + Description: "The password of the new user. If not specified, you will be prompted via stdin.", + Value: clibase.StringOf(&newUserPassword), + }, + ) return createAdminUserCommand } diff --git a/cli/server_createadminuser_test.go b/cli/server_createadminuser_test.go index d222d122bff50..3d6bebf2a4a15 100644 --- a/cli/server_createadminuser_test.go +++ b/cli/server_createadminuser_test.go @@ -92,9 +92,7 @@ func TestServerCreateAdminUser(t *testing.T) { defer sqlDB.Close() db := database.New(sqlDB) - // Sometimes generating SSH keys takes a really long time if there isn't - // enough entropy. We don't want the tests to fail in these cases. - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() pingCtx, pingCancel := context.WithTimeout(ctx, testutil.WaitShort) @@ -120,7 +118,7 @@ func TestServerCreateAdminUser(t *testing.T) { }) require.NoError(t, err) - root, _ := clitest.New(t, + inv, _ := clitest.New(t, "server", "create-admin-user", "--postgres-url", connectionURL, "--ssh-keygen-algorithm", "ed25519", @@ -129,14 +127,9 @@ func TestServerCreateAdminUser(t *testing.T) { "--password", password, ) pty := ptytest.New(t) - root.SetOutput(pty.Output()) - root.SetErr(pty.Output()) - errC := make(chan error, 1) - go func() { - err := root.ExecuteContext(ctx) - t.Log("root.ExecuteContext() returned:", err) - errC <- err - }() + inv.Stdout = pty.Output() + inv.Stderr = pty.Output() + clitest.Start(t, inv) pty.ExpectMatchContext(ctx, "Creating user...") pty.ExpectMatchContext(ctx, "Generating user SSH key...") @@ -147,13 +140,11 @@ func TestServerCreateAdminUser(t *testing.T) { pty.ExpectMatchContext(ctx, email) pty.ExpectMatchContext(ctx, "****") - require.NoError(t, <-errC) - verifyUser(t, connectionURL, username, email, password) }) - //nolint:paralleltest t.Run("Env", func(t *testing.T) { + t.Parallel() if runtime.GOOS != "linux" || testing.Short() { // Skip on non-Linux because it spawns a PostgreSQL instance. t.SkipNow() @@ -162,35 +153,26 @@ func TestServerCreateAdminUser(t *testing.T) { require.NoError(t, err) defer closeFunc() - // Sometimes generating SSH keys takes a really long time if there isn't - // enough entropy. We don't want the tests to fail in these cases. - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() - t.Setenv("CODER_POSTGRES_URL", connectionURL) - t.Setenv("CODER_SSH_KEYGEN_ALGORITHM", "ed25519") - t.Setenv("CODER_USERNAME", username) - t.Setenv("CODER_EMAIL", email) - t.Setenv("CODER_PASSWORD", password) + inv, _ := clitest.New(t, "server", "create-admin-user") + inv.Environ.Set("CODER_POSTGRES_URL", connectionURL) + inv.Environ.Set("CODER_SSH_KEYGEN_ALGORITHM", "ed25519") + inv.Environ.Set("CODER_USERNAME", username) + inv.Environ.Set("CODER_EMAIL", email) + inv.Environ.Set("CODER_PASSWORD", password) - root, _ := clitest.New(t, "server", "create-admin-user") pty := ptytest.New(t) - root.SetOutput(pty.Output()) - root.SetErr(pty.Output()) - errC := make(chan error, 1) - go func() { - err := root.ExecuteContext(ctx) - t.Log("root.ExecuteContext() returned:", err) - errC <- err - }() + inv.Stdout = pty.Output() + inv.Stderr = pty.Output() + clitest.Start(t, inv) pty.ExpectMatchContext(ctx, "User created successfully.") pty.ExpectMatchContext(ctx, username) pty.ExpectMatchContext(ctx, email) pty.ExpectMatchContext(ctx, "****") - require.NoError(t, <-errC) - verifyUser(t, connectionURL, username, email, password) }) @@ -205,34 +187,25 @@ func TestServerCreateAdminUser(t *testing.T) { require.NoError(t, err) defer closeFunc() - // Sometimes generating SSH keys takes a really long time if there isn't - // enough entropy. We don't want the tests to fail in these cases. - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitSuperLong) + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitMedium) defer cancel() - root, _ := clitest.New(t, + inv, _ := clitest.New(t, "server", "create-admin-user", "--postgres-url", connectionURL, "--ssh-keygen-algorithm", "ed25519", ) - pty := ptytest.New(t) - root.SetIn(pty.Input()) - root.SetOutput(pty.Output()) - root.SetErr(pty.Output()) - errC := make(chan error, 1) - go func() { - err := root.ExecuteContext(ctx) - t.Log("root.ExecuteContext() returned:", err) - errC <- err - }() - - pty.ExpectMatchContext(ctx, "> Username") + pty := ptytest.New(t).Attach(inv) + + clitest.Start(t, inv) + + pty.ExpectMatchContext(ctx, "Username") pty.WriteLine(username) - pty.ExpectMatchContext(ctx, "> Email") + pty.ExpectMatchContext(ctx, "Email") pty.WriteLine(email) - pty.ExpectMatchContext(ctx, "> Password") + pty.ExpectMatchContext(ctx, "Password") pty.WriteLine(password) - pty.ExpectMatchContext(ctx, "> Confirm password") + pty.ExpectMatchContext(ctx, "Confirm password") pty.WriteLine(password) pty.ExpectMatchContext(ctx, "User created successfully.") @@ -240,8 +213,6 @@ func TestServerCreateAdminUser(t *testing.T) { pty.ExpectMatchContext(ctx, email) pty.ExpectMatchContext(ctx, "****") - require.NoError(t, <-errC) - verifyUser(t, connectionURL, username, email, password) }) @@ -267,10 +238,10 @@ func TestServerCreateAdminUser(t *testing.T) { "--password", "x", ) pty := ptytest.New(t) - root.SetOutput(pty.Output()) - root.SetErr(pty.Output()) + root.Stdout = pty.Output() + root.Stderr = pty.Output() - err = root.ExecuteContext(ctx) + err = root.WithContext(ctx).Run() require.Error(t, err) require.ErrorContains(t, err, "'email' failed on the 'email' tag") require.ErrorContains(t, err, "'username' failed on the 'username' tag") diff --git a/cli/server_slim.go b/cli/server_slim.go index b54cf8c88d52b..417a32ff13ae7 100644 --- a/cli/server_slim.go +++ b/cli/server_slim.go @@ -8,72 +8,24 @@ import ( "io" "os" - "github.com/spf13/cobra" - + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd" ) -func Server(_ func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *cobra.Command { - root := &cobra.Command{ - Use: "server", - Short: "Start a Coder server", - Hidden: true, - RunE: func(cmd *cobra.Command, args []string) error { - serverUnsupported(cmd.ErrOrStderr()) - return nil - }, - } - - var pgRawURL bool - postgresBuiltinURLCmd := &cobra.Command{ - Use: "postgres-builtin-url", - Short: "Output the connection URL for the built-in PostgreSQL deployment.", - Hidden: true, - RunE: func(cmd *cobra.Command, _ []string) error { - serverUnsupported(cmd.ErrOrStderr()) - return nil - }, - } - postgresBuiltinServeCmd := &cobra.Command{ - Use: "postgres-builtin-serve", - Short: "Run the built-in PostgreSQL deployment.", - Hidden: true, - RunE: func(cmd *cobra.Command, args []string) error { - serverUnsupported(cmd.ErrOrStderr()) - return nil - }, - } - - var ( - newUserDBURL string - newUserSSHKeygenAlgorithm string - newUserUsername string - newUserEmail string - newUserPassword string - ) - createAdminUserCommand := &cobra.Command{ - Use: "create-admin-user", - Short: "Create a new admin user with the given username, email and password and adds it to every organization.", - Hidden: true, - RunE: func(cmd *cobra.Command, args []string) error { - serverUnsupported(cmd.ErrOrStderr()) +func (r *RootCmd) Server(_ func(context.Context, *coderd.Options) (*coderd.API, io.Closer, error)) *clibase.Cmd { + root := &clibase.Cmd{ + Use: "server", + Short: "Start a Coder server", + // We accept RawArgs so all commands and flags are accepted. + RawArgs: true, + Hidden: true, + Handler: func(inv *clibase.Invocation) error { + serverUnsupported(inv.Stderr) return nil }, } - // We still have to attach the flags to the commands so users don't get - // an error when they try to use them. - postgresBuiltinURLCmd.Flags().BoolVar(&pgRawURL, "raw-url", false, "Output the raw connection URL instead of a psql command.") - postgresBuiltinServeCmd.Flags().BoolVar(&pgRawURL, "raw-url", false, "Output the raw connection URL instead of a psql command.") - createAdminUserCommand.Flags().StringVar(&newUserDBURL, "postgres-url", "", "URL of a PostgreSQL database. If empty, the built-in PostgreSQL deployment will be used (Coder must not be already running in this case). Consumes $CODER_POSTGRES_URL.") - createAdminUserCommand.Flags().StringVar(&newUserSSHKeygenAlgorithm, "ssh-keygen-algorithm", "ed25519", "The algorithm to use for generating ssh keys. Accepted values are \"ed25519\", \"ecdsa\", or \"rsa4096\". Consumes $CODER_SSH_KEYGEN_ALGORITHM.") - createAdminUserCommand.Flags().StringVar(&newUserUsername, "username", "", "The username of the new user. If not specified, you will be prompted via stdin. Consumes $CODER_USERNAME.") - createAdminUserCommand.Flags().StringVar(&newUserEmail, "email", "", "The email of the new user. If not specified, you will be prompted via stdin. Consumes $CODER_EMAIL.") - createAdminUserCommand.Flags().StringVar(&newUserPassword, "password", "", "The password of the new user. If not specified, you will be prompted via stdin. Consumes $CODER_PASSWORD.") - - root.AddCommand(postgresBuiltinURLCmd, postgresBuiltinServeCmd, createAdminUserCommand) - return root } diff --git a/cli/server_test.go b/cli/server_test.go index cd887c502cd19..757d38fd91457 100644 --- a/cli/server_test.go +++ b/cli/server_test.go @@ -108,78 +108,66 @@ func TestServer(t *testing.T) { connectionURL, closeFunc, err := postgres.Open() require.NoError(t, err) defer closeFunc() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - root, cfg := clitest.New(t, + // Postgres + race detector + CI = slow. + ctx := testutil.Context(t, testutil.WaitSuperLong*3) + + inv, cfg := clitest.New(t, "server", "--http-address", ":0", "--access-url", "http://example.com", "--postgres-url", connectionURL, "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t) - root.SetOutput(pty.Output()) - root.SetErr(pty.Output()) - errC := make(chan error, 1) - go func() { - errC <- root.ExecuteContext(ctx) - }() + clitest.Start(t, inv.WithContext(ctx)) accessURL := waitAccessURL(t, cfg) client := codersdk.New(accessURL) _, err = client.CreateFirstUser(ctx, coderdtest.FirstUserParams) require.NoError(t, err) - cancelFunc() - require.NoError(t, <-errC) }) t.Run("BuiltinPostgres", func(t *testing.T) { t.Parallel() if testing.Short() { t.SkipNow() } - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - root, cfg := clitest.New(t, + inv, cfg := clitest.New(t, "server", "--http-address", ":0", "--access-url", "http://example.com", "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t) - root.SetOutput(pty.Output()) - root.SetErr(pty.Output()) - errC := make(chan error, 1) - go func() { - errC <- root.ExecuteContext(ctx) - }() + + const superDuperLong = testutil.WaitSuperLong * 3 + + ctx := testutil.Context(t, superDuperLong) + clitest.Start(t, inv.WithContext(ctx)) + //nolint:gocritic // Embedded postgres take a while to fire up. require.Eventually(t, func() bool { rawURL, err := cfg.URL().Read() return err == nil && rawURL != "" - }, 3*time.Minute, testutil.IntervalFast, "failed to get access URL") - cancelFunc() - require.NoError(t, <-errC) + }, superDuperLong, testutil.IntervalFast, "failed to get access URL") }) t.Run("BuiltinPostgresURL", func(t *testing.T) { t.Parallel() root, _ := clitest.New(t, "server", "postgres-builtin-url") pty := ptytest.New(t) - root.SetOutput(pty.Output()) - err := root.Execute() + root.Stdout = pty.Output() + err := root.Run() require.NoError(t, err) pty.ExpectMatch("psql") }) t.Run("BuiltinPostgresURLRaw", func(t *testing.T) { t.Parallel() - ctx, _ := testutil.Context(t) + ctx := testutil.Context(t, testutil.WaitLong) root, _ := clitest.New(t, "server", "postgres-builtin-url", "--raw-url") pty := ptytest.New(t) - root.SetOutput(pty.Output()) - err := root.ExecuteContext(ctx) + root.Stdout = pty.Output() + err := root.WithContext(ctx).Run() require.NoError(t, err) got := pty.ReadLine(ctx) @@ -192,93 +180,62 @@ func TestServer(t *testing.T) { // reachable. t.Run("LocalAccessURL", func(t *testing.T) { t.Parallel() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - root, cfg := clitest.New(t, + inv, cfg := clitest.New(t, "server", "--in-memory", "--http-address", ":0", "--access-url", "http://localhost:3000/", "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t) - root.SetIn(pty.Input()) - root.SetOut(pty.Output()) - errC := make(chan error, 1) - go func() { - errC <- root.ExecuteContext(ctx) - }() + pty := ptytest.New(t).Attach(inv) + clitest.Start(t, inv) // Just wait for startup _ = waitAccessURL(t, cfg) pty.ExpectMatch("this may cause unexpected problems when creating workspaces") pty.ExpectMatch("View the Web UI: http://localhost:3000/") - - cancelFunc() - require.NoError(t, <-errC) }) // Validate that an https scheme is prepended to a remote access URL // and that a warning is printed for a host that cannot be resolved. t.Run("RemoteAccessURL", func(t *testing.T) { t.Parallel() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - root, cfg := clitest.New(t, + inv, cfg := clitest.New(t, "server", "--in-memory", "--http-address", ":0", "--access-url", "https://foobarbaz.mydomain", "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t) - root.SetIn(pty.Input()) - root.SetOut(pty.Output()) - errC := make(chan error, 1) - go func() { - errC <- root.ExecuteContext(ctx) - }() + pty := ptytest.New(t).Attach(inv) + + clitest.Start(t, inv) // Just wait for startup _ = waitAccessURL(t, cfg) pty.ExpectMatch("this may cause unexpected problems when creating workspaces") pty.ExpectMatch("View the Web UI: https://foobarbaz.mydomain") - - cancelFunc() - require.NoError(t, <-errC) }) t.Run("NoWarningWithRemoteAccessURL", func(t *testing.T) { t.Parallel() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - - root, cfg := clitest.New(t, + inv, cfg := clitest.New(t, "server", "--in-memory", "--http-address", ":0", "--access-url", "https://google.com", "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t) - root.SetIn(pty.Input()) - root.SetOut(pty.Output()) - errC := make(chan error, 1) - go func() { - errC <- root.ExecuteContext(ctx) - }() + pty := ptytest.New(t).Attach(inv) + clitest.Start(t, inv) // Just wait for startup _ = waitAccessURL(t, cfg) pty.ExpectMatch("View the Web UI: https://google.com") - - cancelFunc() - require.NoError(t, <-errC) }) t.Run("NoSchemeAccessURL", func(t *testing.T) { @@ -293,7 +250,7 @@ func TestServer(t *testing.T) { "--access-url", "google.com", "--cache-dir", t.TempDir(), ) - err := root.ExecuteContext(ctx) + err := root.WithContext(ctx).Run() require.Error(t, err) }) @@ -312,7 +269,7 @@ func TestServer(t *testing.T) { "--tls-min-version", "tls9", "--cache-dir", t.TempDir(), ) - err := root.ExecuteContext(ctx) + err := root.WithContext(ctx).Run() require.Error(t, err) }) t.Run("TLSBadClientAuth", func(t *testing.T) { @@ -330,7 +287,7 @@ func TestServer(t *testing.T) { "--tls-client-auth", "something", "--cache-dir", t.TempDir(), ) - err := root.ExecuteContext(ctx) + err := root.WithContext(ctx).Run() require.Error(t, err) }) t.Run("TLSInvalid", func(t *testing.T) { @@ -382,7 +339,7 @@ func TestServer(t *testing.T) { } args = append(args, c.args...) root, _ := clitest.New(t, args...) - err := root.ExecuteContext(ctx) + err := root.WithContext(ctx).Run() require.Error(t, err) t.Logf("args: %v", args) require.ErrorContains(t, err, c.errContains) @@ -406,7 +363,7 @@ func TestServer(t *testing.T) { "--tls-key-file", keyPath, "--cache-dir", t.TempDir(), ) - clitest.Start(ctx, t, root) + clitest.Start(t, root.WithContext(ctx)) // Verify HTTPS accessURL := waitAccessURL(t, cfg) @@ -445,8 +402,8 @@ func TestServer(t *testing.T) { "--cache-dir", t.TempDir(), ) pty := ptytest.New(t) - root.SetOut(pty.Output()) - clitest.Start(ctx, t, root) + root.Stdout = pty.Output() + clitest.Start(t, root.WithContext(ctx)) accessURL := waitAccessURL(t, cfg) require.Equal(t, "https", accessURL.Scheme) @@ -511,7 +468,7 @@ func TestServer(t *testing.T) { defer cancelFunc() certPath, keyPath := generateTLSCertificate(t) - root, _ := clitest.New(t, + inv, _ := clitest.New(t, "server", "--in-memory", "--http-address", ":0", @@ -523,17 +480,11 @@ func TestServer(t *testing.T) { "--tls-key-file", keyPath, "--cache-dir", t.TempDir(), ) - pty := ptytest.New(t) - root.SetOutput(pty.Output()) - root.SetErr(pty.Output()) - - errC := make(chan error, 1) - go func() { - errC <- root.ExecuteContext(ctx) - }() + pty := ptytest.New(t).Attach(inv) + clitest.Start(t, inv) // We can't use waitAccessURL as it will only return the HTTP URL. - const httpLinePrefix = "Started HTTP listener at " + const httpLinePrefix = "Started HTTP listener at" pty.ExpectMatch(httpLinePrefix) httpLine := pty.ReadLine(ctx) httpAddr := strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix)) @@ -572,9 +523,6 @@ func TestServer(t *testing.T) { defer client.HTTPClient.CloseIdleConnections() _, err = client.HasFirstUser(ctx) require.NoError(t, err) - - cancelFunc() - require.NoError(t, <-errC) }) t.Run("TLSRedirect", func(t *testing.T) { @@ -670,15 +618,11 @@ func TestServer(t *testing.T) { flags = append(flags, "--redirect-to-access-url") } - root, _ := clitest.New(t, flags...) + inv, _ := clitest.New(t, flags...) pty := ptytest.New(t) - root.SetOutput(pty.Output()) - root.SetErr(pty.Output()) + pty.Attach(inv) - errC := make(chan error, 1) - go func() { - errC <- root.ExecuteContext(ctx) - }() + clitest.Start(t, inv) var ( httpAddr string @@ -686,14 +630,14 @@ func TestServer(t *testing.T) { ) // We can't use waitAccessURL as it will only return the HTTP URL. if c.httpListener { - const httpLinePrefix = "Started HTTP listener at " + const httpLinePrefix = "Started HTTP listener at" pty.ExpectMatch(httpLinePrefix) httpLine := pty.ReadLine(ctx) httpAddr = strings.TrimSpace(strings.TrimPrefix(httpLine, httpLinePrefix)) require.NotEmpty(t, httpAddr) } if c.tlsListener { - const tlsLinePrefix = "Started TLS/HTTPS listener at " + const tlsLinePrefix = "Started TLS/HTTPS listener at" pty.ExpectMatch(tlsLinePrefix) tlsLine := pty.ReadLine(ctx) tlsAddr = strings.TrimSpace(strings.TrimPrefix(tlsLine, tlsLinePrefix)) @@ -742,8 +686,6 @@ func TestServer(t *testing.T) { if err != nil { require.ErrorContains(t, err, "Invalid application URL") } - cancelFunc() - require.NoError(t, <-errC) } }) } @@ -762,18 +704,19 @@ func TestServer(t *testing.T) { ) pty := ptytest.New(t) - root.SetOutput(pty.Output()) - root.SetErr(pty.Output()) + root.Stdout = pty.Output() + root.Stderr = pty.Output() serverStop := make(chan error, 1) go func() { - err := root.ExecuteContext(ctx) + err := root.WithContext(ctx).Run() if err != nil { t.Error(err) } close(serverStop) }() - pty.ExpectMatch("Started HTTP listener at http://0.0.0.0:") + pty.ExpectMatch("Started HTTP listener") + pty.ExpectMatch("http://0.0.0.0:") cancelFunc() <-serverStop @@ -781,32 +724,19 @@ func TestServer(t *testing.T) { t.Run("CanListenUnspecifiedv6", func(t *testing.T) { t.Parallel() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - root, _ := clitest.New(t, + inv, _ := clitest.New(t, "server", "--in-memory", "--http-address", "[::]:0", "--access-url", "http://example.com", ) - pty := ptytest.New(t) - root.SetOutput(pty.Output()) - root.SetErr(pty.Output()) - serverClose := make(chan struct{}, 1) - go func() { - err := root.ExecuteContext(ctx) - if err != nil { - t.Error(err) - } - close(serverClose) - }() - - pty.ExpectMatch("Started HTTP listener at http://[::]:") + pty := ptytest.New(t).Attach(inv) + clitest.Start(t, inv) - cancelFunc() - <-serverClose + pty.ExpectMatch("Started HTTP listener at") + pty.ExpectMatch("http://[::]:") }) t.Run("NoAddress", func(t *testing.T) { @@ -814,14 +744,14 @@ func TestServer(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - root, _ := clitest.New(t, + inv, _ := clitest.New(t, "server", "--in-memory", "--http-address", ":80", "--tls-enable=false", "--tls-address", "", ) - err := root.ExecuteContext(ctx) + err := inv.WithContext(ctx).Run() require.Error(t, err) require.ErrorContains(t, err, "tls-address") }) @@ -831,13 +761,13 @@ func TestServer(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - root, _ := clitest.New(t, + inv, _ := clitest.New(t, "server", "--in-memory", "--tls-enable=true", "--tls-address", "", ) - err := root.ExecuteContext(ctx) + err := inv.WithContext(ctx).Run() require.Error(t, err) require.ErrorContains(t, err, "must not be empty") }) @@ -854,7 +784,7 @@ func TestServer(t *testing.T) { ctx, cancelFunc := context.WithCancel(context.Background()) defer cancelFunc() - root, cfg := clitest.New(t, + inv, cfg := clitest.New(t, "server", "--in-memory", "--address", ":0", @@ -862,9 +792,9 @@ func TestServer(t *testing.T) { "--cache-dir", t.TempDir(), ) pty := ptytest.New(t) - root.SetOutput(pty.Output()) - root.SetErr(pty.Output()) - clitest.Start(ctx, t, root) + inv.Stdout = pty.Output() + inv.Stderr = pty.Output() + clitest.Start(t, inv.WithContext(ctx)) pty.ExpectMatch("is deprecated") @@ -892,9 +822,9 @@ func TestServer(t *testing.T) { "--cache-dir", t.TempDir(), ) pty := ptytest.New(t) - root.SetOutput(pty.Output()) - root.SetErr(pty.Output()) - clitest.Start(ctx, t, root) + root.Stdout = pty.Output() + root.Stderr = pty.Output() + clitest.Start(t, root.WithContext(ctx)) pty.ExpectMatch("is deprecated") @@ -935,7 +865,7 @@ func TestServer(t *testing.T) { ) serverErr := make(chan error, 1) go func() { - serverErr <- root.ExecuteContext(ctx) + serverErr <- root.WithContext(ctx).Run() }() _ = waitAccessURL(t, cfg) currentProcess, err := os.FindProcess(os.Getpid()) @@ -949,10 +879,8 @@ func TestServer(t *testing.T) { }) t.Run("TracerNoLeak", func(t *testing.T) { t.Parallel() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - root, _ := clitest.New(t, + inv, _ := clitest.New(t, "server", "--in-memory", "--http-address", ":0", @@ -960,18 +888,14 @@ func TestServer(t *testing.T) { "--trace=true", "--cache-dir", t.TempDir(), ) - errC := make(chan error, 1) - go func() { - errC <- root.ExecuteContext(ctx) - }() - cancelFunc() - require.NoError(t, <-errC) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + clitest.Start(t, inv.WithContext(ctx)) + cancel() require.Error(t, goleak.Find()) }) t.Run("Telemetry", func(t *testing.T) { t.Parallel() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() deployment := make(chan struct{}, 64) snapshot := make(chan *telemetry.Snapshot, 64) @@ -990,7 +914,7 @@ func TestServer(t *testing.T) { server := httptest.NewServer(r) defer server.Close() - root, _ := clitest.New(t, + inv, _ := clitest.New(t, "server", "--in-memory", "--http-address", ":0", @@ -999,21 +923,13 @@ func TestServer(t *testing.T) { "--telemetry-url", server.URL, "--cache-dir", t.TempDir(), ) - errC := make(chan error, 1) - go func() { - errC <- root.ExecuteContext(ctx) - }() + clitest.Start(t, inv) <-deployment <-snapshot - cancelFunc() - <-errC }) t.Run("Prometheus", func(t *testing.T) { t.Parallel() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() - random, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) _ = random.Close() @@ -1021,7 +937,7 @@ func TestServer(t *testing.T) { require.True(t, valid) randomPort := tcpAddr.Port - root, cfg := clitest.New(t, + inv, cfg := clitest.New(t, "server", "--in-memory", "--http-address", ":0", @@ -1031,10 +947,11 @@ func TestServer(t *testing.T) { "--prometheus-address", ":"+strconv.Itoa(randomPort), "--cache-dir", t.TempDir(), ) - serverErr := make(chan error, 1) - go func() { - serverErr <- root.ExecuteContext(ctx) - }() + + ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) + defer cancel() + + clitest.Start(t, inv) _ = waitAccessURL(t, cfg) var res *http.Response @@ -1045,6 +962,7 @@ func TestServer(t *testing.T) { res, err = http.DefaultClient.Do(req) return err == nil }, testutil.WaitShort, testutil.IntervalFast) + defer res.Body.Close() scanner := bufio.NewScanner(res.Body) hasActiveUsers := false @@ -1065,16 +983,12 @@ func TestServer(t *testing.T) { require.NoError(t, scanner.Err()) require.True(t, hasActiveUsers) require.True(t, hasWorkspaces) - cancelFunc() - <-serverErr }) t.Run("GitHubOAuth", func(t *testing.T) { t.Parallel() - ctx, cancelFunc := context.WithCancel(context.Background()) - defer cancelFunc() fakeRedirect := "https://fake-url.com" - root, cfg := clitest.New(t, + inv, cfg := clitest.New(t, "server", "--in-memory", "--http-address", ":0", @@ -1084,10 +998,7 @@ func TestServer(t *testing.T) { "--oauth2-github-client-secret", "fake", "--oauth2-github-enterprise-base-url", fakeRedirect, ) - serverErr := make(chan error, 1) - go func() { - serverErr <- root.ExecuteContext(ctx) - }() + clitest.Start(t, inv) accessURL := waitAccessURL(t, cfg) client := codersdk.New(accessURL) client.HTTPClient.CheckRedirect = func(req *http.Request, via []*http.Request) error { @@ -1095,7 +1006,7 @@ func TestServer(t *testing.T) { } githubURL, err := accessURL.Parse("/api/v2/users/oauth2/github") require.NoError(t, err) - req, err := http.NewRequestWithContext(ctx, http.MethodGet, githubURL.String(), nil) + req, err := http.NewRequestWithContext(inv.Context(), http.MethodGet, githubURL.String(), nil) require.NoError(t, err) res, err := client.HTTPClient.Do(req) require.NoError(t, err) @@ -1103,8 +1014,6 @@ func TestServer(t *testing.T) { fakeURL, err := res.Location() require.NoError(t, err) require.True(t, strings.HasPrefix(fakeURL.String(), fakeRedirect), fakeURL.String()) - cancelFunc() - <-serverErr }) t.Run("RateLimit", func(t *testing.T) { @@ -1123,7 +1032,7 @@ func TestServer(t *testing.T) { ) serverErr := make(chan error, 1) go func() { - serverErr <- root.ExecuteContext(ctx) + serverErr <- root.WithContext(ctx).Run() }() accessURL := waitAccessURL(t, cfg) client := codersdk.New(accessURL) @@ -1152,7 +1061,7 @@ func TestServer(t *testing.T) { ) serverErr := make(chan error, 1) go func() { - serverErr <- root.ExecuteContext(ctx) + serverErr <- root.WithContext(ctx).Run() }() accessURL := waitAccessURL(t, cfg) client := codersdk.New(accessURL) @@ -1180,7 +1089,7 @@ func TestServer(t *testing.T) { ) serverErr := make(chan error, 1) go func() { - serverErr <- root.ExecuteContext(ctx) + serverErr <- root.WithContext(ctx).Run() }() accessURL := waitAccessURL(t, cfg) client := codersdk.New(accessURL) @@ -1230,9 +1139,9 @@ func TestServer(t *testing.T) { "--access-url", "http://example.com", "--log-human", fiName, ) - clitest.Start(context.Background(), t, root) + clitest.Start(t, root) - waitFile(t, fiName, testutil.WaitShort) + waitFile(t, fiName, testutil.WaitLong) }) t.Run("Human", func(t *testing.T) { @@ -1247,7 +1156,7 @@ func TestServer(t *testing.T) { "--access-url", "http://example.com", "--log-human", fi, ) - clitest.Start(context.Background(), t, root) + clitest.Start(t, root) waitFile(t, fi, testutil.WaitShort) }) @@ -1264,7 +1173,7 @@ func TestServer(t *testing.T) { "--access-url", "http://example.com", "--log-json", fi, ) - clitest.Start(context.Background(), t, root) + clitest.Start(t, root) waitFile(t, fi, testutil.WaitShort) }) @@ -1276,7 +1185,7 @@ func TestServer(t *testing.T) { fi := testutil.TempFile(t, "", "coder-logging-test-*") - root, _ := clitest.New(t, + inv, _ := clitest.New(t, "server", "--verbose", "--in-memory", @@ -1286,22 +1195,13 @@ func TestServer(t *testing.T) { ) // Attach pty so we get debug output from the command if this test // fails. - pty := ptytest.New(t) - root.SetOut(pty.Output()) - root.SetErr(pty.Output()) + pty := ptytest.New(t).Attach(inv) - serverErr := make(chan error, 1) - go func() { - serverErr <- root.ExecuteContext(ctx) - }() - defer func() { - cancelFunc() - <-serverErr - }() + clitest.Start(t, inv.WithContext(ctx)) // Wait for server to listen on HTTP, this is a good // starting point for expecting logs. - _ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ") + _ = pty.ExpectMatchContext(ctx, "Started HTTP listener at") waitFile(t, fi, testutil.WaitSuperLong) }) @@ -1319,7 +1219,7 @@ func TestServer(t *testing.T) { // which can take a long time and end up failing the test. // This is why we wait extra long below for server to listen on // HTTP. - root, _ := clitest.New(t, + inv, _ := clitest.New(t, "server", "--verbose", "--in-memory", @@ -1331,15 +1231,13 @@ func TestServer(t *testing.T) { ) // Attach pty so we get debug output from the command if this test // fails. - pty := ptytest.New(t) - root.SetOut(pty.Output()) - root.SetErr(pty.Output()) + pty := ptytest.New(t).Attach(inv) - clitest.Start(ctx, t, root) + clitest.Start(t, inv) // Wait for server to listen on HTTP, this is a good // starting point for expecting logs. - _ = pty.ExpectMatchContext(ctx, "Started HTTP listener at ") + _ = pty.ExpectMatchContext(ctx, "Started HTTP listener at") waitFile(t, fi1, testutil.WaitSuperLong) waitFile(t, fi2, testutil.WaitSuperLong) diff --git a/cli/show.go b/cli/show.go index 9ed91d3e511a9..3dff78fcaefdc 100644 --- a/cli/show.go +++ b/cli/show.go @@ -1,32 +1,32 @@ package cli import ( - "github.com/spf13/cobra" "golang.org/x/xerrors" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" + "github.com/coder/coder/codersdk" ) -func show() *cobra.Command { - return &cobra.Command{ - Annotations: workspaceCommand, - Use: "show ", - Short: "Display details of a workspace's resources and agents", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) - if err != nil { - return err - } - buildInfo, err := client.BuildInfo(cmd.Context()) +func (r *RootCmd) show() *clibase.Cmd { + client := new(codersdk.Client) + return &clibase.Cmd{ + Use: "show ", + Short: "Display details of a workspace's resources and agents", + Middleware: clibase.Chain( + clibase.RequireNArgs(1), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + buildInfo, err := client.BuildInfo(inv.Context()) if err != nil { return xerrors.Errorf("get server version: %w", err) } - workspace, err := namedWorkspace(cmd, client, args[0]) + workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return xerrors.Errorf("get workspace: %w", err) } - return cliui.WorkspaceResources(cmd.OutOrStdout(), workspace.LatestBuild.Resources, cliui.WorkspaceResourcesOptions{ + return cliui.WorkspaceResources(inv.Stdout, workspace.LatestBuild.Resources, cliui.WorkspaceResourcesOptions{ WorkspaceName: workspace.Name, ServerVersion: buildInfo.Version, }) diff --git a/cli/show_test.go b/cli/show_test.go index 088c0c21e60d8..6f5faaa3fde11 100644 --- a/cli/show_test.go +++ b/cli/show_test.go @@ -31,15 +31,13 @@ func TestShow(t *testing.T) { "show", workspace.Name, } - cmd, root := clitest.New(t, args...) + inv, root := clitest.New(t, args...) clitest.SetupConfig(t, client, root) doneChan := make(chan struct{}) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) go func() { defer close(doneChan) - err := cmd.Execute() + err := inv.Run() assert.NoError(t, err) }() matches := []struct { diff --git a/cli/speedtest.go b/cli/speedtest.go index 2fc62227fdd58..986088e2ea238 100644 --- a/cli/speedtest.go +++ b/cli/speedtest.go @@ -6,43 +6,41 @@ import ( "time" "github.com/jedib0t/go-pretty/v6/table" - "github.com/spf13/cobra" "golang.org/x/xerrors" tsspeedtest "tailscale.com/net/speedtest" "cdr.dev/slog" "cdr.dev/slog/sloggers/sloghuman" - "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) -func speedtest() *cobra.Command { +func (r *RootCmd) speedtest() *clibase.Cmd { var ( direct bool duration time.Duration direction string ) - cmd := &cobra.Command{ + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Annotations: workspaceCommand, Use: "speedtest ", - Args: cobra.ExactArgs(1), Short: "Run upload and download tests from your machine to a workspace", - RunE: func(cmd *cobra.Command, args []string) error { - ctx, cancel := context.WithCancel(cmd.Context()) + Middleware: clibase.Chain( + clibase.RequireNArgs(1), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + ctx, cancel := context.WithCancel(inv.Context()) defer cancel() - client, err := CreateClient(cmd) - if err != nil { - return xerrors.Errorf("create codersdk client: %w", err) - } - - workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], false) + workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, inv, client, codersdk.Me, inv.Args[0]) if err != nil { return err } - err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ + err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{ WorkspaceName: workspace.Name, Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { return client.WorkspaceAgent(ctx, workspaceAgent.ID) @@ -53,9 +51,9 @@ func speedtest() *cobra.Command { } logger, ok := LoggerFromContext(ctx) if !ok { - logger = slog.Make(sloghuman.Sink(cmd.ErrOrStderr())) + logger = slog.Make(sloghuman.Sink(inv.Stderr)) } - if cliflag.IsSetBool(cmd, varVerbose) { + if r.verbose { logger = logger.Leveled(slog.LevelDebug) } conn, err := client.DialWorkspaceAgent(ctx, workspaceAgent.ID, &codersdk.DialWorkspaceAgentOptions{ @@ -84,14 +82,14 @@ func speedtest() *cobra.Command { } peer := status.Peer[status.Peers()[0]] if !p2p && direct { - cmd.Printf("Waiting for a direct connection... (%dms via %s)\n", dur.Milliseconds(), peer.Relay) + cliui.Infof(inv.Stdout, "Waiting for a direct connection... (%dms via %s)\n", dur.Milliseconds(), peer.Relay) continue } via := peer.Relay if via == "" { via = "direct" } - cmd.Printf("%dms via %s\n", dur.Milliseconds(), via) + cliui.Infof(inv.Stdout, "%dms via %s\n", dur.Milliseconds(), via) break } } else { @@ -106,7 +104,7 @@ func speedtest() *cobra.Command { default: return xerrors.Errorf("invalid direction: %q", direction) } - cmd.Printf("Starting a %ds %s test...\n", int(duration.Seconds()), tsDir) + cliui.Infof(inv.Stdout, "Starting a %ds %s test...\n", int(duration.Seconds()), tsDir) results, err := conn.Speedtest(ctx, tsDir, duration) if err != nil { return err @@ -123,16 +121,31 @@ func speedtest() *cobra.Command { fmt.Sprintf("%.4f Mbits/sec", r.MBitsPerSecond()), }) } - _, err = fmt.Fprintln(cmd.OutOrStdout(), tableWriter.Render()) + _, err = fmt.Fprintln(inv.Stdout, tableWriter.Render()) return err }, } - cliflag.BoolVarP(cmd.Flags(), &direct, "direct", "d", "", false, - "Specifies whether to wait for a direct connection before testing speed.") - cliflag.StringVarP(cmd.Flags(), &direction, "direction", "", "", "down", - "Specifies whether to run in reverse mode where the client receives and the server sends. (up|down)", - ) - cmd.Flags().DurationVarP(&duration, "time", "t", tsspeedtest.DefaultDuration, - "Specifies the duration to monitor traffic.") + cmd.Options = clibase.OptionSet{ + { + Description: "Specifies whether to wait for a direct connection before testing speed.", + Flag: "direct", + FlagShorthand: "d", + + Value: clibase.BoolOf(&direct), + }, + { + Description: "Specifies whether to run in reverse mode where the client receives and the server sends.", + Flag: "direction", + Default: "down", + Value: clibase.EnumOf(&direction, "up", "down"), + }, + { + Description: "Specifies the duration to monitor traffic.", + Flag: "time", + FlagShorthand: "t", + Default: tsspeedtest.DefaultDuration.String(), + Value: clibase.DurationOf(&duration), + }, + } return cmd } diff --git a/cli/speedtest_test.go b/cli/speedtest_test.go index 3cb2956975525..b05e3689347a3 100644 --- a/cli/speedtest_test.go +++ b/cli/speedtest_test.go @@ -48,18 +48,18 @@ func TestSpeedtest(t *testing.T) { a.LifecycleState == codersdk.WorkspaceAgentLifecycleReady }, testutil.WaitLong, testutil.IntervalFast, "agent is not ready") - cmd, root := clitest.New(t, "speedtest", workspace.Name) + inv, root := clitest.New(t, "speedtest", workspace.Name) clitest.SetupConfig(t, client, root) pty := ptytest.New(t) - cmd.SetOut(pty.Output()) - cmd.SetErr(pty.Output()) + inv.Stdout = pty.Output() + inv.Stderr = pty.Output() ctx, cancel = context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() ctx = cli.ContextWithLogger(ctx, slogtest.Make(t, nil).Named("speedtest").Leveled(slog.LevelDebug)) cmdDone := tGo(t, func() { - err := cmd.ExecuteContext(ctx) + err := inv.WithContext(ctx).Run() assert.NoError(t, err) }) <-cmdDone diff --git a/cli/ssh.go b/cli/ssh.go index 5adeba63bbae6..3b02ea387f82d 100644 --- a/cli/ssh.go +++ b/cli/ssh.go @@ -18,14 +18,13 @@ import ( "github.com/gofrs/flock" "github.com/google/uuid" "github.com/mattn/go-isatty" - "github.com/spf13/cobra" gossh "golang.org/x/crypto/ssh" gosshagent "golang.org/x/crypto/ssh/agent" "golang.org/x/term" "golang.org/x/xerrors" "github.com/coder/coder/agent" - "github.com/coder/coder/cli/cliflag" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/autobuild/notify" "github.com/coder/coder/coderd/util/ptr" @@ -38,55 +37,41 @@ var ( autostopNotifyCountdown = []time.Duration{30 * time.Minute} ) -func ssh() *cobra.Command { +func (r *RootCmd) ssh() *clibase.Cmd { var ( stdio bool - shuffle bool forwardAgent bool forwardGPG bool identityAgent string wsPollInterval time.Duration noWait bool ) - cmd := &cobra.Command{ + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Annotations: workspaceCommand, Use: "ssh ", Short: "Start a shell into a workspace", - Args: cobra.ArbitraryArgs, - RunE: func(cmd *cobra.Command, args []string) error { - ctx, cancel := context.WithCancel(cmd.Context()) + Middleware: clibase.Chain( + clibase.RequireNArgs(1), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + ctx, cancel := context.WithCancel(inv.Context()) defer cancel() - client, err := CreateClient(cmd) - if err != nil { - return err - } - - if shuffle { - err := cobra.ExactArgs(0)(cmd, args) - if err != nil { - return err - } - } else { - err := cobra.MinimumNArgs(1)(cmd, args) - if err != nil { - return err - } - } - - workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, cmd, client, codersdk.Me, args[0], shuffle) + workspace, workspaceAgent, err := getWorkspaceAndAgent(ctx, inv, client, codersdk.Me, inv.Args[0]) if err != nil { return err } updateWorkspaceBanner, outdated := verifyWorkspaceOutdated(client, workspace) - if outdated && isTTYErr(cmd) { - _, _ = fmt.Fprintln(cmd.ErrOrStderr(), updateWorkspaceBanner) + if outdated && isTTYErr(inv) { + _, _ = fmt.Fprintln(inv.Stderr, updateWorkspaceBanner) } // OpenSSH passes stderr directly to the calling TTY. // This is required in "stdio" mode so a connecting indicator can be displayed. - err = cliui.Agent(ctx, cmd.ErrOrStderr(), cliui.AgentOptions{ + err = cliui.Agent(ctx, inv.Stderr, cliui.AgentOptions{ WorkspaceName: workspace.Name, Fetch: func(ctx context.Context) (codersdk.WorkspaceAgent, error) { return client.WorkspaceAgent(ctx, workspaceAgent.ID) @@ -120,9 +105,9 @@ func ssh() *cobra.Command { defer rawSSH.Close() go func() { - _, _ = io.Copy(cmd.OutOrStdout(), rawSSH) + _, _ = io.Copy(inv.Stdout, rawSSH) }() - _, _ = io.Copy(rawSSH, cmd.InOrStdin()) + _, _ = io.Copy(rawSSH, inv.Stdin) return nil } @@ -168,15 +153,15 @@ func ssh() *cobra.Command { if err != nil { return xerrors.Errorf("upload GPG public keys and ownertrust to workspace: %w", err) } - closer, err := forwardGPGAgent(ctx, cmd.ErrOrStderr(), sshClient) + closer, err := forwardGPGAgent(ctx, inv.Stderr, sshClient) if err != nil { return xerrors.Errorf("forward GPG socket: %w", err) } defer closer.Close() } - stdoutFile, validOut := cmd.OutOrStdout().(*os.File) - stdinFile, validIn := cmd.InOrStdin().(*os.File) + stdoutFile, validOut := inv.Stdout.(*os.File) + stdinFile, validIn := inv.Stdin.(*os.File) if validOut && validIn && isatty.IsTerminal(stdoutFile.Fd()) { state, err := term.MakeRaw(int(stdinFile.Fd())) if err != nil { @@ -208,9 +193,9 @@ func ssh() *cobra.Command { return err } - sshSession.Stdin = cmd.InOrStdin() - sshSession.Stdout = cmd.OutOrStdout() - sshSession.Stderr = cmd.ErrOrStderr() + sshSession.Stdin = inv.Stdin + sshSession.Stdout = inv.Stdout + sshSession.Stderr = inv.Stderr err = sshSession.Shell() if err != nil { @@ -243,53 +228,70 @@ func ssh() *cobra.Command { return nil }, } - cliflag.BoolVarP(cmd.Flags(), &stdio, "stdio", "", "CODER_SSH_STDIO", false, "Specifies whether to emit SSH output over stdin/stdout.") - cliflag.BoolVarP(cmd.Flags(), &shuffle, "shuffle", "", "CODER_SSH_SHUFFLE", false, "Specifies whether to choose a random workspace") - _ = cmd.Flags().MarkHidden("shuffle") - cliflag.BoolVarP(cmd.Flags(), &forwardAgent, "forward-agent", "A", "CODER_SSH_FORWARD_AGENT", false, "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK") - cliflag.BoolVarP(cmd.Flags(), &forwardGPG, "forward-gpg", "G", "CODER_SSH_FORWARD_GPG", false, "Specifies whether to forward the GPG agent. Unsupported on Windows workspaces, but supports all clients. Requires gnupg (gpg, gpgconf) on both the client and workspace. The GPG agent must already be running locally and will not be started for you. If a GPG agent is already running in the workspace, it will be attempted to be killed.") - cliflag.StringVarP(cmd.Flags(), &identityAgent, "identity-agent", "", "CODER_SSH_IDENTITY_AGENT", "", "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled") - cliflag.DurationVarP(cmd.Flags(), &wsPollInterval, "workspace-poll-interval", "", "CODER_WORKSPACE_POLL_INTERVAL", workspacePollInterval, "Specifies how often to poll for workspace automated shutdown.") - cliflag.BoolVarP(cmd.Flags(), &noWait, "no-wait", "", "CODER_SSH_NO_WAIT", false, "Specifies whether to wait for a workspace to become ready before logging in (only applicable when the login before ready option has not been enabled). Note that the workspace agent may still be in the process of executing the startup script and the workspace may be in an incomplete state.") + cmd.Options = clibase.OptionSet{ + { + Flag: "stdio", + Env: "CODER_SSH_STDIO", + Description: "Specifies whether to emit SSH output over stdin/stdout.", + Value: clibase.BoolOf(&stdio), + }, + { + Flag: "forward-agent", + FlagShorthand: "A", + Env: "CODER_SSH_FORWARD_AGENT", + Description: "Specifies whether to forward the SSH agent specified in $SSH_AUTH_SOCK.", + Value: clibase.BoolOf(&forwardAgent), + }, + { + Flag: "forward-gpg", + FlagShorthand: "G", + Env: "CODER_SSH_FORWARD_GPG", + Description: "Specifies whether to forward the GPG agent. Unsupported on Windows workspaces, but supports all clients. Requires gnupg (gpg, gpgconf) on both the client and workspace. The GPG agent must already be running locally and will not be started for you. If a GPG agent is already running in the workspace, it will be attempted to be killed.", + Value: clibase.BoolOf(&forwardGPG), + }, + { + Flag: "identity-agent", + Env: "CODER_SSH_IDENTITY_AGENT", + Description: "Specifies which identity agent to use (overrides $SSH_AUTH_SOCK), forward agent must also be enabled.", + Value: clibase.StringOf(&identityAgent), + }, + { + Flag: "workspace-poll-interval", + Env: "CODER_WORKSPACE_POLL_INTERVAL", + Description: "Specifies how often to poll for workspace automated shutdown.", + Default: "1m", + Value: clibase.DurationOf(&wsPollInterval), + }, + { + Flag: "no-wait", + Env: "CODER_SSH_NO_WAIT", + Description: "Specifies whether to wait for a workspace to become ready before logging in (only applicable when the login before ready option has not been enabled). Note that the workspace agent may still be in the process of executing the startup script and the workspace may be in an incomplete state.", + Value: clibase.BoolOf(&noWait), + }, + } return cmd } // getWorkspaceAgent returns the workspace and agent selected using either the // `[.]` syntax via `in` or picks a random workspace and agent // if `shuffle` is true. -func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *codersdk.Client, userID string, in string, shuffle bool) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive +func getWorkspaceAndAgent(ctx context.Context, inv *clibase.Invocation, client *codersdk.Client, userID string, in string) (codersdk.Workspace, codersdk.WorkspaceAgent, error) { //nolint:revive var ( workspace codersdk.Workspace workspaceParts = strings.Split(in, ".") err error ) - if shuffle { - res, err := client.Workspaces(ctx, codersdk.WorkspaceFilter{ - Owner: userID, - }) - if err != nil { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err - } - if len(res.Workspaces) == 0 { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("no workspaces to shuffle") - } - workspace, err = cryptorand.Element(res.Workspaces) - if err != nil { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err - } - } else { - workspace, err = namedWorkspace(cmd, client, workspaceParts[0]) - if err != nil { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err - } + workspace, err = namedWorkspace(inv.Context(), client, workspaceParts[0]) + if err != nil { + return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err } if workspace.LatestBuild.Transition != codersdk.WorkspaceTransitionStart { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("workspace must be in start transition to ssh") } if workspace.LatestBuild.Job.CompletedAt == nil { - err := cliui.WorkspaceBuild(ctx, cmd.ErrOrStderr(), client, workspace.LatestBuild.ID) + err := cliui.WorkspaceBuild(ctx, inv.Stderr, client, workspace.LatestBuild.ID) if err != nil { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err } @@ -322,9 +324,6 @@ func getWorkspaceAndAgent(ctx context.Context, cmd *cobra.Command, client *coder } if workspaceAgent.ID == uuid.Nil { if len(agents) > 1 { - if !shuffle { - return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, xerrors.New("you must specify the name of an agent") - } workspaceAgent, err = cryptorand.Element(agents) if err != nil { return codersdk.Workspace{}, codersdk.WorkspaceAgent{}, err diff --git a/cli/ssh_test.go b/cli/ssh_test.go index 1fb13c0593d87..ec1dc1cb46b74 100644 --- a/cli/ssh_test.go +++ b/cli/ssh_test.go @@ -87,18 +87,15 @@ func TestSSH(t *testing.T) { t.Parallel() client, workspace, agentToken := setupWorkspaceForAgent(t, nil) - cmd, root := clitest.New(t, "ssh", workspace.Name) + inv, root := clitest.New(t, "ssh", workspace.Name) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetErr(pty.Output()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() cmdDone := tGo(t, func() { - err := cmd.ExecuteContext(ctx) + err := inv.WithContext(ctx).Run() assert.NoError(t, err) }) pty.ExpectMatch("Waiting") @@ -128,18 +125,18 @@ func TestSSH(t *testing.T) { a[0].TroubleshootingUrl = wantURL return a }) - cmd, root := clitest.New(t, "ssh", workspace.Name) + inv, root := clitest.New(t, "ssh", workspace.Name) clitest.SetupConfig(t, client, root) pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetErr(pty.Output()) - cmd.SetOut(pty.Output()) + inv.Stdin = pty.Input() + inv.Stderr = pty.Output() + inv.Stdout = pty.Output() ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() cmdDone := tGo(t, func() { - err := cmd.ExecuteContext(ctx) + err := inv.WithContext(ctx).Run() assert.ErrorIs(t, err, cliui.Canceled) }) pty.ExpectMatch(wantURL) @@ -173,13 +170,13 @@ func TestSSH(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - cmd, root := clitest.New(t, "ssh", "--stdio", workspace.Name) + inv, root := clitest.New(t, "ssh", "--stdio", workspace.Name) clitest.SetupConfig(t, client, root) - cmd.SetIn(clientOutput) - cmd.SetOut(serverInput) - cmd.SetErr(io.Discard) + inv.Stdin = clientOutput + inv.Stdout = serverInput + inv.Stderr = io.Discard cmdDone := tGo(t, func() { - err := cmd.ExecuteContext(ctx) + err := inv.WithContext(ctx).Run() assert.NoError(t, err) }) @@ -262,19 +259,17 @@ func TestSSH(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) defer cancel() - cmd, root := clitest.New(t, + inv, root := clitest.New(t, "ssh", workspace.Name, "--forward-agent", "--identity-agent", agentSock, // Overrides $SSH_AUTH_SOCK. ) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - cmd.SetErr(pty.Output()) + pty := ptytest.New(t).Attach(inv) + inv.Stderr = pty.Output() cmdDone := tGo(t, func() { - err := cmd.ExecuteContext(ctx) + err := inv.WithContext(ctx).Run() assert.NoError(t, err, "ssh command failed") }) @@ -466,18 +461,18 @@ Expire-Date: 0 }) defer agentCloser.Close() - cmd, root := clitest.New(t, + inv, root := clitest.New(t, "ssh", workspace.Name, "--forward-gpg", ) clitest.SetupConfig(t, client, root) tpty := ptytest.New(t) - cmd.SetIn(tpty.Input()) - cmd.SetOut(tpty.Output()) - cmd.SetErr(tpty.Output()) + inv.Stdin = tpty.Input() + inv.Stdout = tpty.Output() + inv.Stderr = tpty.Output() cmdDone := tGo(t, func() { - err := cmd.ExecuteContext(ctx) + err := inv.WithContext(ctx).Run() assert.NoError(t, err, "ssh command failed") }) // Prevent the test from hanging if the asserts below kill the test diff --git a/cli/start.go b/cli/start.go index 7bf4782e14bad..6ce6093afe774 100644 --- a/cli/start.go +++ b/cli/start.go @@ -4,43 +4,44 @@ import ( "fmt" "time" - "github.com/spf13/cobra" - + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) -func start() *cobra.Command { - cmd := &cobra.Command{ +func (r *RootCmd) start() *clibase.Cmd { + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Annotations: workspaceCommand, Use: "start ", Short: "Start a workspace", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) - if err != nil { - return err - } - workspace, err := namedWorkspace(cmd, client, args[0]) + Middleware: clibase.Chain( + clibase.RequireNArgs(1), + r.InitClient(client), + ), + Options: clibase.OptionSet{ + cliui.SkipPromptOption(), + }, + Handler: func(inv *clibase.Invocation) error { + workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return err } - build, err := client.CreateWorkspaceBuild(cmd.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{ + build, err := client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{ Transition: codersdk.WorkspaceTransitionStart, }) if err != nil { return err } - err = cliui.WorkspaceBuild(cmd.Context(), cmd.OutOrStdout(), client, build.ID) + err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, build.ID) if err != nil { return err } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nThe %s workspace has been started at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) + _, _ = fmt.Fprintf(inv.Stdout, "\nThe %s workspace has been started at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) return nil }, } - cliui.AllowSkipPrompt(cmd) return cmd } diff --git a/cli/state.go b/cli/state.go index cbb2074dacf7b..dd18e56d90f41 100644 --- a/cli/state.go +++ b/cli/state.go @@ -6,78 +6,92 @@ import ( "os" "strconv" - "github.com/spf13/cobra" - + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) -func state() *cobra.Command { - cmd := &cobra.Command{ +func (r *RootCmd) state() *clibase.Cmd { + cmd := &clibase.Cmd{ Use: "state", Short: "Manually manage Terraform state to fix broken workspaces", - RunE: func(cmd *cobra.Command, args []string) error { - return cmd.Help() + Handler: func(inv *clibase.Invocation) error { + return inv.Command.HelpHandler(inv) + }, + Children: []*clibase.Cmd{ + r.statePull(), + r.statePush(), }, } - cmd.AddCommand(statePull(), statePush()) return cmd } -func statePull() *cobra.Command { - var buildNumber int - cmd := &cobra.Command{ +func (r *RootCmd) statePull() *clibase.Cmd { + var buildNumber int64 + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Use: "pull [file]", Short: "Pull a Terraform state file from a workspace.", - Args: cobra.MinimumNArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) - if err != nil { - return err - } + Middleware: clibase.Chain( + clibase.RequireRangeArgs(1, 2), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + var err error var build codersdk.WorkspaceBuild if buildNumber == 0 { - workspace, err := namedWorkspace(cmd, client, args[0]) + workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return err } build = workspace.LatestBuild } else { - build, err = client.WorkspaceBuildByUsernameAndWorkspaceNameAndBuildNumber(cmd.Context(), codersdk.Me, args[0], strconv.Itoa(buildNumber)) + build, err = client.WorkspaceBuildByUsernameAndWorkspaceNameAndBuildNumber(inv.Context(), codersdk.Me, inv.Args[0], strconv.FormatInt(buildNumber, 10)) if err != nil { return err } } - state, err := client.WorkspaceBuildState(cmd.Context(), build.ID) + state, err := client.WorkspaceBuildState(inv.Context(), build.ID) if err != nil { return err } - if len(args) < 2 { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), string(state)) + if len(inv.Args) < 2 { + _, _ = fmt.Fprintln(inv.Stdout, string(state)) return nil } - return os.WriteFile(args[1], state, 0o600) + return os.WriteFile(inv.Args[1], state, 0o600) }, } - cmd.Flags().IntVarP(&buildNumber, "build", "b", 0, "Specify a workspace build to target by name.") + cmd.Options = clibase.OptionSet{ + buildNumberOption(&buildNumber), + } return cmd } -func statePush() *cobra.Command { - var buildNumber int - cmd := &cobra.Command{ +func buildNumberOption(n *int64) clibase.Option { + return clibase.Option{ + Flag: "build", + FlagShorthand: "b", + Description: "Specify a workspace build to target by name. Defaults to latest.", + Value: clibase.Int64Of(n), + } +} + +func (r *RootCmd) statePush() *clibase.Cmd { + var buildNumber int64 + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Use: "push ", - Args: cobra.ExactArgs(2), Short: "Push a Terraform state file to a workspace.", - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) - if err != nil { - return err - } - workspace, err := namedWorkspace(cmd, client, args[0]) + Middleware: clibase.Chain( + clibase.RequireNArgs(2), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return err } @@ -85,23 +99,23 @@ func statePush() *cobra.Command { if buildNumber == 0 { build = workspace.LatestBuild } else { - build, err = client.WorkspaceBuildByUsernameAndWorkspaceNameAndBuildNumber(cmd.Context(), codersdk.Me, args[0], strconv.Itoa(buildNumber)) + build, err = client.WorkspaceBuildByUsernameAndWorkspaceNameAndBuildNumber(inv.Context(), codersdk.Me, inv.Args[0], strconv.FormatInt((buildNumber), 10)) if err != nil { return err } } var state []byte - if args[1] == "-" { - state, err = io.ReadAll(cmd.InOrStdin()) + if inv.Args[1] == "-" { + state, err = io.ReadAll(inv.Stdin) } else { - state, err = os.ReadFile(args[1]) + state, err = os.ReadFile(inv.Args[1]) } if err != nil { return err } - build, err = client.CreateWorkspaceBuild(cmd.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{ + build, err = client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{ TemplateVersionID: build.TemplateVersionID, Transition: build.Transition, ProvisionerState: state, @@ -109,9 +123,11 @@ func statePush() *cobra.Command { if err != nil { return err } - return cliui.WorkspaceBuild(cmd.Context(), cmd.OutOrStderr(), client, build.ID) + return cliui.WorkspaceBuild(inv.Context(), inv.Stderr, client, build.ID) }, } - cmd.Flags().IntVarP(&buildNumber, "build", "b", 0, "Specify a workspace build to target by name.") + cmd.Options = clibase.OptionSet{ + buildNumberOption(&buildNumber), + } return cmd } diff --git a/cli/state_test.go b/cli/state_test.go index 5d05313eb5414..2a208fd64d25c 100644 --- a/cli/state_test.go +++ b/cli/state_test.go @@ -38,9 +38,9 @@ func TestStatePull(t *testing.T) { workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) statefilePath := filepath.Join(t.TempDir(), "state") - cmd, root := clitest.New(t, "state", "pull", workspace.Name, statefilePath) + inv, root := clitest.New(t, "state", "pull", workspace.Name, statefilePath) clitest.SetupConfig(t, client, root) - err := cmd.Execute() + err := inv.Run() require.NoError(t, err) gotState, err := os.ReadFile(statefilePath) require.NoError(t, err) @@ -65,11 +65,11 @@ func TestStatePull(t *testing.T) { template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, root := clitest.New(t, "state", "pull", workspace.Name) + inv, root := clitest.New(t, "state", "pull", workspace.Name) var gotState bytes.Buffer - cmd.SetOut(&gotState) + inv.Stdout = &gotState clitest.SetupConfig(t, client, root) - err := cmd.Execute() + err := inv.Run() require.NoError(t, err) require.Equal(t, wantState, bytes.TrimSpace(gotState.Bytes())) }) @@ -96,9 +96,9 @@ func TestStatePush(t *testing.T) { require.NoError(t, err) err = stateFile.Close() require.NoError(t, err) - cmd, root := clitest.New(t, "state", "push", workspace.Name, stateFile.Name()) + inv, root := clitest.New(t, "state", "push", workspace.Name, stateFile.Name()) clitest.SetupConfig(t, client, root) - err = cmd.Execute() + err = inv.Run() require.NoError(t, err) }) @@ -114,10 +114,10 @@ func TestStatePush(t *testing.T) { template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) workspace := coderdtest.CreateWorkspace(t, client, user.OrganizationID, template.ID) coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID) - cmd, root := clitest.New(t, "state", "push", "--build", strconv.Itoa(int(workspace.LatestBuild.BuildNumber)), workspace.Name, "-") + inv, root := clitest.New(t, "state", "push", "--build", strconv.Itoa(int(workspace.LatestBuild.BuildNumber)), workspace.Name, "-") clitest.SetupConfig(t, client, root) - cmd.SetIn(strings.NewReader("some magic state")) - err := cmd.Execute() + inv.Stdin = strings.NewReader("some magic state") + err := inv.Run() require.NoError(t, err) }) } diff --git a/cli/stop.go b/cli/stop.go index 9bb355ef0bd5a..442b6b662ea8b 100644 --- a/cli/stop.go +++ b/cli/stop.go @@ -4,20 +4,26 @@ import ( "fmt" "time" - "github.com/spf13/cobra" - + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) -func stop() *cobra.Command { - cmd := &cobra.Command{ +func (r *RootCmd) stop() *clibase.Cmd { + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Annotations: workspaceCommand, Use: "stop ", Short: "Stop a workspace", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - _, err := cliui.Prompt(cmd, cliui.PromptOptions{ + Middleware: clibase.Chain( + clibase.RequireNArgs(1), + r.InitClient(client), + ), + Options: clibase.OptionSet{ + cliui.SkipPromptOption(), + }, + Handler: func(inv *clibase.Invocation) error { + _, err := cliui.Prompt(inv, cliui.PromptOptions{ Text: "Confirm stop workspace?", IsConfirm: true, }) @@ -25,30 +31,25 @@ func stop() *cobra.Command { return err } - client, err := CreateClient(cmd) - if err != nil { - return err - } - workspace, err := namedWorkspace(cmd, client, args[0]) + workspace, err := namedWorkspace(inv.Context(), client, inv.Args[0]) if err != nil { return err } - build, err := client.CreateWorkspaceBuild(cmd.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{ + build, err := client.CreateWorkspaceBuild(inv.Context(), workspace.ID, codersdk.CreateWorkspaceBuildRequest{ Transition: codersdk.WorkspaceTransitionStop, }) if err != nil { return err } - err = cliui.WorkspaceBuild(cmd.Context(), cmd.OutOrStdout(), client, build.ID) + err = cliui.WorkspaceBuild(inv.Context(), inv.Stdout, client, build.ID) if err != nil { return err } - _, _ = fmt.Fprintf(cmd.OutOrStdout(), "\nThe %s workspace has been stopped at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) + _, _ = fmt.Fprintf(inv.Stdout, "\nThe %s workspace has been stopped at %s!\n", cliui.Styles.Keyword.Render(workspace.Name), cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))) return nil }, } - cliui.AllowSkipPrompt(cmd) return cmd } diff --git a/cli/templatecreate.go b/cli/templatecreate.go index be5bc59d3c0b2..823a3cd1e45a4 100644 --- a/cli/templatecreate.go +++ b/cli/templatecreate.go @@ -11,9 +11,9 @@ import ( "unicode/utf8" "github.com/google/uuid" - "github.com/spf13/cobra" "golang.org/x/xerrors" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/coderd/database" "github.com/coder/coder/coderd/util/ptr" @@ -21,7 +21,7 @@ import ( "github.com/coder/coder/provisionerd" ) -func templateCreate() *cobra.Command { +func (r *RootCmd) templateCreate() *clibase.Cmd { var ( provisioner string provisionerTags []string @@ -32,22 +32,21 @@ func templateCreate() *cobra.Command { uploadFlags templateUploadFlags ) - cmd := &cobra.Command{ + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Use: "create [name]", Short: "Create a template from the current directory or as specified by flag", - Args: cobra.MaximumNArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - client, err := CreateClient(cmd) + Middleware: clibase.Chain( + clibase.RequireRangeArgs(0, 1), + r.InitClient(client), + ), + Handler: func(inv *clibase.Invocation) error { + organization, err := CurrentOrganization(inv, client) if err != nil { return err } - organization, err := CurrentOrganization(cmd, client) - if err != nil { - return err - } - - templateName, err := uploadFlags.templateName(args) + templateName, err := uploadFlags.templateName(inv.Args) if err != nil { return err } @@ -56,13 +55,13 @@ func templateCreate() *cobra.Command { return xerrors.Errorf("Template name must be less than 32 characters") } - _, err = client.TemplateByName(cmd.Context(), organization.ID, templateName) + _, err = client.TemplateByName(inv.Context(), organization.ID, templateName) if err == nil { return xerrors.Errorf("A template already exists named %q!", templateName) } // Confirm upload of the directory. - resp, err := uploadFlags.upload(cmd, client) + resp, err := uploadFlags.upload(inv, client) if err != nil { return err } @@ -72,7 +71,7 @@ func templateCreate() *cobra.Command { return err } - job, _, err := createValidTemplateVersion(cmd, createValidTemplateVersionArgs{ + job, _, err := createValidTemplateVersion(inv, createValidTemplateVersionArgs{ Client: client, Organization: organization, Provisioner: database.ProvisionerType(provisioner), @@ -87,7 +86,7 @@ func templateCreate() *cobra.Command { } if !uploadFlags.stdin() { - _, err = cliui.Prompt(cmd, cliui.PromptOptions{ + _, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: "Confirm create?", IsConfirm: true, }) @@ -102,34 +101,58 @@ func templateCreate() *cobra.Command { DefaultTTLMillis: ptr.Ref(defaultTTL.Milliseconds()), } - _, err = client.CreateTemplate(cmd.Context(), organization.ID, createReq) + _, err = client.CreateTemplate(inv.Context(), organization.ID, createReq) if err != nil { return err } - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "\n"+cliui.Styles.Wrap.Render( + _, _ = fmt.Fprintln(inv.Stdout, "\n"+cliui.Styles.Wrap.Render( "The "+cliui.Styles.Keyword.Render(templateName)+" template has been created at "+cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))+"! "+ "Developers can provision a workspace with this template using:")+"\n") - _, _ = fmt.Fprintln(cmd.OutOrStdout(), " "+cliui.Styles.Code.Render(fmt.Sprintf("coder create --template=%q [workspace name]", templateName))) - _, _ = fmt.Fprintln(cmd.OutOrStdout()) + _, _ = fmt.Fprintln(inv.Stdout, " "+cliui.Styles.Code.Render(fmt.Sprintf("coder create --template=%q [workspace name]", templateName))) + _, _ = fmt.Fprintln(inv.Stdout) return nil }, } - cmd.Flags().StringVarP(¶meterFile, "parameter-file", "", "", "Specify a file path with parameter values.") - cmd.Flags().StringVarP(&variablesFile, "variables-file", "", "", "Specify a file path with values for Terraform-managed variables.") - cmd.Flags().StringArrayVarP(&variables, "variable", "", []string{}, "Specify a set of values for Terraform-managed variables.") - cmd.Flags().StringArrayVarP(&provisionerTags, "provisioner-tag", "", []string{}, "Specify a set of tags to target provisioner daemons.") - cmd.Flags().DurationVarP(&defaultTTL, "default-ttl", "", 24*time.Hour, "Specify a default TTL for workspaces created from this template.") - uploadFlags.register(cmd.Flags()) - cmd.Flags().StringVarP(&provisioner, "test.provisioner", "", "terraform", "Customize the provisioner backend") - // This is for testing! - err := cmd.Flags().MarkHidden("test.provisioner") - if err != nil { - panic(err) + cmd.Options = clibase.OptionSet{ + { + Flag: "parameter-file", + Description: "Specify a file path with parameter values.", + Value: clibase.StringOf(¶meterFile), + }, + { + Flag: "variables-file", + Description: "Specify a file path with values for Terraform-managed variables.", + Value: clibase.StringOf(&variablesFile), + }, + { + Flag: "variable", + Description: "Specify a set of values for Terraform-managed variables.", + Value: clibase.StringArrayOf(&variables), + }, + { + Flag: "provisioner-tag", + Description: "Specify a set of tags to target provisioner daemons.", + Value: clibase.StringArrayOf(&provisionerTags), + }, + { + Flag: "default-ttl", + Description: "Specify a default TTL for workspaces created from this template.", + Default: "24h", + Value: clibase.DurationOf(&defaultTTL), + }, + uploadFlags.option(), + { + Flag: "test.provisioner", + Description: "Customize the provisioner backend.", + Default: "terraform", + Value: clibase.StringOf(&provisioner), + Hidden: true, + }, + cliui.SkipPromptOption(), } - cliui.AllowSkipPrompt(cmd) return cmd } @@ -153,7 +176,7 @@ type createValidTemplateVersionArgs struct { ProvisionerTags map[string]string } -func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVersionArgs, parameters ...codersdk.CreateParameterRequest) (*codersdk.TemplateVersion, []codersdk.CreateParameterRequest, error) { +func createValidTemplateVersion(inv *clibase.Invocation, args createValidTemplateVersionArgs, parameters ...codersdk.CreateParameterRequest) (*codersdk.TemplateVersion, []codersdk.CreateParameterRequest, error) { client := args.Client variableValues, err := loadVariableValuesFromFile(args.VariablesFile) @@ -179,21 +202,21 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers if args.Template != nil { req.TemplateID = args.Template.ID } - version, err := client.CreateTemplateVersion(cmd.Context(), args.Organization.ID, req) + version, err := client.CreateTemplateVersion(inv.Context(), args.Organization.ID, req) if err != nil { return nil, nil, err } - err = cliui.ProvisionerJob(cmd.Context(), cmd.OutOrStdout(), cliui.ProvisionerJobOptions{ + err = cliui.ProvisionerJob(inv.Context(), inv.Stdout, cliui.ProvisionerJobOptions{ Fetch: func() (codersdk.ProvisionerJob, error) { - version, err := client.TemplateVersion(cmd.Context(), version.ID) + version, err := client.TemplateVersion(inv.Context(), version.ID) return version.Job, err }, Cancel: func() error { - return client.CancelTemplateVersion(cmd.Context(), version.ID) + return client.CancelTemplateVersion(inv.Context(), version.ID) }, Logs: func() (<-chan codersdk.ProvisionerJobLog, io.Closer, error) { - return client.TemplateVersionLogsAfter(cmd.Context(), version.ID, 0) + return client.TemplateVersionLogsAfter(inv.Context(), version.ID, 0) }, }) if err != nil { @@ -202,15 +225,15 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers return nil, nil, err } } - version, err = client.TemplateVersion(cmd.Context(), version.ID) + version, err = client.TemplateVersion(inv.Context(), version.ID) if err != nil { return nil, nil, err } - parameterSchemas, err := client.TemplateVersionSchema(cmd.Context(), version.ID) + parameterSchemas, err := client.TemplateVersionSchema(inv.Context(), version.ID) if err != nil { return nil, nil, err } - parameterValues, err := client.TemplateVersionParameters(cmd.Context(), version.ID) + parameterValues, err := client.TemplateVersionParameters(inv.Context(), version.ID) if err != nil { return nil, nil, err } @@ -220,13 +243,13 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers // version instead of prompting if we are updating template versions. lastParameterValues := make(map[string]codersdk.Parameter) if args.ReuseParameters && args.Template != nil { - activeVersion, err := client.TemplateVersion(cmd.Context(), args.Template.ActiveVersionID) + activeVersion, err := client.TemplateVersion(inv.Context(), args.Template.ActiveVersionID) if err != nil { return nil, nil, xerrors.Errorf("Fetch current active template version: %w", err) } // We don't want to compute the params, we only want to copy from this scope - values, err := client.Parameters(cmd.Context(), codersdk.ParameterImportJob, activeVersion.Job.ID) + values, err := client.Parameters(inv.Context(), codersdk.ParameterImportJob, activeVersion.Job.ID) if err != nil { return nil, nil, xerrors.Errorf("Fetch previous version parameters: %w", err) } @@ -244,7 +267,7 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers // parameterMapFromFile can be nil if parameter file is not specified var parameterMapFromFile map[string]string if args.ParameterFile != "" { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("Attempting to read the variables from the parameter file.")+"\r\n") + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("Attempting to read the variables from the parameter file.")+"\r\n") parameterMapFromFile, err = createParameterMapFromFile(args.ParameterFile) if err != nil { return nil, nil, err @@ -275,15 +298,15 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers missingSchemas = append(missingSchemas, parameterSchema) } - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("This template has required variables! They are scoped to the template, and not viewable after being set.")) + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("This template has required variables! They are scoped to the template, and not viewable after being set.")) if len(pulled) > 0 { - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render(fmt.Sprintf("The following parameter values are being pulled from the latest template version: %s.", strings.Join(pulled, ", ")))) - _, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Paragraph.Render("Use \"--always-prompt\" flag to change the values.")) + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render(fmt.Sprintf("The following parameter values are being pulled from the latest template version: %s.", strings.Join(pulled, ", ")))) + _, _ = fmt.Fprintln(inv.Stdout, cliui.Styles.Paragraph.Render("Use \"--always-prompt\" flag to change the values.")) } - _, _ = fmt.Fprint(cmd.OutOrStdout(), "\r\n") + _, _ = fmt.Fprint(inv.Stdout, "\r\n") for _, parameterSchema := range missingSchemas { - parameterValue, err := getParameterValueFromMapOrInput(cmd, parameterMapFromFile, parameterSchema) + parameterValue, err := getParameterValueFromMapOrInput(inv, parameterMapFromFile, parameterSchema) if err != nil { return nil, nil, err } @@ -293,19 +316,19 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers SourceScheme: codersdk.ParameterSourceSchemeData, DestinationScheme: parameterSchema.DefaultDestinationScheme, }) - _, _ = fmt.Fprintln(cmd.OutOrStdout()) + _, _ = fmt.Fprintln(inv.Stdout) } // This recursion is only 1 level deep in practice. // The first pass populates the missing parameters, so it does not enter this `if` block again. - return createValidTemplateVersion(cmd, args, parameters...) + return createValidTemplateVersion(inv, args, parameters...) } if version.Job.Status != codersdk.ProvisionerJobSucceeded { return nil, nil, xerrors.New(version.Job.Error) } - resources, err := client.TemplateVersionResources(cmd.Context(), version.ID) + resources, err := client.TemplateVersionResources(inv.Context(), version.ID) if err != nil { return nil, nil, err } @@ -317,7 +340,7 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers startResources = append(startResources, r) } } - err = cliui.WorkspaceResources(cmd.OutOrStdout(), startResources, cliui.WorkspaceResourcesOptions{ + err = cliui.WorkspaceResources(inv.Stdout, startResources, cliui.WorkspaceResourcesOptions{ HideAgentState: true, HideAccess: true, Title: "Template Preview", diff --git a/cli/templatecreate_test.go b/cli/templatecreate_test.go index b9308854dc868..5bf688972e5ec 100644 --- a/cli/templatecreate_test.go +++ b/cli/templatecreate_test.go @@ -55,16 +55,11 @@ func TestTemplateCreate(t *testing.T) { "--test.provisioner", string(database.ProvisionerTypeEcho), "--default-ttl", "24h", } - cmd, root := clitest.New(t, args...) + inv, root := clitest.New(t, args...) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) - execDone := make(chan error) - go func() { - execDone <- cmd.Execute() - }() + clitest.Start(t, inv) matches := []struct { match string @@ -81,8 +76,6 @@ func TestTemplateCreate(t *testing.T) { pty.WriteLine(m.write) } } - - require.NoError(t, <-execDone) }) t.Run("CreateStdin", func(t *testing.T) { @@ -103,18 +96,13 @@ func TestTemplateCreate(t *testing.T) { "--test.provisioner", string(database.ProvisionerTypeEcho), "--default-ttl", "24h", } - cmd, root := clitest.New(t, args...) + inv, root := clitest.New(t, args...) clitest.SetupConfig(t, client, root) pty := ptytest.New(t) - cmd.SetIn(bytes.NewReader(source)) - cmd.SetOut(pty.Output()) - - execDone := make(chan error) - go func() { - execDone <- cmd.Execute() - }() + inv.Stdin = bytes.NewReader(source) + inv.Stdout = pty.Output() - require.NoError(t, <-execDone) + require.NoError(t, inv.Run()) }) t.Run("WithParameter", func(t *testing.T) { @@ -126,17 +114,11 @@ func TestTemplateCreate(t *testing.T) { ProvisionApply: echo.ProvisionComplete, ProvisionPlan: echo.ProvisionComplete, }) - cmd, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho)) + inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho)) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - - execDone := make(chan error) - go func() { - execDone <- cmd.Execute() - }() + pty := ptytest.New(t).Attach(inv) + clitest.Start(t, inv) matches := []struct { match string write string @@ -149,8 +131,6 @@ func TestTemplateCreate(t *testing.T) { pty.ExpectMatch(m.match) pty.WriteLine(m.write) } - - require.NoError(t, <-execDone) }) t.Run("WithParameterFileContainingTheValue", func(t *testing.T) { @@ -166,16 +146,11 @@ func TestTemplateCreate(t *testing.T) { removeTmpDirUntilSuccessAfterTest(t, tempDir) parameterFile, _ := os.CreateTemp(tempDir, "testParameterFile*.yaml") _, _ = parameterFile.WriteString("region: \"bananas\"") - cmd, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--parameter-file", parameterFile.Name()) + inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--parameter-file", parameterFile.Name()) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) - execDone := make(chan error) - go func() { - execDone <- cmd.Execute() - }() + clitest.Start(t, inv) matches := []struct { match string @@ -188,8 +163,6 @@ func TestTemplateCreate(t *testing.T) { pty.ExpectMatch(m.match) pty.WriteLine(m.write) } - - require.NoError(t, <-execDone) }) t.Run("WithParameterFileNotContainingTheValue", func(t *testing.T) { @@ -205,16 +178,11 @@ func TestTemplateCreate(t *testing.T) { removeTmpDirUntilSuccessAfterTest(t, tempDir) parameterFile, _ := os.CreateTemp(tempDir, "testParameterFile*.yaml") _, _ = parameterFile.WriteString("zone: \"bananas\"") - cmd, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--parameter-file", parameterFile.Name()) + inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--parameter-file", parameterFile.Name()) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) - execDone := make(chan error) - go func() { - execDone <- cmd.Execute() - }() + clitest.Start(t, inv) matches := []struct { match string @@ -237,8 +205,6 @@ func TestTemplateCreate(t *testing.T) { pty.ExpectMatch(m.match) pty.WriteLine(m.write) } - - require.NoError(t, <-execDone) }) t.Run("Recreate template with same name (create, delete, create)", func(t *testing.T) { @@ -259,10 +225,10 @@ func TestTemplateCreate(t *testing.T) { "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), } - cmd, root := clitest.New(t, args...) + inv, root := clitest.New(t, args...) clitest.SetupConfig(t, client, root) - return cmd.Execute() + return inv.Run() } del := func() error { args := []string{ @@ -271,10 +237,10 @@ func TestTemplateCreate(t *testing.T) { "my-template", "--yes", } - cmd, root := clitest.New(t, args...) + inv, root := clitest.New(t, args...) clitest.SetupConfig(t, client, root) - return cmd.Execute() + return inv.Run() } err := create() @@ -289,15 +255,10 @@ func TestTemplateCreate(t *testing.T) { t.Parallel() client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true}) coderdtest.CreateFirstUser(t, client) - cmd, root := clitest.New(t, "templates", "create", "1234567890123456789012345678901234567891", "--test.provisioner", string(database.ProvisionerTypeEcho)) + inv, root := clitest.New(t, "templates", "create", "1234567890123456789012345678901234567891", "--test.provisioner", string(database.ProvisionerTypeEcho)) clitest.SetupConfig(t, client, root) - execDone := make(chan error) - go func() { - execDone <- cmd.Execute() - }() - - require.EqualError(t, <-execDone, "Template name must be less than 32 characters") + clitest.StartWithWaiter(t, inv).RequireContains("Template name must be less than 32 characters") }) t.Run("WithVariablesFileWithoutRequiredValue", func(t *testing.T) { @@ -309,7 +270,7 @@ func TestTemplateCreate(t *testing.T) { templateVariables := []*proto.TemplateVariable{ { Name: "first_variable", - Description: "This is the first variable", + Description: "This is the first variable.", Type: "string", Required: true, Sensitive: true, @@ -329,17 +290,11 @@ func TestTemplateCreate(t *testing.T) { removeTmpDirUntilSuccessAfterTest(t, tempDir) variablesFile, _ := os.CreateTemp(tempDir, "variables*.yaml") _, _ = variablesFile.WriteString(`second_variable: foobar`) - cmd, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variables-file", variablesFile.Name()) + inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variables-file", variablesFile.Name()) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) - - execDone := make(chan error) - go func() { - execDone <- cmd.Execute() - }() + pty := ptytest.New(t).Attach(inv) + clitest.Start(t, inv) matches := []struct { match string write string @@ -352,8 +307,6 @@ func TestTemplateCreate(t *testing.T) { pty.WriteLine(m.write) } } - - require.Error(t, <-execDone) }) t.Run("WithVariablesFileWithTheRequiredValue", func(t *testing.T) { @@ -365,7 +318,7 @@ func TestTemplateCreate(t *testing.T) { templateVariables := []*proto.TemplateVariable{ { Name: "first_variable", - Description: "This is the first variable", + Description: "This is the first variable.", Type: "string", Required: true, Sensitive: true, @@ -385,16 +338,11 @@ func TestTemplateCreate(t *testing.T) { removeTmpDirUntilSuccessAfterTest(t, tempDir) variablesFile, _ := os.CreateTemp(tempDir, "variables*.yaml") _, _ = variablesFile.WriteString(`first_variable: foobar`) - cmd, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variables-file", variablesFile.Name()) + inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variables-file", variablesFile.Name()) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) - execDone := make(chan error) - go func() { - execDone <- cmd.Execute() - }() + clitest.Start(t, inv) matches := []struct { match string @@ -409,8 +357,6 @@ func TestTemplateCreate(t *testing.T) { pty.WriteLine(m.write) } } - - require.NoError(t, <-execDone) }) t.Run("WithVariableOption", func(t *testing.T) { t.Parallel() @@ -421,7 +367,7 @@ func TestTemplateCreate(t *testing.T) { templateVariables := []*proto.TemplateVariable{ { Name: "first_variable", - Description: "This is the first variable", + Description: "This is the first variable.", Type: "string", Required: true, Sensitive: true, @@ -429,16 +375,11 @@ func TestTemplateCreate(t *testing.T) { } source := clitest.CreateTemplateVersionSource(t, createEchoResponsesWithTemplateVariables(templateVariables)) - cmd, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variable", "first_variable=foobar") + inv, root := clitest.New(t, "templates", "create", "my-template", "--directory", source, "--test.provisioner", string(database.ProvisionerTypeEcho), "--variable", "first_variable=foobar") clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) - execDone := make(chan error) - go func() { - execDone <- cmd.Execute() - }() + clitest.Start(t, inv) matches := []struct { match string @@ -451,8 +392,6 @@ func TestTemplateCreate(t *testing.T) { pty.ExpectMatch(m.match) pty.WriteLine(m.write) } - - require.NoError(t, <-execDone) }) } diff --git a/cli/templatedelete.go b/cli/templatedelete.go index 230bb4bc2662d..4833362861489 100644 --- a/cli/templatedelete.go +++ b/cli/templatedelete.go @@ -5,35 +5,38 @@ import ( "strings" "time" - "github.com/spf13/cobra" "golang.org/x/xerrors" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) -func templateDelete() *cobra.Command { - cmd := &cobra.Command{ +func (r *RootCmd) templateDelete() *clibase.Cmd { + client := new(codersdk.Client) + cmd := &clibase.Cmd{ Use: "delete [name...]", Short: "Delete templates", - RunE: func(cmd *cobra.Command, args []string) error { + Middleware: clibase.Chain( + r.InitClient(client), + ), + Options: clibase.OptionSet{ + cliui.SkipPromptOption(), + }, + Handler: func(inv *clibase.Invocation) error { var ( - ctx = cmd.Context() + ctx = inv.Context() templateNames = []string{} templates = []codersdk.Template{} ) - client, err := CreateClient(cmd) - if err != nil { - return err - } - organization, err := CurrentOrganization(cmd, client) + organization, err := CurrentOrganization(inv, client) if err != nil { return err } - if len(args) > 0 { - templateNames = args + if len(inv.Args) > 0 { + templateNames = inv.Args for _, templateName := range templateNames { template, err := client.TemplateByName(ctx, organization.ID, templateName) @@ -57,7 +60,7 @@ func templateDelete() *cobra.Command { opts = append(opts, template.Name) } - selection, err := cliui.Select(cmd, cliui.SelectOptions{ + selection, err := cliui.Select(inv, cliui.SelectOptions{ Options: opts, }) if err != nil { @@ -73,7 +76,7 @@ func templateDelete() *cobra.Command { } // Confirm deletion of the template. - _, err = cliui.Prompt(cmd, cliui.PromptOptions{ + _, err = cliui.Prompt(inv, cliui.PromptOptions{ Text: fmt.Sprintf("Delete these templates: %s?", cliui.Styles.Code.Render(strings.Join(templateNames, ", "))), IsConfirm: true, Default: cliui.ConfirmNo, @@ -88,13 +91,12 @@ func templateDelete() *cobra.Command { return xerrors.Errorf("delete template %q: %w", template.Name, err) } - _, _ = fmt.Fprintln(cmd.OutOrStdout(), "Deleted template "+cliui.Styles.Code.Render(template.Name)+" at "+cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))+"!") + _, _ = fmt.Fprintln(inv.Stdout, "Deleted template "+cliui.Styles.Code.Render(template.Name)+" at "+cliui.Styles.DateTimeStamp.Render(time.Now().Format(time.Stamp))+"!") } return nil }, } - cliui.AllowSkipPrompt(cmd) return cmd } diff --git a/cli/templatedelete_test.go b/cli/templatedelete_test.go index e2e404cf82a26..86893f8cd0328 100644 --- a/cli/templatedelete_test.go +++ b/cli/templatedelete_test.go @@ -27,16 +27,14 @@ func TestTemplateDelete(t *testing.T) { _ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - cmd, root := clitest.New(t, "templates", "delete", template.Name) + inv, root := clitest.New(t, "templates", "delete", template.Name) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) execDone := make(chan error) go func() { - execDone <- cmd.Execute() + execDone <- inv.Run() }() pty.ExpectMatch(fmt.Sprintf("Delete these templates: %s?", cliui.Styles.Code.Render(template.Name))) @@ -65,9 +63,9 @@ func TestTemplateDelete(t *testing.T) { templateNames = append(templateNames, template.Name) } - cmd, root := clitest.New(t, append([]string{"templates", "delete", "--yes"}, templateNames...)...) + inv, root := clitest.New(t, append([]string{"templates", "delete", "--yes"}, templateNames...)...) clitest.SetupConfig(t, client, root) - require.NoError(t, cmd.Execute()) + require.NoError(t, inv.Run()) for _, template := range templates { _, err := client.Template(context.Background(), template.ID) @@ -92,15 +90,13 @@ func TestTemplateDelete(t *testing.T) { templateNames = append(templateNames, template.Name) } - cmd, root := clitest.New(t, append([]string{"templates", "delete"}, templateNames...)...) + inv, root := clitest.New(t, append([]string{"templates", "delete"}, templateNames...)...) clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) execDone := make(chan error) go func() { - execDone <- cmd.Execute() + execDone <- inv.Run() }() pty.ExpectMatch(fmt.Sprintf("Delete these templates: %s?", cliui.Styles.Code.Render(strings.Join(templateNames, ", ")))) @@ -123,16 +119,14 @@ func TestTemplateDelete(t *testing.T) { _ = coderdtest.AwaitTemplateVersionJob(t, client, version.ID) template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID) - cmd, root := clitest.New(t, "templates", "delete") + inv, root := clitest.New(t, "templates", "delete") clitest.SetupConfig(t, client, root) - pty := ptytest.New(t) - cmd.SetIn(pty.Input()) - cmd.SetOut(pty.Output()) + pty := ptytest.New(t).Attach(inv) execDone := make(chan error) go func() { - execDone <- cmd.Execute() + execDone <- inv.Run() }() pty.WriteLine("yes") diff --git a/cli/templateedit.go b/cli/templateedit.go index 1e487fa8000c2..e0aa6bf694fd3 100644 --- a/cli/templateedit.go +++ b/cli/templateedit.go @@ -5,14 +5,14 @@ import ( "net/http" "time" - "github.com/spf13/cobra" "golang.org/x/xerrors" + "github.com/coder/coder/cli/clibase" "github.com/coder/coder/cli/cliui" "github.com/coder/coder/codersdk" ) -func templateEdit() *cobra.Command { +func (r *RootCmd) templateEdit() *clibase.Cmd { var ( name string displayName string @@ -22,19 +22,18 @@ func templateEdit() *cobra.Command { maxTTL time.Duration allowUserCancelWorkspaceJobs bool ) + client := new(codersdk.Client) - cmd := &cobra.Command{ - Use: "edit