Skip to content

Commit

Permalink
fix not all pointer receivers being converted correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
Tim Cooper committed Jul 25, 2016
1 parent ca74b80 commit 6fd8014
Show file tree
Hide file tree
Showing 9 changed files with 52 additions and 28 deletions.
2 changes: 1 addition & 1 deletion array.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func arrayNewIndex(L *lua.LState) int {
if index < 1 || index > ref.Len() {
L.ArgError(2, "index out of range")
}
ref.Index(index - 1).Set(lValueToReflect(L, value, ref.Type().Elem(), false))
ref.Index(index - 1).Set(lValueToReflect(L, value, ref.Type().Elem(), nil))
return 0
}

Expand Down
2 changes: 1 addition & 1 deletion chan.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func chanLen(L *lua.LState) int {
func chanSend(L *lua.LState) int {
ref, _, _ := checkChan(L, 1)
value := L.CheckAny(2)
convertedValue := lValueToReflect(L, value, ref.Type().Elem(), false)
convertedValue := lValueToReflect(L, value, ref.Type().Elem(), nil)
if convertedValue.Type() != ref.Type().Elem() {
L.ArgError(2, "incorrect type")
}
Expand Down
16 changes: 16 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package luar
import (
"fmt"
"strconv"
"strings"

"github.com/yuin/gopher-lua"
)
Expand Down Expand Up @@ -1300,6 +1301,12 @@ func Example__37() {
// 2 y
}

type E38String string

func (s *E38String) ToUpper() {
*s = E38String(strings.ToUpper(string(*s)))
}

func Example__38() {
const code = `
print(a[1]:AddNumbers(1, 2, 3, 4, 5))
Expand All @@ -1311,6 +1318,10 @@ func Example__38() {
print(p.Age)
p:IncreaseAge()
print(p.Age)
print(-str)
str:ToUpper()
print(-str)
`

L := lua.NewState()
Expand All @@ -1323,9 +1334,12 @@ func Example__38() {
{Name: "Tim", Age: 32},
}

str := E38String("Hello World")

L.SetGlobal("a", New(L, &a))
L.SetGlobal("s", New(L, s))
L.SetGlobal("p", New(L, s[0]))
L.SetGlobal("str", New(L, &str))

if err := L.DoString(code); err != nil {
panic(err)
Expand All @@ -1339,6 +1353,8 @@ func Example__38() {
// 15
// 32
// 33
// Hello World
// HELLO WORLD
}

func ExampleLState() {
Expand Down
28 changes: 16 additions & 12 deletions func.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func getFunc(L *lua.LState) (ref reflect.Value, refType reflect.Type) {
return
}

func shouldConvertPtr(L *lua.LState) bool {
func isPtrReceiverMethod(L *lua.LState) bool {
return bool(L.Get(lua.UpvalueIndex(2)).(lua.LBool))
}

Expand All @@ -50,7 +50,7 @@ func funcIsBypass(t reflect.Type) bool {
func funcBypass(L *lua.LState) int {
ref, refType := getFunc(L)

convertPtr := shouldConvertPtr(L)
convertedPtr := false
var receiver reflect.Value
var ud lua.LValue

Expand All @@ -59,7 +59,11 @@ func funcBypass(L *lua.LState) int {
if refType.NumIn() == 2 {
receiverHint := refType.In(0)
ud = L.Get(1)
receiver = lValueToReflect(L, ud, receiverHint, convertPtr)
if isPtrReceiverMethod(L) {
receiver = lValueToReflect(L, ud, receiverHint, &convertedPtr)
} else {
receiver = lValueToReflect(L, ud, receiverHint, nil)
}
if receiver.Type() != receiverHint {
L.RaiseError("incorrect receiver type")
}
Expand All @@ -68,7 +72,7 @@ func funcBypass(L *lua.LState) int {
}
args = append(args, reflect.ValueOf(&luarState))
ret := ref.Call(args)[0].Interface().(int)
if receiver.IsValid() && convertPtr && receiver.Kind() == reflect.Ptr {
if convertedPtr {
ud.(*lua.LUserData).Value = receiver.Elem().Interface()
}
return ret
Expand All @@ -87,7 +91,7 @@ func funcRegular(L *lua.LState) int {
L.RaiseError("invalid number of function arguments (%d or more expected, got %d)", expected-1, top)
}

convertPtr := shouldConvertPtr(L)
convertedPtr := false
var receiver reflect.Value
var ud lua.LValue

Expand All @@ -100,18 +104,18 @@ func funcRegular(L *lua.LState) int {
hint = refType.In(i)
}
var arg reflect.Value
if i == 0 && convertPtr {
if i == 0 && isPtrReceiverMethod(L) {
ud = L.Get(1)
arg = lValueToReflect(L, ud, hint, true)
arg = lValueToReflect(L, ud, hint, &convertedPtr)
receiver = arg
} else {
arg = lValueToReflect(L, L.Get(i+1), hint, false)
arg = lValueToReflect(L, L.Get(i+1), hint, nil)
}
args[i] = arg
}
ret := ref.Call(args)

if receiver.IsValid() && convertPtr && receiver.Kind() == reflect.Ptr {
if convertedPtr {
ud.(*lua.LUserData).Value = receiver.Elem().Interface()
}

Expand All @@ -128,12 +132,12 @@ func funcRegular(L *lua.LState) int {
return len(ret)
}

func funcWrapper(L *lua.LState, fn reflect.Value, convertToPtr bool) *lua.LFunction {
func funcWrapper(L *lua.LState, fn reflect.Value, isPtrReceiverMethod bool) *lua.LFunction {
up := L.NewUserData()
up.Value = fn

if funcIsBypass(fn.Type()) {
return L.NewClosure(funcBypass, up, lua.LBool(convertToPtr))
return L.NewClosure(funcBypass, up, lua.LBool(isPtrReceiverMethod))
}
return L.NewClosure(funcRegular, up, lua.LBool(convertToPtr))
return L.NewClosure(funcRegular, up, lua.LBool(isPtrReceiverMethod))
}
18 changes: 11 additions & 7 deletions luar.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func NewType(L *lua.LState, value interface{}) lua.LValue {
return ud
}

func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type, shouldConvertToPtr bool) reflect.Value {
func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type, tryConvertPtr *bool) reflect.Value {
if hint.Implements(refTypeLuaLValue) {
return reflect.ValueOf(v)
}
Expand Down Expand Up @@ -159,7 +159,7 @@ func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type, shouldConve

for i := 0; i < hint.NumOut(); i++ {
outHint := hint.Out(i)
ret[i] = lValueToReflect(L, L.Get(-hint.NumOut()+i), outHint, false)
ret[i] = lValueToReflect(L, L.Get(-hint.NumOut()+i), outHint, nil)
}

return ret
Expand All @@ -185,7 +185,7 @@ func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type, shouldConve

for i := 0; i < len; i++ {
value := converted.RawGetInt(i + 1)
elemValue := lValueToReflect(L, value, elemType, false)
elemValue := lValueToReflect(L, value, elemType, nil)
s.Index(i).Set(elemValue)
}

Expand All @@ -201,8 +201,8 @@ func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type, shouldConve
return
}

lKey := lValueToReflect(L, key, keyType, false)
lValue := lValueToReflect(L, value, elemType, false)
lKey := lValueToReflect(L, key, keyType, nil)
lValue := lValueToReflect(L, value, elemType, nil)
s.SetMapIndex(lKey, lValue)
})

Expand Down Expand Up @@ -232,7 +232,7 @@ func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type, shouldConve
}
field := hint.FieldByIndex(index)

lValue := lValueToReflect(L, value, field.Type, false)
lValue := lValueToReflect(L, value, field.Type, nil)
t.FieldByIndex(field.Index).Set(lValue)
})

Expand All @@ -247,12 +247,16 @@ func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type, shouldConve
}
case *lua.LUserData:
val := reflect.ValueOf(converted.Value)
if val.Kind() != reflect.Ptr && hint.Kind() == reflect.Ptr && shouldConvertToPtr {
if tryConvertPtr != nil && val.Kind() != reflect.Ptr && hint.Kind() == reflect.Ptr && val.Type() == hint.Elem() {
newVal := reflect.New(hint.Elem())
newVal.Elem().Set(val)
val = newVal
*tryConvertPtr = true
} else {
val = val.Convert(hint)
if tryConvertPtr != nil {
*tryConvertPtr = false
}
}
return val
}
Expand Down
6 changes: 3 additions & 3 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ func mapIndex(L *lua.LState) int {
return 0
}

convertedKey := lValueToReflect(L, key, ref.Type().Key(), false)
convertedKey := lValueToReflect(L, key, ref.Type().Key(), nil)
item := ref.MapIndex(convertedKey)
if !item.IsValid() {

Expand Down Expand Up @@ -67,13 +67,13 @@ func mapNewIndex(L *lua.LState) int {
key := L.CheckAny(2)
value := L.CheckAny(3)

convertedKey := lValueToReflect(L, key, ref.Type().Key(), false)
convertedKey := lValueToReflect(L, key, ref.Type().Key(), nil)
if convertedKey.Type() != ref.Type().Key() {
L.ArgError(2, "invalid map key type")
}
var convertedValue reflect.Value
if value != lua.LNil {
convertedValue = lValueToReflect(L, value, ref.Type().Elem(), false)
convertedValue = lValueToReflect(L, value, ref.Type().Elem(), nil)
if convertedValue.Type() != ref.Type().Elem() {
L.ArgError(3, "invalid map value type")
}
Expand Down
2 changes: 1 addition & 1 deletion ptr.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func ptrPow(L *lua.LState) int {
if !elem.CanSet() {
L.RaiseError("unable to set pointer value")
}
value := lValueToReflect(L, val, elem.Type(), false)
value := lValueToReflect(L, val, elem.Type(), nil)
elem.Set(value)
return 1
}
Expand Down
4 changes: 2 additions & 2 deletions slice.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func sliceNewIndex(L *lua.LState) int {
if index < 1 || index > ref.Len() {
L.ArgError(2, "index out of range")
}
ref.Index(index - 1).Set(lValueToReflect(L, value, ref.Type().Elem(), false))
ref.Index(index - 1).Set(lValueToReflect(L, value, ref.Type().Elem(), nil))
return 0
}

Expand Down Expand Up @@ -115,7 +115,7 @@ func sliceAppend(L *lua.LState) int {
hint := ref.Type().Elem()
values := make([]reflect.Value, L.GetTop()-1)
for i := 2; i <= L.GetTop(); i++ {
value := lValueToReflect(L, L.Get(i), hint, false)
value := lValueToReflect(L, L.Get(i), hint, nil)
if value.Type() != hint {
L.ArgError(i, "invalid type")
}
Expand Down
2 changes: 1 addition & 1 deletion struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,6 @@ func structNewIndex(L *lua.LState) int {
if !field.CanSet() {
L.RaiseError("cannot set field " + key)
}
field.Set(lValueToReflect(L, value, field.Type(), false))
field.Set(lValueToReflect(L, value, field.Type(), nil))
return 0
}

0 comments on commit 6fd8014

Please sign in to comment.