Skip to content

Commit

Permalink
feat: relate with model
Browse files Browse the repository at this point in the history
  • Loading branch information
tr1v3r committed Oct 1, 2021
1 parent 3883762 commit 8825025
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 64 deletions.
19 changes: 13 additions & 6 deletions field/association.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,10 @@ type RelationField interface {
type Relation struct {
relationship RelationshipType

fieldName string
fieldType string
path string
fieldName string
fieldType string
fieldPath string
fieldModel interface{} // store relaiton model

childRelations []*Relation

Expand All @@ -73,10 +74,12 @@ type Relation struct {

func (r Relation) Name() string { return r.fieldName }

func (r Relation) Path() string { return r.path }
func (r Relation) Path() string { return r.fieldPath }

func (r Relation) Type() string { return r.fieldType }

func (r Relation) Model() interface{} { return r.fieldModel }

func (r Relation) Relationship() RelationshipType { return r.relationship }

func (r Relation) Field(member ...string) Expr {
Expand All @@ -86,6 +89,10 @@ func (r Relation) Field(member ...string) Expr {
return NewString("", r.fieldName).appendBuildOpts(WithoutQuote)
}

func (r *Relation) AppendChildRelation(relations ...*Relation) {
r.childRelations = append(r.childRelations, wrapPath(r.fieldPath, relations)...)
}

func (r *Relation) On(conds ...Expr) RelationField {
r.conds = append(r.conds, conds...)
return r
Expand Down Expand Up @@ -116,7 +123,7 @@ func (r *Relation) StructMember() string {
}

func (r *Relation) StructMemberInit() string {
initStr := fmt.Sprintf("RelationField: field.NewRelation(%q, %q),\n", r.path, r.fieldType)
initStr := fmt.Sprintf("RelationField: field.NewRelation(%q, %q),\n", r.fieldPath, r.fieldType)
for _, relation := range r.childRelations {
initStr += relation.fieldName + ": struct {\nfield.RelationField\n" + strings.TrimSpace(relation.StructMember()) + "}"
initStr += "{\n" + relation.StructMemberInit() + "},\n"
Expand All @@ -126,7 +133,7 @@ func (r *Relation) StructMemberInit() string {

func wrapPath(root string, rs []*Relation) []*Relation {
for _, r := range rs {
r.path = root + "." + r.path
r.fieldPath = root + "." + r.fieldPath
r.childRelations = wrapPath(root, r.childRelations)
}
return rs
Expand Down
22 changes: 17 additions & 5 deletions field/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,23 @@ var Associations RelationField = NewRelation(clause.Associations, "")
func NewRelation(fieldName string, fieldType string, relations ...*Relation) *Relation {
return &Relation{
fieldName: fieldName,
path: fieldName,
fieldPath: fieldName,
fieldType: fieldType,
childRelations: wrapPath(fieldName, relations)}
childRelations: wrapPath(fieldName, relations),
}
}

func NewRelationWithModel(relationship RelationshipType, fieldName string, fieldType string, fieldModel interface{}, relations ...*Relation) *Relation {
return &Relation{
relationship: relationship,
fieldName: fieldName,
fieldType: fieldType,
fieldPath: fieldName,
fieldModel: fieldModel,
}
}

func NewRelationWithCopy(relationship RelationshipType, fieldName string, fieldType string, relations ...*Relation) *Relation {
func NewRelationAndCopy(relationship RelationshipType, fieldName string, fieldType string, relations ...*Relation) *Relation {
copiedRelations := make([]*Relation, len(relations))
for i, r := range relations {
copy := *r
Expand All @@ -236,6 +247,7 @@ func NewRelationWithCopy(relationship RelationshipType, fieldName string, fieldT
relationship: relationship,
fieldName: fieldName,
fieldType: fieldType,
path: fieldName,
childRelations: wrapPath(fieldName, copiedRelations)}
fieldPath: fieldName,
childRelations: wrapPath(fieldName, copiedRelations),
}
}
42 changes: 33 additions & 9 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"log"
"os"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
Expand Down Expand Up @@ -81,7 +82,7 @@ func (cfg *Config) WithDbNameOpts(opts ...check.SchemaNameOpt) {

func (cfg *Config) Revise() (err error) {
if cfg.ModelPkgPath == "" {
cfg.ModelPkgPath = check.ModelPkg
cfg.ModelPkgPath = check.DefaultModelPkg
}

cfg.OutPath, err = filepath.Abs(cfg.OutPath)
Expand Down Expand Up @@ -283,12 +284,39 @@ var (
NewTag: config.NewTag,
OverwriteTag: config.OverwriteTag,

Relation: field.NewRelationWithCopy(
Relation: field.NewRelationAndCopy(
relationship, fieldName, table.StructInfo.Package+"."+table.StructInfo.Type,
table.Relations.SingleRelation()...),
}
}
}
FieldRelateModel = func(relationship field.RelationshipType, fieldName string, model interface{}, config *field.RelateConfig) check.CreateMemberOpt {
st := reflect.TypeOf(model)
if st.Kind() == reflect.Ptr {
st = st.Elem()
}
fieldType := st.String()

if config == nil {
config = &field.RelateConfig{}
}
if config.JSONTag == "" {
config.JSONTag = schema.NamingStrategy{}.ColumnName("", fieldName)
}

return func(*check.Member) *check.Member {
return &check.Member{
Name: fieldName,
Type: config.RelateFieldPrefix(relationship) + fieldType,
JSONTag: config.JSONTag,
GORMTag: config.GORMTag,
NewTag: config.NewTag,
OverwriteTag: config.OverwriteTag,

Relation: field.NewRelationWithModel(relationship, fieldName, fieldType, model),
}
}
}
)

/*
Expand Down Expand Up @@ -337,7 +365,7 @@ func (g *Generator) apply(fc interface{}, structs []*check.BaseStruct) {
panic("check interface fail")
}

err = readInterface.ParseFile(interfacePaths, check.GetNames(structs))
err = readInterface.ParseFile(interfacePaths, check.GetStructNames(structs))
if err != nil {
g.db.Logger.Error(context.Background(), "parser interface file fail: %s", err)
panic("parser interface file fail")
Expand Down Expand Up @@ -494,7 +522,7 @@ func (g *Generator) generateBaseStruct() (err error) {
}
path := filepath.Clean(g.ModelPkgPath)
if path == "" {
path = check.ModelPkg
path = check.DefaultModelPkg
}
if strings.Contains(path, "/") {
outPath, err = filepath.Abs(path)
Expand Down Expand Up @@ -546,11 +574,7 @@ func (g *Generator) output(fileName string, content []byte) error {
result, err := imports.Process(fileName, content, nil)
if err != nil {
errLine, _ := strconv.Atoi(strings.Split(err.Error(), ":")[1])
startLine, endLine := errLine-3, errLine+3
if startLine < 0 {
startLine = 0
}

startLine, endLine := 0, errLine+3
fmt.Println("Format fail:")
line := strings.Split(string(content), "\n")
for i := startLine; i <= endLine; i++ {
Expand Down
72 changes: 37 additions & 35 deletions internal/check/checkstruct.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,44 +45,11 @@ func (b *BaseStruct) parseStruct(st interface{}) error {
}))
}

b.Relations = b.parseStructRelationShip(&stmt.Schema.Relationships)
b.Relations = *ParseStructRelationShip(&stmt.Schema.Relationships)

return nil
}

func (b *BaseStruct) parseStructRelationShip(relationship *schema.Relationships) field.Relations {
cache := make(map[string]bool)
return field.Relations{
HasOne: b.pullRelationShip(cache, relationship.HasOne),
BelongsTo: b.pullRelationShip(cache, relationship.BelongsTo),
HasMany: b.pullRelationShip(cache, relationship.HasMany),
Many2Many: b.pullRelationShip(cache, relationship.Many2Many),
}
}

func (b *BaseStruct) pullRelationShip(cache map[string]bool, relationships []*schema.Relationship) []*field.Relation {
if len(relationships) == 0 {
return nil
}
result := make([]*field.Relation, len(relationships))
for i, relationship := range relationships {
var childRelations []*field.Relation
varType := strings.TrimLeft(relationship.Field.FieldType.String(), "[]*")
if !cache[varType] {
cache[varType] = true
childRelations = b.pullRelationShip(cache, append(append(append(append(
make([]*schema.Relationship, 0, 4),
relationship.FieldSchema.Relationships.BelongsTo...),
relationship.FieldSchema.Relationships.HasOne...),
relationship.FieldSchema.Relationships.HasMany...),
relationship.FieldSchema.Relationships.Many2Many...),
)
}
result[i] = field.NewRelation(relationship.Name, varType, childRelations...)
}
return result
}

// getMemberRealType get basic type of member
func (b *BaseStruct) getMemberRealType(member reflect.Type) string {
scanValuer := reflect.TypeOf((*field.ScanValuer)(nil)).Elem()
Expand Down Expand Up @@ -134,7 +101,7 @@ func (b *BaseStruct) check() (err error) {
return nil
}

func GetNames(bases []*BaseStruct) (res []string) {
func GetStructNames(bases []*BaseStruct) (res []string) {
for _, base := range bases {
res = append(res, base.StructName)
}
Expand All @@ -145,3 +112,38 @@ func isStructType(data reflect.Value) bool {
return data.Kind() == reflect.Struct ||
(data.Kind() == reflect.Ptr && data.Elem().Kind() == reflect.Struct)
}

// ParseStructRelationShip parse struct's relationship
// No one should use it directly in project
func ParseStructRelationShip(relationship *schema.Relationships) *field.Relations {
cache := make(map[string]bool)
return &field.Relations{
HasOne: pullRelationShip(cache, relationship.HasOne),
BelongsTo: pullRelationShip(cache, relationship.BelongsTo),
HasMany: pullRelationShip(cache, relationship.HasMany),
Many2Many: pullRelationShip(cache, relationship.Many2Many),
}
}

func pullRelationShip(cache map[string]bool, relationships []*schema.Relationship) []*field.Relation {
if len(relationships) == 0 {
return nil
}
result := make([]*field.Relation, len(relationships))
for i, relationship := range relationships {
var childRelations []*field.Relation
varType := strings.TrimLeft(relationship.Field.FieldType.String(), "[]*")
if !cache[varType] {
cache[varType] = true
childRelations = pullRelationShip(cache, append(append(append(append(
make([]*schema.Relationship, 0, 4),
relationship.FieldSchema.Relationships.BelongsTo...),
relationship.FieldSchema.Relationships.HasOne...),
relationship.FieldSchema.Relationships.HasMany...),
relationship.FieldSchema.Relationships.Many2Many...),
)
}
result[i] = field.NewRelation(relationship.Name, varType, childRelations...)
}
return result
}
23 changes: 14 additions & 9 deletions internal/check/gen_structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/utils/tests"

"gorm.io/gen/field"
"gorm.io/gen/internal/parser"
)

Expand All @@ -20,7 +19,7 @@ import (
*/

const (
ModelPkg = "model"
DefaultModelPkg = "model"

//query table structure
columnQuery = "SELECT COLUMN_NAME ,COLUMN_COMMENT ,DATA_TYPE ,IS_NULLABLE ,COLUMN_KEY,COLUMN_TYPE,COLUMN_DEFAULT,EXTRA" +
Expand All @@ -30,18 +29,18 @@ const (
type SchemaNameOpt func(*gorm.DB) string

// GenBaseStructs generate db model by table name
func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpts []SchemaNameOpt, memberOpts []MemberOpt, nullable bool) (bases *BaseStruct, err error) {
func GenBaseStructs(db *gorm.DB, modelPkg, tableName, modelName string, schemaNameOpts []SchemaNameOpt, memberOpts []MemberOpt, nullable bool) (bases *BaseStruct, err error) {
if _, ok := db.Config.Dialector.(tests.DummyDialector); ok {
return nil, fmt.Errorf("UseDB() is necessary to generate model struct [%s] from database table [%s]", modelName, tableName)
}

if err = checkModelName(modelName); err != nil {
return nil, fmt.Errorf("model name %q is invalid: %w", modelName, err)
}
if pkg == "" {
pkg = ModelPkg
if modelPkg == "" {
modelPkg = DefaultModelPkg
}
pkg = filepath.Base(pkg)
modelPkg = filepath.Base(modelPkg)
dbName := getSchemaName(db, schemaNameOpts...)
columns, err := getTbColumns(db, dbName, tableName)
if err != nil {
Expand All @@ -54,9 +53,7 @@ func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpt
StructName: modelName,
NewStructName: uncaptialize(modelName),
S: strings.ToLower(modelName[0:1]),
StructInfo: parser.Param{Type: modelName, Package: pkg},

Relations: field.Relations{},
StructInfo: parser.Param{Type: modelName, Package: modelPkg},
}

modifyOpts, filterOpts, createOpts := sortOpt(memberOpts)
Expand All @@ -77,6 +74,14 @@ func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpt
m := create.self()(nil)

if m.Relation != nil {
if m.Relation.Model() != nil {
stmt := gorm.Statement{DB: db}
_ = stmt.Parse(m.Relation.Model())
if stmt.Schema != nil {
m.Relation.AppendChildRelation(ParseStructRelationShip(&stmt.Schema.Relationships).SingleRelation()...)
}
}
m.Type = strings.ReplaceAll(m.Type, modelPkg+".", "") // remove modelPkg in field's Type, avoid import error
base.Relations.Accept(m.Relation)
} else { // Relation Field do not need SchemaName convert
m.Name = db.NamingStrategy.SchemaName(m.Name)
Expand Down

0 comments on commit 8825025

Please sign in to comment.