diff --git a/cmd/swag/main.go b/cmd/swag/main.go index b2242423e..db84a1695 100644 --- a/cmd/swag/main.go +++ b/cmd/swag/main.go @@ -60,8 +60,9 @@ var initFlags = []cli.Flag{ Usage: "Parse go files in 'vendor' folder, disabled by default", }, &cli.BoolFlag{ - Name: parseDependencyFlag, - Usage: "Parse go files in outside dependency folder, disabled by default", + Name: parseDependencyFlag, + Aliases: []string{"pd"}, + Usage: "Parse go files inside dependency folder, disabled by default", }, &cli.StringFlag{ Name: markdownFilesFlag, diff --git a/packages.go b/packages.go index 8dbfb4de2..bff103dc7 100644 --- a/packages.go +++ b/packages.go @@ -2,7 +2,10 @@ package swag import ( "go/ast" + goparser "go/parser" "go/token" + "golang.org/x/tools/go/loader" + "os" "path/filepath" "sort" "strings" @@ -95,51 +98,58 @@ func (pkgs *PackagesDefinitions) RangeFiles(handle func(filename string, file *a func (pkgs *PackagesDefinitions) ParseTypes() (map[*TypeSpecDef]*Schema, error) { parsedSchemas := make(map[*TypeSpecDef]*Schema) for astFile, info := range pkgs.files { - for _, astDeclaration := range astFile.Decls { - generalDeclaration, ok := astDeclaration.(*ast.GenDecl) - if ok && generalDeclaration.Tok == token.TYPE { - for _, astSpec := range generalDeclaration.Specs { - typeSpec, ok := astSpec.(*ast.TypeSpec) - if ok { - typeSpecDef := &TypeSpecDef{ - PkgPath: info.PackagePath, - File: astFile, - TypeSpec: typeSpec, - } + pkgs.parseTypesFromFile(astFile, info.PackagePath, parsedSchemas) + } + return parsedSchemas, nil +} - idt, ok := typeSpec.Type.(*ast.Ident) - if ok && IsGolangPrimitiveType(idt.Name) { - parsedSchemas[typeSpecDef] = &Schema{ - PkgPath: typeSpecDef.PkgPath, - Name: astFile.Name.Name, - Schema: PrimitiveSchema(TransToValidSchemeType(idt.Name)), - } - } +func (pkgs *PackagesDefinitions) parseTypesFromFile(astFile *ast.File, packagePath string, parsedSchemas map[*TypeSpecDef]*Schema) { + for _, astDeclaration := range astFile.Decls { + if generalDeclaration, ok := astDeclaration.(*ast.GenDecl); ok && generalDeclaration.Tok == token.TYPE { + for _, astSpec := range generalDeclaration.Specs { + if typeSpec, ok := astSpec.(*ast.TypeSpec); ok { + typeSpecDef := &TypeSpecDef{ + PkgPath: packagePath, + File: astFile, + TypeSpec: typeSpec, + } - if pkgs.uniqueDefinitions == nil { - pkgs.uniqueDefinitions = make(map[string]*TypeSpecDef) + if idt, ok := typeSpec.Type.(*ast.Ident); ok && IsGolangPrimitiveType(idt.Name) && parsedSchemas != nil { + parsedSchemas[typeSpecDef] = &Schema{ + PkgPath: typeSpecDef.PkgPath, + Name: astFile.Name.Name, + Schema: PrimitiveSchema(TransToValidSchemeType(idt.Name)), } + } - fullName := typeSpecDef.FullName() - anotherTypeDef, ok := pkgs.uniqueDefinitions[fullName] - if ok { - if typeSpecDef.PkgPath == anotherTypeDef.PkgPath { - continue - } else { - delete(pkgs.uniqueDefinitions, fullName) - } + if pkgs.uniqueDefinitions == nil { + pkgs.uniqueDefinitions = make(map[string]*TypeSpecDef) + } + + fullName := typeSpecDef.FullName() + anotherTypeDef, ok := pkgs.uniqueDefinitions[fullName] + if ok { + if typeSpecDef.PkgPath == anotherTypeDef.PkgPath { + continue } else { - pkgs.uniqueDefinitions[fullName] = typeSpecDef + delete(pkgs.uniqueDefinitions, fullName) } + } else { + pkgs.uniqueDefinitions[fullName] = typeSpecDef + } + if pkgs.packages[typeSpecDef.PkgPath] == nil { + pkgs.packages[typeSpecDef.PkgPath] = &PackageDefinitions{ + Name: astFile.Name.Name, + TypeDefinitions: map[string]*TypeSpecDef{typeSpecDef.Name(): typeSpecDef}, + } + } else if _, ok = pkgs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()]; !ok { pkgs.packages[typeSpecDef.PkgPath].TypeDefinitions[typeSpecDef.Name()] = typeSpecDef } } } } } - - return parsedSchemas, nil } func (pkgs *PackagesDefinitions) findTypeSpec(pkgPath string, typeName string) *TypeSpecDef { @@ -157,11 +167,43 @@ func (pkgs *PackagesDefinitions) findTypeSpec(pkgPath string, typeName string) * return nil } +func (pkgs *PackagesDefinitions) loadExternalPackage(importPath string) error { + cwd, err := os.Getwd() + if err != nil { + return err + } + + conf := loader.Config{ + ParserMode: goparser.ParseComments, + Cwd: cwd, + } + + conf.Import(importPath) + + lprog, err := conf.Load() + if err != nil { + return err + } + + for _, info := range lprog.AllPackages { + pkgPath := info.Pkg.Path() + if strings.HasPrefix(pkgPath, "vendor/") { + pkgPath = pkgPath[7:] + } + for _, astFile := range info.Files { + pkgs.parseTypesFromFile(astFile, pkgPath, nil) + } + } + + return nil +} + // findPackagePathFromImports finds out the package path of a package via ranging imports of a ast.File // @pkg the name of the target package // @file current ast.File in which to search imports +// @fuzzy search for the package path that the last part matches the @pkg if true // @return the package path of a package of @pkg. -func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *ast.File) string { +func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *ast.File, fuzzy bool) string { if file == nil { return "" } @@ -172,6 +214,14 @@ func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *as hasAnonymousPkg := false + matchLastPathPart := func(pkgPath string) bool { + paths := strings.Split(pkgPath, "/") + if paths[len(paths)-1] == pkg { + return true + } + return false + } + // prior to match named package for _, imp := range file.Imports { if imp.Name != nil { @@ -186,11 +236,12 @@ func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *as } if pkgs.packages != nil { path := strings.Trim(imp.Path.Value, `"`) - pd, ok := pkgs.packages[path] - if ok { - if pd.Name == pkg { + if fuzzy { + if matchLastPathPart(path) { return path } + } else if pd, ok := pkgs.packages[path]; ok && pd.Name == pkg { + return path } } } @@ -203,11 +254,12 @@ func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *as } if imp.Name.Name == "_" { path := strings.Trim(imp.Path.Value, `"`) - pd, ok := pkgs.packages[path] - if ok { - if pd.Name == pkg { + if fuzzy { + if matchLastPathPart(path) { return path } + } else if pd, ok := pkgs.packages[path]; ok && pd.Name == pkg { + return path } } } @@ -220,7 +272,7 @@ func (pkgs *PackagesDefinitions) findPackagePathFromImports(pkg string, file *as // @typeName the name of the target type, if it starts with a package name, find its own package path from imports on top of @file // @file the ast.file in which @typeName is used // @pkgPath the package path of @file. -func (pkgs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File) *TypeSpecDef { +func (pkgs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File, parseDependency bool) *TypeSpecDef { if IsGolangPrimitiveType(typeName) { return nil } @@ -248,10 +300,19 @@ func (pkgs *PackagesDefinitions) FindTypeSpec(typeName string, file *ast.File) * return typeDef } } - - pkgPath := pkgs.findPackagePathFromImports(parts[0], file) - if len(pkgPath) == 0 && parts[0] == file.Name.Name { - pkgPath = pkgs.files[file].PackagePath + pkgPath := pkgs.findPackagePathFromImports(parts[0], file, false) + if len(pkgPath) == 0 { + //check if the current package + if parts[0] == file.Name.Name { + pkgPath = pkgs.files[file].PackagePath + } else if parseDependency { + //take it as an external package, needs to be loaded + if pkgPath = pkgs.findPackagePathFromImports(parts[0], file, true); len(pkgPath) > 0 { + if err := pkgs.loadExternalPackage(pkgPath); err != nil { + return nil + } + } + } } return pkgs.findTypeSpec(pkgPath, parts[1]) diff --git a/parser.go b/parser.go index ed670d4e9..2adaadf5b 100644 --- a/parser.go +++ b/parser.go @@ -758,7 +758,7 @@ func (parser *Parser) getTypeSchema(typeName string, file *ast.File, ref bool) ( return PrimitiveSchema(schemaType), nil } - typeSpecDef := parser.packages.FindTypeSpec(typeName, file) + typeSpecDef := parser.packages.FindTypeSpec(typeName, file, parser.ParseDependency) if typeSpecDef == nil { return nil, fmt.Errorf("cannot find type definition: %s", typeName) } diff --git a/parser_test.go b/parser_test.go index 5d06f0bca..4867c083d 100644 --- a/parser_test.go +++ b/parser_test.go @@ -1988,6 +1988,20 @@ func TestParseConflictSchemaName(t *testing.T) { assert.Equal(t, string(expected), string(b)) } +func TestParseExternalModels(t *testing.T) { + searchDir := "testdata/external_models/main" + mainAPIFile := "main.go" + p := New() + p.ParseDependency = true + err := p.ParseAPI(searchDir, mainAPIFile, defaultParseDepth) + assert.NoError(t, err) + b, _ := json.MarshalIndent(p.swagger, "", " ") + //ioutil.WriteFile("./testdata/external_models/main/expected.json",b,0777) + expected, err := ioutil.ReadFile(filepath.Join(searchDir, "expected.json")) + assert.NoError(t, err) + assert.Equal(t, string(expected), string(b)) +} + func TestParser_ParseStructArrayObject(t *testing.T) { t.Parallel() diff --git a/testdata/external_models/external/model.go b/testdata/external_models/external/model.go new file mode 100644 index 000000000..199ea2d24 --- /dev/null +++ b/testdata/external_models/external/model.go @@ -0,0 +1,7 @@ +package external + +import "github.com/urfave/cli/v2" + +type MyError struct { + cli.Author +} diff --git a/testdata/external_models/main/api/api.go b/testdata/external_models/main/api/api.go new file mode 100644 index 000000000..9b12d8bda --- /dev/null +++ b/testdata/external_models/main/api/api.go @@ -0,0 +1,18 @@ +package api + +import ( + "net/http" +) + +// GetExternalModels example +// @Summary parse external models +// @Description get string by ID +// @ID get_external_models +// @Accept json +// @Produce json +// @Success 200 {string} string "ok" +// @Failure 400 {object} http.Header "from internal pkg" +// @Router /testapi/external_models [get] +func GetExternalModels(w http.ResponseWriter, r *http.Request) { + +} diff --git a/testdata/external_models/main/expected.json b/testdata/external_models/main/expected.json new file mode 100644 index 000000000..d8c5e5e3d --- /dev/null +++ b/testdata/external_models/main/expected.json @@ -0,0 +1,50 @@ +{ + "swagger": "2.0", + "info": { + "description": "Parse external models.", + "title": "Swagger Example API", + "contact": {}, + "version": "1.0" + }, + "basePath": "/v1", + "paths": { + "/testapi/external_models": { + "get": { + "description": "get string by ID", + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "summary": "parse external models", + "operationId": "get_external_models", + "responses": { + "200": { + "description": "ok", + "schema": { + "type": "string" + } + }, + "400": { + "description": "from internal pkg", + "schema": { + "$ref": "#/definitions/http.Header" + } + } + } + } + } + }, + "definitions": { + "http.Header": { + "type": "object", + "additionalProperties": { + "type": "array", + "items": { + "type": "string" + } + } + } + } +} \ No newline at end of file diff --git a/testdata/external_models/main/main.go b/testdata/external_models/main/main.go new file mode 100644 index 000000000..07d306259 --- /dev/null +++ b/testdata/external_models/main/main.go @@ -0,0 +1,8 @@ +package main + +// @title Swagger Example API +// @version 1.0 +// @description Parse external models. +// @BasePath /v1 +func main() { +}