Skip to content

Commit

Permalink
Extended generics support (swaggo#1277)
Browse files Browse the repository at this point in the history
* feat: add support for nested generics

nested generics support and related tests added

* fix: Multiple usage of same generic generate different definition paths

cache generic definitions by full name

* feat: add support for generic array parameter

- allow usage of arrays as parameter definitions
- tests extended and new body param added to tests

* feat: Add support for generic properties

- get generic field type
- support built in types in structs

refs swaggo#1213

* feat: Support custom model names for generics

add prefix to generic model names, to prevent renaming, if name annotation exists

* fix: Check if generic name starts with pkgName

- The first underscore was replaced instead of checking if the generated name even starts with the package name.
- New Tests added to test the name generation
- schema test extended to test the new behavior

* refactor: Apply suggested changes from PR
  • Loading branch information
FabianMartin authored Aug 2, 2022
1 parent cc25410 commit 2f148dd
Show file tree
Hide file tree
Showing 29 changed files with 3,720 additions and 397 deletions.
201 changes: 176 additions & 25 deletions generics.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,35 @@
package swag

import (
"fmt"
"go/ast"
"strings"
)

var genericsDefinitions = map[*TypeSpecDef]map[string]*TypeSpecDef{}

type genericTypeSpec struct {
ArrayDepth int
TypeSpec *TypeSpecDef
Name string
}

func (s *genericTypeSpec) Type() ast.Expr {
if s.TypeSpec != nil {
return s.TypeSpec.TypeSpec.Type
}

return &ast.Ident{Name: s.Name}
}

func (s *genericTypeSpec) TypeDocName() string {
if s.TypeSpec != nil {
return strings.Replace(TypeDocName(s.TypeSpec.FullName(), s.TypeSpec.TypeSpec), "-", "_", -1)
}

return s.Name
}

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
fullName := typeSpecDef.FullName()

Expand All @@ -26,29 +51,44 @@ func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return fullName
}

