diff --git a/field/association.go b/field/association.go index 5abef7c4..d87d2842 100644 --- a/field/association.go +++ b/field/association.go @@ -24,6 +24,25 @@ type Relations struct { Many2Many []*Relation } +func (r *Relations) Accept(relations ...*Relation) { + for _, relation := range relations { + switch relation.Relationship() { + case HasOne: + r.HasOne = append(r.HasOne, relation) + case HasMany: + r.HasMany = append(r.HasMany, relation) + case BelongsTo: + r.BelongsTo = append(r.BelongsTo, relation) + case Many2Many: + r.Many2Many = append(r.Many2Many, relation) + } + } +} + +func (r *Relations) SingleRelation() []*Relation { + return append(append(append(append(make([]*Relation, 0, 4), r.HasOne...), r.BelongsTo...), r.HasMany...), r.Many2Many...) +} + type RelationField interface { Name() string Path() string @@ -39,28 +58,39 @@ type RelationField interface { } type Relation struct { - varName string - varType string - path string + relationship RelationshipType - relations []*Relation + fieldName string + fieldType string + fieldPath string + fieldModel interface{} // store relaiton model + + childRelations []*Relation conds []Expr order []Expr clauses []clause.Expression } -func (r Relation) Name() string { return r.varName } +func (r Relation) Name() string { return r.fieldName } + +func (r Relation) Path() string { return r.fieldPath } -func (r Relation) Path() string { return r.path } +func (r Relation) Type() string { return r.fieldType } -func (r Relation) Type() string { return r.varType } +func (r Relation) Model() interface{} { return r.fieldModel } + +func (r Relation) Relationship() RelationshipType { return r.relationship } func (r Relation) Field(member ...string) Expr { if len(member) > 0 { - return NewString("", r.varName+"."+strings.Join(member, ".")).appendBuildOpts(WithoutQuote) + return NewString("", r.fieldName+"."+strings.Join(member, ".")).appendBuildOpts(WithoutQuote) } - return NewString("", r.varName).appendBuildOpts(WithoutQuote) + return NewString("", r.fieldName).appendBuildOpts(WithoutQuote) +} + +func (r *Relation) AppendChildRelation(relations ...*Relation) { + r.childRelations = append(r.childRelations, wrapPath(r.fieldPath, relations)...) } func (r *Relation) On(conds ...Expr) RelationField { @@ -86,16 +116,16 @@ func (r *Relation) GetClauses() []clause.Expression { return r.clauses } func (r *Relation) StructMember() string { var memberStr string - for _, relation := range r.relations { - memberStr += relation.varName + " struct {\nfield.RelationField\n" + relation.StructMember() + "}\n" + for _, relation := range r.childRelations { + memberStr += relation.fieldName + " struct {\nfield.RelationField\n" + relation.StructMember() + "}\n" } return memberStr } func (r *Relation) StructMemberInit() string { - initStr := fmt.Sprintf("RelationField: field.NewRelation(%q, %q),\n", r.path, r.varType) - for _, relation := range r.relations { - initStr += relation.varName + ": struct {\nfield.RelationField\n" + strings.TrimSpace(relation.StructMember()) + "}" + initStr := fmt.Sprintf("RelationField: field.NewRelation(%q, %q),\n", r.fieldPath, r.fieldType) + for _, relation := range r.childRelations { + initStr += relation.fieldName + ": struct {\nfield.RelationField\n" + strings.TrimSpace(relation.StructMember()) + "}" initStr += "{\n" + relation.StructMemberInit() + "},\n" } return initStr @@ -103,8 +133,39 @@ func (r *Relation) StructMemberInit() string { func wrapPath(root string, rs []*Relation) []*Relation { for _, r := range rs { - r.path = root + "." + r.path - r.relations = wrapPath(root, r.relations) + r.fieldPath = root + "." + r.fieldPath + r.childRelations = wrapPath(root, r.childRelations) } return rs } + +var defaultRelationshipPrefix = map[RelationshipType]string{ + // HasOne: "", + // BelongsTo: "", + HasMany: "[]", + Many2Many: "[]", +} + +type RelateConfig struct { + RelatePointer bool + RelateSlice bool + RelateSlicePointer bool + + JSONTag string + GORMTag string + NewTag string + OverwriteTag string +} + +func (c *RelateConfig) RelateFieldPrefix(relationshipType RelationshipType) string { + switch { + case c.RelatePointer: + return "*" + case c.RelateSlice: + return "[]" + case c.RelateSlicePointer: + return "[]*" + default: + return defaultRelationshipPrefix[relationshipType] + } +} diff --git a/field/export.go b/field/export.go index dd94a8d5..17ff7e67 100644 --- a/field/export.go +++ b/field/export.go @@ -218,6 +218,36 @@ func EmptyExpr() Expr { return expr{e: clause.Expr{}} } var AssociationFields Expr = NewString("", clause.Associations) var Associations RelationField = NewRelation(clause.Associations, "") -func NewRelation(varName string, varType string, relations ...*Relation) *Relation { - return &Relation{varName: varName, path: varName, varType: varType, relations: wrapPath(varName, relations)} +func NewRelation(fieldName string, fieldType string, relations ...*Relation) *Relation { + return &Relation{ + fieldName: fieldName, + fieldPath: fieldName, + fieldType: fieldType, + childRelations: wrapPath(fieldName, relations), + } +} + +func NewRelationWithModel(relationship RelationshipType, fieldName string, fieldType string, fieldModel interface{}, relations ...*Relation) *Relation { + return &Relation{ + relationship: relationship, + fieldName: fieldName, + fieldType: fieldType, + fieldPath: fieldName, + fieldModel: fieldModel, + } +} + +func NewRelationAndCopy(relationship RelationshipType, fieldName string, fieldType string, relations ...*Relation) *Relation { + copiedRelations := make([]*Relation, len(relations)) + for i, r := range relations { + copy := *r + copiedRelations[i] = © + } + return &Relation{ + relationship: relationship, + fieldName: fieldName, + fieldType: fieldType, + fieldPath: fieldName, + childRelations: wrapPath(fieldName, copiedRelations), + } } diff --git a/generator.go b/generator.go index 57e23347..c5796edd 100644 --- a/generator.go +++ b/generator.go @@ -8,6 +8,7 @@ import ( "log" "os" "path/filepath" + "reflect" "regexp" "strconv" "strings" @@ -15,8 +16,10 @@ import ( "golang.org/x/tools/imports" "gorm.io/gorm" + "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" + "gorm.io/gen/field" "gorm.io/gen/internal/check" "gorm.io/gen/internal/parser" tmpl "gorm.io/gen/internal/template" @@ -79,7 +82,7 @@ func (cfg *Config) WithDbNameOpts(opts ...check.SchemaNameOpt) { func (cfg *Config) Revise() (err error) { if cfg.ModelPkgPath == "" { - cfg.ModelPkgPath = check.ModelPkg + cfg.ModelPkgPath = check.DefaultModelPkg } cfg.OutPath, err = filepath.Abs(cfg.OutPath) @@ -265,6 +268,55 @@ var ( return m } } + FieldRelate = func(relationship field.RelationshipType, fieldName string, table *check.BaseStruct, config *field.RelateConfig) check.CreateMemberOpt { + if config == nil { + config = &field.RelateConfig{} + } + if config.JSONTag == "" { + config.JSONTag = schema.NamingStrategy{}.ColumnName("", fieldName) + } + return func(*check.Member) *check.Member { + return &check.Member{ + Name: fieldName, + Type: config.RelateFieldPrefix(relationship) + table.StructInfo.Type, + JSONTag: config.JSONTag, + GORMTag: config.GORMTag, + NewTag: config.NewTag, + OverwriteTag: config.OverwriteTag, + + Relation: field.NewRelationAndCopy( + relationship, fieldName, table.StructInfo.Package+"."+table.StructInfo.Type, + table.Relations.SingleRelation()...), + } + } + } + FieldRelateModel = func(relationship field.RelationshipType, fieldName string, model interface{}, config *field.RelateConfig) check.CreateMemberOpt { + st := reflect.TypeOf(model) + if st.Kind() == reflect.Ptr { + st = st.Elem() + } + fieldType := st.String() + + if config == nil { + config = &field.RelateConfig{} + } + if config.JSONTag == "" { + config.JSONTag = schema.NamingStrategy{}.ColumnName("", fieldName) + } + + return func(*check.Member) *check.Member { + return &check.Member{ + Name: fieldName, + Type: config.RelateFieldPrefix(relationship) + fieldType, + JSONTag: config.JSONTag, + GORMTag: config.GORMTag, + NewTag: config.NewTag, + OverwriteTag: config.OverwriteTag, + + Relation: field.NewRelationWithModel(relationship, fieldName, fieldType, model), + } + } + } ) /* @@ -313,7 +365,7 @@ func (g *Generator) apply(fc interface{}, structs []*check.BaseStruct) { panic("check interface fail") } - err = readInterface.ParseFile(interfacePaths, check.GetNames(structs)) + err = readInterface.ParseFile(interfacePaths, check.GetStructNames(structs)) if err != nil { g.db.Logger.Error(context.Background(), "parser interface file fail: %s", err) panic("parser interface file fail") @@ -470,7 +522,7 @@ func (g *Generator) generateBaseStruct() (err error) { } path := filepath.Clean(g.ModelPkgPath) if path == "" { - path = check.ModelPkg + path = check.DefaultModelPkg } if strings.Contains(path, "/") { outPath, err = filepath.Abs(path) @@ -522,11 +574,7 @@ func (g *Generator) output(fileName string, content []byte) error { result, err := imports.Process(fileName, content, nil) if err != nil { errLine, _ := strconv.Atoi(strings.Split(err.Error(), ":")[1]) - startLine, endLine := errLine-3, errLine+3 - if startLine < 0 { - startLine = 0 - } - + startLine, endLine := 0, errLine+3 fmt.Println("Format fail:") line := strings.Split(string(content), "\n") for i := startLine; i <= endLine; i++ { diff --git a/internal/check/base.go b/internal/check/base.go index 749371df..8ce65d90 100644 --- a/internal/check/base.go +++ b/internal/check/base.go @@ -4,6 +4,8 @@ import ( "bytes" "fmt" "strings" + + "gorm.io/gen/field" ) type Status int @@ -112,9 +114,17 @@ type Member struct { GORMTag string NewTag string OverwriteTag string + + Relation *field.Relation } +func (m *Member) IsRelation() bool { return m.Relation != nil } + func (m *Member) GenType() string { + if m.IsRelation() { + return m.Type + } + switch m.Type { case "string", "bytes": return strings.Title(m.Type) diff --git a/internal/check/checkstruct.go b/internal/check/checkstruct.go index 61594480..2c24d7ab 100644 --- a/internal/check/checkstruct.go +++ b/internal/check/checkstruct.go @@ -45,44 +45,11 @@ func (b *BaseStruct) parseStruct(st interface{}) error { })) } - b.Relations = b.parseStructRelationShip(stmt.Schema.Relationships) + b.Relations = *ParseStructRelationShip(&stmt.Schema.Relationships) return nil } -func (b *BaseStruct) parseStructRelationShip(relationship schema.Relationships) field.Relations { - cache := make(map[string]bool) - return field.Relations{ - HasOne: b.pullRelationShip(cache, relationship.HasOne), - BelongsTo: b.pullRelationShip(cache, relationship.BelongsTo), - HasMany: b.pullRelationShip(cache, relationship.HasMany), - Many2Many: b.pullRelationShip(cache, relationship.Many2Many), - } -} - -func (b *BaseStruct) pullRelationShip(cache map[string]bool, relationships []*schema.Relationship) []*field.Relation { - if len(relationships) == 0 { - return nil - } - result := make([]*field.Relation, len(relationships)) - for i, relationship := range relationships { - var childRelations []*field.Relation - varType := strings.TrimLeft(relationship.Field.FieldType.String(), "[]*") - if !cache[varType] { - cache[varType] = true - childRelations = b.pullRelationShip(cache, append(append(append(append( - make([]*schema.Relationship, 0, 4), - relationship.FieldSchema.Relationships.BelongsTo...), - relationship.FieldSchema.Relationships.HasOne...), - relationship.FieldSchema.Relationships.HasMany...), - relationship.FieldSchema.Relationships.Many2Many...), - ) - } - result[i] = field.NewRelation(relationship.Name, varType, childRelations...) - } - return result -} - // getMemberRealType get basic type of member func (b *BaseStruct) getMemberRealType(member reflect.Type) string { scanValuer := reflect.TypeOf((*field.ScanValuer)(nil)).Elem() @@ -134,7 +101,7 @@ func (b *BaseStruct) check() (err error) { return nil } -func GetNames(bases []*BaseStruct) (res []string) { +func GetStructNames(bases []*BaseStruct) (res []string) { for _, base := range bases { res = append(res, base.StructName) } @@ -145,3 +112,38 @@ func isStructType(data reflect.Value) bool { return data.Kind() == reflect.Struct || (data.Kind() == reflect.Ptr && data.Elem().Kind() == reflect.Struct) } + +// ParseStructRelationShip parse struct's relationship +// No one should use it directly in project +func ParseStructRelationShip(relationship *schema.Relationships) *field.Relations { + cache := make(map[string]bool) + return &field.Relations{ + HasOne: pullRelationShip(cache, relationship.HasOne), + BelongsTo: pullRelationShip(cache, relationship.BelongsTo), + HasMany: pullRelationShip(cache, relationship.HasMany), + Many2Many: pullRelationShip(cache, relationship.Many2Many), + } +} + +func pullRelationShip(cache map[string]bool, relationships []*schema.Relationship) []*field.Relation { + if len(relationships) == 0 { + return nil + } + result := make([]*field.Relation, len(relationships)) + for i, relationship := range relationships { + var childRelations []*field.Relation + varType := strings.TrimLeft(relationship.Field.FieldType.String(), "[]*") + if !cache[varType] { + cache[varType] = true + childRelations = pullRelationShip(cache, append(append(append(append( + make([]*schema.Relationship, 0, 4), + relationship.FieldSchema.Relationships.BelongsTo...), + relationship.FieldSchema.Relationships.HasOne...), + relationship.FieldSchema.Relationships.HasMany...), + relationship.FieldSchema.Relationships.Many2Many...), + ) + } + result[i] = field.NewRelation(relationship.Name, varType, childRelations...) + } + return result +} diff --git a/internal/check/gen_structs.go b/internal/check/gen_structs.go index 5999cb4a..879c3212 100644 --- a/internal/check/gen_structs.go +++ b/internal/check/gen_structs.go @@ -19,7 +19,7 @@ import ( */ const ( - ModelPkg = "model" + DefaultModelPkg = "model" //query table structure columnQuery = "SELECT COLUMN_NAME ,COLUMN_COMMENT ,DATA_TYPE ,IS_NULLABLE ,COLUMN_KEY,COLUMN_TYPE,COLUMN_DEFAULT,EXTRA" + @@ -29,7 +29,7 @@ const ( type SchemaNameOpt func(*gorm.DB) string // GenBaseStructs generate db model by table name -func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpts []SchemaNameOpt, memberOpts []MemberOpt, nullable bool) (bases *BaseStruct, err error) { +func GenBaseStructs(db *gorm.DB, modelPkg, tableName, modelName string, schemaNameOpts []SchemaNameOpt, memberOpts []MemberOpt, nullable bool) (bases *BaseStruct, err error) { if _, ok := db.Config.Dialector.(tests.DummyDialector); ok { return nil, fmt.Errorf("UseDB() is necessary to generate model struct [%s] from database table [%s]", modelName, tableName) } @@ -37,10 +37,10 @@ func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpt if err = checkModelName(modelName); err != nil { return nil, fmt.Errorf("model name %q is invalid: %w", modelName, err) } - if pkg == "" { - pkg = ModelPkg + if modelPkg == "" { + modelPkg = DefaultModelPkg } - pkg = filepath.Base(pkg) + modelPkg = filepath.Base(modelPkg) dbName := getSchemaName(db, schemaNameOpts...) columns, err := getTbColumns(db, dbName, tableName) if err != nil { @@ -53,17 +53,10 @@ func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpt StructName: modelName, NewStructName: uncaptialize(modelName), S: strings.ToLower(modelName[0:1]), - StructInfo: parser.Param{Type: modelName, Package: pkg}, + StructInfo: parser.Param{Type: modelName, Package: modelPkg}, } modifyOpts, filterOpts, createOpts := sortOpt(memberOpts) - for _, create := range createOpts { - m := create.self()(nil) - m.Name = db.NamingStrategy.SchemaName(m.Name) - - base.Members = append(base.Members, m) - } - for _, field := range columns { m := field.toMember(nullable) @@ -77,6 +70,26 @@ func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpt base.Members = append(base.Members, m) } + for _, create := range createOpts { + m := create.self()(nil) + + if m.Relation != nil { + if m.Relation.Model() != nil { + stmt := gorm.Statement{DB: db} + _ = stmt.Parse(m.Relation.Model()) + if stmt.Schema != nil { + m.Relation.AppendChildRelation(ParseStructRelationShip(&stmt.Schema.Relationships).SingleRelation()...) + } + } + m.Type = strings.ReplaceAll(m.Type, modelPkg+".", "") // remove modelPkg in field's Type, avoid import error + base.Relations.Accept(m.Relation) + } else { // Relation Field do not need SchemaName convert + m.Name = db.NamingStrategy.SchemaName(m.Name) + } + + base.Members = append(base.Members, m) + } + return &base, nil } diff --git a/internal/template/struct.go b/internal/template/struct.go index 3130358a..65f9acec 100644 --- a/internal/template/struct.go +++ b/internal/template/struct.go @@ -29,7 +29,8 @@ const ( _{{.NewStructName}}.{{.NewStructName}}Do.UseModel({{.StructInfo.Package}}.{{.StructInfo.Type}}{}) {{if .HasMember}}tableName := _{{.NewStructName}}.{{.NewStructName}}Do.TableName(){{end}} - {{range .Members}} _{{$.NewStructName}}.{{.Name}} = field.New{{.GenType}}(tableName, "{{.ColumnName}}") + {{range .Members -}} + {{if not .IsRelation}}_{{$.NewStructName}}.{{.Name}} = field.New{{.GenType}}(tableName, "{{.ColumnName}}"){{end}} {{end}} {{range .Relations.HasOne}} _{{$.NewStructName}}.{{.Name}} = {{$.NewStructName}}HasOne{{.Name}}{ @@ -64,15 +65,20 @@ const ( } ` members = ` - {{range .Members}}{{.Name}} field.{{.GenType}} + {{range .Members -}} + {{if not .IsRelation}}{{.Name}} field.{{.GenType}}{{end}} {{end}} - {{range .Relations.HasOne}}{{.Name}} {{$.NewStructName}}HasOne{{.Name}} + {{range .Relations.HasOne -}} + {{.Name}} {{$.NewStructName}}HasOne{{.Name}} {{end}} - {{- range .Relations.HasMany}}{{.Name}} {{$.NewStructName}}HasMany{{.Name}} + {{- range .Relations.HasMany -}} + {{.Name}} {{$.NewStructName}}HasMany{{.Name}} {{end}} - {{- range .Relations.BelongsTo}}{{.Name}} {{$.NewStructName}}BelongsTo{{.Name}} + {{- range .Relations.BelongsTo -}} + {{.Name}} {{$.NewStructName}}BelongsTo{{.Name}} {{end}} - {{- range .Relations.Many2Many}}{{.Name}} {{$.NewStructName}}Many2Many{{.Name}} + {{- range .Relations.Many2Many -}} + {{.Name}} {{$.NewStructName}}Many2Many{{.Name}} {{end}} ` cloneMethod = `