Skip to content

Commit

Permalink
Merge pull request remind101#12 from remind101/duration
Browse files Browse the repository at this point in the history
Change default duration back to 1 hour
ejholmes authored May 31, 2017
2 parents 839fbe6 + 7d9496a commit 40a3af7
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
@@ -2,15 +2,18 @@ package main

import (
"bufio"
"flag"
"fmt"
"io/ioutil"
"os"
"os/exec"
"strings"
"syscall"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/sts"
"gopkg.in/yaml.v2"
@@ -19,18 +22,30 @@ import (
var configFilePath = fmt.Sprintf("%s/.aws/roles", os.Getenv("HOME"))

func usage() {
fmt.Print(`Usage: assume-role <role> [<command> <args...>]
`)
fmt.Fprintf(os.Stderr, "Usage: %s <role> [<command> <args...>]\n", os.Args[0])
flag.PrintDefaults()
}

func init() {
flag.Usage = usage
}

func main() {
if len(os.Args) < 2 {
usage()
var (
duration = flag.Duration("duration", time.Hour, "The duration that the credentials will be valid for.")
)

flag.Parse()
argv := flag.Args()
if len(argv) < 1 {
flag.Usage()
os.Exit(1)
}

role := os.Args[1]
args := os.Args[2:]
stscreds.DefaultDuration = *duration

role := argv[0]
args := argv[1:]

// Load credentials from configFilePath if it exists, else use regular AWS config
var creds *credentials.Value
@@ -52,7 +67,7 @@ func main() {
cleanEnv()
}

creds, err = assumeRole(roleConfig.Role, roleConfig.MFA)
creds, err = assumeRole(roleConfig.Role, roleConfig.MFA, *duration)
must(err)
} else {
if os.Getenv("ASSUMED_ROLE") != "" {
@@ -123,14 +138,15 @@ func assumeProfile(profile string) (*credentials.Value, error) {
}

// assumeRole assumes the given role and returns the temporary STS credentials.
func assumeRole(role, mfa string) (*credentials.Value, error) {
func assumeRole(role, mfa string, duration time.Duration) (*credentials.Value, error) {
sess := session.Must(session.NewSession())

svc := sts.New(sess)

params := &sts.AssumeRoleInput{
RoleArn: aws.String(role),
RoleSessionName: aws.String("cli"),
DurationSeconds: aws.Int64(int64(duration / time.Second)),
}
if mfa != "" {
params.SerialNumber = aws.String(mfa)

0 comments on commit 40a3af7

Please sign in to comment.