Skip to content

Commit

Permalink
Merge pull request cli#1626 from cli/ghe-auth-tweaks
Browse files Browse the repository at this point in the history
Make GitHub remote parsing and authentication stricter
  • Loading branch information
mislav authored Sep 8, 2020
2 parents d13e6b3 + ece17c4 commit 72e9747
Show file tree
Hide file tree
Showing 19 changed files with 555 additions and 278 deletions.
20 changes: 7 additions & 13 deletions cmd/gh/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,23 +92,17 @@ func main() {
}
}

authCheckEnabled := cmdutil.IsAuthCheckEnabled(cmd)

// TODO support other names
ghtoken := os.Getenv("GITHUB_TOKEN")
if ghtoken != "" {
authCheckEnabled = false
}

authCheckEnabled := os.Getenv("GITHUB_TOKEN") == "" &&
os.Getenv("GITHUB_ENTERPRISE_TOKEN") == "" &&
cmdutil.IsAuthCheckEnabled(cmd)
if authCheckEnabled {
hasAuth := false

cfg, err := cmdFactory.Config()
if err == nil {
hasAuth = cmdutil.CheckAuth(cfg)
if err != nil {
fmt.Fprintf(stderr, "failed to read configuration: %s\n", err)
os.Exit(2)
}

if !hasAuth {
if !cmdutil.CheckAuth(cfg) {
fmt.Fprintln(stderr, utils.Bold("Welcome to GitHub CLI!"))
fmt.Fprintln(stderr)
fmt.Fprintln(stderr, "To authenticate, please run `gh auth login`.")
Expand Down
29 changes: 20 additions & 9 deletions git/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,35 @@ package git

import (
"net/url"
"regexp"
"strings"
)

var (
protocolRe = regexp.MustCompile("^[a-zA-Z_+-]+://")
)

func IsURL(u string) bool {
return strings.HasPrefix(u, "git@") || protocolRe.MatchString(u)
return strings.HasPrefix(u, "git@") || isSupportedProtocol(u)
}

func isSupportedProtocol(u string) bool {
return strings.HasPrefix(u, "ssh:") ||
strings.HasPrefix(u, "git+ssh:") ||
strings.HasPrefix(u, "git:") ||
strings.HasPrefix(u, "http:") ||
strings.HasPrefix(u, "https:")
}

func isPossibleProtocol(u string) bool {
return isSupportedProtocol(u) ||
strings.HasPrefix(u, "ftp:") ||
strings.HasPrefix(u, "ftps:") ||
strings.HasPrefix(u, "file:")
}

// ParseURL normalizes git remote urls
func ParseURL(rawURL string) (u *url.URL, err error) {
if !protocolRe.MatchString(rawURL) &&
strings.Contains(rawURL, ":") &&
if !isPossibleProtocol(rawURL) &&
strings.ContainsRune(rawURL, ':') &&
// not a Windows path
!strings.Contains(rawURL, "\\") {
!strings.ContainsRune(rawURL, '\\') {
// support scp-like syntax for ssh protocol
rawURL = "ssh://" + strings.Replace(rawURL, ":", "/", 1)
}

Expand Down
195 changes: 195 additions & 0 deletions git/url_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
package git

import "testing"

func TestIsURL(t *testing.T) {
tests := []struct {
name string
url string
want bool
}{
{
name: "scp-like",
url: "[email protected]:owner/repo",
want: true,
},
{
name: "scp-like with no user",
url: "example.com:owner/repo",
want: false,
},
{
name: "ssh",
url: "ssh://[email protected]/owner/repo",
want: true,
},
{
name: "git",
url: "git://example.com/owner/repo",
want: true,
},
{
name: "https",
url: "https://example.com/owner/repo.git",
want: true,
},
{
name: "no protocol",
url: "example.com/owner/repo",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsURL(tt.url); got != tt.want {
t.Errorf("IsURL() = %v, want %v", got, tt.want)
}
})
}
}

func TestParseURL(t *testing.T) {
type url struct {
Scheme string
User string
Host string
Path string
}
tests := []struct {
name string
url string
want url
wantErr bool
}{
{
name: "HTTPS",
url: "https://example.com/owner/repo.git",
want: url{
Scheme: "https",
User: "",
Host: "example.com",
Path: "/owner/repo.git",
},
},
{
name: "HTTP",
url: "http://example.com/owner/repo.git",
want: url{
Scheme: "http",
User: "",
Host: "example.com",
Path: "/owner/repo.git",
},
},
{
name: "git",
url: "git://example.com/owner/repo.git",
want: url{
Scheme: "git",
User: "",
Host: "example.com",
Path: "/owner/repo.git",
},
},
{
name: "ssh",
url: "ssh://[email protected]/owner/repo.git",
want: url{
Scheme: "ssh",
User: "git",
Host: "example.com",
Path: "/owner/repo.git",
},
},
{
name: "ssh with port",
url: "ssh://[email protected]:443/owner/repo.git",
want: url{
Scheme: "ssh",
User: "git",
Host: "example.com",
Path: "/owner/repo.git",
},
},
{
name: "git+ssh",
url: "git+ssh://example.com/owner/repo.git",
want: url{
Scheme: "ssh",
User: "",
Host: "example.com",
Path: "/owner/repo.git",
},
},
{
name: "scp-like",
url: "[email protected]:owner/repo.git",
want: url{
Scheme: "ssh",
User: "git",
Host: "example.com",
Path: "/owner/repo.git",
},
},
{
name: "scp-like, leading slash",
url: "[email protected]:/owner/repo.git",
want: url{
Scheme: "ssh",
User: "git",
Host: "example.com",
Path: "/owner/repo.git",
},
},
{
name: "file protocol",
url: "file:///example.com/owner/repo.git",
want: url{
Scheme: "file",
User: "",
Host: "",
Path: "/example.com/owner/repo.git",
},
},
{
name: "file path",
url: "/example.com/owner/repo.git",
want: url{
Scheme: "",
User: "",
Host: "",
Path: "/example.com/owner/repo.git",
},
},
{
name: "Windows file path",
url: "C:\\example.com\\owner\\repo.git",
want: url{
Scheme: "c",
User: "",
Host: "",
Path: "",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
u, err := ParseURL(tt.url)
if (err != nil) != tt.wantErr {
t.Fatalf("got error: %v", err)
}
if u.Scheme != tt.want.Scheme {
t.Errorf("expected scheme %q, got %q", tt.want.Scheme, u.Scheme)
}
if u.User.Username() != tt.want.User {
t.Errorf("expected user %q, got %q", tt.want.User, u.User.Username())
}
if u.Host != tt.want.Host {
t.Errorf("expected host %q, got %q", tt.want.Host, u.Host)
}
if u.Path != tt.want.Path {
t.Errorf("expected path %q, got %q", tt.want.Path, u.Path)
}
})
}
}
30 changes: 23 additions & 7 deletions internal/config/config_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ const defaultGitProtocol = "https"
// This interface describes interacting with some persistent configuration for gh.
type Config interface {
Get(string, string) (string, error)
GetWithSource(string, string) (string, string, error)
Set(string, string, string) error
UnsetHost(string)
Hosts() ([]string, error)
Aliases() (*AliasConfig, error)
CheckWriteable(string, string) error
Write() error
}

Expand Down Expand Up @@ -200,42 +202,51 @@ func (c *fileConfig) Root() *yaml.Node {
}

func (c *fileConfig) Get(hostname, key string) (string, error) {
val, _, err := c.GetWithSource(hostname, key)
return val, err
}

func (c *fileConfig) GetWithSource(hostname, key string) (string, string, error) {
if hostname != "" {
var notFound *NotFoundError

hostCfg, err := c.configForHost(hostname)
if err != nil && !errors.As(err, &notFound) {
return "", err
return "", "", err
}

var hostValue string
if hostCfg != nil {
hostValue, err = hostCfg.GetStringValue(key)
if err != nil && !errors.As(err, &notFound) {
return "", err
return "", "", err
}
}

if hostValue != "" {
return hostValue, nil
// TODO: avoid hardcoding this
return hostValue, "~/.config/gh/hosts.yml", nil
}
}

// TODO: avoid hardcoding this
defaultSource := "~/.config/gh/config.yml"

value, err := c.GetStringValue(key)

var notFound *NotFoundError

if err != nil && errors.As(err, &notFound) {
return defaultFor(key), nil
return defaultFor(key), defaultSource, nil
} else if err != nil {
return "", err
return "", defaultSource, err
}

if value == "" {
return defaultFor(key), nil
return defaultFor(key), defaultSource, nil
}

return value, nil
return value, defaultSource, nil
}

func (c *fileConfig) Set(hostname, key, value string) error {
Expand Down Expand Up @@ -281,6 +292,11 @@ func (c *fileConfig) configForHost(hostname string) (*HostConfig, error) {
return nil, &NotFoundError{fmt.Errorf("could not find config entry for %q", hostname)}
}

func (c *fileConfig) CheckWriteable(hostname, key string) error {
// TODO: check filesystem permissions
return nil
}

func (c *fileConfig) Write() error {
mainData := yaml.Node{Kind: yaml.MappingNode}
hostsData := yaml.Node{Kind: yaml.MappingNode}
Expand Down
Loading

0 comments on commit 72e9747

Please sign in to comment.