diff --git a/git/remote_test.go b/git/remote_test.go index 2e7d30cb622..e8c0916534f 100644 --- a/git/remote_test.go +++ b/git/remote_test.go @@ -1,6 +1,17 @@ package git -import "testing" +import ( + "reflect" + "testing" +) + +// TODO: extract assertion helpers into a shared package +func eq(t *testing.T, got interface{}, expected interface{}) { + t.Helper() + if !reflect.DeepEqual(got, expected) { + t.Errorf("expected: %v, got: %v", expected, got) + } +} func Test_parseRemotes(t *testing.T) { remoteList := []string{ diff --git a/git/ssh_config.go b/git/ssh_config.go index 287298cd944..317ff605941 100644 --- a/git/ssh_config.go +++ b/git/ssh_config.go @@ -13,15 +13,10 @@ import ( ) var ( - sshHostRE, - sshTokenRE *regexp.Regexp + sshConfigLineRE = regexp.MustCompile(`\A\s*(?P[A-Za-z][A-Za-z0-9]*)(?:\s+|\s*=\s*)(?P.+)`) + sshTokenRE = regexp.MustCompile(`%[%h]`) ) -func init() { - sshHostRE = regexp.MustCompile("(?i)^[ \t]*(host|hostname)[ \t]+(.+)$") - sshTokenRE = regexp.MustCompile(`%[%h]`) -} - // SSHAliasMap encapsulates the translation of SSH hostname aliases type SSHAliasMap map[string]string @@ -45,55 +40,79 @@ func (m SSHAliasMap) Translator() func(*url.URL) *url.URL { } } -// ParseSSHConfig constructs a map of SSH hostname aliases based on user and -// system configuration files -func ParseSSHConfig() SSHAliasMap { - configFiles := []string{ - "/etc/ssh_config", - "/etc/ssh/ssh_config", - } - if homedir, err := homedir.Dir(); err == nil { - userConfig := filepath.Join(homedir, ".ssh", "config") - configFiles = append([]string{userConfig}, configFiles...) - } +type sshParser struct { + homeDir string - openFiles := make([]io.Reader, 0, len(configFiles)) - for _, file := range configFiles { - f, err := os.Open(file) + aliasMap SSHAliasMap + hosts []string + + open func(string) (io.Reader, error) + glob func(string) ([]string, error) +} + +func (p *sshParser) read(fileName string) error { + var file io.Reader + if p.open == nil { + f, err := os.Open(fileName) if err != nil { - continue + return err } defer f.Close() - openFiles = append(openFiles, f) + file = f + } else { + var err error + file, err = p.open(fileName) + if err != nil { + return err + } } - return sshParse(openFiles...) -} -func sshParse(r ...io.Reader) SSHAliasMap { - config := make(SSHAliasMap) - for _, file := range r { - _ = sshParseConfig(config, file) + if len(p.hosts) == 0 { + p.hosts = []string{"*"} } - return config -} -func sshParseConfig(c SSHAliasMap, file io.Reader) error { - hosts := []string{"*"} scanner := bufio.NewScanner(file) for scanner.Scan() { - line := scanner.Text() - match := sshHostRE.FindStringSubmatch(line) - if match == nil { + m := sshConfigLineRE.FindStringSubmatch(scanner.Text()) + if len(m) < 3 { continue } - names := strings.Fields(match[2]) - if strings.EqualFold(match[1], "host") { - hosts = names - } else { - for _, host := range hosts { - for _, name := range names { - c[host] = sshExpandTokens(name, host) + keyword, arguments := strings.ToLower(m[1]), m[2] + switch keyword { + case "host": + p.hosts = strings.Fields(arguments) + case "hostname": + for _, host := range p.hosts { + for _, name := range strings.Fields(arguments) { + if p.aliasMap == nil { + p.aliasMap = make(SSHAliasMap) + } + p.aliasMap[host] = sshExpandTokens(name, host) + } + } + case "include": + for _, arg := range strings.Fields(arguments) { + path := p.absolutePath(fileName, arg) + + var fileNames []string + if p.glob == nil { + paths, _ := filepath.Glob(path) + for _, p := range paths { + if s, err := os.Stat(p); err == nil && !s.IsDir() { + fileNames = append(fileNames, p) + } + } + } else { + var err error + fileNames, err = p.glob(path) + if err != nil { + continue + } + } + + for _, fileName := range fileNames { + _ = p.read(fileName) } } } @@ -102,6 +121,44 @@ func sshParseConfig(c SSHAliasMap, file io.Reader) error { return scanner.Err() } +func (p *sshParser) absolutePath(parentFile, path string) string { + if filepath.IsAbs(path) || strings.HasPrefix(filepath.ToSlash(path), "/") { + return path + } + + if strings.HasPrefix(path, "~") { + return filepath.Join(p.homeDir, strings.TrimPrefix(path, "~")) + } + + if strings.HasPrefix(filepath.ToSlash(parentFile), "/etc/ssh") { + return filepath.Join("/etc/ssh", path) + } + + return filepath.Join(p.homeDir, ".ssh", path) +} + +// ParseSSHConfig constructs a map of SSH hostname aliases based on user and +// system configuration files +func ParseSSHConfig() SSHAliasMap { + configFiles := []string{ + "/etc/ssh_config", + "/etc/ssh/ssh_config", + } + + p := sshParser{} + + if homedir, err := homedir.Dir(); err == nil { + userConfig := filepath.Join(homedir, ".ssh", "config") + configFiles = append([]string{userConfig}, configFiles...) + p.homeDir = homedir + } + + for _, file := range configFiles { + _ = p.read(file) + } + return p.aliasMap +} + func sshExpandTokens(text, host string) string { return sshTokenRE.ReplaceAllStringFunc(text, func(match string) string { switch match { diff --git a/git/ssh_config_test.go b/git/ssh_config_test.go index 7aafc5b219b..f05ca303b9e 100644 --- a/git/ssh_config_test.go +++ b/git/ssh_config_test.go @@ -1,31 +1,127 @@ package git import ( + "bytes" + "fmt" + "io" "net/url" - "reflect" - "strings" + "path/filepath" "testing" + + "github.com/MakeNowJust/heredoc" ) -// TODO: extract assertion helpers into a shared package -func eq(t *testing.T, got interface{}, expected interface{}) { - t.Helper() - if !reflect.DeepEqual(got, expected) { - t.Errorf("expected: %v, got: %v", expected, got) +func Test_sshParser_read(t *testing.T) { + testFiles := map[string]string{ + "/etc/ssh/config": heredoc.Doc(` + Include sites/* + `), + "/etc/ssh/sites/cfg1": heredoc.Doc(` + Host s1 + Hostname=site1.net + `), + "/etc/ssh/sites/cfg2": heredoc.Doc(` + Host s2 + Hostname = site2.net + `), + "HOME/.ssh/config": heredoc.Doc(` + Host * + Host gh gittyhubby + Hostname github.com + #Hostname example.com + Host ex + Include ex_config/* + `), + "HOME/.ssh/ex_config/ex_cfg": heredoc.Doc(` + Hostname example.com + `), + } + globResults := map[string][]string{ + "/etc/ssh/sites/*": {"/etc/ssh/sites/cfg1", "/etc/ssh/sites/cfg2"}, + "HOME/.ssh/ex_config/*": {"HOME/.ssh/ex_config/ex_cfg"}, + } + + p := &sshParser{ + homeDir: "HOME", + open: func(s string) (io.Reader, error) { + if contents, ok := testFiles[filepath.ToSlash(s)]; ok { + return bytes.NewBufferString(contents), nil + } else { + return nil, fmt.Errorf("no test file stub found: %q", s) + } + }, + glob: func(p string) ([]string, error) { + if results, ok := globResults[filepath.ToSlash(p)]; ok { + return results, nil + } else { + return nil, fmt.Errorf("no glob stubs found: %q", p) + } + }, + } + + if err := p.read("/etc/ssh/config"); err != nil { + t.Fatalf("read(global config) = %v", err) + } + if err := p.read("HOME/.ssh/config"); err != nil { + t.Fatalf("read(user config) = %v", err) + } + + if got := p.aliasMap["gh"]; got != "github.com" { + t.Errorf("expected alias %q to expand to %q, got %q", "gh", "github.com", got) + } + if got := p.aliasMap["gittyhubby"]; got != "github.com" { + t.Errorf("expected alias %q to expand to %q, got %q", "gittyhubby", "github.com", got) + } + if got := p.aliasMap["example.com"]; got != "" { + t.Errorf("expected alias %q to expand to %q, got %q", "example.com", "", got) + } + if got := p.aliasMap["ex"]; got != "example.com" { + t.Errorf("expected alias %q to expand to %q, got %q", "ex", "example.com", got) + } + if got := p.aliasMap["s1"]; got != "site1.net" { + t.Errorf("expected alias %q to expand to %q, got %q", "s1", "site1.net", got) } } -func Test_sshParse(t *testing.T) { - m := sshParse(strings.NewReader(` - Host foo bar - HostName example.com - `), strings.NewReader(` - Host bar baz - hostname %%%h.net%% - `)) - eq(t, m["foo"], "example.com") - eq(t, m["bar"], "%bar.net%") - eq(t, m["nonexistent"], "") +func Test_sshParser_absolutePath(t *testing.T) { + dir := "HOME" + p := &sshParser{homeDir: dir} + + tests := map[string]struct { + parentFile string + arg string + want string + wantErr bool + }{ + "absolute path": { + parentFile: "/etc/ssh/ssh_config", + arg: "/etc/ssh/config", + want: "/etc/ssh/config", + }, + "system relative path": { + parentFile: "/etc/ssh/config", + arg: "configs/*.conf", + want: filepath.Join("/etc", "ssh", "configs", "*.conf"), + }, + "user relative path": { + parentFile: filepath.Join(dir, ".ssh", "ssh_config"), + arg: "configs/*.conf", + want: filepath.Join(dir, ".ssh", "configs/*.conf"), + }, + "shell-like ~ rerefence": { + parentFile: filepath.Join(dir, ".ssh", "ssh_config"), + arg: "~/.ssh/*.conf", + want: filepath.Join(dir, ".ssh", "*.conf"), + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + if got := p.absolutePath(tt.parentFile, tt.arg); got != tt.want { + t.Errorf("absolutePath(): %q, wants %q", got, tt.want) + } + }) + } } func Test_Translator(t *testing.T) {