Skip to content

Commit

Permalink
fix: add timeout handling to legacy cli (snyk#4950)
Browse files Browse the repository at this point in the history
* fix: use deadline context for timeout in legacy workflow

* fix: tests & refactor error handling for deadline

* fix: display error better

* chore: use context.WithTimeout

Co-authored-by: Casey Marshall <[email protected]>

* chore: assert error is nil

Co-authored-by: Casey Marshall <[email protected]>

* chore: implement pr suggestions

* fix: test

---------

Co-authored-by: Casey Marshall <[email protected]>
  • Loading branch information
bastiandoetsch and cmars authored Nov 28, 2023
1 parent a0308f1 commit 1686db5
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 36 deletions.
23 changes: 15 additions & 8 deletions cliv2/cmd/cliv2/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ package main
import _ "github.com/snyk/go-application-framework/pkg/networking/fips_enable"

import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
Expand All @@ -13,6 +15,9 @@ import (
"time"

"github.com/rs/zerolog"
"github.com/spf13/cobra"
"github.com/spf13/pflag"

"github.com/snyk/cli-extension-dep-graph/pkg/depgraph"
"github.com/snyk/cli-extension-iac-rules/iacrules"
"github.com/snyk/cli-extension-sbom/pkg/sbom"
Expand All @@ -24,18 +29,14 @@ import (
"github.com/snyk/go-application-framework/pkg/app"
"github.com/snyk/go-application-framework/pkg/auth"
"github.com/snyk/go-application-framework/pkg/configuration"

localworkflows "github.com/snyk/go-application-framework/pkg/local_workflows"
"github.com/snyk/go-application-framework/pkg/networking"
"github.com/snyk/go-application-framework/pkg/runtimeinfo"
"github.com/snyk/go-application-framework/pkg/utils"
"github.com/snyk/go-application-framework/pkg/workflow"
"github.com/snyk/go-httpauth/pkg/httpauth"
"github.com/snyk/snyk-iac-capture/pkg/capture"

snykls "github.com/snyk/snyk-ls/ls_extension"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
)

var internalOS string
Expand Down Expand Up @@ -355,7 +356,8 @@ func handleError(err error) HandleError {

func displayError(err error) {
if err != nil {
if _, ok := err.(*exec.ExitError); !ok {
var exitError *exec.ExitError
if !errors.As(err, &exitError) {
if globalConfiguration.GetBool(localworkflows.OUTPUT_CONFIG_KEY_JSON) {
jsonError := JsonErrorStruct{
Ok: false,
Expand All @@ -366,7 +368,11 @@ func displayError(err error) {
jsonErrorBuffer, _ := json.MarshalIndent(jsonError, "", " ")
fmt.Println(string(jsonErrorBuffer))
} else {
fmt.Println(err)
if errors.Is(err, context.DeadlineExceeded) {
fmt.Println("command timed out")
} else {
fmt.Println(err)
}
}
}
}
Expand Down Expand Up @@ -479,8 +485,9 @@ func setTimeout(config configuration.Configuration, onTimeout func()) {
}
debugLogger.Printf("Command timeout set for %d seconds", timeout)
go func() {
<-time.After(time.Duration(timeout) * time.Second)
fmt.Fprintf(os.Stderr, "command timed out\n")
const gracePeriodForSubProcesses = 3
<-time.After(time.Duration(timeout+gracePeriodForSubProcesses) * time.Second)
fmt.Fprintf(os.Stdout, "command timed out")
onTimeout()
}()
}
7 changes: 4 additions & 3 deletions cliv2/cmd/cliv2/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import (
"testing"
"time"

"github.com/snyk/go-application-framework/pkg/configuration"
localworkflows "github.com/snyk/go-application-framework/pkg/local_workflows"
"github.com/snyk/go-application-framework/pkg/workflow"
"github.com/spf13/cobra"
"github.com/spf13/pflag"
"github.com/stretchr/testify/assert"

"github.com/snyk/go-application-framework/pkg/configuration"
localworkflows "github.com/snyk/go-application-framework/pkg/local_workflows"
"github.com/snyk/go-application-framework/pkg/workflow"
)

func cleanup() {
Expand Down
45 changes: 34 additions & 11 deletions cliv2/internal/cliv2/cliv2.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ Entry point class for the CLIv2 version.
package cliv2

import (
"context"
_ "embed"
"errors"
"fmt"
"io"
"log"
Expand All @@ -13,6 +15,7 @@ import (
"path"
"regexp"
"strings"
"time"

"github.com/gofrs/flock"
"github.com/snyk/go-application-framework/pkg/configuration"
Expand Down Expand Up @@ -74,6 +77,11 @@ func NewCLIv2(config configuration.Configuration, debugLogger *log.Logger) (*CLI
return &cli, nil
}

// SetV1BinaryLocation for testing purposes
func (c *CLI) SetV1BinaryLocation(filePath string) {
c.v1BinaryLocation = filePath
}

func (c *CLI) Init() (err error) {
c.DebugLogger.Println("Init start")

Expand Down Expand Up @@ -200,7 +208,7 @@ func (c *CLI) GetBinaryLocation() string {
}

func (c *CLI) printVersion() {
fmt.Fprintln(c.stdout, GetFullVersion())
_, _ = fmt.Fprintln(c.stdout, GetFullVersion())
}

func (c *CLI) commandVersion(passthroughArgs []string) error {
Expand Down Expand Up @@ -235,8 +243,8 @@ func (c *CLI) commandAbout(proxyInfo *proxy.ProxyInfo, passthroughArgs []string)
}

fmt.Printf("Package: %s \n", strings.ReplaceAll(strings.ReplaceAll(fPath, "/licenses/", ""), "/"+f.Name(), ""))
fmt.Fprintln(c.stdout, string(data))
fmt.Fprint(c.stdout, separator)
_, _ = fmt.Fprintln(c.stdout, string(data))
_, _ = fmt.Fprint(c.stdout, separator)
}
}

Expand Down Expand Up @@ -341,15 +349,15 @@ func PrepareV1EnvironmentVariables(
}

func (c *CLI) PrepareV1Command(
ctx context.Context,
cmd string,
args []string,
proxyInfo *proxy.ProxyInfo,
integrationName string,
integrationVersion string,
) (snykCmd *exec.Cmd, err error) {
proxyAddress := fmt.Sprintf("http://%s:%[email protected]:%d", proxy.PROXY_USERNAME, proxyInfo.Password, proxyInfo.Port)

snykCmd = exec.Command(cmd, args...)
snykCmd = exec.CommandContext(ctx, cmd, args...)
snykCmd.Env, err = PrepareV1EnvironmentVariables(c.env, integrationName, integrationVersion, proxyAddress, proxyInfo.CertificateLocation, c.globalConfig, args)

if len(c.WorkingDirectory) > 0 {
Expand All @@ -360,8 +368,17 @@ func (c *CLI) PrepareV1Command(
}

func (c *CLI) executeV1Default(proxyInfo *proxy.ProxyInfo, passThroughArgs []string) error {
timeout := c.globalConfig.GetInt(configuration.TIMEOUT)
var ctx context.Context
var cancel context.CancelFunc
if timeout == 0 {
ctx = context.Background()
} else {
ctx, cancel = context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second)
defer cancel()
}

snykCmd, err := c.PrepareV1Command(c.v1BinaryLocation, passThroughArgs, proxyInfo, c.GetIntegrationName(), GetFullVersion())
snykCmd, err := c.PrepareV1Command(ctx, c.v1BinaryLocation, passThroughArgs, proxyInfo, c.GetIntegrationName(), GetFullVersion())

if c.DebugLogger.Writer() != io.Discard {
c.DebugLogger.Println("Launching: ")
Expand Down Expand Up @@ -397,13 +414,16 @@ func (c *CLI) executeV1Default(proxyInfo *proxy.ProxyInfo, passThroughArgs []str
snykCmd.Stderr = c.stderr

if err != nil {
if evWarning, ok := err.(EnvironmentWarning); ok {
fmt.Fprintln(c.stdout, "WARNING! ", evWarning)
var evWarning EnvironmentWarning
if errors.As(err, &evWarning) {
_, _ = fmt.Fprintln(c.stdout, "WARNING! ", evWarning)
}
}

err = snykCmd.Run()

if errors.Is(ctx.Err(), context.DeadlineExceeded) {
return ctx.Err()
}
return err
}

Expand All @@ -427,14 +447,17 @@ func DeriveExitCode(err error) int {
returnCode := constants.SNYK_EXIT_CODE_OK

if err != nil {
if exitError, ok := err.(*exec.ExitError); ok {
var exitError *exec.ExitError

if errors.As(err, &exitError) {
returnCode = exitError.ExitCode()
} else if errors.Is(err, context.DeadlineExceeded) {
returnCode = constants.SNYK_EXIT_CODE_EX_UNAVAILABLE
} else {
// got an error but it's not an ExitError
returnCode = constants.SNYK_EXIT_CODE_ERROR
}
}

return returnCode
}

Expand Down
57 changes: 43 additions & 14 deletions cliv2/internal/cliv2/cliv2_test.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
package cliv2_test

import (
"context"
"io"
"log"
"os"
"os/exec"
"path"
"runtime"
"sort"
"testing"
"time"
Expand Down Expand Up @@ -270,9 +272,11 @@ func Test_prepareV1Command(t *testing.T) {
cacheDir := getCacheDir(t)
config := configuration.NewInMemory()
config.Set(configuration.CACHE_PATH, cacheDir)
cli, _ := cliv2.NewCLIv2(config, discardLogger)
cli, err := cliv2.NewCLIv2(config, discardLogger)
assert.NoError(t, err)

snykCmd, err := cli.PrepareV1Command(
context.Background(),
"someExecutable",
expectedArgs,
getProxyInfoForTest(),
Expand All @@ -297,11 +301,12 @@ func Test_extractOnlyOnce(t *testing.T) {
assert.NoDirExists(t, tmpDir)

// create instance under test
cli, _ := cliv2.NewCLIv2(config, discardLogger)
cli, err := cliv2.NewCLIv2(config, discardLogger)
assert.NoError(t, err)
assert.NoError(t, cli.Init())

// run once
assert.Nil(t, cli.Init())
cli.Execute(getProxyInfoForTest(), []string{"--help"})
err = cli.Execute(getProxyInfoForTest(), []string{"--help"})
assert.FileExists(t, cli.GetBinaryLocation())
fileInfo1, _ := os.Stat(cli.GetBinaryLocation())

Expand All @@ -310,7 +315,7 @@ func Test_extractOnlyOnce(t *testing.T) {

// run twice
assert.Nil(t, cli.Init())
cli.Execute(getProxyInfoForTest(), []string{"--help"})
_ = cli.Execute(getProxyInfoForTest(), []string{"--help"})
assert.FileExists(t, cli.GetBinaryLocation())
fileInfo2, _ := os.Stat(cli.GetBinaryLocation())

Expand All @@ -326,7 +331,8 @@ func Test_init_extractDueToInvalidBinary(t *testing.T) {
assert.NoDirExists(t, tmpDir)

// create instance under test
cli, _ := cliv2.NewCLIv2(config, discardLogger)
cli, err := cliv2.NewCLIv2(config, discardLogger)
assert.NoError(t, err)

// fill binary with invalid data
_ = os.MkdirAll(tmpDir, 0755)
Expand Down Expand Up @@ -363,8 +369,9 @@ func Test_executeRunV2only(t *testing.T) {
assert.NoDirExists(t, tmpDir)

// create instance under test
cli, _ := cliv2.NewCLIv2(config, discardLogger)
assert.Nil(t, cli.Init())
cli, err := cliv2.NewCLIv2(config, discardLogger)
assert.NoError(t, err)
assert.NoError(t, cli.Init())

actualReturnCode := cliv2.DeriveExitCode(cli.Execute(getProxyInfoForTest(), []string{"--version"}))
assert.Equal(t, expectedReturnCode, actualReturnCode)
Expand All @@ -380,8 +387,9 @@ func Test_executeUnknownCommand(t *testing.T) {
config.Set(configuration.CACHE_PATH, cacheDir)

// create instance under test
cli, _ := cliv2.NewCLIv2(config, discardLogger)
assert.Nil(t, cli.Init())
cli, err := cliv2.NewCLIv2(config, discardLogger)
assert.NoError(t, err)
assert.NoError(t, cli.Init())

actualReturnCode := cliv2.DeriveExitCode(cli.Execute(getProxyInfoForTest(), []string{"bogusCommand"}))
assert.Equal(t, expectedReturnCode, actualReturnCode)
Expand Down Expand Up @@ -427,8 +435,9 @@ func Test_clearCacheBigCache(t *testing.T) {
config.Set(configuration.CACHE_PATH, cacheDir)

// create instance under test
cli, _ := cliv2.NewCLIv2(config, discardLogger)
assert.Nil(t, cli.Init())
cli, err := cliv2.NewCLIv2(config, discardLogger)
assert.NoError(t, err)
assert.NoError(t, cli.Init())

// create folders and files in cache dir
dir1 := path.Join(cli.CacheDirectory, "dir1")
Expand All @@ -447,8 +456,8 @@ func Test_clearCacheBigCache(t *testing.T) {
_ = os.Mkdir(dir6, 0755)

// clear cache
err := cli.ClearCache()
assert.Nil(t, err)
err = cli.ClearCache()
assert.NoError(t, err)

// check if directories that need to be deleted don't exist
assert.NoDirExists(t, dir1)
Expand All @@ -460,3 +469,23 @@ func Test_clearCacheBigCache(t *testing.T) {
assert.DirExists(t, dir6)
assert.FileExists(t, currentVersion)
}

func Test_setTimeout(t *testing.T) {
if //goland:noinspection ALL
runtime.GOOS == "windows" {
t.Skip("Skipping test on windows")
}
config := configuration.NewInMemory()
cli, err := cliv2.NewCLIv2(config, discardLogger)
assert.NoError(t, err)
config.Set(configuration.TIMEOUT, 1)

// sleep for 2s
cli.SetV1BinaryLocation("/bin/sleep")
err = cli.Execute(getProxyInfoForTest(), []string{"2"})

assert.ErrorIs(t, err, context.DeadlineExceeded)

// ensure that -1 is correctly mapped if timeout is set
assert.Equal(t, constants.SNYK_EXIT_CODE_EX_UNAVAILABLE, cliv2.DeriveExitCode(err))
}

0 comments on commit 1686db5

Please sign in to comment.