Skip to content

Commit

Permalink
Add config for flags
Browse files Browse the repository at this point in the history
  • Loading branch information
dearchap committed Nov 26, 2022
1 parent 397a9df commit 8d49a02
Show file tree
Hide file tree
Showing 21 changed files with 372 additions and 333 deletions.
6 changes: 4 additions & 2 deletions app_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2871,8 +2871,10 @@ func TestFlagAction(t *testing.T) {
},
},
&TimestampFlag{
Name: "f_timestamp",
Layout: "2006-01-02 15:04:05",
Name: "f_timestamp",
Config: TimestampConfig{
Layout: "2006-01-02 15:04:05",
},
Action: func(c *Context, v time.Time) error {
if v.IsZero() {
return fmt.Errorf("zero timestamp")
Expand Down
36 changes: 25 additions & 11 deletions flag_bool.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ import (
"strconv"
)

type BoolFlag = FlagBase[bool, BoolConfig, boolValue]

// BoolConfig defines the configuration for bool flags
type BoolConfig struct {
Count *int
}

// boolValue needs to implement the boolFlag internal interface in flag
// to be able to capture bool fields and values
//
Expand All @@ -19,18 +26,34 @@ type boolValue struct {
count *int
}

func (i boolValue) Create(val bool, p *bool, c FlagConfig) flag.Value {
func (cCtx *Context) Bool(name string) bool {
if v, ok := cCtx.Value(name).(bool); ok {
return v
}
return false
}

// Below functions are to satisfy the ValueCreator interface

// Create creates the bool value
func (i boolValue) Create(val bool, p *bool, c BoolConfig) flag.Value {
*p = val
if c.Count == nil {
c.Count = new(int)
}
return &boolValue{
destination: p,
count: c.GetCount(),
count: c.Count,
}
}

// ToString formats the bool value
func (i boolValue) ToString(b bool) string {
return fmt.Sprintf("%v", b)
}

// Below functions are to satisfy the flag.Value interface

func (b *boolValue) Set(s string) error {
v, err := strconv.ParseBool(s)
if err != nil {
Expand Down Expand Up @@ -61,12 +84,3 @@ func (b *boolValue) Count() int {
}
return 0
}

type BoolFlag = FlagBase[bool, boolValue]

func (cCtx *Context) Bool(name string) bool {
if v, ok := cCtx.Value(name).(bool); ok {
return v
}
return false
}
10 changes: 7 additions & 3 deletions flag_duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ import (
"time"
)

type DurationFlag = FlagBase[time.Duration, NoConfig, durationValue]

// -- time.Duration Value
type durationValue time.Duration

func (i durationValue) Create(val time.Duration, p *time.Duration, c FlagConfig) flag.Value {
// Below functions are to satisfy the ValueCreator interface

func (i durationValue) Create(val time.Duration, p *time.Duration, c NoConfig) flag.Value {
*p = val
return (*durationValue)(p)
}
Expand All @@ -18,6 +22,8 @@ func (i durationValue) ToString(d time.Duration) string {
return fmt.Sprintf("%v", d)
}

// Below functions are to satisfy the flag.Value interface

func (d *durationValue) Set(s string) error {
v, err := time.ParseDuration(s)
if err != nil {
Expand All @@ -31,8 +37,6 @@ func (d *durationValue) Get() any { return time.Duration(*d) }

func (d *durationValue) String() string { return (*time.Duration)(d).String() }

type DurationFlag = FlagBase[time.Duration, durationValue]

func (cCtx *Context) Duration(name string) time.Duration {
if v, ok := cCtx.Value(name).(time.Duration); ok {
return v
Expand Down
10 changes: 7 additions & 3 deletions flag_float64.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ import (
"strconv"
)

type Float64Flag = FlagBase[float64, NoConfig, float64Value]

// -- float64 Value
type float64Value float64

func (f float64Value) Create(val float64, p *float64, c FlagConfig) flag.Value {
// Below functions are to satisfy the ValueCreator interface

func (f float64Value) Create(val float64, p *float64, c NoConfig) flag.Value {
*p = val
return (*float64Value)(p)
}
Expand All @@ -18,6 +22,8 @@ func (f float64Value) ToString(b float64) string {
return fmt.Sprintf("%v", b)
}

// Below functions are to satisfy the flag.Value interface

func (f *float64Value) Set(s string) error {
v, err := strconv.ParseFloat(s, 64)
if err != nil {
Expand All @@ -31,8 +37,6 @@ func (f *float64Value) Get() any { return float64(*f) }

func (f *float64Value) String() string { return strconv.FormatFloat(float64(*f), 'g', -1, 64) }

type Float64Flag = FlagBase[float64, float64Value]

// Int looks up the value of a local IntFlag, returns
// 0 if not found
func (cCtx *Context) Float64(name string) float64 {
Expand Down
10 changes: 5 additions & 5 deletions flag_float64_slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ import (
"flag"
)

type Float64Slice = SliceBase[float64, NoConfig, float64Value]
type Float64SliceFlag = FlagBase[[]float64, NoConfig, Float64Slice]

var NewFloat64Slice = NewSliceBase[float64, NoConfig, float64Value]

// Float64Slice looks up the value of a local Float64SliceFlag, returns
// nil if not found
func (cCtx *Context) Float64Slice(name string) []float64 {
Expand All @@ -22,8 +27,3 @@ func lookupFloat64Slice(name string, set *flag.FlagSet) []float64 {
}
return nil
}

type Float64Slice = SliceBase[float64, float64Value]
type Float64SliceFlag = FlagBase[[]float64, Float64Slice]

var NewFloat64Slice = NewSliceBase[float64, float64Value]
98 changes: 36 additions & 62 deletions flag_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,29 @@ import (
"flag"
"fmt"
"reflect"
"time"
)

type FlagConfig interface {
GetNumberBase() int
GetCount() *int
GetLayout() string
GetTimezone() *time.Location
}

type ValueCreator[T any] interface {
Create(T, *T, FlagConfig) flag.Value
// ValueCreator is responsible for creating a flag.Value emulation
// as well as custom formatting
//
// T specifies the type
// C specifies the config for the type
type ValueCreator[T any, C any] interface {
Create(T, *T, C) flag.Value
ToString(T) string
}

// FlagBase[T,VC] is a generic flag base which can be used
// NoConfig is for flags which dont need a custom configuration
type NoConfig struct{}

// FlagBase[T,C,VC] is a generic flag base which can be used
// as a boilerplate to implement the most common interfaces
// used by urfave/cli. T specifies the types and VC specifies
// a value creator which can be used to get the correct flag.Value
// for that type
type FlagBase[T any, VC ValueCreator[T]] struct {
// used by urfave/cli.
//
// T specifies the type
// C specifies the configuration required(if any for that flag type)
// VC specifies the value creator which creates the flag.Value emulation
type FlagBase[T any, C any, VC ValueCreator[T, C]] struct {
Name string

Category string
Expand All @@ -46,55 +48,27 @@ type FlagBase[T any, VC ValueCreator[T]] struct {

Action func(*Context, T) error

// for int flags only
NumberBase int

// for bool flags only
Count *int

// for timestamp flags only
Layout string
Timezone *time.Location
Config C

creator VC
value flag.Value
}

func (f *FlagBase[T, V]) GetNumberBase() int {
return f.NumberBase
}

func (f *FlagBase[T, V]) GetCount() *int {
return f.Count
}

func (f *FlagBase[T, V]) GetLayout() string {
return f.Layout
}

func (f *FlagBase[T, V]) GetTimezone() *time.Location {
return f.Timezone
}

// GetValue returns the flags value as string representation and an empty
// string if the flag takes no value at all.
func (f *FlagBase[T, V]) GetValue() string {
func (f *FlagBase[T, C, V]) GetValue() string {
if reflect.TypeOf(f.Value).Kind() == reflect.Bool {
return ""
}
return fmt.Sprintf("%v", f.Value)
}

// Apply populates the flag given the flag set and environment
func (f *FlagBase[T, V]) Apply(set *flag.FlagSet) error {
if f.Count == nil {
f.Count = new(int)
}

func (f *FlagBase[T, C, V]) Apply(set *flag.FlagSet) error {
newVal := f.Value

if val, source, found := flagFromEnvOrFile(f.EnvVars, f.FilePath); found {
tmpVal := f.creator.Create(f.Value, new(T), f)
tmpVal := f.creator.Create(f.Value, new(T), f.Config)
if val != "" || reflect.TypeOf(f.Value).Kind() == reflect.String {
if err := tmpVal.Set(val); err != nil {
return fmt.Errorf("could not parse %q as %T value from %s for flag %s: %s", val, f.Value, source, f.Name, err)
Expand All @@ -111,9 +85,9 @@ func (f *FlagBase[T, V]) Apply(set *flag.FlagSet) error {
}

if f.Destination == nil {
f.value = f.creator.Create(newVal, new(T), f)
f.value = f.creator.Create(newVal, new(T), f.Config)
} else {
f.value = f.creator.Create(newVal, f.Destination, f)
f.value = f.creator.Create(newVal, f.Destination, f.Config)
}

for _, name := range f.Names() {
Expand All @@ -124,53 +98,53 @@ func (f *FlagBase[T, V]) Apply(set *flag.FlagSet) error {
}

// String returns a readable representation of this value (for usage defaults)
func (f *FlagBase[T, V]) String() string {
func (f *FlagBase[T, C, V]) String() string {
return FlagStringer(f)
}

// IsSet returns whether or not the flag has been set through env or file
func (f *FlagBase[T, V]) IsSet() bool {
func (f *FlagBase[T, C, V]) IsSet() bool {
return f.hasBeenSet
}

// Names returns the names of the flag
func (f *FlagBase[T, V]) Names() []string {
func (f *FlagBase[T, C, V]) Names() []string {
return FlagNames(f.Name, f.Aliases)
}

// IsRequired returns whether or not the flag is required
func (f *FlagBase[T, V]) IsRequired() bool {
func (f *FlagBase[T, C, V]) IsRequired() bool {
return f.Required
}

// IsVisible returns true if the flag is not hidden, otherwise false
func (f *FlagBase[T, V]) IsVisible() bool {
func (f *FlagBase[T, C, V]) IsVisible() bool {
return !f.Hidden
}

// GetCategory returns the category of the flag
func (f *FlagBase[T, V]) GetCategory() string {
func (f *FlagBase[T, C, V]) GetCategory() string {
return f.Category
}

// GetUsage returns the usage string for the flag
func (f *FlagBase[T, V]) GetUsage() string {
func (f *FlagBase[T, C, V]) GetUsage() string {
return f.Usage
}

// GetEnvVars returns the env vars for this flag
func (f *FlagBase[T, V]) GetEnvVars() []string {
func (f *FlagBase[T, C, V]) GetEnvVars() []string {
return f.EnvVars
}

// TakesValue returns true if the flag takes a value, otherwise false
func (f *FlagBase[T, V]) TakesValue() bool {
func (f *FlagBase[T, C, V]) TakesValue() bool {
var t T
return reflect.TypeOf(t).Kind() != reflect.Bool
}

// GetDefaultText returns the default text for this flag
func (f *FlagBase[T, V]) GetDefaultText() string {
func (f *FlagBase[T, C, V]) GetDefaultText() string {
if f.DefaultText != "" {
return f.DefaultText
}
Expand All @@ -179,7 +153,7 @@ func (f *FlagBase[T, V]) GetDefaultText() string {
}

// Get returns the flag’s value in the given Context.
func (f *FlagBase[T, V]) Get(ctx *Context) T {
func (f *FlagBase[T, C, V]) Get(ctx *Context) T {
if v, ok := ctx.Value(f.Name).(T); ok {
return v
}
Expand All @@ -188,15 +162,15 @@ func (f *FlagBase[T, V]) Get(ctx *Context) T {
}

// RunAction executes flag action if set
func (f *FlagBase[T, V]) RunAction(ctx *Context) error {
func (f *FlagBase[T, C, V]) RunAction(ctx *Context) error {
if f.Action != nil {
return f.Action(ctx, f.Get(ctx))
}

return nil
}

func (f *FlagBase[T, VC]) IsSliceFlag() bool {
func (f *FlagBase[T, C, VC]) IsSliceFlag() bool {
// TBD how to specify
return reflect.TypeOf(f.Value).Kind() == reflect.Slice
}
Loading

0 comments on commit 8d49a02

Please sign in to comment.