Skip to content

Commit

Permalink
Handle collisions of imports and interface package
Browse files Browse the repository at this point in the history
  • Loading branch information
colonelpanic8 committed Jun 25, 2016
1 parent da23b10 commit 0d607e7
Show file tree
Hide file tree
Showing 4 changed files with 196 additions and 41 deletions.
13 changes: 13 additions & 0 deletions mockery/fixtures/imports_same_as_package.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package test

import (
"github.com/vektra/mockery/mockery/fixtures/test"
)

type C int

type ImportsSameAsPackage interface {
A() test.B
B() KeyManager
C(C)
}
3 changes: 3 additions & 0 deletions mockery/fixtures/test/test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package test

type B int
120 changes: 90 additions & 30 deletions mockery/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
"strings"
"unicode"

"code.uber.internal/go-common.git/x/log"

"golang.org/x/tools/imports"
)

Expand All @@ -20,36 +22,90 @@ import (
type Generator struct {
buf bytes.Buffer

ip bool
iface *Interface
pkg string
packageToName map[string]string
imports []*ast.ImportSpec
ip bool
iface *Interface
pkg string
packageToName map[string]string
imports []*ast.ImportSpec
localPackageName *string
}

// NewGenerator builds a Generator.
func NewGenerator(iface *Interface, pkg string, inPackage bool) *Generator {
return &Generator{
iface: iface,
pkg: pkg,
packageToName: make(map[string]string),
imports: iface.File.Imports,
ip: inPackage,
iface: iface,
pkg: pkg,
ip: inPackage,
}
}

func (g *Generator) getLocalPackageName() string {
if g.localPackageName == nil {
localName := g.getLocalPackageNameFromPackageMap(g.getPackageToName())
g.localPackageName = &localName
}
return *g.localPackageName
}

func (g *Generator) getLocalPackageNameFromPackageMap(packageToName map[string]string) string {
localPackageName := g.iface.Pkg.Name()
for path, name := range packageToName {
if localPackageName == name && path != g.getInterfacePackagePath() {
return "_interfacePackage"
}
}
return localPackageName
}

func (g *Generator) getInterfacePackagePath() string {
return g.getLocalizedPath(g.iface.Pkg.Path())
}

func (g *Generator) getLocalizedPath(path string) string {
local, err := filepath.Rel(
filepath.Join(os.Getenv("GOPATH"), "src"),
filepath.Dir(path),
)
if err != nil {
panic("unable to figure out path for package")
}
return local
}

func (g *Generator) getPackageToName() map[string]string {
if true {
if g.packageToName == nil {
g.packageToName = make(map[string]string)
for _, imp := range g.imports {
if imp.Name != nil {
g.packageToName[strings.Replace(imp.Path.Value, "\"", "", -1)] = imp.Name.Name
for _, imp := range g.iface.File.Imports {
importName, err := g.getNameForImport(imp)
if err == nil {
g.packageToName[g.unescapedImportPath(imp)] = importName
} else {
log.Warn(err)
}
}
g.packageToName[g.getInterfacePackagePath()] =
g.getLocalPackageNameFromPackageMap(g.packageToName)
}
return g.packageToName
}

func (g *Generator) unescapedImportPath(imp *ast.ImportSpec) string {
return strings.Replace(imp.Path.Value, "\"", "", -1)
}

func (g *Generator) getNameForImport(imp *ast.ImportSpec) (string, error) {
if imp.Name != nil {
return imp.Name.Name, nil
}
unescapedPath := g.unescapedImportPath(imp)
for _, p := range g.iface.Pkg.Imports() {
if p.Path() == unescapedPath {
return p.Name(), nil
}
}
return "", fmt.Errorf("Unable to find package name for import")
}

func (g *Generator) mockName() string {
if g.ip {
if ast.IsExported(g.iface.Name) {
Expand All @@ -68,34 +124,31 @@ func (g *Generator) mockName() string {
return g.iface.Name
}

func (g *Generator) getImportStringFromSpec(imp *ast.ImportSpec) string {
if name, ok := g.getPackageToName()[g.unescapedImportPath(imp)]; ok {
return fmt.Sprintf("import %s %s\n", name, imp.Path.Value)
}
return fmt.Sprintf("import %s\n", imp.Path.Value)
}

func (g *Generator) generateImports() {
if g.iface.File.Imports == nil {
return
}
for _, imp := range g.iface.File.Imports {
if imp.Name == nil {
g.printf("import %s\n", imp.Path.Value)
} else {
g.printf("import %s %s\n", imp.Name.Name, imp.Path.Value)
}
g.printf(g.getImportStringFromSpec(imp))
}
}

// GeneratePrologue generates the prologue of the mock.
func (g *Generator) GeneratePrologue(pkg string) {
if g.ip {
g.printf("package %s\n\n", g.iface.File.Name)
g.printf("package %s\n\n", g.iface.Pkg.Name())
} else {
g.printf("package %v\n\n", pkg)
local, err := filepath.Rel(
filepath.Join(os.Getenv("GOPATH"), "src"),
filepath.Dir(g.iface.Path),
g.printf(
"import %s \"%s\"\n", g.getLocalPackageName(), g.getInterfacePackagePath(),
)
if err != nil {
panic("unable to figure out path for package")
}

g.printf("import \"%s\"\n", local)
}
g.printf("import \"github.com/stretchr/testify/mock\"\n\n")
g.generateImports()
Expand Down Expand Up @@ -155,7 +208,14 @@ type namer interface {
}

func (g *Generator) getInFilePackageNameFromPackage(p *types.Package) string {
if name, ok := g.getPackageToName()[p.Path()]; ok {
path := p.Path()
if strings.HasPrefix(path, "/") {
path = g.getLocalizedPath(path)
}
// if path == g.getInterfacePackagePath() {
// return g.getLocalPackageName()
// }
if name, ok := g.getPackageToName()[path]; ok {
return name
}
return p.Name()
Expand All @@ -165,7 +225,7 @@ func (g *Generator) renderType(t types.Type) string {
switch t := t.(type) {
case *types.Named:
o := t.Obj()
if o.Pkg() == nil || o.Pkg().Name() == "main" || o.Pkg().Name() == g.pkg {
if o.Pkg() == nil || o.Pkg().Name() == "main" || (g.ip && o.Pkg().Name() == g.pkg) {
return o.Name()
}
return g.getInFilePackageNameFromPackage(o.Pkg()) + "." + o.Name()
Expand Down
101 changes: 90 additions & 11 deletions mockery/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,22 +188,22 @@ func (s *GeneratorSuite) TestGeneratorPrologue() {
generator := s.getGenerator(testFile, "Requester", false)
expected := `package mocks
import "` + s.getInterfaceRelPath(generator.iface) + `"
import test "` + s.getInterfaceRelPath(generator.iface) + `"
import "github.com/stretchr/testify/mock"
`
s.checkPrologueGeneration(generator, expected)
}

func (s *GeneratorSuite) TestGeneratorProloguewithImports() {
func (s *GeneratorSuite) TestGeneratorPrologueWithImports() {
generator := s.getGenerator("requester_ns.go", "RequesterNS", false)
expected := `package mocks
import "` + s.getInterfaceRelPath(generator.iface) + `"
import test "` + s.getInterfaceRelPath(generator.iface) + `"
import "github.com/stretchr/testify/mock"
import "net/http"
import http "net/http"
`
s.checkPrologueGeneration(generator, expected)
Expand All @@ -214,10 +214,10 @@ func (s *GeneratorSuite) TestGeneratorPrologueWithMultipleImportsSameName() {

expected := `package mocks
import "` + s.getInterfaceRelPath(generator.iface) + `"
import test "` + s.getInterfaceRelPath(generator.iface) + `"
import "github.com/stretchr/testify/mock"
import "net/http"
import http "net/http"
import my_http "github.com/vektra/mockery/mockery/fixtures/http"
`
Expand Down Expand Up @@ -483,7 +483,7 @@ type KeyManager struct {
}
// GetKey provides a mock function with given fields: _a0, _a1
func (_m *KeyManager) GetKey(_a0 string, _a1 uint16) ([]byte, *Err) {
func (_m *KeyManager) GetKey(_a0 string, _a1 uint16) ([]byte, *test.Err) {
ret := _m.Called(_a0, _a1)
var r0 []byte
Expand All @@ -495,12 +495,12 @@ func (_m *KeyManager) GetKey(_a0 string, _a1 uint16) ([]byte, *Err) {
}
}
var r1 *Err
if rf, ok := ret.Get(1).(func(string, uint16) *Err); ok {
var r1 *test.Err
if rf, ok := ret.Get(1).(func(string, uint16) *test.Err); ok {
r1 = rf(_a0, _a1)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).(*Err)
r1 = ret.Get(1).(*test.Err)
}
}
Expand Down Expand Up @@ -852,7 +852,7 @@ type UsesOtherPkgIface struct {
}
// DoSomethingElse provides a mock function with given fields: obj
func (_m *UsesOtherPkgIface) DoSomethingElse(obj Sibling) {
func (_m *UsesOtherPkgIface) DoSomethingElse(obj test.Sibling) {
_m.Called(obj)
}
`
Expand All @@ -862,6 +862,23 @@ func (_m *UsesOtherPkgIface) DoSomethingElse(obj Sibling) {
)
}

func (s *GeneratorSuite) TestGeneratorForMethodUsingInterfaceInPackage() {
expected := `// MockUsesOtherPkgIface is an autogenerated mock type for the UsesOtherPkgIface type
type MockUsesOtherPkgIface struct {
mock.Mock
}
// DoSomethingElse provides a mock function with given fields: obj
func (_m *MockUsesOtherPkgIface) DoSomethingElse(obj Sibling) {
_m.Called(obj)
}
`
s.checkGeneration(
filepath.Join(fixturePath, "mock_method_uses_pkg_iface.go"),
"UsesOtherPkgIface", true, expected,
)
}

func (s *GeneratorSuite) TestGeneratorWithAliasing() {
expected := `// Example is an autogenerated mock type for the Example type
type Example struct {
Expand Down Expand Up @@ -903,6 +920,68 @@ func (_m *Example) B(_a0 string) my_http.MyStruct {
)
}

func (s *GeneratorSuite) TestGeneratorWithImportSameAsLocalPackage() {
expected := `// ImportsSameAsPackage is an autogenerated mock type for the ImportsSameAsPackage type
type ImportsSameAsPackage struct {
mock.Mock
}
// A provides a mock function with given fields:
func (_m *ImportsSameAsPackage) A() test.B {
ret := _m.Called()
var r0 test.B
if rf, ok := ret.Get(0).(func() test.B); ok {
r0 = rf()
} else {
r0 = ret.Get(0).(test.B)
}
return r0
}
// B provides a mock function with given fields:
func (_m *ImportsSameAsPackage) B() _interfacePackage.KeyManager {
ret := _m.Called()
var r0 _interfacePackage.KeyManager
if rf, ok := ret.Get(0).(func() _interfacePackage.KeyManager); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(_interfacePackage.KeyManager)
}
}
return r0
}
// C provides a mock function with given fields: _a0
func (_m *ImportsSameAsPackage) C(_a0 _interfacePackage.C) {
_m.Called(_a0)
}
`
s.checkGeneration(
"imports_same_as_package.go", "ImportsSameAsPackage", false,
expected,
)
}

func (s *GeneratorSuite) TestPrologueWithImportSameAsLocalPackage() {
generator := s.getGenerator(
"imports_same_as_package.go", "ImportsSameAsPackage", false,
)
s.getInterfaceRelPath(generator.iface)
expected := `package mocks
import _interfacePackage "` + s.getInterfaceRelPath(generator.iface) + `"
import "github.com/stretchr/testify/mock"
import test "github.com/vektra/mockery/mockery/fixtures/test"
`

s.checkPrologueGeneration(generator, expected)
}

func TestGeneratorSuite(t *testing.T) {
generatorSuite := new(GeneratorSuite)
suite.Run(t, generatorSuite)
Expand Down

0 comments on commit 0d607e7

Please sign in to comment.