Skip to content

Commit

Permalink
Aggresively refactor import generation
Browse files Browse the repository at this point in the history
The Generator now eagerly goes through all types used in methods of the
interface to be mocked and performs a lookup of to ensure that all
package import names are unique. This fixes support for imports that are
needed for methods coming from a nested interface whose names conflict
with imports coming from the main interface or other nested interfaces.

Fixes: vektra#94
  • Loading branch information
colonelpanic8 committed Jun 28, 2016
1 parent fb563f2 commit 97a72e5
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 110 deletions.
10 changes: 10 additions & 0 deletions mockery/fixtures/imports_from_nested_interface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package test

import (
"github.com/vektra/mockery/mockery/fixtures/http"
)

type HasConflictingNestedImports interface {
RequesterNS
Z() http.MyStruct
}
2 changes: 1 addition & 1 deletion mockery/fixtures/same_name_imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ import (
// Example is an example
type Example interface {
A() http.Flusher
B(my_http string) my_http.MyStruct
B(fixtureshttp string) my_http.MyStruct
}
200 changes: 113 additions & 87 deletions mockery/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,60 +7,128 @@ import (
"go/ast"
"go/types"
"io"
"log"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"unicode"

"code.uber.internal/go-common.git/x/log"

"golang.org/x/tools/imports"
)

var invalidIdentifierChar = regexp.MustCompile("[^[:digit:][:alpha:]_]")

func stripChars(str, chr string) string {
return strings.Map(func(r rune) rune {
if strings.IndexRune(chr, r) < 0 {
return r
}
return -1
}, str)
}

// Generator is responsible for generating the string containing
// imports and the mock struct that will later be written out as file.
type Generator struct {
buf bytes.Buffer

ip bool
iface *Interface
pkg string
packageToName map[string]string
imports []*ast.ImportSpec
localPackageName *string
localizationCache map[string]string
ip bool
iface *Interface
pkg string
localPackageName *string

importsWerePopulated bool
localizationCache map[string]string
packagePathToName map[string]string
nameToPackagePath map[string]string
}

// NewGenerator builds a Generator.
func NewGenerator(iface *Interface, pkg string, inPackage bool) *Generator {
return &Generator{
g := &Generator{
iface: iface,
pkg: pkg,
ip: inPackage,
localizationCache: make(map[string]string),
packagePathToName: make(map[string]string),
nameToPackagePath: make(map[string]string),
}
g.addPackageImportWithName("github.com/stretchr/testify/mock", "mock")
return g
}

func (g *Generator) populateImports() {
if g.importsWerePopulated {
return
}
for i := 0; i < g.iface.Type.NumMethods(); i++ {
fn := g.iface.Type.Method(i)
ftype := fn.Type().(*types.Signature)
g.addImportsFromTuple(ftype.Params())
g.addImportsFromTuple(ftype.Results())
}
}

func (g *Generator) addImportsFromTuple(list *types.Tuple) {
for i := 0; i < list.Len(); i++ {
// We use renderType here because we need to recursively
// resolve any types to make sure that all named types that
// will appear in the interface file are known
g.renderType(list.At(i).Type())
}
}

func (g *Generator) getLocalPackageName() string {
if g.localPackageName == nil {
localName := g.getLocalPackageNameFromPackageMap(g.getPackageToName())
g.localPackageName = &localName
func (g *Generator) addPackageImport(pkg *types.Package) string {
return g.addPackageImportWithName(pkg.Path(), pkg.Name())
}

func (g *Generator) addPackageImportWithName(path, name string) string {
path = g.getLocalizedPath(path)
if existingName, pathExists := g.packagePathToName[path]; pathExists {
return existingName
}
return *g.localPackageName

nonConflictingName := g.getNonConflictingName(path, name)
g.packagePathToName[path] = nonConflictingName
g.nameToPackagePath[nonConflictingName] = path
return nonConflictingName
}

func (g *Generator) getLocalPackageNameFromPackageMap(packageToName map[string]string) string {
localPackageName := g.iface.Pkg.Name()
for path, name := range packageToName {
if localPackageName == name && path != g.getInterfacePackagePath() {
return "_interfacePackage"
func (g *Generator) getNonConflictingName(path, name string) string {
if !g.importNameExists(name) {
return name
}
directories := strings.Split(path, string(filepath.Separator))

cleanedDirectories := make([]string, 0, len(directories))
for _, directory := range directories {
cleaned := invalidIdentifierChar.ReplaceAllString(directory, "_")
cleanedDirectories = append(cleanedDirectories, cleaned)
}
numDirectories := len(cleanedDirectories)
var prospectiveName string
for i := 1; i <= numDirectories; i++ {
prospectiveName = strings.Join(cleanedDirectories[numDirectories-i:], "")
if !g.importNameExists(prospectiveName) {
return prospectiveName
}
}
return localPackageName
// Try adding numbers to the given name
i := 2
for {
prospectiveName = fmt.Sprintf("%v%d", name, i)
if !g.importNameExists(prospectiveName) {
return prospectiveName
}
i++
}
}

func (g *Generator) getInterfacePackagePath() string {
return g.getLocalizedPathFromPackage(g.iface.Pkg)
func (g *Generator) importNameExists(name string) bool {
_, nameExists := g.nameToPackagePath[name]
return nameExists
}

func (g *Generator) getLocalizedPathFromPackage(pkg *types.Package) string {
Expand Down Expand Up @@ -97,7 +165,7 @@ func (g *Generator) getLocalizedPath(path string) string {
if err == nil {
toReturn = packagePath
} else {
log.Warn("Unable to localize path %v, %v", path, err)
log.Printf("Unable to localize path %v, %v", path, err)
}
}
}
Expand All @@ -106,40 +174,6 @@ func (g *Generator) getLocalizedPath(path string) string {
return toReturn
}

func (g *Generator) getPackageToName() map[string]string {
if g.packageToName == nil {
g.packageToName = make(map[string]string)
for _, imp := range g.iface.File.Imports {
importName, err := g.getNameForImport(imp)
if err == nil {
g.packageToName[g.unescapedImportPath(imp)] = importName
} else {
log.Warn(err)
}
}
g.packageToName[g.getInterfacePackagePath()] =
g.getLocalPackageNameFromPackageMap(g.packageToName)
}
return g.packageToName
}

func (g *Generator) unescapedImportPath(imp *ast.ImportSpec) string {
return strings.Replace(imp.Path.Value, "\"", "", -1)
}

func (g *Generator) getNameForImport(imp *ast.ImportSpec) (string, error) {
if imp.Name != nil {
return imp.Name.Name, nil
}
unescapedPath := g.unescapedImportPath(imp)
for _, p := range g.iface.Pkg.Imports() {
if g.getLocalizedPathFromPackage(p) == unescapedPath {
return p.Name(), nil
}
}
return "", fmt.Errorf("unable to find package name for import: %v", unescapedPath)
}

func (g *Generator) mockName() string {
if g.ip {
if ast.IsExported(g.iface.Name) {
Expand All @@ -158,34 +192,42 @@ func (g *Generator) mockName() string {
return g.iface.Name
}

func (g *Generator) unescapedImportPath(imp *ast.ImportSpec) string {
return strings.Replace(imp.Path.Value, "\"", "", -1)
}

func (g *Generator) getImportStringFromSpec(imp *ast.ImportSpec) string {
if name, ok := g.getPackageToName()[g.unescapedImportPath(imp)]; ok {
if name, ok := g.packagePathToName[g.unescapedImportPath(imp)]; ok {
return fmt.Sprintf("import %s %s\n", name, imp.Path.Value)
}
return fmt.Sprintf("import %s\n", imp.Path.Value)
}

func (g *Generator) generateImports() {
if g.iface.File.Imports == nil {
return
func (g *Generator) sortedImportNames() (importNames []string) {
for name := range g.nameToPackagePath {
importNames = append(importNames, name)
}
sort.Strings(importNames)
return
}

for _, imp := range g.iface.File.Imports {
g.printf(g.getImportStringFromSpec(imp))
func (g *Generator) generateImports() {
// Sort by import name so that we get a deterministic order
for _, name := range g.sortedImportNames() {
path := g.nameToPackagePath[name]
g.printf("import %s \"%s\"\n", name, path)
}
}

// GeneratePrologue generates the prologue of the mock.
func (g *Generator) GeneratePrologue(pkg string) {
g.populateImports()
if g.ip {
g.printf("package %s\n\n", g.iface.Pkg.Name())
} else {
g.printf("package %v\n\n", pkg)
g.printf(
"import %s \"%s\"\n", g.getLocalPackageName(), g.getInterfacePackagePath(),
)
}
g.printf("import \"github.com/stretchr/testify/mock\"\n\n")

g.generateImports()
g.printf("\n")
}
Expand Down Expand Up @@ -242,24 +284,14 @@ type namer interface {
Name() string
}

func (g *Generator) getInFilePackageNameFromPackage(p *types.Package) string {
path := p.Path()
path = g.getLocalizedPathFromPackage(p)
if name, ok := g.getPackageToName()[path]; ok {
return name
}
log.Warnf("Could not find package name for %v", path)
return p.Name()
}

func (g *Generator) renderType(t types.Type) string {
switch t := t.(type) {
case *types.Named:
o := t.Obj()
if o.Pkg() == nil || o.Pkg().Name() == "main" || (g.ip && o.Pkg().Name() == g.pkg) {
return o.Name()
}
return g.getInFilePackageNameFromPackage(o.Pkg()) + "." + o.Name()
return g.addPackageImport(o.Pkg()) + "." + o.Name()
case *types.Basic:
return t.Name()
case *types.Pointer:
Expand Down Expand Up @@ -398,15 +430,8 @@ func (g *Generator) genList(list *types.Tuple, varadic bool) *paramList {
func (g *Generator) nameCollides(pname string) bool {
if pname == g.pkg {
return true
} else if g.iface.Pkg != nil {
for _, imp := range g.iface.Pkg.Imports() {
if g.getInFilePackageNameFromPackage(imp) == pname {
// Argument is same as that of an imported package
return true
}
}
}
return false
return g.importNameExists(pname)
}

// ErrNotSetup is returned when the generator is not configured.
Expand All @@ -415,6 +440,7 @@ var ErrNotSetup = errors.New("not setup")
// Generate builds a string that constitutes a valid go source file
// containing the mock of the relevant interface.
func (g *Generator) Generate() error {
g.populateImports()
if g.iface == nil {
return ErrNotSetup
}
Expand Down
Loading

0 comments on commit 97a72e5

Please sign in to comment.