Skip to content

Commit

Permalink
Merge pull request vektra#87 from arbortech/many-mockery-fixes
Browse files Browse the repository at this point in the history
Many mockery fixes
  • Loading branch information
evanphx committed Jun 6, 2016
2 parents 3603a98 + 79510c3 commit 70d3236
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 61 deletions.
5 changes: 5 additions & 0 deletions mockery/fixtures/invalid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package test

// If you reach this build error, it means that Parser is trying to parse
// *_test.go files, which it shouldn't.
var x = y
9 changes: 9 additions & 0 deletions mockery/fixtures/mock_method_uses_pkg_iface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
package test

type Sibling interface {
DoSomething()
}

type UsesOtherPkgIface interface {
DoSomethingElse(obj Sibling)
}
52 changes: 27 additions & 25 deletions mockery/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@ type Generator struct {

ip bool
iface *Interface
pkg string
}

func NewGenerator(iface *Interface) *Generator {
func NewGenerator(iface *Interface, pkg string) *Generator {
return &Generator{
iface: iface,
pkg: pkg,
}
}

Expand Down Expand Up @@ -145,56 +147,56 @@ type namer interface {
Name() string
}

func renderType(t types.Type) string {
func (g *Generator) renderType(t types.Type) string {
switch t := t.(type) {
case *types.Named:
o := t.Obj()
if o.Pkg() == nil || o.Pkg().Name() == "main" {
if o.Pkg() == nil || o.Pkg().Name() == "main" || o.Pkg().Name() == g.pkg {
return o.Name()
} else {
return o.Pkg().Name() + "." + o.Name()
}
case *types.Basic:
return t.Name()
case *types.Pointer:
return "*" + renderType(t.Elem())
return "*" + g.renderType(t.Elem())
case *types.Slice:
return "[]" + renderType(t.Elem())
return "[]" + g.renderType(t.Elem())
case *types.Array:
return fmt.Sprintf("[%d]%s", t.Len(), renderType(t.Elem()))
return fmt.Sprintf("[%d]%s", t.Len(), g.renderType(t.Elem()))
case *types.Signature:
switch t.Results().Len() {
case 0:
return fmt.Sprintf(
"func(%s)",
renderTypeTuple(t.Params()),
g.renderTypeTuple(t.Params()),
)
case 1:
return fmt.Sprintf(
"func(%s) %s",
renderTypeTuple(t.Params()),
renderType(t.Results().At(0).Type()),
g.renderTypeTuple(t.Params()),
g.renderType(t.Results().At(0).Type()),
)
default:
return fmt.Sprintf(
"func(%s)(%s)",
renderTypeTuple(t.Params()),
renderTypeTuple(t.Results()),
g.renderTypeTuple(t.Params()),
g.renderTypeTuple(t.Results()),
)
}
case *types.Map:
kt := renderType(t.Key())
vt := renderType(t.Elem())
kt := g.renderType(t.Key())
vt := g.renderType(t.Elem())

return fmt.Sprintf("map[%s]%s", kt, vt)
case *types.Chan:
switch t.Dir() {
case types.SendRecv:
return "chan " + renderType(t.Elem())
return "chan " + g.renderType(t.Elem())
case types.RecvOnly:
return "<-chan " + renderType(t.Elem())
return "<-chan " + g.renderType(t.Elem())
default:
return "chan<- " + renderType(t.Elem())
return "chan<- " + g.renderType(t.Elem())
}
case *types.Struct:
var fields []string
Expand All @@ -203,9 +205,9 @@ func renderType(t types.Type) string {
f := t.Field(i)

if f.Anonymous() {
fields = append(fields, renderType(f.Type()))
fields = append(fields, g.renderType(f.Type()))
} else {
fields = append(fields, fmt.Sprintf("%s %s", f.Name(), renderType(f.Type())))
fields = append(fields, fmt.Sprintf("%s %s", f.Name(), g.renderType(f.Type())))
}
}

Expand All @@ -223,13 +225,13 @@ func renderType(t types.Type) string {
}
}

func renderTypeTuple(tup *types.Tuple) string {
func (g *Generator) renderTypeTuple(tup *types.Tuple) string {
var parts []string

for i := 0; i < tup.Len(); i++ {
v := tup.At(i)

parts = append(parts, renderType(v.Type()))
parts = append(parts, g.renderType(v.Type()))
}

return strings.Join(parts, " , ")
Expand All @@ -252,7 +254,7 @@ type paramList struct {
Nilable []bool
}

func genList(list *types.Tuple, varadic bool) *paramList {
func (g *Generator) genList(list *types.Tuple, varadic bool) *paramList {
var params paramList

if list == nil {
Expand All @@ -262,13 +264,13 @@ func genList(list *types.Tuple, varadic bool) *paramList {
for i := 0; i < list.Len(); i++ {
v := list.At(i)

ts := renderType(v.Type())
ts := g.renderType(v.Type())

if varadic && i == list.Len()-1 {
t := v.Type()
switch t := t.(type) {
case *types.Slice:
ts = "..." + renderType(t.Elem())
ts = "..." + g.renderType(t.Elem())
default:
panic("bad varadic type!")
}
Expand Down Expand Up @@ -305,8 +307,8 @@ func (g *Generator) Generate() error {
ftype := fn.Type().(*types.Signature)
fname := fn.Name()

params := genList(ftype.Params(), ftype.Variadic())
returns := genList(ftype.Results(), false)
params := g.genList(ftype.Params(), ftype.Variadic())
returns := g.genList(ftype.Results(), false)

g.printf("// %s provides a mock function with given fields: %s\n", fname, strings.Join(params.Names, ", "))
g.printf("func (_m *%s) %s(%s) ", g.mockName(), fname, strings.Join(params.Params, ", "))
Expand Down
Loading

0 comments on commit 70d3236

Please sign in to comment.