Skip to content

Commit

Permalink
feat: enable --header flag to specify headers with requests (oras-pro…
Browse files Browse the repository at this point in the history
…ject#794)

Signed-off-by: wangxiaoxuan273 <[email protected]>
  • Loading branch information
wangxiaoxuan273 authored Feb 14, 2023
1 parent d6240d6 commit 1543ce2
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 1 deletion.
26 changes: 25 additions & 1 deletion cmd/oras/internal/option/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ type Remote struct {
resolveDialContext func(dialer *net.Dialer) func(context.Context, string, string) (net.Conn, error)
applyDistributionSpec bool
distributionSpec distributionSpec
headerFlags []string
headers http.Header
}

// EnableDistributionSpecFlag set distribution specification flag as applicable.
Expand All @@ -62,6 +64,7 @@ func (opts *Remote) EnableDistributionSpecFlag() {
func (opts *Remote) ApplyFlags(fs *pflag.FlagSet) {
opts.ApplyFlagsWithPrefix(fs, "", "")
fs.BoolVarP(&opts.PasswordFromStdin, "password-stdin", "", false, "read password or identity token from stdin")
fs.StringArrayVarP(&opts.headerFlags, "header", "H", nil, "add custom headers to requests")
}

func applyPrefix(prefix, description string) (flagPrefix, notePrefix string) {
Expand Down Expand Up @@ -105,6 +108,9 @@ func (opts *Remote) ApplyFlagsWithPrefix(fs *pflag.FlagSet, prefix, description

// Parse tries to read password with optional cmd prompt.
func (opts *Remote) Parse() error {
if err := opts.parseCustomHeaders(); err != nil {
return err
}
if err := opts.readPassword(); err != nil {
return err
}
Expand Down Expand Up @@ -209,7 +215,8 @@ func (opts *Remote) authClient(registry string, debug bool) (client *auth.Client
TLSClientConfig: config,
},
},
Cache: auth.NewCache(),
Cache: auth.NewCache(),
Header: opts.headers,
}
client.SetUserAgent("oras/" + version.GetVersion())
if debug {
Expand Down Expand Up @@ -243,6 +250,23 @@ func (opts *Remote) authClient(registry string, debug bool) (client *auth.Client
return
}

func (opts *Remote) parseCustomHeaders() error {
if len(opts.headerFlags) != 0 {
headers := map[string][]string{}
for _, h := range opts.headerFlags {
name, value, found := strings.Cut(h, ":")
if !found || strings.TrimSpace(name) == "" {
// In conformance to the RFC 2616 specification
// Reference: https://www.rfc-editor.org/rfc/rfc2616#section-4.2
return fmt.Errorf("invalid header: %q", h)
}
headers[name] = append(headers[name], value)
}
opts.headers = headers
}
return nil
}

// Credential returns a credential based on the remote options.
func (opts *Remote) Credential() auth.Credential {
return credential.Credential(opts.Username, opts.Password)
Expand Down
101 changes: 101 additions & 0 deletions cmd/oras/internal/option/remote_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,104 @@ func TestRemote_parseResolve_err(t *testing.T) {
})
}
}

func TestRemote_parseCustomHeaders(t *testing.T) {
tests := []struct {
name string
headerFlags []string
want nhttp.Header
wantErr bool
}{
{
name: "no custom header is provided",
headerFlags: []string{},
want: nil,
wantErr: false,
},
{
name: "one name-value pair",
headerFlags: []string{"key:value"},
want: map[string][]string{"key": {"value"}},
wantErr: false,
},
{
name: "multiple name-value pairs",
headerFlags: []string{"key:value", "k:v"},
want: map[string][]string{"key": {"value"}, "k": {"v"}},
wantErr: false,
},
{
name: "multiple name-value pairs with commas",
headerFlags: []string{"key:value,value2,value3", "k:v,v2,v3"},
want: map[string][]string{"key": {"value,value2,value3"}, "k": {"v,v2,v3"}},
wantErr: false,
},
{
name: "empty string is a valid value",
headerFlags: []string{"k:", "key:value,value2,value3"},
want: map[string][]string{"k": {""}, "key": {"value,value2,value3"}},
wantErr: false,
},
{
name: "multiple colons are allowed",
headerFlags: []string{"k::::v,v2,v3", "key:value,value2,value3"},
want: map[string][]string{"k": {":::v,v2,v3"}, "key": {"value,value2,value3"}},
wantErr: false,
},
{
name: "name with spaces",
headerFlags: []string{"bar :b"},
want: map[string][]string{"bar ": {"b"}},
wantErr: false,
},
{
name: "value with spaces",
headerFlags: []string{"foo: a"},
want: map[string][]string{"foo": {" a"}},
wantErr: false,
},
{
name: "repeated pairs",
headerFlags: []string{"key:value", "key:value"},
want: map[string][]string{"key": {"value", "value"}},
wantErr: false,
},
{
name: "repeated name with different values",
headerFlags: []string{"key:value", "key:value2"},
want: map[string][]string{"key": {"value", "value2"}},
wantErr: false,
},
{
name: "one valid header and one invalid header(no pair)",
headerFlags: []string{"key:value,value2,value3", "vk"},
want: nil,
wantErr: true,
},
{
name: "one valid header and one invalid header(empty name)",
headerFlags: []string{":v", "key:value,value2,value3"},
want: nil,
wantErr: true,
},
{
name: "pure-space name is invalid",
headerFlags: []string{" : foo "},
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := &Remote{
headerFlags: tt.headerFlags,
}
if err := opts.parseCustomHeaders(); (err != nil) != tt.wantErr {
t.Errorf("Remote.parseCustomHeaders() error = %v, wantErr %v", err, tt.wantErr)
}
if !reflect.DeepEqual(tt.want, opts.headers) {
t.Errorf("Remote.parseCustomHeaders() = %v, want %v", opts.headers, tt.want)
}
})
}
}

0 comments on commit 1543ce2

Please sign in to comment.