Skip to content

Commit

Permalink
Initial attempt at s3 getter
Browse files Browse the repository at this point in the history
  • Loading branch information
dancannon committed Oct 22, 2015
1 parent 2463fe5 commit f8a65f2
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 1 deletion.
1 change: 1 addition & 0 deletions detect.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ func init() {
Detectors = []Detector{
new(GitHubDetector),
new(BitBucketDetector),
new(S3Detector),
new(FileDetector),
}
}
Expand Down
89 changes: 89 additions & 0 deletions detect_s3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
package getter

import (
"fmt"
"net/url"
"strings"
)

const (
vhostFormat = ""
)

// S3Detector implements Detector to detect S3 URLs and turn
// them into URLs that the S3 getter can understand.
type S3Detector struct{}

func (d *S3Detector) Detect(src, _ string) (string, bool, error) {
if len(src) == 0 {
return "", false, nil
}

if strings.Contains(src, ".amazonaws.com/") {
return d.detectHTTP(src)
}

return "", false, nil
}

func (d *S3Detector) detectHTTP(src string) (string, bool, error) {
parts := strings.Split(src, "/")
if len(parts) < 0 {
return "", false, fmt.Errorf(
"URL is not a valid S3 URL")
}

hostParts := strings.Split(parts[0], ".")
if len(hostParts) == 3 {
return d.detectPathStyle(hostParts[0], parts[1:])
} else if len(hostParts) == 4 {
return d.detectVhostStyle(hostParts[1], hostParts[0], parts[1:])
} else {
return "", false, fmt.Errorf(
"URL is not a valid S3 URL")
}
}

func (d *S3Detector) detectPathStyle(region string, parts []string) (string, bool, error) {
urlStr := fmt.Sprintf("https://%s.amazonaws.com/%s", region, strings.Join(parts, "/"))
url, err := url.Parse(urlStr)
if err != nil {
return "", true, fmt.Errorf("error parsing GitHub URL: %s", err)
}

return "s3::" + url.String(), true, nil
}

func (d *S3Detector) detectVhostStyle(region, bucket string, parts []string) (string, bool, error) {
urlStr := fmt.Sprintf("https://%s.amazonaws.com/%s/%s", region, bucket, strings.Join(parts, "/"))
url, err := url.Parse(urlStr)
if err != nil {
return "", true, fmt.Errorf("error parsing S3 URL: %s", err)
}

return "s3::" + url.String(), true, nil
}

// func (d *S3Detector) detectSSH(src string) (string, bool, error) {
// idx := strings.Index(src, ":")
// qidx := strings.Index(src, "?")
// if qidx == -1 {
// qidx = len(src)
// }

// var u url.URL
// u.Scheme = "ssh"
// u.User = url.User("git")
// u.Host = "github.com"
// u.Path = src[idx+1 : qidx]
// if qidx < len(src) {
// q, err := url.ParseQuery(src[qidx+1:])
// if err != nil {
// return "", true, fmt.Errorf("error parsing GitHub SSH URL: %s", err)
// }

// u.RawQuery = q.Encode()
// }

// return "git::" + u.String(), true, nil
// }
79 changes: 79 additions & 0 deletions detect_s3_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package getter

import (
"testing"
)

