Skip to content

Commit

Permalink
feat(update): update from subquery(multi columns)
Browse files Browse the repository at this point in the history
  • Loading branch information
tr1v3r committed Nov 11, 2021
1 parent b3e5d69 commit 75075f9
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 33 deletions.
83 changes: 56 additions & 27 deletions do.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strings"

"gorm.io/gorm"
"gorm.io/gorm/callbacks"
"gorm.io/gorm/clause"
"gorm.io/gorm/schema"

Expand All @@ -27,6 +28,8 @@ type resultInfo struct {
Error error
}

var _ Dao = new(DO)

// DO (data object): implement basic query methods
// the structure embedded with a *gorm.DB, and has a element item "alias" will be used when used as a sub query
type DO struct {
Expand Down Expand Up @@ -143,7 +146,10 @@ func (d *DO) Clauses(conds ...clause.Expression) Dao {
}

// As alias cannot be heired, As must used on tail
func (d *DO) As(alias string) Dao { return &DO{db: d.db, alias: alias} }
func (d DO) As(alias string) Dao {
d.alias = alias
return &d
}

// Columns return columns for Subquery
func (*DO) Columns(cols ...field.Expr) columns { return cols }
Expand Down Expand Up @@ -351,6 +357,26 @@ func (d *DO) Preload(field field.RelationField) Dao {
return d.getInstance(d.db.Preload(field.Path(), args...))
}

// UpdateFrom specify update sub query
// WARNNING!!! This Method will be deprecated soon!!!
func (d *DO) UpdateFrom(querys ...subQuery) Dao {
var tableName strings.Builder
d.db.Statement.QuoteTo(&tableName, d.db.Statement.Table)
tableName.WriteByte(' ')
d.db.Statement.QuoteTo(&tableName, d.alias)
for _, q := range querys {
tableName.WriteByte(',')
if _, ok := q.underlyingDB().Statement.Clauses["SELECT"]; ok || len(q.underlyingDB().Statement.Selects) > 0 {
tableName.WriteString("(" + q.underlyingDB().ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Find(nil) }) + ")")
} else {
d.db.Statement.QuoteTo(&tableName, q.underlyingDB().Statement.Table)
}
tableName.WriteByte(' ')
d.db.Statement.QuoteTo(&tableName, q.underlyingDO().alias)
}
return d.getInstance(d.db.Clauses(clause.Update{Table: clause.Table{Name: tableName.String(), Raw: true}}))
}

