Skip to content

Commit

Permalink
Merge pull request go-gorm#131 from go-gorm/feature/association
Browse files Browse the repository at this point in the history
Feature/association
  • Loading branch information
tr1v3r authored Oct 2, 2021
2 parents 8098372 + 8825025 commit 1cedc61
Show file tree
Hide file tree
Showing 7 changed files with 250 additions and 80 deletions.
93 changes: 77 additions & 16 deletions field/association.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,25 @@ type Relations struct {
Many2Many []*Relation
}

func (r *Relations) Accept(relations ...*Relation) {
for _, relation := range relations {
switch relation.Relationship() {
case HasOne:
r.HasOne = append(r.HasOne, relation)
case HasMany:
r.HasMany = append(r.HasMany, relation)
case BelongsTo:
r.BelongsTo = append(r.BelongsTo, relation)
case Many2Many:
r.Many2Many = append(r.Many2Many, relation)
}
}
}

func (r *Relations) SingleRelation() []*Relation {
return append(append(append(append(make([]*Relation, 0, 4), r.HasOne...), r.BelongsTo...), r.HasMany...), r.Many2Many...)
}

type RelationField interface {
Name() string
Path() string
Expand All @@ -39,28 +58,39 @@ type RelationField interface {
}

type Relation struct {
varName string
varType string
path string
relationship RelationshipType

relations []*Relation
fieldName string
fieldType string
fieldPath string
fieldModel interface{} // store relaiton model

childRelations []*Relation

conds []Expr
order []Expr
clauses []clause.Expression
}

func (r Relation) Name() string { return r.varName }
func (r Relation) Name() string { return r.fieldName }

func (r Relation) Path() string { return r.fieldPath }

func (r Relation) Path() string { return r.path }
func (r Relation) Type() string { return r.fieldType }

func (r Relation) Type() string { return r.varType }
func (r Relation) Model() interface{} { return r.fieldModel }

func (r Relation) Relationship() RelationshipType { return r.relationship }

func (r Relation) Field(member ...string) Expr {
if len(member) > 0 {
return NewString("", r.varName+"."+strings.Join(member, ".")).appendBuildOpts(WithoutQuote)
return NewString("", r.fieldName+"."+strings.Join(member, ".")).appendBuildOpts(WithoutQuote)
}
return NewString("", r.varName).appendBuildOpts(WithoutQuote)
return NewString("", r.fieldName).appendBuildOpts(WithoutQuote)
}

func (r *Relation) AppendChildRelation(relations ...*Relation) {
r.childRelations = append(r.childRelations, wrapPath(r.fieldPath, relations)...)
}

func (r *Relation) On(conds ...Expr) RelationField {
Expand All @@ -86,25 +116,56 @@ func (r *Relation) GetClauses() []clause.Expression { return r.clauses }

func (r *Relation) StructMember() string {
var memberStr string
for _, relation := range r.relations {
memberStr += relation.varName + " struct {\nfield.RelationField\n" + relation.StructMember() + "}\n"
for _, relation := range r.childRelations {
memberStr += relation.fieldName + " struct {\nfield.RelationField\n" + relation.StructMember() + "}\n"
}
return memberStr
}

func (r *Relation) StructMemberInit() string {
initStr := fmt.Sprintf("RelationField: field.NewRelation(%q, %q),\n", r.path, r.varType)
for _, relation := range r.relations {
initStr += relation.varName + ": struct {\nfield.RelationField\n" + strings.TrimSpace(relation.StructMember()) + "}"
initStr := fmt.Sprintf("RelationField: field.NewRelation(%q, %q),\n", r.fieldPath, r.fieldType)
for _, relation := range r.childRelations {
initStr += relation.fieldName + ": struct {\nfield.RelationField\n" + strings.TrimSpace(relation.StructMember()) + "}"
initStr += "{\n" + relation.StructMemberInit() + "},\n"
}
return initStr
}

func wrapPath(root string, rs []*Relation) []*Relation {
for _, r := range rs {
r.path = root + "." + r.path
r.relations = wrapPath(root, r.relations)
r.fieldPath = root + "." + r.fieldPath
r.childRelations = wrapPath(root, r.childRelations)
}
return rs
}

var defaultRelationshipPrefix = map[RelationshipType]string{
// HasOne: "",
// BelongsTo: "",
HasMany: "[]",
Many2Many: "[]",
}

type RelateConfig struct {
RelatePointer bool
RelateSlice bool
RelateSlicePointer bool

JSONTag string
GORMTag string
NewTag string
OverwriteTag string
}

func (c *RelateConfig) RelateFieldPrefix(relationshipType RelationshipType) string {
switch {
case c.RelatePointer:
return "*"
case c.RelateSlice:
return "[]"
case c.RelateSlicePointer:
return "[]*"
default:
return defaultRelationshipPrefix[relationshipType]
}
}
34 changes: 32 additions & 2 deletions field/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,36 @@ func EmptyExpr() Expr { return expr{e: clause.Expr{}} }
var AssociationFields Expr = NewString("", clause.Associations)
var Associations RelationField = NewRelation(clause.Associations, "")

func NewRelation(varName string, varType string, relations ...*Relation) *Relation {
return &Relation{varName: varName, path: varName, varType: varType, relations: wrapPath(varName, relations)}
func NewRelation(fieldName string, fieldType string, relations ...*Relation) *Relation {
return &Relation{
fieldName: fieldName,
fieldPath: fieldName,
fieldType: fieldType,
childRelations: wrapPath(fieldName, relations),
}
}

func NewRelationWithModel(relationship RelationshipType, fieldName string, fieldType string, fieldModel interface{}, relations ...*Relation) *Relation {
return &Relation{
relationship: relationship,
fieldName: fieldName,
fieldType: fieldType,
fieldPath: fieldName,
fieldModel: fieldModel,
}
}

func NewRelationAndCopy(relationship RelationshipType, fieldName string, fieldType string, relations ...*Relation) *Relation {
copiedRelations := make([]*Relation, len(relations))
for i, r := range relations {
copy := *r
copiedRelations[i] = &copy
}
return &Relation{
relationship: relationship,
fieldName: fieldName,
fieldType: fieldType,
fieldPath: fieldName,
childRelations: wrapPath(fieldName, copiedRelations),
}
}
64 changes: 56 additions & 8 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@ import (
"log"
"os"
"path/filepath"
"reflect"
"regexp"
"strconv"
"strings"
"text/template"

"golang.org/x/tools/imports"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"gorm.io/gorm/utils/tests"

"gorm.io/gen/field"
"gorm.io/gen/internal/check"
"gorm.io/gen/internal/parser"
tmpl "gorm.io/gen/internal/template"
Expand Down Expand Up @@ -79,7 +82,7 @@ func (cfg *Config) WithDbNameOpts(opts ...check.SchemaNameOpt) {

func (cfg *Config) Revise() (err error) {
if cfg.ModelPkgPath == "" {
cfg.ModelPkgPath = check.ModelPkg
cfg.ModelPkgPath = check.DefaultModelPkg
}

cfg.OutPath, err = filepath.Abs(cfg.OutPath)
Expand Down Expand Up @@ -265,6 +268,55 @@ var (
return m
}
}
FieldRelate = func(relationship field.RelationshipType, fieldName string, table *check.BaseStruct, config *field.RelateConfig) check.CreateMemberOpt {
if config == nil {
config = &field.RelateConfig{}
}
if config.JSONTag == "" {
config.JSONTag = schema.NamingStrategy{}.ColumnName("", fieldName)
}
return func(*check.Member) *check.Member {
return &check.Member{
Name: fieldName,
Type: config.RelateFieldPrefix(relationship) + table.StructInfo.Type,
JSONTag: config.JSONTag,
GORMTag: config.GORMTag,
NewTag: config.NewTag,
OverwriteTag: config.OverwriteTag,

Relation: field.NewRelationAndCopy(
relationship, fieldName, table.StructInfo.Package+"."+table.StructInfo.Type,
table.Relations.SingleRelation()...),
}
}
}
FieldRelateModel = func(relationship field.RelationshipType, fieldName string, model interface{}, config *field.RelateConfig) check.CreateMemberOpt {
st := reflect.TypeOf(model)
if st.Kind() == reflect.Ptr {
st = st.Elem()
}
fieldType := st.String()

if config == nil {
config = &field.RelateConfig{}
}
if config.JSONTag == "" {
config.JSONTag = schema.NamingStrategy{}.ColumnName("", fieldName)
}

return func(*check.Member) *check.Member {
return &check.Member{
Name: fieldName,
Type: config.RelateFieldPrefix(relationship) + fieldType,
JSONTag: config.JSONTag,
GORMTag: config.GORMTag,
NewTag: config.NewTag,
OverwriteTag: config.OverwriteTag,

Relation: field.NewRelationWithModel(relationship, fieldName, fieldType, model),
}
}
}
)

/*
Expand Down Expand Up @@ -313,7 +365,7 @@ func (g *Generator) apply(fc interface{}, structs []*check.BaseStruct) {
panic("check interface fail")
}

err = readInterface.ParseFile(interfacePaths, check.GetNames(structs))
err = readInterface.ParseFile(interfacePaths, check.GetStructNames(structs))
if err != nil {
g.db.Logger.Error(context.Background(), "parser interface file fail: %s", err)
panic("parser interface file fail")
Expand Down Expand Up @@ -470,7 +522,7 @@ func (g *Generator) generateBaseStruct() (err error) {
}
path := filepath.Clean(g.ModelPkgPath)
if path == "" {
path = check.ModelPkg
path = check.DefaultModelPkg
}
if strings.Contains(path, "/") {
outPath, err = filepath.Abs(path)
Expand Down Expand Up @@ -522,11 +574,7 @@ func (g *Generator) output(fileName string, content []byte) error {
result, err := imports.Process(fileName, content, nil)
if err != nil {
errLine, _ := strconv.Atoi(strings.Split(err.Error(), ":")[1])
startLine, endLine := errLine-3, errLine+3
if startLine < 0 {
startLine = 0
}

startLine, endLine := 0, errLine+3
fmt.Println("Format fail:")
line := strings.Split(string(content), "\n")
for i := startLine; i <= endLine; i++ {
Expand Down
10 changes: 10 additions & 0 deletions internal/check/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"bytes"
"fmt"
"strings"

"gorm.io/gen/field"
)

type Status int
Expand Down Expand Up @@ -112,9 +114,17 @@ type Member struct {
GORMTag string
NewTag string
OverwriteTag string

Relation *field.Relation
}

func (m *Member) IsRelation() bool { return m.Relation != nil }

func (m *Member) GenType() string {
if m.IsRelation() {
return m.Type
}

switch m.Type {
case "string", "bytes":
return strings.Title(m.Type)
Expand Down
Loading

0 comments on commit 1cedc61

Please sign in to comment.