func TestS3Detector(t *testing.T) {
cases := []struct {
Input string
Output string
}{
// Virtual hosted style
{
"bucket.s3.amazonaws.com/foo",
"s3::https://s3.amazonaws.com/bucket/foo",
},
{
"bucket.s3.amazonaws.com/foo/bar",
"s3::https://s3.amazonaws.com/bucket/foo/bar",
},
{
"bucket.s3.amazonaws.com/foo/bar.baz",
"s3::https://s3.amazonaws.com/bucket/foo/bar.baz",
},
{
"bucket.s3-eu-west-1.amazonaws.com/foo",
"s3::https://s3-eu-west-1.amazonaws.com/bucket/foo",
},
{
"bucket.s3-eu-west-1.amazonaws.com/foo/bar",
"s3::https://s3-eu-west-1.amazonaws.com/bucket/foo/bar",
},
{
"bucket.s3-eu-west-1.amazonaws.com/foo/bar.baz",
"s3::https://s3-eu-west-1.amazonaws.com/bucket/foo/bar.baz",
},
// Path style
{
"s3.amazonaws.com/bucket/foo",
"s3::https://s3.amazonaws.com/bucket/foo",
},
{
"s3.amazonaws.com/bucket/foo/bar",
"s3::https://s3.amazonaws.com/bucket/foo/bar",
},
{
"s3.amazonaws.com/bucket/foo/bar.baz",
"s3::https://s3.amazonaws.com/bucket/foo/bar.baz",
},
{
"s3-eu-west-1.amazonaws.com/bucket/foo",
"s3::https://s3-eu-west-1.amazonaws.com/bucket/foo",
},
{
"s3-eu-west-1.amazonaws.com/bucket/foo/bar",
"s3::https://s3-eu-west-1.amazonaws.com/bucket/foo/bar",
},
{
"s3-eu-west-1.amazonaws.com/bucket/foo/bar.baz",
"s3::https://s3-eu-west-1.amazonaws.com/bucket/foo/bar.baz",
},
}

pwd := "/pwd"
f := new(S3Detector)
for i, tc := range cases {
output, ok, err := f.Detect(tc.Input, pwd)
if err != nil {
t.Fatalf("err: %s", err)
}
if !ok {
t.Fatal("not ok")
}

if output != tc.Output {
t.Fatalf("%d: bad: %#v", i, output)
}
}
}
3 changes: 2 additions & 1 deletion get.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var Getters map[string]Getter

// forcedRegexp is the regular expression that finds forced getters. This
// syntax is schema::url, example: git::https://foo.com
var forcedRegexp = regexp.MustCompile(`^([A-Za-z]+)::(.+)$`)
var forcedRegexp = regexp.MustCompile(`^([A-Za-z0-9]+)::(.+)$`)

func init() {
httpGetter := new(HttpGetter)
Expand All @@ -52,6 +52,7 @@ func init() {
"file": new(FileGetter),
"git": new(GitGetter),
"hg": new(HgGetter),
"s3": new(S3Getter),
"http": httpGetter,
"https": httpGetter,
}
Expand Down
66 changes: 66 additions & 0 deletions get_s3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package getter

import (
"fmt"
"io"
"net/url"
"os"
"path/filepath"
"strings"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
)

// S3Getter is a Getter implementation that will download a module from
// a S3 bucket.
type S3Getter struct{}

func (g *S3Getter) Get(dst string, u *url.URL) error {
return fmt.Errorf("Operation is unsupported")
}

func (g *S3Getter) GetFile(dst string, u *url.URL) error {
region, bucket, path, err := g.parseUrl(u)
if err != nil {
return err
}

client := s3.New(&aws.Config{Region: aws.String(region)})
resp, err := client.GetObject(&s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(path),
})
if err != nil {
return err
}

// Create all the parent directories
if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil {
return err
}

f, err := os.Create(dst)
if err != nil {
return err
}
defer f.Close()

_, err = io.Copy(f, resp.Body)
return err
}

func (g *S3Getter) parseUrl(u *url.URL) (region, bucket, path string, err error) {
hostParts := strings.Split(u.Host, ".")

if len(hostParts) != 3 {
return "", "", "", fmt.Errorf("URL is not a valid S3 URL")
}
region = strings.TrimPrefix(strings.TrimPrefix(hostParts[0], "s3-"), "s3")

pathParts := strings.Split(u.Path, "/")
bucket = pathParts[1]
path = strings.Join(pathParts[2:], "/")

return
}

0 comments on commit f8a65f2

Please sign in to comment.