From a68c4a2ac70f729b8e55c495e74ae56800c94e32 Mon Sep 17 00:00:00 2001
From: riverchu <churiver@outlook.com>
Date: Fri, 1 Oct 2021 21:16:41 +0800
Subject: [PATCH 1/3] feat: add FieldRelate opt

---
 field/association.go          | 82 +++++++++++++++++++++++++++++------
 field/export.go               | 22 +++++++++-
 generator.go                  | 24 ++++++++++
 internal/check/base.go        | 10 +++++
 internal/check/checkstruct.go |  4 +-
 internal/check/gen_structs.go | 21 ++++++---
 internal/template/struct.go   | 18 +++++---
 7 files changed, 150 insertions(+), 31 deletions(-)

diff --git a/field/association.go b/field/association.go
index 5abef7c4..b3b57461 100644
--- a/field/association.go
+++ b/field/association.go
@@ -24,6 +24,25 @@ type Relations struct {
 	Many2Many []*Relation
 }
 
+func (r *Relations) Accept(relations ...*Relation) {
+	for _, relation := range relations {
+		switch relation.Relationship() {
+		case HasOne:
+			r.HasOne = append(r.HasOne, relation)
+		case HasMany:
+			r.HasMany = append(r.HasMany, relation)
+		case BelongsTo:
+			r.BelongsTo = append(r.BelongsTo, relation)
+		case Many2Many:
+			r.Many2Many = append(r.Many2Many, relation)
+		}
+	}
+}
+
+func (r *Relations) SingleRelation() []*Relation {
+	return append(append(append(append(make([]*Relation, 0, 4), r.HasOne...), r.BelongsTo...), r.HasMany...), r.Many2Many...)
+}
+
 type RelationField interface {
 	Name() string
 	Path() string
@@ -39,28 +58,32 @@ type RelationField interface {
 }
 
 type Relation struct {
-	varName string
-	varType string
-	path    string
+	relationship RelationshipType
+
+	fieldName string
+	fieldType string
+	path      string
 
-	relations []*Relation
+	childRelations []*Relation
 
 	conds   []Expr
 	order   []Expr
 	clauses []clause.Expression
 }
 
-func (r Relation) Name() string { return r.varName }
+func (r Relation) Name() string { return r.fieldName }
 
 func (r Relation) Path() string { return r.path }
 
-func (r Relation) Type() string { return r.varType }
+func (r Relation) Type() string { return r.fieldType }
+
+func (r Relation) Relationship() RelationshipType { return r.relationship }
 
 func (r Relation) Field(member ...string) Expr {
 	if len(member) > 0 {
-		return NewString("", r.varName+"."+strings.Join(member, ".")).appendBuildOpts(WithoutQuote)
+		return NewString("", r.fieldName+"."+strings.Join(member, ".")).appendBuildOpts(WithoutQuote)
 	}
-	return NewString("", r.varName).appendBuildOpts(WithoutQuote)
+	return NewString("", r.fieldName).appendBuildOpts(WithoutQuote)
 }
 
 func (r *Relation) On(conds ...Expr) RelationField {
@@ -86,16 +109,16 @@ func (r *Relation) GetClauses() []clause.Expression { return r.clauses }
 
 func (r *Relation) StructMember() string {
 	var memberStr string
-	for _, relation := range r.relations {
-		memberStr += relation.varName + " struct {\nfield.RelationField\n" + relation.StructMember() + "}\n"
+	for _, relation := range r.childRelations {
+		memberStr += relation.fieldName + " struct {\nfield.RelationField\n" + relation.StructMember() + "}\n"
 	}
 	return memberStr
 }
 
 func (r *Relation) StructMemberInit() string {
-	initStr := fmt.Sprintf("RelationField: field.NewRelation(%q, %q),\n", r.path, r.varType)
-	for _, relation := range r.relations {
-		initStr += relation.varName + ": struct {\nfield.RelationField\n" + strings.TrimSpace(relation.StructMember()) + "}"
+	initStr := fmt.Sprintf("RelationField: field.NewRelation(%q, %q),\n", r.path, r.fieldType)
+	for _, relation := range r.childRelations {
+		initStr += relation.fieldName + ": struct {\nfield.RelationField\n" + strings.TrimSpace(relation.StructMember()) + "}"
 		initStr += "{\n" + relation.StructMemberInit() + "},\n"
 	}
 	return initStr
@@ -104,7 +127,38 @@ func (r *Relation) StructMemberInit() string {
 func wrapPath(root string, rs []*Relation) []*Relation {
 	for _, r := range rs {
 		r.path = root + "." + r.path
-		r.relations = wrapPath(root, r.relations)
+		r.childRelations = wrapPath(root, r.childRelations)
 	}
 	return rs
 }
+
+var defaultRelationshipPrefix = map[RelationshipType]string{
+	// HasOne:    "",
+	// BelongsTo: "",
+	HasMany:   "[]",
+	Many2Many: "[]",
+}
+
+type RelateConfig struct {
+	RelatePointer      bool
+	RelateSlice        bool
+	RelateSlicePointer bool
+
+	JSONTag      string
+	GORMTag      string
+	NewTag       string
+	OverwriteTag string
+}
+
+func (c *RelateConfig) RelateFieldPrefix(relationshipType RelationshipType) string {
+	switch {
+	case c.RelatePointer:
+		return "*"
+	case c.RelateSlice:
+		return "[]"
+	case c.RelateSlicePointer:
+		return "[]*"
+	default:
+		return defaultRelationshipPrefix[relationshipType]
+	}
+}
diff --git a/field/export.go b/field/export.go
index dd94a8d5..f320ecba 100644
--- a/field/export.go
+++ b/field/export.go
@@ -218,6 +218,24 @@ func EmptyExpr() Expr { return expr{e: clause.Expr{}} }
 var AssociationFields Expr = NewString("", clause.Associations)
 var Associations RelationField = NewRelation(clause.Associations, "")
 
-func NewRelation(varName string, varType string, relations ...*Relation) *Relation {
-	return &Relation{varName: varName, path: varName, varType: varType, relations: wrapPath(varName, relations)}
+func NewRelation(fieldName string, fieldType string, relations ...*Relation) *Relation {
+	return &Relation{
+		fieldName:      fieldName,
+		path:           fieldName,
+		fieldType:      fieldType,
+		childRelations: wrapPath(fieldName, relations)}
+}
+
+func NewRelationWithCopy(relationship RelationshipType, fieldName string, fieldType string, relations ...*Relation) *Relation {
+	copiedRelations := make([]*Relation, len(relations))
+	for i, r := range relations {
+		copy := *r
+		copiedRelations[i] = &copy
+	}
+	return &Relation{
+		relationship:   relationship,
+		fieldName:      fieldName,
+		fieldType:      fieldType,
+		path:           fieldName,
+		childRelations: wrapPath(fieldName, copiedRelations)}
 }
diff --git a/generator.go b/generator.go
index 57e23347..5304a7b5 100644
--- a/generator.go
+++ b/generator.go
@@ -15,8 +15,10 @@ import (
 
 	"golang.org/x/tools/imports"
 	"gorm.io/gorm"
+	"gorm.io/gorm/schema"
 	"gorm.io/gorm/utils/tests"
 
+	"gorm.io/gen/field"
 	"gorm.io/gen/internal/check"
 	"gorm.io/gen/internal/parser"
 	tmpl "gorm.io/gen/internal/template"
@@ -265,6 +267,28 @@ var (
 			return m
 		}
 	}
+	FieldRelate = func(relationship field.RelationshipType, fieldName string, table *check.BaseStruct, config *field.RelateConfig) check.CreateMemberOpt {
+		if config == nil {
+			config = &field.RelateConfig{}
+		}
+		if config.JSONTag == "" {
+			config.JSONTag = schema.NamingStrategy{}.ColumnName("", fieldName)
+		}
+		return func(*check.Member) *check.Member {
+			return &check.Member{
+				Name:         fieldName,
+				Type:         config.RelateFieldPrefix(relationship) + table.StructInfo.Type,
+				JSONTag:      config.JSONTag,
+				GORMTag:      config.GORMTag,
+				NewTag:       config.NewTag,
+				OverwriteTag: config.OverwriteTag,
+
+				Relation: field.NewRelationWithCopy(
+					relationship, fieldName, table.StructInfo.Package+"."+table.StructInfo.Type,
+					table.Relations.SingleRelation()...),
+			}
+		}
+	}
 )
 
 /*
diff --git a/internal/check/base.go b/internal/check/base.go
index 749371df..8ce65d90 100644
--- a/internal/check/base.go
+++ b/internal/check/base.go
@@ -4,6 +4,8 @@ import (
 	"bytes"
 	"fmt"
 	"strings"
+
+	"gorm.io/gen/field"
 )
 
 type Status int
@@ -112,9 +114,17 @@ type Member struct {
 	GORMTag          string
 	NewTag           string
 	OverwriteTag     string
+
+	Relation *field.Relation
 }
 
+func (m *Member) IsRelation() bool { return m.Relation != nil }
+
 func (m *Member) GenType() string {
+	if m.IsRelation() {
+		return m.Type
+	}
+
 	switch m.Type {
 	case "string", "bytes":
 		return strings.Title(m.Type)
diff --git a/internal/check/checkstruct.go b/internal/check/checkstruct.go
index 61594480..bccfbac8 100644
--- a/internal/check/checkstruct.go
+++ b/internal/check/checkstruct.go
@@ -45,12 +45,12 @@ func (b *BaseStruct) parseStruct(st interface{}) error {
 		}))
 	}
 
-	b.Relations = b.parseStructRelationShip(stmt.Schema.Relationships)
+	b.Relations = b.parseStructRelationShip(&stmt.Schema.Relationships)
 
 	return nil
 }
 
-func (b *BaseStruct) parseStructRelationShip(relationship schema.Relationships) field.Relations {
+func (b *BaseStruct) parseStructRelationShip(relationship *schema.Relationships) field.Relations {
 	cache := make(map[string]bool)
 	return field.Relations{
 		HasOne:    b.pullRelationShip(cache, relationship.HasOne),
diff --git a/internal/check/gen_structs.go b/internal/check/gen_structs.go
index 5999cb4a..b2e0b8fa 100644
--- a/internal/check/gen_structs.go
+++ b/internal/check/gen_structs.go
@@ -10,6 +10,7 @@ import (
 	"gorm.io/gorm"
 	"gorm.io/gorm/utils/tests"
 
+	"gorm.io/gen/field"
 	"gorm.io/gen/internal/parser"
 )
 
@@ -54,16 +55,11 @@ func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpt
 		NewStructName: uncaptialize(modelName),
 		S:             strings.ToLower(modelName[0:1]),
 		StructInfo:    parser.Param{Type: modelName, Package: pkg},
-	}
 
-	modifyOpts, filterOpts, createOpts := sortOpt(memberOpts)
-	for _, create := range createOpts {
-		m := create.self()(nil)
-		m.Name = db.NamingStrategy.SchemaName(m.Name)
-
-		base.Members = append(base.Members, m)
+		Relations: field.Relations{},
 	}
 
+	modifyOpts, filterOpts, createOpts := sortOpt(memberOpts)
 	for _, field := range columns {
 		m := field.toMember(nullable)
 
@@ -77,6 +73,17 @@ func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpt
 		base.Members = append(base.Members, m)
 	}
 
+	for _, create := range createOpts {
+		m := create.self()(nil)
+
+		if m.Relation != nil {
+			base.Relations.Accept(m.Relation)
+		}
+
+		m.Name = db.NamingStrategy.SchemaName(m.Name)
+		base.Members = append(base.Members, m)
+	}
+
 	return &base, nil
 }
 
diff --git a/internal/template/struct.go b/internal/template/struct.go
index 3130358a..65f9acec 100644
--- a/internal/template/struct.go
+++ b/internal/template/struct.go
@@ -29,7 +29,8 @@ const (
 		_{{.NewStructName}}.{{.NewStructName}}Do.UseModel({{.StructInfo.Package}}.{{.StructInfo.Type}}{})
 	
 		{{if .HasMember}}tableName := _{{.NewStructName}}.{{.NewStructName}}Do.TableName(){{end}}
-		{{range .Members}} _{{$.NewStructName}}.{{.Name}} = field.New{{.GenType}}(tableName, "{{.ColumnName}}")
+		{{range .Members -}}
+		{{if not .IsRelation}}_{{$.NewStructName}}.{{.Name}} = field.New{{.GenType}}(tableName, "{{.ColumnName}}"){{end}}
 		{{end}}
 		{{range .Relations.HasOne}}
 			_{{$.NewStructName}}.{{.Name}} = {{$.NewStructName}}HasOne{{.Name}}{
@@ -64,15 +65,20 @@ const (
 	}
 	`
 	members = `
-	{{range .Members}}{{.Name}} field.{{.GenType}}
+	{{range .Members -}}
+	{{if not .IsRelation}}{{.Name}} field.{{.GenType}}{{end}}
 	{{end}}
-	{{range .Relations.HasOne}}{{.Name}} {{$.NewStructName}}HasOne{{.Name}}
+	{{range .Relations.HasOne -}}
+	{{.Name}} {{$.NewStructName}}HasOne{{.Name}}
 	{{end}}
-	{{- range .Relations.HasMany}}{{.Name}} {{$.NewStructName}}HasMany{{.Name}}
+	{{- range .Relations.HasMany -}}
+	{{.Name}} {{$.NewStructName}}HasMany{{.Name}}
 	{{end}}
-	{{- range .Relations.BelongsTo}}{{.Name}} {{$.NewStructName}}BelongsTo{{.Name}}
+	{{- range .Relations.BelongsTo -}}
+	{{.Name}} {{$.NewStructName}}BelongsTo{{.Name}}
 	{{end}}
-	{{- range .Relations.Many2Many}}{{.Name}} {{$.NewStructName}}Many2Many{{.Name}}
+	{{- range .Relations.Many2Many -}}
+	{{.Name}} {{$.NewStructName}}Many2Many{{.Name}}
 	{{end}}
 `
 	cloneMethod = `

From 388376237bdc5674948101818ac8d3250363b8f8 Mon Sep 17 00:00:00 2001
From: riverchu <churiver@outlook.com>
Date: Fri, 1 Oct 2021 21:32:56 +0800
Subject: [PATCH 2/3] feat: relation field do not convert name

---
 internal/check/gen_structs.go | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/internal/check/gen_structs.go b/internal/check/gen_structs.go
index b2e0b8fa..f8022f70 100644
--- a/internal/check/gen_structs.go
+++ b/internal/check/gen_structs.go
@@ -78,9 +78,10 @@ func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpt
 
 		if m.Relation != nil {
 			base.Relations.Accept(m.Relation)
+		} else { // Relation Field do not need SchemaName convert
+			m.Name = db.NamingStrategy.SchemaName(m.Name)
 		}
 
-		m.Name = db.NamingStrategy.SchemaName(m.Name)
 		base.Members = append(base.Members, m)
 	}
 

From 8825025432e4604c6d39b85da8c91cee6498d09c Mon Sep 17 00:00:00 2001
From: riverchu <churiver@outlook.com>
Date: Sat, 2 Oct 2021 00:04:50 +0800
Subject: [PATCH 3/3] feat: relate with model

---
 field/association.go          | 19 ++++++---
 field/export.go               | 22 ++++++++---
 generator.go                  | 42 +++++++++++++++-----
 internal/check/checkstruct.go | 72 ++++++++++++++++++-----------------
 internal/check/gen_structs.go | 23 ++++++-----
 5 files changed, 114 insertions(+), 64 deletions(-)

diff --git a/field/association.go b/field/association.go
index b3b57461..d87d2842 100644
--- a/field/association.go
+++ b/field/association.go
@@ -60,9 +60,10 @@ type RelationField interface {
 type Relation struct {
 	relationship RelationshipType
 
-	fieldName string
-	fieldType string
-	path      string
+	fieldName  string
+	fieldType  string
+	fieldPath  string
+	fieldModel interface{} // store relaiton model
 
 	childRelations []*Relation
 
@@ -73,10 +74,12 @@ type Relation struct {
 
 func (r Relation) Name() string { return r.fieldName }
 
-func (r Relation) Path() string { return r.path }
+func (r Relation) Path() string { return r.fieldPath }
 
 func (r Relation) Type() string { return r.fieldType }
 
+func (r Relation) Model() interface{} { return r.fieldModel }
+
 func (r Relation) Relationship() RelationshipType { return r.relationship }
 
 func (r Relation) Field(member ...string) Expr {
@@ -86,6 +89,10 @@ func (r Relation) Field(member ...string) Expr {
 	return NewString("", r.fieldName).appendBuildOpts(WithoutQuote)
 }
 
+func (r *Relation) AppendChildRelation(relations ...*Relation) {
+	r.childRelations = append(r.childRelations, wrapPath(r.fieldPath, relations)...)
+}
+
 func (r *Relation) On(conds ...Expr) RelationField {
 	r.conds = append(r.conds, conds...)
 	return r
@@ -116,7 +123,7 @@ func (r *Relation) StructMember() string {
 }
 
 func (r *Relation) StructMemberInit() string {
-	initStr := fmt.Sprintf("RelationField: field.NewRelation(%q, %q),\n", r.path, r.fieldType)
+	initStr := fmt.Sprintf("RelationField: field.NewRelation(%q, %q),\n", r.fieldPath, r.fieldType)
 	for _, relation := range r.childRelations {
 		initStr += relation.fieldName + ": struct {\nfield.RelationField\n" + strings.TrimSpace(relation.StructMember()) + "}"
 		initStr += "{\n" + relation.StructMemberInit() + "},\n"
@@ -126,7 +133,7 @@ func (r *Relation) StructMemberInit() string {
 
 func wrapPath(root string, rs []*Relation) []*Relation {
 	for _, r := range rs {
-		r.path = root + "." + r.path
+		r.fieldPath = root + "." + r.fieldPath
 		r.childRelations = wrapPath(root, r.childRelations)
 	}
 	return rs
diff --git a/field/export.go b/field/export.go
index f320ecba..17ff7e67 100644
--- a/field/export.go
+++ b/field/export.go
@@ -221,12 +221,23 @@ var Associations RelationField = NewRelation(clause.Associations, "")
 func NewRelation(fieldName string, fieldType string, relations ...*Relation) *Relation {
 	return &Relation{
 		fieldName:      fieldName,
-		path:           fieldName,
+		fieldPath:      fieldName,
 		fieldType:      fieldType,
-		childRelations: wrapPath(fieldName, relations)}
+		childRelations: wrapPath(fieldName, relations),
+	}
+}
+
+func NewRelationWithModel(relationship RelationshipType, fieldName string, fieldType string, fieldModel interface{}, relations ...*Relation) *Relation {
+	return &Relation{
+		relationship: relationship,
+		fieldName:    fieldName,
+		fieldType:    fieldType,
+		fieldPath:    fieldName,
+		fieldModel:   fieldModel,
+	}
 }
 
-func NewRelationWithCopy(relationship RelationshipType, fieldName string, fieldType string, relations ...*Relation) *Relation {
+func NewRelationAndCopy(relationship RelationshipType, fieldName string, fieldType string, relations ...*Relation) *Relation {
 	copiedRelations := make([]*Relation, len(relations))
 	for i, r := range relations {
 		copy := *r
@@ -236,6 +247,7 @@ func NewRelationWithCopy(relationship RelationshipType, fieldName string, fieldT
 		relationship:   relationship,
 		fieldName:      fieldName,
 		fieldType:      fieldType,
-		path:           fieldName,
-		childRelations: wrapPath(fieldName, copiedRelations)}
+		fieldPath:      fieldName,
+		childRelations: wrapPath(fieldName, copiedRelations),
+	}
 }
diff --git a/generator.go b/generator.go
index 5304a7b5..c5796edd 100644
--- a/generator.go
+++ b/generator.go
@@ -8,6 +8,7 @@ import (
 	"log"
 	"os"
 	"path/filepath"
+	"reflect"
 	"regexp"
 	"strconv"
 	"strings"
@@ -81,7 +82,7 @@ func (cfg *Config) WithDbNameOpts(opts ...check.SchemaNameOpt) {
 
 func (cfg *Config) Revise() (err error) {
 	if cfg.ModelPkgPath == "" {
-		cfg.ModelPkgPath = check.ModelPkg
+		cfg.ModelPkgPath = check.DefaultModelPkg
 	}
 
 	cfg.OutPath, err = filepath.Abs(cfg.OutPath)
@@ -283,12 +284,39 @@ var (
 				NewTag:       config.NewTag,
 				OverwriteTag: config.OverwriteTag,
 
-				Relation: field.NewRelationWithCopy(
+				Relation: field.NewRelationAndCopy(
 					relationship, fieldName, table.StructInfo.Package+"."+table.StructInfo.Type,
 					table.Relations.SingleRelation()...),
 			}
 		}
 	}
+	FieldRelateModel = func(relationship field.RelationshipType, fieldName string, model interface{}, config *field.RelateConfig) check.CreateMemberOpt {
+		st := reflect.TypeOf(model)
+		if st.Kind() == reflect.Ptr {
+			st = st.Elem()
+		}
+		fieldType := st.String()
+
+		if config == nil {
+			config = &field.RelateConfig{}
+		}
+		if config.JSONTag == "" {
+			config.JSONTag = schema.NamingStrategy{}.ColumnName("", fieldName)
+		}
+
+		return func(*check.Member) *check.Member {
+			return &check.Member{
+				Name:         fieldName,
+				Type:         config.RelateFieldPrefix(relationship) + fieldType,
+				JSONTag:      config.JSONTag,
+				GORMTag:      config.GORMTag,
+				NewTag:       config.NewTag,
+				OverwriteTag: config.OverwriteTag,
+
+				Relation: field.NewRelationWithModel(relationship, fieldName, fieldType, model),
+			}
+		}
+	}
 )
 
 /*
@@ -337,7 +365,7 @@ func (g *Generator) apply(fc interface{}, structs []*check.BaseStruct) {
 		panic("check interface fail")
 	}
 
-	err = readInterface.ParseFile(interfacePaths, check.GetNames(structs))
+	err = readInterface.ParseFile(interfacePaths, check.GetStructNames(structs))
 	if err != nil {
 		g.db.Logger.Error(context.Background(), "parser interface file fail: %s", err)
 		panic("parser interface file fail")
@@ -494,7 +522,7 @@ func (g *Generator) generateBaseStruct() (err error) {
 	}
 	path := filepath.Clean(g.ModelPkgPath)
 	if path == "" {
-		path = check.ModelPkg
+		path = check.DefaultModelPkg
 	}
 	if strings.Contains(path, "/") {
 		outPath, err = filepath.Abs(path)
@@ -546,11 +574,7 @@ func (g *Generator) output(fileName string, content []byte) error {
 	result, err := imports.Process(fileName, content, nil)
 	if err != nil {
 		errLine, _ := strconv.Atoi(strings.Split(err.Error(), ":")[1])
-		startLine, endLine := errLine-3, errLine+3
-		if startLine < 0 {
-			startLine = 0
-		}
-
+		startLine, endLine := 0, errLine+3
 		fmt.Println("Format fail:")
 		line := strings.Split(string(content), "\n")
 		for i := startLine; i <= endLine; i++ {
diff --git a/internal/check/checkstruct.go b/internal/check/checkstruct.go
index bccfbac8..2c24d7ab 100644
--- a/internal/check/checkstruct.go
+++ b/internal/check/checkstruct.go
@@ -45,44 +45,11 @@ func (b *BaseStruct) parseStruct(st interface{}) error {
 		}))
 	}
 
-	b.Relations = b.parseStructRelationShip(&stmt.Schema.Relationships)
+	b.Relations = *ParseStructRelationShip(&stmt.Schema.Relationships)
 
 	return nil
 }
 
-func (b *BaseStruct) parseStructRelationShip(relationship *schema.Relationships) field.Relations {
-	cache := make(map[string]bool)
-	return field.Relations{
-		HasOne:    b.pullRelationShip(cache, relationship.HasOne),
-		BelongsTo: b.pullRelationShip(cache, relationship.BelongsTo),
-		HasMany:   b.pullRelationShip(cache, relationship.HasMany),
-		Many2Many: b.pullRelationShip(cache, relationship.Many2Many),
-	}
-}
-
-func (b *BaseStruct) pullRelationShip(cache map[string]bool, relationships []*schema.Relationship) []*field.Relation {
-	if len(relationships) == 0 {
-		return nil
-	}
-	result := make([]*field.Relation, len(relationships))
-	for i, relationship := range relationships {
-		var childRelations []*field.Relation
-		varType := strings.TrimLeft(relationship.Field.FieldType.String(), "[]*")
-		if !cache[varType] {
-			cache[varType] = true
-			childRelations = b.pullRelationShip(cache, append(append(append(append(
-				make([]*schema.Relationship, 0, 4),
-				relationship.FieldSchema.Relationships.BelongsTo...),
-				relationship.FieldSchema.Relationships.HasOne...),
-				relationship.FieldSchema.Relationships.HasMany...),
-				relationship.FieldSchema.Relationships.Many2Many...),
-			)
-		}
-		result[i] = field.NewRelation(relationship.Name, varType, childRelations...)
-	}
-	return result
-}
-
 // getMemberRealType  get basic type of member
 func (b *BaseStruct) getMemberRealType(member reflect.Type) string {
 	scanValuer := reflect.TypeOf((*field.ScanValuer)(nil)).Elem()
@@ -134,7 +101,7 @@ func (b *BaseStruct) check() (err error) {
 	return nil
 }
 
-func GetNames(bases []*BaseStruct) (res []string) {
+func GetStructNames(bases []*BaseStruct) (res []string) {
 	for _, base := range bases {
 		res = append(res, base.StructName)
 	}
@@ -145,3 +112,38 @@ func isStructType(data reflect.Value) bool {
 	return data.Kind() == reflect.Struct ||
 		(data.Kind() == reflect.Ptr && data.Elem().Kind() == reflect.Struct)
 }
+
+// ParseStructRelationShip parse struct's relationship
+// No one should use it directly in project
+func ParseStructRelationShip(relationship *schema.Relationships) *field.Relations {
+	cache := make(map[string]bool)
+	return &field.Relations{
+		HasOne:    pullRelationShip(cache, relationship.HasOne),
+		BelongsTo: pullRelationShip(cache, relationship.BelongsTo),
+		HasMany:   pullRelationShip(cache, relationship.HasMany),
+		Many2Many: pullRelationShip(cache, relationship.Many2Many),
+	}
+}
+
+func pullRelationShip(cache map[string]bool, relationships []*schema.Relationship) []*field.Relation {
+	if len(relationships) == 0 {
+		return nil
+	}
+	result := make([]*field.Relation, len(relationships))
+	for i, relationship := range relationships {
+		var childRelations []*field.Relation
+		varType := strings.TrimLeft(relationship.Field.FieldType.String(), "[]*")
+		if !cache[varType] {
+			cache[varType] = true
+			childRelations = pullRelationShip(cache, append(append(append(append(
+				make([]*schema.Relationship, 0, 4),
+				relationship.FieldSchema.Relationships.BelongsTo...),
+				relationship.FieldSchema.Relationships.HasOne...),
+				relationship.FieldSchema.Relationships.HasMany...),
+				relationship.FieldSchema.Relationships.Many2Many...),
+			)
+		}
+		result[i] = field.NewRelation(relationship.Name, varType, childRelations...)
+	}
+	return result
+}
diff --git a/internal/check/gen_structs.go b/internal/check/gen_structs.go
index f8022f70..879c3212 100644
--- a/internal/check/gen_structs.go
+++ b/internal/check/gen_structs.go
@@ -10,7 +10,6 @@ import (
 	"gorm.io/gorm"
 	"gorm.io/gorm/utils/tests"
 
-	"gorm.io/gen/field"
 	"gorm.io/gen/internal/parser"
 )
 
@@ -20,7 +19,7 @@ import (
  */
 
 const (
-	ModelPkg = "model"
+	DefaultModelPkg = "model"
 
 	//query table structure
 	columnQuery = "SELECT COLUMN_NAME ,COLUMN_COMMENT ,DATA_TYPE ,IS_NULLABLE ,COLUMN_KEY,COLUMN_TYPE,COLUMN_DEFAULT,EXTRA" +
@@ -30,7 +29,7 @@ const (
 type SchemaNameOpt func(*gorm.DB) string
 
 // GenBaseStructs generate db model by table name
-func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpts []SchemaNameOpt, memberOpts []MemberOpt, nullable bool) (bases *BaseStruct, err error) {
+func GenBaseStructs(db *gorm.DB, modelPkg, tableName, modelName string, schemaNameOpts []SchemaNameOpt, memberOpts []MemberOpt, nullable bool) (bases *BaseStruct, err error) {
 	if _, ok := db.Config.Dialector.(tests.DummyDialector); ok {
 		return nil, fmt.Errorf("UseDB() is necessary to generate model struct [%s] from database table [%s]", modelName, tableName)
 	}
@@ -38,10 +37,10 @@ func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpt
 	if err = checkModelName(modelName); err != nil {
 		return nil, fmt.Errorf("model name %q is invalid: %w", modelName, err)
 	}
-	if pkg == "" {
-		pkg = ModelPkg
+	if modelPkg == "" {
+		modelPkg = DefaultModelPkg
 	}
-	pkg = filepath.Base(pkg)
+	modelPkg = filepath.Base(modelPkg)
 	dbName := getSchemaName(db, schemaNameOpts...)
 	columns, err := getTbColumns(db, dbName, tableName)
 	if err != nil {
@@ -54,9 +53,7 @@ func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpt
 		StructName:    modelName,
 		NewStructName: uncaptialize(modelName),
 		S:             strings.ToLower(modelName[0:1]),
-		StructInfo:    parser.Param{Type: modelName, Package: pkg},
-
-		Relations: field.Relations{},
+		StructInfo:    parser.Param{Type: modelName, Package: modelPkg},
 	}
 
 	modifyOpts, filterOpts, createOpts := sortOpt(memberOpts)
@@ -77,6 +74,14 @@ func GenBaseStructs(db *gorm.DB, pkg, tableName, modelName string, schemaNameOpt
 		m := create.self()(nil)
 
 		if m.Relation != nil {
+			if m.Relation.Model() != nil {
+				stmt := gorm.Statement{DB: db}
+				_ = stmt.Parse(m.Relation.Model())
+				if stmt.Schema != nil {
+					m.Relation.AppendChildRelation(ParseStructRelationShip(&stmt.Schema.Relationships).SingleRelation()...)
+				}
+			}
+			m.Type = strings.ReplaceAll(m.Type, modelPkg+".", "") // remove modelPkg in field's Type, avoid import error
 			base.Relations.Accept(m.Relation)
 		} else { // Relation Field do not need SchemaName convert
 			m.Name = db.NamingStrategy.SchemaName(m.Name)