func getFromClause(db *gorm.DB) *clause.From {
if db == nil || db.Statement == nil {
return &clause.From{}
Expand Down Expand Up @@ -468,12 +494,7 @@ func (d *DO) UpdateSimple(columns ...field.AssignExpr) (info resultInfo, err err
return
}

dest, err := assignMap(d.db.Statement, columns)
if err != nil {
return resultInfo{Error: err}, err
}

result := d.db.Model(d.newResultPointer()).Updates(dest)
result := d.db.Model(d.newResultPointer()).Clauses(d.assignSet(columns)).Omit("*").Updates(map[string]interface{}{})
return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

Expand Down Expand Up @@ -503,12 +524,7 @@ func (d *DO) UpdateColumnSimple(columns ...field.AssignExpr) (info resultInfo, e
return
}

dest, err := assignMap(d.db.Statement, columns)
if err != nil {
return resultInfo{Error: err}, err
}

result := d.db.Model(d.newResultPointer()).UpdateColumns(dest)
result := d.db.Model(d.newResultPointer()).Clauses(d.assignSet(columns)).Omit("*").UpdateColumns(map[string]interface{}{})
return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

Expand All @@ -517,6 +533,28 @@ func (d *DO) UpdateColumns(value interface{}) (info resultInfo, err error) {
return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

// assignSet fetch all set
func (d *DO) assignSet(exprs []field.AssignExpr) (set clause.Set) {
for _, expr := range exprs {
column := clause.Column{Name: string(expr.ColumnName())}
if d.alias != "" {
column.Table = d.alias
}
switch e := expr.AssignExpr().(type) {
case clause.Expr:
set = append(set, clause.Assignment{Column: column, Value: e})
case clause.Eq:
set = append(set, clause.Assignment{Column: column, Value: e.Value})
case clause.Set:
set = append(set, e...)
}
}

stmt := d.db.Session(&gorm.Session{}).Statement
stmt.Dest = map[string]interface{}{}
return append(set, callbacks.ConvertToAssignments(stmt)...)
}

func (d *DO) Delete() (info resultInfo, err error) {
result := d.db.Model(d.newResultPointer()).Delete(reflect.New(d.modelType).Interface())
return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
Expand Down Expand Up @@ -634,20 +672,6 @@ func toInterfaceSlice(value interface{}) []interface{} {
}
}

func assignMap(stmt *gorm.Statement, exprs []field.AssignExpr) (map[string]interface{}, error) {
dest := make(map[string]interface{}, len(exprs))
for _, expr := range exprs {
target := expr.BuildColumn(stmt, field.WithoutQuote).String()
switch e := expr.AssignExpr().(type) {
case clause.Expr:
dest[target] = e
case clause.Eq:
dest[target] = e.Value
}
}
return dest, nil
}

// ======================== New Table ========================

// Table return a new table produced by subquery,
Expand Down Expand Up @@ -683,6 +707,11 @@ func Table(subQueries ...subQuery) Dao {

type columns []field.Expr

// Set assign value by subquery
func (cs columns) Set(query subQuery) field.AssignExpr {
return field.AssignSubQuery(cs, query.underlyingDB())
}

// In accept query or value
func (cs columns) In(queryOrValue Condition) field.Expr {
if len(cs) == 0 {
Expand Down
27 changes: 22 additions & 5 deletions field/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,18 +142,35 @@ func ContainsSubQuery(columns []Expr, subQuery *gorm.DB) Expr {
Vars: []interface{}{columns[0].RawExpr(), subQuery},
}}
default: // len(columns) > 0
vars := make([]string, len(columns))
queryCols := make([]interface{}, len(columns))
placeholders := make([]string, len(columns))
cols := make([]interface{}, len(columns))
for i, c := range columns {
vars[i], queryCols[i] = "?", c.RawExpr()
placeholders[i], cols[i] = "?", c.RawExpr()
}
return expr{e: clause.Expr{
SQL: fmt.Sprintf("(%s) IN (?)", strings.Join(vars, ", ")),
Vars: append(queryCols, subQuery),
SQL: fmt.Sprintf("(%s) IN (?)", strings.Join(placeholders, ",")),
Vars: append(cols, subQuery),
}}
}
}

func AssignSubQuery(columns []Expr, subQuery *gorm.DB) AssignExpr {
cols := make([]string, len(columns))
for i, c := range columns {
cols[i] = string(c.BuildColumn(subQuery.Statement))
}

name := cols[0]
if len(cols) > 1 {
name = "(" + strings.Join(cols, ",") + ")"
}

return expr{e: clause.Set{{
Column: clause.Column{Name: name, Raw: true},
Value: gorm.Expr("(?)", subQuery),
}}}
}

type CompareOperate string

const (
Expand Down
4 changes: 4 additions & 0 deletions field/expr.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ func (e expr) LteCol(col Expr) Expr {
return e.setE(clause.Expr{SQL: "? <= ?", Vars: []interface{}{e.RawExpr(), col.RawExpr()}})
}

func (e expr) SetCol(col Expr) AssignExpr {
return e.setE(clause.Eq{Column: e.col.Name, Value: col.RawExpr()})
}

// ======================== keyword ========================
func (e expr) As(alias string) Expr {
if e.e != nil {
Expand Down
1 change: 0 additions & 1 deletion generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ type genInfo struct {
Interfaces []*check.InterfaceMethod
}

//
func (i *genInfo) appendMethods(methods []*check.InterfaceMethod) error {
for _, newMethod := range methods {
if i.methodInGenInfo(newMethod) {
Expand Down

0 comments on commit 75075f9

Please sign in to comment.