Skip to content

Commit

Permalink
Merge pull request go-gorm#125 from go-gorm/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
tr1v3r authored Sep 27, 2021
2 parents 555e137 + d34f2b4 commit 388e70d
Show file tree
Hide file tree
Showing 15 changed files with 1,048 additions and 106 deletions.
440 changes: 439 additions & 1 deletion README.md

Large diffs are not rendered by default.

88 changes: 60 additions & 28 deletions do.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ func (d *DO) UseDB(db *gorm.DB, opts ...doOptions) {
d.db = db
}

func (d *DO) ReplaceDB(db *gorm.DB) {
d.db = db
}
func (d *DO) ReplaceDB(db *gorm.DB) { d.db = db }

// UseModel specify a data model structure as a source for table name
func (d *DO) UseModel(model interface{}) {
Expand All @@ -74,24 +72,24 @@ func (d *DO) UseModel(model interface{}) {
}

// UseTable specify table name
func (d *DO) UseTable(tableName string) {
d.db = d.db.Table(tableName).Session(new(gorm.Session))
}
func (d *DO) UseTable(tableName string) { d.db = d.db.Table(tableName).Session(new(gorm.Session)) }

// TableName return table name
func (d *DO) TableName() string {
func (d DO) TableName() string {
if d.schema == nil {
return ""
}
return d.schema.Table
}

// Session replace db with new session
func (d *DO) Session(config *gorm.Session) Dao { return d.getInstance(d.db.Session(config)) }

// UnderlyingDB return the underlying database connection
func (d *DO) UnderlyingDB() *gorm.DB {
return d.db
}
func (d *DO) UnderlyingDB() *gorm.DB { return d.db }

// Quote return qutoed data
func (d *DO) Quote(raw string) string {
return d.db.Statement.Quote(raw)
}
func (d *DO) Quote(raw string) string { return d.db.Statement.Quote(raw) }

// Build implement the interface of claues.Expression
// only call WHERE clause's Build
Expand Down Expand Up @@ -237,13 +235,16 @@ func (d *DO) Order(columns ...field.Expr) Dao {
// }
// }
// return d.newInstance(d.db.Clauses(clause.OrderBy{Expression: clause.CommaExpression{Exprs: toExpression(columns)}}))
return d.getInstance(d.db.Order(d.calcOrderValue(columns...)))
}

func (d *DO) calcOrderValue(columns ...field.Expr) string {
// eager build Columns
orderArray := make([]string, len(columns))
for i, c := range columns {
orderArray[i] = c.Build(d.db.Statement).String()
}
return d.getInstance(d.db.Order(strings.Join(orderArray, ",")))
return strings.Join(orderArray, ",")
}

func (d *DO) Distinct(columns ...field.Expr) Dao {
Expand Down Expand Up @@ -289,6 +290,8 @@ func (d *DO) Unscoped() Dao {
return d.getInstance(d.db.Unscoped())
}

// TODO implement commonDo

func (d *DO) Join(table schema.Tabler, conds ...field.Expr) Dao {
return d.join(table, clause.InnerJoin, conds)
}
Expand Down Expand Up @@ -323,6 +326,35 @@ func (d *DO) Assign(attrs ...field.Expr) Dao {
return d.getInstance(d.db.Assign(toExpressionInterface(attrs...)...))
}

func (d *DO) Joins(field field.RelationField) Dao {
return d.getInstance(d.db.Joins(field.Path()))
}

// func (d *DO) Preload(column field.RelationPath, subQuery ...SubQuery) Dao {
// if len(subQuery) > 0 {
// return d.getInstance(d.db.Preload(string(column.Path()), subQuery[0].underlyingDB()))
// }
// return d.getInstance(d.db.Preload(string(column.Path())))
// }

func (d *DO) Preload(field field.RelationField) Dao {
var args []interface{}
if conds := field.GetConds(); len(conds) > 0 {
args = append(args, toExpressionInterface(conds...)...)
}
if columns := field.GetOrderCol(); len(columns) > 0 {
args = append(args, func(db *gorm.DB) *gorm.DB {
return db.Order(d.calcOrderValue(columns...))
})
}
if clauses := field.GetClauses(); len(clauses) > 0 {
args = append(args, func(db *gorm.DB) *gorm.DB {
return db.Clauses(clauses...)
})
}
return d.getInstance(d.db.Preload(field.Path(), args...))
}

func getFromClause(db *gorm.DB) *clause.From {
if db == nil || db.Statement == nil {
return &clause.From{}
Expand All @@ -340,27 +372,27 @@ func getFromClause(db *gorm.DB) *clause.From {

// ======================== finisher api ========================
func (d *DO) Create(value interface{}) error {
return d.db.Model(d.model).Create(value).Error
return d.db.Create(value).Error
}

func (d *DO) CreateInBatches(value interface{}, batchSize int) error {
return d.db.Model(d.model).CreateInBatches(value, batchSize).Error
return d.db.CreateInBatches(value, batchSize).Error
}

func (d *DO) Save(value interface{}) error {
return d.db.Model(d.model).Save(value).Error
return d.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(value).Error
}

func (d *DO) First() (result interface{}, err error) {
return d.singleQuery(d.db.Model(d.model).First)
return d.singleQuery(d.db.First)
}

func (d *DO) Take() (result interface{}, err error) {
return d.singleQuery(d.db.Model(d.model).Take)
return d.singleQuery(d.db.Take)
}

func (d *DO) Last() (result interface{}, err error) {
return d.singleQuery(d.db.Model(d.model).Last)
return d.singleQuery(d.db.Last)
}

func (d *DO) singleQuery(query func(dest interface{}, conds ...interface{}) *gorm.DB) (result interface{}, err error) {
Expand All @@ -382,7 +414,7 @@ func (d *DO) singleScan() (result interface{}, err error) {
}

func (d *DO) Find() (results interface{}, err error) {
return d.multiQuery(d.db.Model(d.model).Find)
return d.multiQuery(d.db.Find)
}

func (d *DO) multiQuery(query func(dest interface{}, conds ...interface{}) *gorm.DB) (results interface{}, err error) {
Expand All @@ -403,25 +435,25 @@ func (d *DO) findToMap() (interface{}, error) {

func (d *DO) FindInBatch(batchSize int, fc func(tx Dao, batch int) error) (result interface{}, err error) {
resultsPtr := d.newResultSlicePointer()
err = d.db.Model(d.model).FindInBatches(resultsPtr, batchSize, func(tx *gorm.DB, batch int) error { return fc(d.getInstance(tx), batch) }).Error
err = d.db.FindInBatches(resultsPtr, batchSize, func(tx *gorm.DB, batch int) error { return fc(d.getInstance(tx), batch) }).Error
return reflect.Indirect(reflect.ValueOf(resultsPtr)).Interface(), err
}

func (d *DO) FindInBatches(dest interface{}, batchSize int, fc func(tx Dao, batch int) error) error {
return d.db.Model(d.model).FindInBatches(dest, batchSize, func(tx *gorm.DB, batch int) error { return fc(d.getInstance(tx), batch) }).Error
return d.db.FindInBatches(dest, batchSize, func(tx *gorm.DB, batch int) error { return fc(d.getInstance(tx), batch) }).Error
}

func (d *DO) FirstOrInit() (result interface{}, err error) {
return d.singleQuery(d.db.Model(d.model).FirstOrInit)
return d.singleQuery(d.db.FirstOrInit)
}

func (d *DO) FirstOrCreate() (result interface{}, err error) {
return d.singleQuery(d.db.Model(d.model).FirstOrCreate)
return d.singleQuery(d.db.FirstOrCreate)
}

func (d *DO) Update(column field.Expr, value interface{}) (info resultInfo, err error) {
tx := d.db.Model(d.model)
columnStr := column.BuildColumn(d.db.Statement, field.WithTable, field.WithoutQuote).String()
columnStr := column.BuildColumn(d.db.Statement, field.WithoutQuote).String()

var result *gorm.DB
switch value := value.(type) {
Expand Down Expand Up @@ -452,7 +484,7 @@ func (d *DO) Updates(value interface{}) (info resultInfo, err error) {

func (d *DO) UpdateColumn(column field.Expr, value interface{}) (info resultInfo, err error) {
tx := d.db.Model(d.model)
columnStr := column.BuildColumn(d.db.Statement, field.WithTable, field.WithoutQuote).String()
columnStr := column.BuildColumn(d.db.Statement, field.WithoutQuote).String()

var result *gorm.DB
switch value := value.(type) {
Expand Down Expand Up @@ -606,7 +638,7 @@ func parseExprs(stmt *gorm.Statement, exprs []field.Expr) (map[string]interface{
if !ok {
return nil, ErrInvalidExpression
}
dest[e.BuildColumn(stmt, field.WithTable, field.WithoutQuote).String()] = expr
dest[e.BuildColumn(stmt, field.WithoutQuote).String()] = expr
}
return dest, nil
}
Expand Down
112 changes: 112 additions & 0 deletions field/association.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package field

import (
"fmt"
"strings"

"gorm.io/gorm/clause"
"gorm.io/gorm/schema"
)

var s = clause.Associations

type RelationshipType schema.RelationshipType

const (
HasOne RelationshipType = "has_one" // HasOneRel has one relationship
HasMany RelationshipType = "has_many" // HasManyRel has many relationship
BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship
Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship
)

type Relations struct {
HasOne []*Relation
BelongsTo []*Relation
HasMany []*Relation
Many2Many []*Relation
}

type RelationField interface {
Name() string
Path() string
Field(member ...string) Expr

On(conds ...Expr) RelationField
Order(columns ...Expr) RelationField
Clauses(hints ...clause.Expression) RelationField

GetConds() []Expr
GetOrderCol() []Expr
GetClauses() []clause.Expression
}

type Relation struct {
varName string
varType string
path string

relations []*Relation

conds []Expr
order []Expr
clauses []clause.Expression
}

func (r Relation) Name() string { return r.varName }

func (r Relation) Path() string { return r.path }

func (r Relation) Type() string { return r.varType }

func (r Relation) Field(member ...string) Expr {
if len(member) > 0 {
return NewString("", r.varName+"."+strings.Join(member, ".")).appendBuildOpts(WithoutQuote)
}
return NewString("", r.varName).appendBuildOpts(WithoutQuote)
}

func (r *Relation) On(conds ...Expr) RelationField {
r.conds = append(r.conds, conds...)
return r
}

func (r *Relation) Order(columns ...Expr) RelationField {
r.order = append(r.order, columns...)
return r
}

func (r *Relation) Clauses(hints ...clause.Expression) RelationField {
r.clauses = append(r.clauses, hints...)
return r
}

func (r *Relation) GetConds() []Expr { return r.conds }

func (r *Relation) GetOrderCol() []Expr { return r.order }

func (r *Relation) GetClauses() []clause.Expression { return r.clauses }

func (r *Relation) StructMember() string {
var memberStr string
for _, relation := range r.relations {
memberStr += relation.varName + " struct {\nfield.RelationField\n" + relation.StructMember() + "}\n"
}
return memberStr
}

func (r *Relation) StructMemberInit() string {
initStr := fmt.Sprintf("RelationField: field.NewRelation(%q, %q),\n", r.path, r.varType)
for _, relation := range r.relations {
initStr += relation.varName + ": struct {\nfield.RelationField\n" + strings.TrimPrefix(strings.TrimSpace(relation.StructMember()), relation.varName) + "}"
initStr += "{\n" + relation.StructMemberInit() + "},\n"
}
return initStr
}

func wrapPath(root string, rs []*Relation) []*Relation {
for _, r := range rs {
r.path = root + "." + r.path
r.relations = wrapPath(root, r.relations)
}
return rs
}
16 changes: 15 additions & 1 deletion field/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@ import (

type FieldOption func(clause.Column) clause.Column

var (
banColumnRaw FieldOption = func(col clause.Column) clause.Column {
col.Raw = false
return col
}
)

// TODO implement validator options

// ======================== generic field =======================
Expand Down Expand Up @@ -97,7 +104,7 @@ func toColumn(table, column string, opts ...FieldOption) clause.Column {
for _, opt := range opts {
col = opt(col)
}
return col
return banColumnRaw(col)
}

// ======================== boolean operate ========================
Expand Down Expand Up @@ -207,3 +214,10 @@ func ContainsValue(columns []Expr, value Value) Expr {
}

func EmptyExpr() Expr { return expr{e: clause.Expr{}} }

var AssociationFields Expr = NewString("", clause.Associations)
var Associations RelationField = NewRelation(clause.Associations, "")

func NewRelation(varName string, varType string, relations ...*Relation) *Relation {
return &Relation{varName: varName, path: varName, varType: varType, relations: wrapPath(varName, relations)}
}
Loading

0 comments on commit 388e70d

Please sign in to comment.