Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Smuggle): fields-path param can contain method calls #265

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 122 additions & 15 deletions td/td_smuggle.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2018-2023, Maxime Soulé
// Copyright (c) 2018-2024, Maxime Soulé
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -72,6 +72,7 @@ var smuggleValueType = reflect.TypeOf(smuggleValue{})
type smuggleField struct {
Name string
Indexed bool
Method bool
}

func joinFieldsPath(path []smuggleField) string {
Expand All @@ -84,6 +85,9 @@ func joinFieldsPath(path []smuggleField) string {
buf.WriteByte('.')
}
buf.WriteString(part.Name)
if part.Method {
buf.WriteString("()")
}
}
}
return buf.String()
Expand All @@ -94,6 +98,7 @@ func splitFieldsPath(origPath string) ([]smuggleField, error) {
return nil, fmt.Errorf("FIELD_PATH cannot be empty")
}

privateField := ""
var res []smuggleField
for path := origPath; len(path) > 0; {
r, _ := utf8.DecodeRuneInString(path)
Expand Down Expand Up @@ -130,12 +135,33 @@ func splitFieldsPath(origPath string) ([]smuggleField, error) {
field, path = path[:end], path[end:]
}

for j, r := range field {
if !unicode.IsLetter(r) && (j == 0 || !unicode.IsNumber(r)) {
return nil, fmt.Errorf("unexpected %q in field name %q in FIELDS_PATH %q", r, field, origPath)
if strings.HasSuffix(field, "()") {
if len(field) == 2 {
return nil, fmt.Errorf("missing method name before () in FIELDS_PATH %q", origPath)
}
for j, r := range field[:len(field)-2] {
if j == 0 && !unicode.IsUpper(r) {
return nil, fmt.Errorf("method name %q is not public in FIELDS_PATH %q", field, origPath)
}
if !unicode.IsLetter(r) && !unicode.IsNumber(r) {
return nil, fmt.Errorf("unexpected %q in method name %q in FIELDS_PATH %q", r, field, origPath)
}
}
if privateField != "" {
return nil, fmt.Errorf("cannot call method %s as it is based on an unexported field %q in FIELDS_PATH %q", field, privateField, origPath)
}
res = append(res, smuggleField{Name: field[:len(field)-2], Method: true})
} else {
for j, r := range field {
if privateField == "" && j == 0 && !unicode.IsUpper(r) {
privateField = field
}
if !unicode.IsLetter(r) && (j == 0 || !unicode.IsNumber(r)) {
return nil, fmt.Errorf("unexpected %q in field name %q in FIELDS_PATH %q", r, field, origPath)
}
}
res = append(res, smuggleField{Name: field})
}
res = append(res, smuggleField{Name: field})
}
}
return res, nil
Expand All @@ -155,7 +181,63 @@ func buildFieldsPathFn(path string) (func(any) (smuggleValue, error), error) {
vgot := reflect.ValueOf(got)

for idxPart, field := range parts {
if field.Method {
var method reflect.Value
for {
method = vgot.MethodByName(field.Name)
if !method.IsValid() {
switch vgot.Kind() {
case reflect.Interface, reflect.Ptr:
if !vgot.IsNil() {
vgot = vgot.Elem()
continue
}
return smuggleValue{}, nilFieldErr(parts[:idxPart])
}
if idxPart > 0 {
return smuggleValue{}, fmt.Errorf(
"field %s (type %s) does not implement %s() method",
joinFieldsPath(parts[:idxPart]),
vgot.Type(),
field.Name)
}
return smuggleValue{}, fmt.Errorf(
"type %s has no method %s()", vgot.Type(), field.Name)
}
break
}
mt := method.Type()
if mt.NumIn() != 0 ||
(mt.NumOut() != 1 && (mt.NumOut() != 2 || mt.Out(1) != types.Error)) {
return smuggleValue{}, fmt.Errorf(
"cannot call %s, signature %s not handled, only func() A or func() (A, error) allowed",
joinFieldsPath(parts[:idxPart+1]),
method.Type())
}
var ret []reflect.Value
var panicked any
func() {
defer func() { panicked = recover() }()
ret = method.Call(nil)
}()
if panicked != nil {
return smuggleValue{}, fmt.Errorf(
"method %s panicked: %v",
joinFieldsPath(parts[:idxPart+1]),
panicked)
}
if len(ret) == 2 && !ret[1].IsNil() {
return smuggleValue{}, fmt.Errorf(
"method %s returned an error: %w",
joinFieldsPath(parts[:idxPart+1]),
ret[1].Interface().(error))
}
vgot = ret[0]
continue
}

// Resolve all interface and pointer dereferences
origKind := vgot.Kind()
for {
switch vgot.Kind() {
case reflect.Interface, reflect.Ptr:
Expand All @@ -178,13 +260,22 @@ func buildFieldsPathFn(path string) (func(any) (smuggleValue, error), error) {
}
continue
}
deref := ""
if origKind != vgot.Kind() {
deref = " (after dereferencing)"
}
if idxPart == 0 {
return smuggleValue{},
fmt.Errorf("it is a %s and should be a struct", vgot.Kind())
fmt.Errorf("it is a %s%s and should be a struct", vgot.Kind(), deref)
}
if parts[idxPart-1].Method {
return smuggleValue{}, fmt.Errorf(
"method %s returned a %s%s and should be a struct",
joinFieldsPath(parts[:idxPart]), vgot.Kind(), deref)
}
return smuggleValue{}, fmt.Errorf(
"field %q is a %s and should be a struct",
joinFieldsPath(parts[:idxPart]), vgot.Kind())
"field %q is a %s%s and should be a struct",
joinFieldsPath(parts[:idxPart]), vgot.Kind(), deref)
}

switch vgot.Kind() {
Expand Down Expand Up @@ -546,6 +637,7 @@ func buildCaster(outType reflect.Type, useString bool) reflect.Value {
// several struct layers.
//
// type A struct{ Num int }
// // func (a *A) String() string { return fmt.Sprintf("Num is %d", a.Num) }
// type B struct{ As map[string]*A }
// type C struct{ B B }
// got := C{B: B{As: map[string]*A{"foo": {Num: 12}}}}
Expand All @@ -563,15 +655,30 @@ func buildCaster(outType reflect.Type, useString bool) reflect.Value {
// // Tests that got.B.As["foo"].Num is 12
// td.Cmp(t, got, td.Smuggle("B.As[foo].Num", 12))
//
// Contrary to [JSONPointer] operator, private fields can be
// followed. Arrays, slices and maps work using the index/key inside
// square brackets (e.g. [12] or [foo]). Maps work only for simple key
// types (string or numbers), without "" when using strings
// (e.g. [foo]).
// In addition, simple public methods can also be called like in:
//
// td.Cmp(t, got, td.Smuggle("B.As[foo].String()", "Num is 12"))
//
// Allowed methods must not take any parameter and must return one
// value or a value and an error. For the latter case, if the method
// returns a non-nil error, the comparison fails. The comparison also
// fails if a panic occurs or if a method cannot be called. No private
// fields should be traversed before calling the method. For fun,
// consider a more complex example involving [reflect] and chaining
// method calls:
//
// got := reflect.Valueof(&C{B: B{As: map[string]*A{"foo": {Num: 12}}}})
// td.Cmp(t, got, td.Smuggle("Elem().Interface().B.As[foo].String()", "Num is 12"))
//
// Contrary to [JSONPointer] operator, private fields can be followed
// and public methods on public fields can be called. Arrays, slices
// and maps work using the index/key inside square brackets (e.g. [12]
// or [foo]). Maps work only for simple key types (string or numbers),
// without "" when using strings (e.g. [foo]).
//
// Behind the scenes, a temporary function is automatically created to
// achieve the same goal, but add some checks against nil values and
// auto-dereference interfaces and pointers, even on several levels,
// achieve the same goal, but adds some checks against nil values and
// auto-dereferences interfaces and pointers, even on several levels,
// like in:
//
// type A struct{ N any }
Expand Down
Loading
Loading