Skip to content

Commit

Permalink
Merge pull request hashicorp#133 from hashicorp/gracefull_termination
Browse files Browse the repository at this point in the history
Gracefull (context) cancellation
  • Loading branch information
vancluever authored Jan 17, 2019
2 parents c68364a + bde4550 commit 3279f78
Show file tree
Hide file tree
Showing 13 changed files with 225 additions and 50 deletions.
6 changes: 5 additions & 1 deletion client.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package getter

import (
"context"
"fmt"
"io/ioutil"
"os"
Expand All @@ -18,6 +19,9 @@ import (
// Using a client directly allows more fine-grained control over how downloading
// is done, as well as customizing the protocols supported.
type Client struct {
// Ctx for cancellation
Ctx context.Context

// Src is the source URL to get.
//
// Dst is the path to save the downloaded thing as. If Dir is set to
Expand Down Expand Up @@ -287,7 +291,7 @@ func (c *Client) Get() error {
return err
}

return copyDir(realDst, subDir, false)
return copyDir(c.Ctx, realDst, subDir, false)
}

return nil
Expand Down
14 changes: 14 additions & 0 deletions client_option.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
package getter

import "context"

// A ClientOption allows to configure a client
type ClientOption func(*Client) error

// Configure configures a client with options.
func (c *Client) Configure(opts ...ClientOption) error {
if c.Ctx == nil {
c.Ctx = context.Background()
}
c.Options = opts
for _, opt := range opts {
err := opt(c)
Expand All @@ -30,3 +35,12 @@ func (c *Client) Configure(opts ...ClientOption) error {
}
return nil
}

// WithContext allows to pass a context to operation
// in order to be able to cancel a download in progress.
func WithContext(ctx context.Context) func(*Client) error {
return func(c *Client) error {
c.Ctx = ctx
return nil
}
}
54 changes: 39 additions & 15 deletions cmd/go-getter/main.go
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
package main

import (
"context"
"flag"
"log"
"os"
"os/signal"
"sync"

"github.com/hashicorp/go-getter"
getter "github.com/hashicorp/go-getter"
)

func main() {
Expand Down Expand Up @@ -36,28 +39,49 @@ func main() {
pwd, err := os.Getwd()
if err != nil {
log.Fatalf("Error getting wd: %s", err)
os.Exit(1)
}

// Build the client
client := &getter.Client{
Src: args[0],
Dst: args[1],
Pwd: pwd,
Mode: mode,
}
var opts []getter.ClientOption
opts := []getter.ClientOption{}
if *progress {
opts = append(opts, getter.WithProgress(defaultProgressBar))
}

if err := client.Configure(opts...); err != nil {
log.Fatalf("Configure: %s", err)
ctx, cancel := context.WithCancel(context.Background())
// Build the client
client := &getter.Client{
Ctx: ctx,
Src: args[0],
Dst: args[1],
Pwd: pwd,
Mode: mode,
Options: opts,
}

if err := client.Get(); err != nil {
wg := sync.WaitGroup{}
wg.Add(1)
errChan := make(chan error, 2)
go func() {
defer wg.Done()
defer cancel()
if err := client.Get(); err != nil {
errChan <- err
}
}()

c := make(chan os.Signal)
signal.Notify(c, os.Interrupt)

select {
case sig := <-c:
signal.Reset(os.Interrupt)
cancel()
wg.Wait()
log.Printf("signal %v", sig)
case <-ctx.Done():
wg.Wait()
log.Printf("success!")
case err := <-errChan:
wg.Wait()
log.Fatalf("Error downloading: %s", err)
}

log.Println("Success!")
}
6 changes: 3 additions & 3 deletions copy_dir.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package getter

import (
"io"
"context"
"os"
"path/filepath"
"strings"
Expand All @@ -11,7 +11,7 @@ import (
// should already exist.
//
// If ignoreDot is set to true, then dot-prefixed files/folders are ignored.
func copyDir(dst string, src string, ignoreDot bool) error {
func copyDir(ctx context.Context, dst string, src string, ignoreDot bool) error {
src, err := filepath.EvalSymlinks(src)
if err != nil {
return err
Expand Down Expand Up @@ -66,7 +66,7 @@ func copyDir(dst string, src string, ignoreDot bool) error {
}
defer dstF.Close()

if _, err := io.Copy(dstF, srcF); err != nil {
if _, err := Copy(ctx, dstF, srcF); err != nil {
return err
}

Expand Down
11 changes: 11 additions & 0 deletions get_base.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
package getter

import "context"

// getter is our base getter; it regroups
// fields all getters have in common.
type getter struct {
client *Client
}

func (g *getter) SetClient(c *Client) { g.client = c }

// Context tries to returns the Contex from the getter's
// client. otherwise context.Background() is returned.
func (g *getter) Context() context.Context {
if g == nil || g.client == nil {
return context.Background()
}
return g.client.Ctx
}
29 changes: 29 additions & 0 deletions get_file_copy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package getter

import (
"context"
"io"
)

// readerFunc is syntactic sugar for read interface.
type readerFunc func(p []byte) (n int, err error)

func (rf readerFunc) Read(p []byte) (n int, err error) { return rf(p) }

// Copy is a io.Copy cancellable by context
func Copy(ctx context.Context, dst io.Writer, src io.Reader) (int64, error) {
// Copy will call the Reader and Writer interface multiple time, in order
// to copy by chunk (avoiding loading the whole file in memory).
return io.Copy(dst, readerFunc(func(p []byte) (int, error) {

select {
case <-ctx.Done():
// context has been canceled
// stop process and propagate "context canceled" error
return 0, ctx.Err()
default:
// otherwise just run default io.Reader implementation
return src.Read(p)
}
}))
}
82 changes: 82 additions & 0 deletions get_file_copy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package getter

import (
"bytes"
"context"
"io"
"testing"
"time"
)

// OneDoneContext is a context that is
// cancelled after a first done is called.
type OneDoneContext bool

func (*OneDoneContext) Deadline() (deadline time.Time, ok bool) { return }
func (*OneDoneContext) Value(key interface{}) interface{} { return nil }

func (o *OneDoneContext) Err() error {
if *o == false {
return nil
}
return context.Canceled
}

func (o *OneDoneContext) Done() <-chan struct{} {
if *o == false {
*o = true
return nil
}
c := make(chan struct{})
close(c)
return c
}

func (o *OneDoneContext) String() string {
if *o {
return "done OneDoneContext"
}
return "OneDoneContext"
}

func TestCopy(t *testing.T) {
const text3lines = `line1
line2
line3
`

cancelledContext, cancel := context.WithCancel(context.Background())
_ = cancelledContext
cancel()
type args struct {
ctx context.Context
src io.Reader
}
tests := []struct {
name string
args args
want int64
wantDst string
wantErr error
}{
{"read all", args{context.Background(), bytes.NewBufferString(text3lines)}, int64(len(text3lines)), text3lines, nil},
{"read none", args{cancelledContext, bytes.NewBufferString(text3lines)}, 0, "", context.Canceled},
{"cancel after read", args{new(OneDoneContext), bytes.NewBufferString(text3lines)}, int64(len(text3lines)), text3lines, context.Canceled},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
dst := &bytes.Buffer{}
got, err := Copy(tt.args.ctx, dst, tt.args.src)
if err != tt.wantErr {
t.Errorf("Copy() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Copy() = %v, want %v", got, tt.want)
}
if gotDst := dst.String(); gotDst != tt.wantDst {
t.Errorf("Copy() = %v, want %v", gotDst, tt.wantDst)
}
})
}
}
4 changes: 2 additions & 2 deletions get_file_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package getter

import (
"fmt"
"io"
"net/url"
"os"
"path/filepath"
Expand Down Expand Up @@ -50,6 +49,7 @@ func (g *FileGetter) Get(dst string, u *url.URL) error {
}

func (g *FileGetter) GetFile(dst string, u *url.URL) error {
ctx := g.Context()
path := u.Path
if u.RawPath != "" {
path = u.RawPath
Expand Down Expand Up @@ -98,6 +98,6 @@ func (g *FileGetter) GetFile(dst string, u *url.URL) error {
}
defer dstF.Close()

_, err = io.Copy(dstF, srcF)
_, err = Copy(ctx, dstF, srcF)
return err
}
7 changes: 4 additions & 3 deletions get_file_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package getter

import (
"fmt"
"io"
"net/url"
"os"
"os/exec"
Expand All @@ -13,6 +12,7 @@ import (
)

func (g *FileGetter) Get(dst string, u *url.URL) error {
ctx := g.Context()
path := u.Path
if u.RawPath != "" {
path = u.RawPath
Expand Down Expand Up @@ -51,7 +51,7 @@ func (g *FileGetter) Get(dst string, u *url.URL) error {
sourcePath := toBackslash(path)

// Use mklink to create a junction point
output, err := exec.Command("cmd", "/c", "mklink", "/J", dst, sourcePath).CombinedOutput()
output, err := exec.CommandContext(ctx, "cmd", "/c", "mklink", "/J", dst, sourcePath).CombinedOutput()
if err != nil {
return fmt.Errorf("failed to run mklink %v %v: %v %q", dst, sourcePath, err, output)
}
Expand All @@ -60,6 +60,7 @@ func (g *FileGetter) Get(dst string, u *url.URL) error {
}

func (g *FileGetter) GetFile(dst string, u *url.URL) error {
ctx := g.Context()
path := u.Path
if u.RawPath != "" {
path = u.RawPath
Expand Down Expand Up @@ -108,7 +109,7 @@ func (g *FileGetter) GetFile(dst string, u *url.URL) error {
}
defer dstF.Close()

_, err = io.Copy(dstF, srcF)
_, err = Copy(ctx, dstF, srcF)
return err
}

Expand Down
Loading

0 comments on commit 3279f78

Please sign in to comment.