Skip to content

Commit

Permalink
Instance lookup by IP and hostname via private IP addresses on attach…
Browse files Browse the repository at this point in the history
…ed NIC (#54)

* Merges Attribute and Tag filters for defining matchers for non-tag values
* Adds an --attribute flag to session calls
  • Loading branch information
Nate Catelli authored Sep 24, 2021
1 parent 1e6dda1 commit d7241cb
Show file tree
Hide file tree
Showing 15 changed files with 274 additions and 49 deletions.
73 changes: 73 additions & 0 deletions aws/resolver/resolve.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package resolver

import (
"fmt"
"net"

"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
)

// Provides a interface for resolving an instance to a single ec2 instance.
type InstanceResolver interface {
ResolveToInstanceId(client ec2iface.EC2API) ([]string, error)
}

/// HostnameResolver attempts to resolve a fqdn or IP to a corresponding instance ID.
type HostnameResolver struct {
addrs []string
}

func NewHostnameResolver(addrs []string) *HostnameResolver {
return &HostnameResolver{
addrs: addrs,
}
}

func (hr *HostnameResolver) ResolveToInstanceId(client ec2iface.EC2API) (output []string, err error) {
ips := make([]*string, 1)
for _, addr := range hr.addrs {
ip, err := resolveToFirst(addr)
if err != nil {
return nil, fmt.Errorf("unable to resolve hostname to %v to ip", hr.addrs)
}

ipString := ip.String()
ips = append(ips, &ipString)

}

ipFilter := &ec2.Filter{}
ipFilter.SetName("addresses.private-ip-address").SetValues(ips)

dniInput := &ec2.DescribeNetworkInterfacesInput{}
dniInput.SetFilters([]*ec2.Filter{ipFilter})

describeNetworkInterfacesPager := func(page *ec2.DescribeNetworkInterfacesOutput, lastPage bool) bool {
for _, nic := range page.NetworkInterfaces {
output = append(output, *nic.Attachment.InstanceId)
}

// If it's not the last page, continue
return !lastPage
}

// Fetch all the instances described
if err = client.DescribeNetworkInterfacesPages(dniInput, describeNetworkInterfacesPager); err != nil {
return nil, fmt.Errorf("could not describe network interfaces\n%v", err)
}

return output, nil
}

func resolveToFirst(addr string) (net.IP, error) {
if ip := net.ParseIP(addr); ip != nil {
return ip, nil
} else if ips, err := net.LookupIP(addr); err == nil {
if len(ips[0]) > 0 {
return ips[0], nil
}
}

return nil, fmt.Errorf("no IP address found for %s", addr)
}
73 changes: 73 additions & 0 deletions aws/resolver/resolver_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package resolver

import (
"io/ioutil"
"net"
"testing"

"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ec2/ec2iface"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
)

type mockedEC2 struct {
ec2iface.EC2API
DescribeNetworkInterfacesOutput []*ec2.DescribeNetworkInterfacesOutput
}

func (c *mockedEC2) DescribeNetworkInterfacesPages(input *ec2.DescribeNetworkInterfacesInput, fn func(*ec2.DescribeNetworkInterfacesOutput, bool) bool) error {
totalPages := len(c.DescribeNetworkInterfacesOutput)
for i, output := range c.DescribeNetworkInterfacesOutput {
isLastPage := (i == (totalPages - 1))
if breakLoop := fn(output, isLastPage); breakLoop {
break
}
}
return nil
}

var (
exampleInstanceId = "i-1234567890abcdef0"
examplePrivateIpAddress = "127.0.0.1"
singleResponseDescribeNetworkInterfacesOutput = ec2.DescribeNetworkInterfacesOutput{
NetworkInterfaces: []*ec2.NetworkInterface{
{
Attachment: &ec2.NetworkInterfaceAttachment{
InstanceId: &exampleInstanceId,
},
PrivateIpAddress: &examplePrivateIpAddress,
},
},
NextToken: nil,
}
)

func TestHostnameResolver(t *testing.T) {
assert := assert.New(t)

logger := logrus.New()
logger.SetOutput(ioutil.Discard)

t.Run("test passed ip causes a short circuit", func(t *testing.T) {
validIp, err := resolveToFirst(examplePrivateIpAddress)
assert.Nil(err)
assert.NotNil(validIp)
assert.EqualValues(validIp, net.ParseIP(examplePrivateIpAddress))
})

t.Run("test passed passed hostname resolves to an IP", func(t *testing.T) {
validIp, err := resolveToFirst("example.com")
assert.Nil(err)
assert.NotNil(validIp)
})

t.Run("test mock resolver returns a valid instance id", func(t *testing.T) {
mockClient := &mockedEC2{DescribeNetworkInterfacesOutput: []*ec2.DescribeNetworkInterfacesOutput{&singleResponseDescribeNetworkInterfacesOutput}}
testResolver := NewHostnameResolver([]string{examplePrivateIpAddress})

resp, err := testResolver.ResolveToInstanceId(mockClient)
assert.Nil(err)
assert.EqualValues(resp, []string{exampleInstanceId})
})
}
10 changes: 10 additions & 0 deletions cmd/cmdutil/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ func AddInstanceFlag(cmd *cobra.Command) {
cmd.Flags().StringSliceP("instance", "i", nil, "Specify what instance IDs you want to target.\nMultiple allowed, delimited by commas (e.g. --instance i-12345,i-23456)")
}

// AddsHostnameFlags adds --address to command
func AddHostnameFlag(cmd *cobra.Command) {
cmd.Flags().StringSliceP("address", "a", nil, "Specify what Address or FQDN you want to target.\nMultiple allowed, delimited by commas (e.g. --address 10.240.12.6,10.240.12.7)")
}

// AddAllProfilesFlag adds --all-profiles to command
func AddAllProfilesFlag(cmd *cobra.Command) {
cmd.Flags().Bool("all-profiles", false, "[USE WITH CAUTION] Parse through ~/.aws/config to target all profiles.")
Expand All @@ -81,6 +86,11 @@ func AddTagFlag(cmd *cobra.Command) {
cmd.Flags().StringSliceP("tag", "t", nil, "Adds the specified tag as an additional column to be displayed during the instance selection prompt.")
}

// AddsAttributeFlag adds the --attribute flag to command.
func AddAttributeFlag(cmd *cobra.Command) {
cmd.Flags().StringSliceP("attribute", "x", nil, "Adds the specified attribute as an additional column to be displayed during the instance selection prompt.")
}

// AddSessionNameFlag adds --session-name to command
func AddSessionNameFlag(cmd *cobra.Command, defaultName string) {
cmd.Flags().String("session-name", defaultName, "Specify a name for the tmux session created when multiple instances are selected")
Expand Down
6 changes: 4 additions & 2 deletions cmd/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ func addBaseFlags(cmd *cobra.Command) {
cmdutil.AddDryRunFlag(cmd)
cmdutil.AddFilterFlag(cmd)
cmdutil.AddInstanceFlag(cmd)
cmdutil.AddHostnameFlag(cmd)
cmdutil.AddProfileFlag(cmd)
cmdutil.AddRegionFlag(cmd)
}
Expand All @@ -33,6 +34,7 @@ func addRunFlags(cmd *cobra.Command) {

func addSessionFlags(cmd *cobra.Command) {
cmdutil.AddTagFlag(cmd)
cmdutil.AddAttributeFlag(cmd)
cmdutil.AddSessionNameFlag(cmd, "ssm-session")
cmdutil.AddLimitFlag(cmd, 10, "Set a limit for the number of instance results returned per profile/region combination.")
}
Expand Down Expand Up @@ -159,12 +161,12 @@ func validateSessionFlags(cmd *cobra.Command, instanceList []string, filterList
}

// validateRunFlags validates the usage of certain flags required by the run subcommand
func validateRunFlags(cmd *cobra.Command, instanceList []string, commandList []string, filterList []*ssm.Target) error {
func validateRunFlags(cmd *cobra.Command, instanceList []string, addressList []string, commandList []string, filterList []*ssm.Target) error {
if len(instanceList) > 0 && len(filterList) > 0 {
return cmdutil.UsageError(cmd, "The --filter and --instance flags cannot be used simultaneously.")
}

if len(instanceList) == 0 && len(filterList) == 0 {
if len(instanceList) == 0 && len(addressList) == 0 && len(filterList) == 0 {
return cmdutil.UsageError(cmd, "You must supply target arguments using either the --filter or --instance flags.")
}

Expand Down
13 changes: 7 additions & 6 deletions cmd/flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,39 +316,40 @@ func Test_validateRunFlags(t *testing.T) {
cmdutil.AddCommandFlag(cmd)
cmdutil.AddFilterFlag(cmd)
cmdutil.AddInstanceFlag(cmd)
cmdutil.AddHostnameFlag(cmd)
cmd.Execute()

instanceList := make([]string, 51)

t.Run("try to use --filter and --instance flags", func(t *testing.T) {
targetList := make([]*ssm.Target, 2)
err := validateRunFlags(cmd, instanceList, []string{"hostname"}, targetList)
err := validateRunFlags(cmd, instanceList, nil, []string{"hostname"}, targetList)
assert.Error(err)
})

t.Run("specify more than 5 filters", func(t *testing.T) {
targetList := make([]*ssm.Target, 6)
err := validateRunFlags(cmd, nil, []string{"hostname"}, targetList)
err := validateRunFlags(cmd, nil, nil, []string{"hostname"}, targetList)
assert.Error(err)
})

t.Run("no instances or filters specified", func(t *testing.T) {
err := validateRunFlags(cmd, nil, []string{"hostname"}, nil)
err := validateRunFlags(cmd, nil, nil, []string{"hostname"}, nil)
assert.Error(err)
})

t.Run(">50 specified instances", func(t *testing.T) {
err := validateRunFlags(cmd, instanceList, []string{"hostname"}, nil)
err := validateRunFlags(cmd, instanceList, nil, []string{"hostname"}, nil)
assert.Error(err)
})

t.Run("no command specified", func(t *testing.T) {
err := validateRunFlags(cmd, []string{"myInstance"}, nil, nil)
err := validateRunFlags(cmd, []string{"myInstance"}, nil, nil, nil)
assert.Error(err)
})

t.Run("valid flag combination", func(t *testing.T) {
err := validateRunFlags(cmd, []string{"myInstance"}, []string{"hostname"}, nil)
err := validateRunFlags(cmd, []string{"myInstance"}, nil, []string{"hostname"}, nil)
assert.NoError(err)
})
}
Expand Down
20 changes: 18 additions & 2 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@ import (
"sync"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ssm"
"github.com/spf13/cobra"

"github.com/disneystreaming/ssm-helpers/aws/resolver"
"github.com/disneystreaming/ssm-helpers/aws/session"
"github.com/disneystreaming/ssm-helpers/cmd/cmdutil"
ssmx "github.com/disneystreaming/ssm-helpers/ssm"
Expand All @@ -34,7 +36,7 @@ func newCommandSSMRun() *cobra.Command {

func runCommand(cmd *cobra.Command, args []string) {
var err error
var instanceList, commandList, profileList, regionList []string
var instanceList, addressList, commandList, profileList, regionList []string
var maxConcurrency, maxErrors string
var targets []*ssm.Target

Expand All @@ -46,14 +48,17 @@ func runCommand(cmd *cobra.Command, args []string) {
if instanceList, err = cmdutil.GetFlagStringSlice(cmd, "instance"); err != nil {
log.Fatal(err)
}
if addressList, err = cmdutil.GetFlagStringSlice(cmd, "address"); err != nil {
log.Fatal(err)
}
if commandList, err = getCommandList(cmd); err != nil {
log.Fatal(err)
}
if targets, err = getTargetList(cmd); err != nil {
log.Fatal(err)
}

if err := validateRunFlags(cmd, instanceList, commandList, targets); err != nil {
if err := validateRunFlags(cmd, instanceList, addressList, commandList, targets); err != nil {
log.Fatal(err)
}

Expand Down Expand Up @@ -101,7 +106,18 @@ func runCommand(cmd *cobra.Command, args []string) {
sessionPool := session.NewPool(profileList, regionList, log)
for _, sess := range sessionPool.Sessions {
wg.Add(1)
var threadLocalSendCommandInput *ssm.SendCommandInput = &(*sciInput)
ssmClient := ssm.New(sess.Session)

if len(addressList) > 0 {
ec2Client := ec2.New(sess.Session)
hr := resolver.NewHostnameResolver(addressList)
ids, _ := hr.ResolveToInstanceId(ec2Client)
for _, id := range ids {
threadLocalSendCommandInput.InstanceIds = append(threadLocalSendCommandInput.InstanceIds, &id)
}
}

log.Debugf("Starting invocation targeting account %s in %s", sess.ProfileName, *sess.Session.Config.Region)
go ssmx.RunInvocations(sess, ssmClient, &wg, sciInput, &output)
}
Expand Down
Loading

0 comments on commit d7241cb

Please sign in to comment.