Skip to content

Commit

Permalink
allow non-pointer luar userdata values to be used for pointer receivers
Browse files Browse the repository at this point in the history
this mimics Go behaviour

fixes layeh#17
  • Loading branch information
Tim Cooper committed Jul 24, 2016
1 parent c4cd169 commit ca74b80
Show file tree
Hide file tree
Showing 10 changed files with 115 additions and 45 deletions.
8 changes: 4 additions & 4 deletions array.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ func arrayIndex(L *lua.LState) int {
}
L.Push(New(L, val.Interface()))
case lua.LString:
if isPtr {
if fn := mt.ptrMethod(string(converted)); fn != nil {
if !isPtr {
if fn := mt.method(string(converted)); fn != nil {
L.Push(fn)
return 1
}
}
if fn := mt.method(string(converted)); fn != nil {
if fn := mt.ptrMethod(string(converted)); fn != nil {
L.Push(fn)
return 1
}
Expand All @@ -67,7 +67,7 @@ func arrayNewIndex(L *lua.LState) int {
if index < 1 || index > ref.Len() {
L.ArgError(2, "index out of range")
}
ref.Index(index - 1).Set(lValueToReflect(L, value, ref.Type().Elem()))
ref.Index(index - 1).Set(lValueToReflect(L, value, ref.Type().Elem(), false))
return 0
}

Expand Down
8 changes: 4 additions & 4 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,13 @@ func getMTCache(L *lua.LState) *mtCache {
return cache
}

func addMethods(L *lua.LState, vtype reflect.Type, tbl *lua.LTable) {
func addMethods(L *lua.LState, vtype reflect.Type, tbl *lua.LTable, ptrReceiver bool) {
for i := 0; i < vtype.NumMethod(); i++ {
method := vtype.Method(i)
if method.PkgPath != "" {
continue
}
fn := funcWrapper(L, method.Func)
fn := funcWrapper(L, method.Func, ptrReceiver)
tbl.RawSetString(method.Name, fn)
tbl.RawSetString(getUnexportedName(method.Name), fn)
}
Expand Down Expand Up @@ -182,10 +182,10 @@ func getMetatable(L *lua.LState, vtype reflect.Type) *lua.LTable {
mt.RawSetString("__index", L.NewFunction(ptrIndex))
}

addMethods(L, reflect.PtrTo(vtype), ptrMethods)
addMethods(L, reflect.PtrTo(vtype), ptrMethods, true)
mt.RawSetString("ptr_methods", ptrMethods)

addMethods(L, vtype, methods)
addMethods(L, vtype, methods, false)
mt.RawSetString("methods", methods)

mt.RawSetString("original", L.NewTable())
Expand Down
8 changes: 4 additions & 4 deletions chan.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ func chanIndex(L *lua.LState) int {
_, mt, isPtr := checkChan(L, 1)
key := L.CheckString(2)

if isPtr {
if fn := mt.ptrMethod(key); fn != nil {
if !isPtr {
if fn := mt.method(key); fn != nil {
L.Push(fn)
return 1
}
}

if fn := mt.method(key); fn != nil {
if fn := mt.ptrMethod(key); fn != nil {
L.Push(fn)
return 1
}
Expand All @@ -52,7 +52,7 @@ func chanLen(L *lua.LState) int {
func chanSend(L *lua.LState) int {
ref, _, _ := checkChan(L, 1)
value := L.CheckAny(2)
convertedValue := lValueToReflect(L, value, ref.Type().Elem())
convertedValue := lValueToReflect(L, value, ref.Type().Elem(), false)
if convertedValue.Type() != ref.Type().Elem() {
L.ArgError(2, "incorrect type")
}
Expand Down
29 changes: 24 additions & 5 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ import (
)

type Person struct {
Name string
Age int
Friend *Person
Name string
Age int
Friend *Person
LastAddSum int
}

func (p Person) Hello() string {
Expand All @@ -27,9 +28,14 @@ func (p *Person) AddNumbers(L *LState) int {
sum += L.CheckInt(i)
}
L.Push(lua.LString(p.Name + " counts: " + strconv.Itoa(sum)))
p.LastAddSum = sum
return 1
}

func (p *Person) IncreaseAge() {
p.Age++
}

func Example__1() {
const code = `
print(user1.Name)
Expand Down Expand Up @@ -1297,7 +1303,14 @@ func Example__37() {
func Example__38() {
const code = `
print(a[1]:AddNumbers(1, 2, 3, 4, 5))
print(s[1]:AddNumbers(1, 2, 3, 4, 5))
print(s[1]:AddNumbers(1, 2, 3, 4))
print(s[1].LastAddSum)
print(p:AddNumbers(1, 2, 3, 4, 5))
print(p.LastAddSum)
print(p.Age)
p:IncreaseAge()
print(p.Age)
`

L := lua.NewState()
Expand All @@ -1307,19 +1320,25 @@ func Example__38() {
{Name: "Tim"},
}
s := []Person{
{Name: "Tim"},
{Name: "Tim", Age: 32},
}

L.SetGlobal("a", New(L, &a))
L.SetGlobal("s", New(L, s))
L.SetGlobal("p", New(L, s[0]))

if err := L.DoString(code); err != nil {
panic(err)
}

// Output:
// Tim counts: 15
// Tim counts: 10
// 10
// Tim counts: 15
// 15
// 32
// 33
}

func ExampleLState() {
Expand Down
44 changes: 38 additions & 6 deletions func.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ func getFunc(L *lua.LState) (ref reflect.Value, refType reflect.Type) {
return
}

func shouldConvertPtr(L *lua.LState) bool {
return bool(L.Get(lua.UpvalueIndex(2)).(lua.LBool))
}

func funcIsBypass(t reflect.Type) bool {
if t.NumIn() == 1 && t.NumOut() == 1 && t.In(0) == refTypeLStatePtr && t.Out(0) == refTypeInt {
return true
Expand All @@ -46,19 +50,28 @@ func funcIsBypass(t reflect.Type) bool {
func funcBypass(L *lua.LState) int {
ref, refType := getFunc(L)

convertPtr := shouldConvertPtr(L)
var receiver reflect.Value
var ud lua.LValue

luarState := LState{L}
args := make([]reflect.Value, 0, 2)
if refType.NumIn() == 2 {
receiverHint := refType.In(0)
receiver := lValueToReflect(L, L.Get(1), receiverHint)
ud = L.Get(1)
receiver = lValueToReflect(L, ud, receiverHint, convertPtr)
if receiver.Type() != receiverHint {
L.RaiseError("incorrect receiver type")
}
args = append(args, receiver)
L.Remove(1)
}
args = append(args, reflect.ValueOf(&luarState))
return ref.Call(args)[0].Interface().(int)
ret := ref.Call(args)[0].Interface().(int)
if receiver.IsValid() && convertPtr && receiver.Kind() == reflect.Ptr {
ud.(*lua.LUserData).Value = receiver.Elem().Interface()
}
return ret
}

func funcRegular(L *lua.LState) int {
Expand All @@ -73,6 +86,11 @@ func funcRegular(L *lua.LState) int {
if variadic && top < expected-1 {
L.RaiseError("invalid number of function arguments (%d or more expected, got %d)", expected-1, top)
}

convertPtr := shouldConvertPtr(L)
var receiver reflect.Value
var ud lua.LValue

args := make([]reflect.Value, top)
for i := 0; i < L.GetTop(); i++ {
var hint reflect.Type
Expand All @@ -81,9 +99,22 @@ func funcRegular(L *lua.LState) int {
} else {
hint = refType.In(i)
}
args[i] = lValueToReflect(L, L.Get(i+1), hint)
var arg reflect.Value
if i == 0 && convertPtr {
ud = L.Get(1)
arg = lValueToReflect(L, ud, hint, true)
receiver = arg
} else {
arg = lValueToReflect(L, L.Get(i+1), hint, false)
}
args[i] = arg
}
ret := ref.Call(args)

if receiver.IsValid() && convertPtr && receiver.Kind() == reflect.Ptr {
ud.(*lua.LUserData).Value = receiver.Elem().Interface()
}

if len(ret) == 1 && ret[0].Type() == refTypeLuaLValueSlice {
values := ret[0].Interface().([]lua.LValue)
for _, value := range values {
Expand All @@ -97,11 +128,12 @@ func funcRegular(L *lua.LState) int {
return len(ret)
}

func funcWrapper(L *lua.LState, fn reflect.Value) *lua.LFunction {
func funcWrapper(L *lua.LState, fn reflect.Value, convertToPtr bool) *lua.LFunction {
up := L.NewUserData()
up.Value = fn

if funcIsBypass(fn.Type()) {
return L.NewClosure(funcBypass, up)
return L.NewClosure(funcBypass, up, lua.LBool(convertToPtr))
}
return L.NewClosure(funcRegular, up)
return L.NewClosure(funcRegular, up, lua.LBool(convertToPtr))
}
24 changes: 16 additions & 8 deletions luar.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ func New(L *lua.LState, value interface{}) lua.LValue {
ud.Metatable = getMetatableFromValue(L, val)
return ud
case reflect.Func:
return funcWrapper(L, val)
return funcWrapper(L, val, false)
case reflect.Interface:
ud := L.NewUserData()
ud.Value = val.Interface()
Expand All @@ -96,7 +96,7 @@ func NewType(L *lua.LState, value interface{}) lua.LValue {
return ud
}

func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type) reflect.Value {
func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type, shouldConvertToPtr bool) reflect.Value {
if hint.Implements(refTypeLuaLValue) {
return reflect.ValueOf(v)
}
Expand Down Expand Up @@ -159,7 +159,7 @@ func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type) reflect.Val

for i := 0; i < hint.NumOut(); i++ {
outHint := hint.Out(i)
ret[i] = lValueToReflect(L, L.Get(-hint.NumOut()+i), outHint)
ret[i] = lValueToReflect(L, L.Get(-hint.NumOut()+i), outHint, false)
}

return ret
Expand All @@ -185,7 +185,7 @@ func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type) reflect.Val

for i := 0; i < len; i++ {
value := converted.RawGetInt(i + 1)
elemValue := lValueToReflect(L, value, elemType)
elemValue := lValueToReflect(L, value, elemType, false)
s.Index(i).Set(elemValue)
}

Expand All @@ -201,8 +201,8 @@ func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type) reflect.Val
return
}

lKey := lValueToReflect(L, key, keyType)
lValue := lValueToReflect(L, value, elemType)
lKey := lValueToReflect(L, key, keyType, false)
lValue := lValueToReflect(L, value, elemType, false)
s.SetMapIndex(lKey, lValue)
})

Expand Down Expand Up @@ -232,7 +232,7 @@ func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type) reflect.Val
}
field := hint.FieldByIndex(index)

lValue := lValueToReflect(L, value, field.Type)
lValue := lValueToReflect(L, value, field.Type, false)
t.FieldByIndex(field.Index).Set(lValue)
})

Expand All @@ -246,7 +246,15 @@ func lValueToReflect(L *lua.LState, v lua.LValue, hint reflect.Type) reflect.Val
return reflect.ValueOf(converted).Convert(hint)
}
case *lua.LUserData:
return reflect.ValueOf(converted.Value).Convert(hint)
val := reflect.ValueOf(converted.Value)
if val.Kind() != reflect.Ptr && hint.Kind() == reflect.Ptr && shouldConvertToPtr {
newVal := reflect.New(hint.Elem())
newVal.Elem().Set(val)
val = newVal
} else {
val = val.Convert(hint)
}
return val
}
L.RaiseError("fatal lValueToReflect error")
return reflect.Value{} // never returns
Expand Down
19 changes: 15 additions & 4 deletions map.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,26 @@ func mapIndex(L *lua.LState) int {
return 0
}

convertedKey := lValueToReflect(L, key, ref.Type().Key())
convertedKey := lValueToReflect(L, key, ref.Type().Key(), false)
item := ref.MapIndex(convertedKey)
if !item.IsValid() {

if !isPtr {
if lstring, ok := key.(lua.LString); ok {
if fn := mt.method(string(lstring)); fn != nil {
L.Push(fn)
return 1
}
}
}

if lstring, ok := key.(lua.LString); ok {
if fn := mt.method(string(lstring)); fn != nil {
if fn := mt.ptrMethod(string(lstring)); fn != nil {
L.Push(fn)
return 1
}
}

return 0
}
L.Push(New(L, item.Interface()))
Expand All @@ -56,13 +67,13 @@ func mapNewIndex(L *lua.LState) int {
key := L.CheckAny(2)
value := L.CheckAny(3)

convertedKey := lValueToReflect(L, key, ref.Type().Key())
convertedKey := lValueToReflect(L, key, ref.Type().Key(), false)
if convertedKey.Type() != ref.Type().Key() {
L.ArgError(2, "invalid map key type")
}
var convertedValue reflect.Value
if value != lua.LNil {
convertedValue = lValueToReflect(L, value, ref.Type().Elem())
convertedValue = lValueToReflect(L, value, ref.Type().Elem(), false)
if convertedValue.Type() != ref.Type().Elem() {
L.ArgError(3, "invalid map value type")
}
Expand Down
2 changes: 1 addition & 1 deletion ptr.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func ptrPow(L *lua.LState) int {
if !elem.CanSet() {
L.RaiseError("unable to set pointer value")
}
value := lValueToReflect(L, val, elem.Type())
value := lValueToReflect(L, val, elem.Type(), false)
elem.Set(value)
return 1
}
Expand Down
Loading

0 comments on commit ca74b80

Please sign in to comment.