Skip to content

Commit

Permalink
Fix binding of untagged struct fields (labstack#812)
Browse files Browse the repository at this point in the history
* Add failing test

A BindUnmarshaler struct with no tag is not decoded properly.

* Fix binding of untagged structs
  • Loading branch information
flimzy authored and vishr committed Jan 16, 2017
1 parent 80d5c96 commit ed7353c
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
18 changes: 13 additions & 5 deletions bind.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ func (b *DefaultBinder) bindData(ptr interface{}, data map[string][]string, tag
if inputFieldName == "" {
inputFieldName = typeField.Name
// If tag is nil, we inspect if the field is a struct.
if structFieldKind == reflect.Struct {
if _, ok := bindUnmarshaler(structField); !ok && structFieldKind == reflect.Struct {
err := b.bindData(structField.Addr().Interface(), data, tag)
if err != nil {
return err
Expand Down Expand Up @@ -185,16 +185,24 @@ func unmarshalField(valueKind reflect.Kind, val string, field reflect.Value) (bo
}
}

func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
// bindUnmarshaler attempts to unmarshal a reflect.Value into a BindUnmarshaler
func bindUnmarshaler(field reflect.Value) (BindUnmarshaler, bool) {
ptr := reflect.New(field.Type())
if ptr.CanInterface() {
iface := ptr.Interface()
if unmarshaler, ok := iface.(BindUnmarshaler); ok {
err := unmarshaler.UnmarshalParam(value)
field.Set(ptr.Elem())
return true, err
return unmarshaler, ok
}
}
return nil, false
}

func unmarshalFieldNonPtr(value string, field reflect.Value) (bool, error) {
if unmarshaler, ok := bindUnmarshaler(field); ok {
err := unmarshaler.UnmarshalParam(value)
field.Set(reflect.ValueOf(unmarshaler).Elem())
return true, err
}
return false, nil
}

Expand Down
15 changes: 14 additions & 1 deletion bind_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ type (
Timestamp time.Time
TA []Timestamp
StringArray []string
Struct struct {
Foo string
}
)

func (t *Timestamp) UnmarshalParam(src string) error {
Expand All @@ -53,6 +56,13 @@ func (a *StringArray) UnmarshalParam(src string) error {
return nil
}

func (s *Struct) UnmarshalParam(src string) error {
*s = Struct{
Foo: src,
}
return nil
}

func (t bindTestStruct) GetCantSet() string {
return t.cantSet
}
Expand All @@ -75,6 +85,7 @@ var values = map[string][]string{
"cantSet": {"test"},
"T": {"2016-12-06T19:09:05+01:00"},
"Tptr": {"2016-12-06T19:09:05+01:00"},
"ST": {"bar"},
}

func TestBindJSON(t *testing.T) {
Expand Down Expand Up @@ -115,13 +126,14 @@ func TestBindQueryParams(t *testing.T) {

func TestBindUnmarshalParam(t *testing.T) {
e := New()
req, _ := http.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z", nil)
req, _ := http.NewRequest(GET, "/?ts=2016-12-06T19:09:05Z&sa=one,two,three&ta=2016-12-06T19:09:05Z&ta=2016-12-06T19:09:05Z&ST=baz", nil)
rec := httptest.NewRecorder()
c := e.NewContext(req, rec)
result := struct {
T Timestamp `query:"ts"`
TA []Timestamp `query:"ta"`
SA StringArray `query:"sa"`
ST Struct
}{}
err := c.Bind(&result)
ts := Timestamp(time.Date(2016, 12, 6, 19, 9, 5, 0, time.UTC))
Expand All @@ -130,6 +142,7 @@ func TestBindUnmarshalParam(t *testing.T) {
assert.Equal(t, ts, result.T)
assert.Equal(t, StringArray([]string{"one", "two", "three"}), result.SA)
assert.Equal(t, []Timestamp{ts, ts}, result.TA)
assert.Equal(t, Struct{"baz"}, result.ST)
}
}

Expand Down

0 comments on commit ed7353c

Please sign in to comment.