func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string) *TypeSpecDef {
genericParams := strings.Split(strings.TrimRight(fullGenericForm, "]"), "[")
if len(genericParams) == 1 {
return nil
func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
if spec, ok := genericsDefinitions[original][fullGenericForm]; ok {
return spec
}

genericParams = strings.Split(genericParams[1], ",")
for i, p := range genericParams {
genericParams[i] = strings.TrimSpace(p)
pkgName := strings.Split(fullGenericForm, ".")[0]
genericTypeName, genericParams := splitStructName(fullGenericForm)
if genericParams == nil {
return nil
}
genericParamTypeDefs := map[string]*TypeSpecDef{}

genericParamTypeDefs := map[string]*genericTypeSpec{}
if len(genericParams) != len(original.TypeSpec.TypeParams.List) {
return nil
}

for i, genericParam := range genericParams {
tdef, ok := pkgDefs.uniqueDefinitions[genericParam]
if !ok {
return nil
arrayDepth := 0
for {
if len(genericParam) <= 2 || genericParam[:2] != "[]" {
break
}
genericParam = genericParam[2:]
arrayDepth++
}

genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = tdef
tdef := pkgDefs.FindTypeSpec(genericParam, original.File, parseDependency)
if tdef == nil {
genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = &genericTypeSpec{
ArrayDepth: arrayDepth,
Name: genericParam,
}
} else {
genericParamTypeDefs[original.TypeSpec.TypeParams.List[i].Names[0].Name] = &genericTypeSpec{
ArrayDepth: arrayDepth,
TypeSpec: tdef,
}
}
}

parametrizedTypeSpec := &TypeSpecDef{
Expand All @@ -66,16 +106,34 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful
Obj: original.TypeSpec.Name.Obj,
}

genNameParts := strings.Split(fullGenericForm, "[")
if strings.Contains(genNameParts[0], ".") {
genNameParts[0] = strings.Split(genNameParts[0], ".")[1]
if strings.Contains(genericTypeName, ".") {
genericTypeName = strings.Split(genericTypeName, ".")[1]
}

ident.Name = genNameParts[0] + "-" + strings.Replace(strings.Join(genericParams, "-"), ".", "_", -1)
ident.Name = strings.Replace(strings.Replace(ident.Name, "\t", "", -1), " ", "", -1)
var typeName = []string{TypeDocName(fullTypeName(pkgName, genericTypeName), parametrizedTypeSpec.TypeSpec)}

parametrizedTypeSpec.TypeSpec.Name = ident
for _, def := range original.TypeSpec.TypeParams.List {
if specDef, ok := genericParamTypeDefs[def.Names[0].Name]; ok {
var prefix = ""
if specDef.ArrayDepth > 0 {
prefix = "array_"
if specDef.ArrayDepth > 1 {
prefix = fmt.Sprintf("array%d_", specDef.ArrayDepth)
}
}
typeName = append(typeName, prefix+specDef.TypeDocName())
}
}

ident.Name = strings.Join(typeName, "-")
ident.Name = strings.Replace(ident.Name, ".", "_", -1)
pkgNamePrefix := pkgName + "_"
if strings.HasPrefix(ident.Name, pkgNamePrefix) {
ident.Name = fullTypeName(pkgName, ident.Name[len(pkgNamePrefix):])
}
ident.Name = string(IgnoreNameOverridePrefix) + ident.Name

parametrizedTypeSpec.TypeSpec.Name = ident
origStructType := original.TypeSpec.Type.(*ast.StructType)

newStructTypeDef := &ast.StructType{
Expand All @@ -101,18 +159,111 @@ func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, ful
}

parametrizedTypeSpec.TypeSpec.Type = newStructTypeDef

if genericsDefinitions[original] == nil {
genericsDefinitions[original] = map[string]*TypeSpecDef{}
}
genericsDefinitions[original][fullGenericForm] = parametrizedTypeSpec
return parametrizedTypeSpec
}

func resolveType(expr ast.Expr, field *ast.Field, genericParamTypeDefs map[string]*TypeSpecDef) ast.Expr {
if asIdent, ok := expr.(*ast.Ident); ok {
if genTypeSpec, ok := genericParamTypeDefs[asIdent.Name]; ok {
return genTypeSpec.TypeSpec.Type
// splitStructName splits a generic struct name in his parts
func splitStructName(fullGenericForm string) (string, []string) {
// split only at the first '[' and remove the last ']'
genericParams := strings.SplitN(strings.TrimSpace(fullGenericForm)[:len(fullGenericForm)-1], "[", 2)
if len(genericParams) == 1 {
return "", nil
}

// generic type name
genericTypeName := genericParams[0]

// generic params
insideBrackets := 0
lastParam := ""
params := strings.Split(genericParams[1], ",")
genericParams = []string{}
for _, p := range params {
numOpened := strings.Count(p, "[")
numClosed := strings.Count(p, "]")
if numOpened == numClosed && insideBrackets == 0 {
genericParams = append(genericParams, strings.TrimSpace(p))
continue
}

insideBrackets += numOpened - numClosed
lastParam += p + ","

if insideBrackets == 0 {
genericParams = append(genericParams, strings.TrimSpace(strings.TrimRight(lastParam, ",")))
lastParam = ""
}
}

return genericTypeName, genericParams
}

func resolveType(expr ast.Expr, field *ast.Field, genericParamTypeDefs map[string]*genericTypeSpec) ast.Expr {
switch astExpr := expr.(type) {
case *ast.Ident:
if genTypeSpec, ok := genericParamTypeDefs[astExpr.Name]; ok {
if genTypeSpec.ArrayDepth > 0 {
genTypeSpec.ArrayDepth--
return &ast.ArrayType{Elt: resolveType(expr, field, genericParamTypeDefs)}
}
return genTypeSpec.Type()
}
case *ast.ArrayType:
return &ast.ArrayType{
Elt: resolveType(astExpr.Elt, field, genericParamTypeDefs),
Len: astExpr.Len,
Lbrack: astExpr.Lbrack,
}
} else if asArray, ok := expr.(*ast.ArrayType); ok {
return &ast.ArrayType{Elt: resolveType(asArray.Elt, field, genericParamTypeDefs), Len: asArray.Len, Lbrack: asArray.Lbrack}
}

return field.Type
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
switch fieldType := field.(type) {
case *ast.IndexListExpr:
spec := &TypeSpecDef{
File: file,
TypeSpec: getGenericTypeSpec(fieldType.X),
PkgPath: file.Name.Name,
}
fullName := spec.FullName() + "["

for _, index := range fieldType.Indices {
var fieldName string
var err error

switch item := index.(type) {
case *ast.ArrayType:
fieldName, err = getFieldType(file, item.Elt)
fieldName = "[]" + fieldName
default:
fieldName, err = getFieldType(file, index)
}

if err != nil {
return "", err
}

fullName += fieldName + ", "
}

return strings.TrimRight(fullName, ", ") + "]", nil
}

return "", fmt.Errorf("unknown field type %#v", field)
}

func getGenericTypeSpec(field ast.Expr) *ast.TypeSpec {
switch indexType := field.(type) {
case *ast.Ident:
return indexType.Obj.Decl.(*ast.TypeSpec)
case *ast.ArrayType:
return indexType.Elt.(*ast.Ident).Obj.Decl.(*ast.TypeSpec)
}
return nil
}
11 changes: 10 additions & 1 deletion generics_other.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,19 @@

package swag

import (
"fmt"
"go/ast"
)

func typeSpecFullName(typeSpecDef *TypeSpecDef) string {
return typeSpecDef.FullName()
}

func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string) *TypeSpecDef {
func (pkgDefs *PackagesDefinitions) parametrizeStruct(original *TypeSpecDef, fullGenericForm string, parseDependency bool) *TypeSpecDef {
return original
}

func getGenericFieldType(file *ast.File, field ast.Expr) (string, error) {
return "", fmt.Errorf("unknown field type %#v", field)
}
Loading

0 comments on commit 2f148dd

Please sign in to comment.