Skip to content

Commit

Permalink
Merge pull request cli#2230 from alissonbrunosa/fix-2179
Browse files Browse the repository at this point in the history
Add support for ssh_config Include directives
  • Loading branch information
mislav authored Dec 15, 2020
2 parents f3ea05f + 935f644 commit ae68da6
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 62 deletions.
13 changes: 12 additions & 1 deletion git/remote_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
package git

import "testing"
import (
"reflect"
"testing"
)

// TODO: extract assertion helpers into a shared package
func eq(t *testing.T, got interface{}, expected interface{}) {
t.Helper()
if !reflect.DeepEqual(got, expected) {
t.Errorf("expected: %v, got: %v", expected, got)
}
}

func Test_parseRemotes(t *testing.T) {
remoteList := []string{
Expand Down
143 changes: 100 additions & 43 deletions git/ssh_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,10 @@ import (
)

var (
sshHostRE,
sshTokenRE *regexp.Regexp
sshConfigLineRE = regexp.MustCompile(`\A\s*(?P<keyword>[A-Za-z][A-Za-z0-9]*)(?:\s+|\s*=\s*)(?P<argument>.+)`)
sshTokenRE = regexp.MustCompile(`%[%h]`)
)

func init() {
sshHostRE = regexp.MustCompile("(?i)^[ \t]*(host|hostname)[ \t]+(.+)$")
sshTokenRE = regexp.MustCompile(`%[%h]`)
}

// SSHAliasMap encapsulates the translation of SSH hostname aliases
type SSHAliasMap map[string]string

Expand All @@ -45,55 +40,79 @@ func (m SSHAliasMap) Translator() func(*url.URL) *url.URL {
}
}

// ParseSSHConfig constructs a map of SSH hostname aliases based on user and
// system configuration files
func ParseSSHConfig() SSHAliasMap {
configFiles := []string{
"/etc/ssh_config",
"/etc/ssh/ssh_config",
}
if homedir, err := homedir.Dir(); err == nil {
userConfig := filepath.Join(homedir, ".ssh", "config")
configFiles = append([]string{userConfig}, configFiles...)
}
type sshParser struct {
homeDir string

openFiles := make([]io.Reader, 0, len(configFiles))
for _, file := range configFiles {
f, err := os.Open(file)
aliasMap SSHAliasMap
hosts []string

open func(string) (io.Reader, error)
glob func(string) ([]string, error)
}

func (p *sshParser) read(fileName string) error {
var file io.Reader
if p.open == nil {
f, err := os.Open(fileName)
if err != nil {
continue
return err
}
defer f.Close()
openFiles = append(openFiles, f)
file = f
} else {
var err error
file, err = p.open(fileName)
if err != nil {
return err
}
}
return sshParse(openFiles...)
}

func sshParse(r ...io.Reader) SSHAliasMap {
config := make(SSHAliasMap)
for _, file := range r {
_ = sshParseConfig(config, file)
if len(p.hosts) == 0 {
p.hosts = []string{"*"}
}
return config
}

func sshParseConfig(c SSHAliasMap, file io.Reader) error {
hosts := []string{"*"}
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
match := sshHostRE.FindStringSubmatch(line)
if match == nil {
m := sshConfigLineRE.FindStringSubmatch(scanner.Text())
if len(m) < 3 {
continue
}

names := strings.Fields(match[2])
if strings.EqualFold(match[1], "host") {
hosts = names
} else {
for _, host := range hosts {
for _, name := range names {
c[host] = sshExpandTokens(name, host)
keyword, arguments := strings.ToLower(m[1]), m[2]
switch keyword {
case "host":
p.hosts = strings.Fields(arguments)
case "hostname":
for _, host := range p.hosts {
for _, name := range strings.Fields(arguments) {
if p.aliasMap == nil {
p.aliasMap = make(SSHAliasMap)
}
p.aliasMap[host] = sshExpandTokens(name, host)
}
}
case "include":
for _, arg := range strings.Fields(arguments) {
path := p.absolutePath(fileName, arg)

var fileNames []string
if p.glob == nil {
paths, _ := filepath.Glob(path)
for _, p := range paths {
if s, err := os.Stat(p); err == nil && !s.IsDir() {
fileNames = append(fileNames, p)
}
}
} else {
var err error
fileNames, err = p.glob(path)
if err != nil {
continue
}
}

for _, fileName := range fileNames {
_ = p.read(fileName)
}
}
}
Expand All @@ -102,6 +121,44 @@ func sshParseConfig(c SSHAliasMap, file io.Reader) error {
return scanner.Err()
}

func (p *sshParser) absolutePath(parentFile, path string) string {
if filepath.IsAbs(path) || strings.HasPrefix(filepath.ToSlash(path), "/") {
return path
}

if strings.HasPrefix(path, "~") {
return filepath.Join(p.homeDir, strings.TrimPrefix(path, "~"))
}

if strings.HasPrefix(filepath.ToSlash(parentFile), "/etc/ssh") {
return filepath.Join("/etc/ssh", path)
}

return filepath.Join(p.homeDir, ".ssh", path)
}

// ParseSSHConfig constructs a map of SSH hostname aliases based on user and
// system configuration files
func ParseSSHConfig() SSHAliasMap {
configFiles := []string{
"/etc/ssh_config",
"/etc/ssh/ssh_config",
}

p := sshParser{}

if homedir, err := homedir.Dir(); err == nil {
userConfig := filepath.Join(homedir, ".ssh", "config")
configFiles = append([]string{userConfig}, configFiles...)
p.homeDir = homedir
}

for _, file := range configFiles {
_ = p.read(file)
}
return p.aliasMap
}

func sshExpandTokens(text, host string) string {
return sshTokenRE.ReplaceAllStringFunc(text, func(match string) string {
switch match {
Expand Down
132 changes: 114 additions & 18 deletions git/ssh_config_test.go
Original file line number Diff line number Diff line change
@@ -1,31 +1,127 @@
package git

import (
"bytes"
"fmt"
"io"
"net/url"
"reflect"
"strings"
"path/filepath"
"testing"

"github.com/MakeNowJust/heredoc"
)

// TODO: extract assertion helpers into a shared package
func eq(t *testing.T, got interface{}, expected interface{}) {
t.Helper()
if !reflect.DeepEqual(got, expected) {
t.Errorf("expected: %v, got: %v", expected, got)
func Test_sshParser_read(t *testing.T) {
testFiles := map[string]string{
"/etc/ssh/config": heredoc.Doc(`
Include sites/*
`),
"/etc/ssh/sites/cfg1": heredoc.Doc(`
Host s1
Hostname=site1.net
`),
"/etc/ssh/sites/cfg2": heredoc.Doc(`
Host s2
Hostname = site2.net
`),
"HOME/.ssh/config": heredoc.Doc(`
Host *
Host gh gittyhubby
Hostname github.com
#Hostname example.com
Host ex
Include ex_config/*
`),
"HOME/.ssh/ex_config/ex_cfg": heredoc.Doc(`
Hostname example.com
`),
}
globResults := map[string][]string{
"/etc/ssh/sites/*": {"/etc/ssh/sites/cfg1", "/etc/ssh/sites/cfg2"},
"HOME/.ssh/ex_config/*": {"HOME/.ssh/ex_config/ex_cfg"},
}

p := &sshParser{
homeDir: "HOME",
open: func(s string) (io.Reader, error) {
if contents, ok := testFiles[filepath.ToSlash(s)]; ok {
return bytes.NewBufferString(contents), nil
} else {
return nil, fmt.Errorf("no test file stub found: %q", s)
}
},
glob: func(p string) ([]string, error) {
if results, ok := globResults[filepath.ToSlash(p)]; ok {
return results, nil
} else {
return nil, fmt.Errorf("no glob stubs found: %q", p)
}
},
}

if err := p.read("/etc/ssh/config"); err != nil {
t.Fatalf("read(global config) = %v", err)
}
if err := p.read("HOME/.ssh/config"); err != nil {
t.Fatalf("read(user config) = %v", err)
}

if got := p.aliasMap["gh"]; got != "github.com" {
t.Errorf("expected alias %q to expand to %q, got %q", "gh", "github.com", got)
}
if got := p.aliasMap["gittyhubby"]; got != "github.com" {
t.Errorf("expected alias %q to expand to %q, got %q", "gittyhubby", "github.com", got)
}
if got := p.aliasMap["example.com"]; got != "" {
t.Errorf("expected alias %q to expand to %q, got %q", "example.com", "", got)
}
if got := p.aliasMap["ex"]; got != "example.com" {
t.Errorf("expected alias %q to expand to %q, got %q", "ex", "example.com", got)
}
if got := p.aliasMap["s1"]; got != "site1.net" {
t.Errorf("expected alias %q to expand to %q, got %q", "s1", "site1.net", got)
}
}

func Test_sshParse(t *testing.T) {
m := sshParse(strings.NewReader(`
Host foo bar
HostName example.com
`), strings.NewReader(`
Host bar baz
hostname %%%h.net%%
`))
eq(t, m["foo"], "example.com")
eq(t, m["bar"], "%bar.net%")
eq(t, m["nonexistent"], "")
func Test_sshParser_absolutePath(t *testing.T) {
dir := "HOME"
p := &sshParser{homeDir: dir}

tests := map[string]struct {
parentFile string
arg string
want string
wantErr bool
}{
"absolute path": {
parentFile: "/etc/ssh/ssh_config",
arg: "/etc/ssh/config",
want: "/etc/ssh/config",
},
"system relative path": {
parentFile: "/etc/ssh/config",
arg: "configs/*.conf",
want: filepath.Join("/etc", "ssh", "configs", "*.conf"),
},
"user relative path": {
parentFile: filepath.Join(dir, ".ssh", "ssh_config"),
arg: "configs/*.conf",
want: filepath.Join(dir, ".ssh", "configs/*.conf"),
},
"shell-like ~ rerefence": {
parentFile: filepath.Join(dir, ".ssh", "ssh_config"),
arg: "~/.ssh/*.conf",
want: filepath.Join(dir, ".ssh", "*.conf"),
},
}

for name, tt := range tests {
t.Run(name, func(t *testing.T) {
if got := p.absolutePath(tt.parentFile, tt.arg); got != tt.want {
t.Errorf("absolutePath(): %q, wants %q", got, tt.want)
}
})
}
}

func Test_Translator(t *testing.T) {
Expand Down

0 comments on commit ae68da6

Please sign in to comment.