From 75075f90d6df3b40a788218a0ba20a06a28fea5f Mon Sep 17 00:00:00 2001 From: riverchu Date: Wed, 10 Nov 2021 18:41:51 +0800 Subject: [PATCH] feat(update): update from subquery(multi columns) --- do.go | 83 +++++++++++++++++++++++++++++++++---------------- field/export.go | 27 +++++++++++++--- field/expr.go | 4 +++ generator.go | 1 - 4 files changed, 82 insertions(+), 33 deletions(-) diff --git a/do.go b/do.go index 9e348bcc..66b3b229 100644 --- a/do.go +++ b/do.go @@ -8,6 +8,7 @@ import ( "strings" "gorm.io/gorm" + "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" "gorm.io/gorm/schema" @@ -27,6 +28,8 @@ type resultInfo struct { Error error } +var _ Dao = new(DO) + // DO (data object): implement basic query methods // the structure embedded with a *gorm.DB, and has a element item "alias" will be used when used as a sub query type DO struct { @@ -143,7 +146,10 @@ func (d *DO) Clauses(conds ...clause.Expression) Dao { } // As alias cannot be heired, As must used on tail -func (d *DO) As(alias string) Dao { return &DO{db: d.db, alias: alias} } +func (d DO) As(alias string) Dao { + d.alias = alias + return &d +} // Columns return columns for Subquery func (*DO) Columns(cols ...field.Expr) columns { return cols } @@ -351,6 +357,26 @@ func (d *DO) Preload(field field.RelationField) Dao { return d.getInstance(d.db.Preload(field.Path(), args...)) } +// UpdateFrom specify update sub query +// WARNNING!!! This Method will be deprecated soon!!! +func (d *DO) UpdateFrom(querys ...subQuery) Dao { + var tableName strings.Builder + d.db.Statement.QuoteTo(&tableName, d.db.Statement.Table) + tableName.WriteByte(' ') + d.db.Statement.QuoteTo(&tableName, d.alias) + for _, q := range querys { + tableName.WriteByte(',') + if _, ok := q.underlyingDB().Statement.Clauses["SELECT"]; ok || len(q.underlyingDB().Statement.Selects) > 0 { + tableName.WriteString("(" + q.underlyingDB().ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Find(nil) }) + ")") + } else { + d.db.Statement.QuoteTo(&tableName, q.underlyingDB().Statement.Table) + } + tableName.WriteByte(' ') + d.db.Statement.QuoteTo(&tableName, q.underlyingDO().alias) + } + return d.getInstance(d.db.Clauses(clause.Update{Table: clause.Table{Name: tableName.String(), Raw: true}})) +} + func getFromClause(db *gorm.DB) *clause.From { if db == nil || db.Statement == nil { return &clause.From{} @@ -468,12 +494,7 @@ func (d *DO) UpdateSimple(columns ...field.AssignExpr) (info resultInfo, err err return } - dest, err := assignMap(d.db.Statement, columns) - if err != nil { - return resultInfo{Error: err}, err - } - - result := d.db.Model(d.newResultPointer()).Updates(dest) + result := d.db.Model(d.newResultPointer()).Clauses(d.assignSet(columns)).Omit("*").Updates(map[string]interface{}{}) return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error } @@ -503,12 +524,7 @@ func (d *DO) UpdateColumnSimple(columns ...field.AssignExpr) (info resultInfo, e return } - dest, err := assignMap(d.db.Statement, columns) - if err != nil { - return resultInfo{Error: err}, err - } - - result := d.db.Model(d.newResultPointer()).UpdateColumns(dest) + result := d.db.Model(d.newResultPointer()).Clauses(d.assignSet(columns)).Omit("*").UpdateColumns(map[string]interface{}{}) return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error } @@ -517,6 +533,28 @@ func (d *DO) UpdateColumns(value interface{}) (info resultInfo, err error) { return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error } +// assignSet fetch all set +func (d *DO) assignSet(exprs []field.AssignExpr) (set clause.Set) { + for _, expr := range exprs { + column := clause.Column{Name: string(expr.ColumnName())} + if d.alias != "" { + column.Table = d.alias + } + switch e := expr.AssignExpr().(type) { + case clause.Expr: + set = append(set, clause.Assignment{Column: column, Value: e}) + case clause.Eq: + set = append(set, clause.Assignment{Column: column, Value: e.Value}) + case clause.Set: + set = append(set, e...) + } + } + + stmt := d.db.Session(&gorm.Session{}).Statement + stmt.Dest = map[string]interface{}{} + return append(set, callbacks.ConvertToAssignments(stmt)...) +} + func (d *DO) Delete() (info resultInfo, err error) { result := d.db.Model(d.newResultPointer()).Delete(reflect.New(d.modelType).Interface()) return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error @@ -634,20 +672,6 @@ func toInterfaceSlice(value interface{}) []interface{} { } } -func assignMap(stmt *gorm.Statement, exprs []field.AssignExpr) (map[string]interface{}, error) { - dest := make(map[string]interface{}, len(exprs)) - for _, expr := range exprs { - target := expr.BuildColumn(stmt, field.WithoutQuote).String() - switch e := expr.AssignExpr().(type) { - case clause.Expr: - dest[target] = e - case clause.Eq: - dest[target] = e.Value - } - } - return dest, nil -} - // ======================== New Table ======================== // Table return a new table produced by subquery, @@ -683,6 +707,11 @@ func Table(subQueries ...subQuery) Dao { type columns []field.Expr +// Set assign value by subquery +func (cs columns) Set(query subQuery) field.AssignExpr { + return field.AssignSubQuery(cs, query.underlyingDB()) +} + // In accept query or value func (cs columns) In(queryOrValue Condition) field.Expr { if len(cs) == 0 { diff --git a/field/export.go b/field/export.go index 6a068803..3b69b44a 100644 --- a/field/export.go +++ b/field/export.go @@ -142,18 +142,35 @@ func ContainsSubQuery(columns []Expr, subQuery *gorm.DB) Expr { Vars: []interface{}{columns[0].RawExpr(), subQuery}, }} default: // len(columns) > 0 - vars := make([]string, len(columns)) - queryCols := make([]interface{}, len(columns)) + placeholders := make([]string, len(columns)) + cols := make([]interface{}, len(columns)) for i, c := range columns { - vars[i], queryCols[i] = "?", c.RawExpr() + placeholders[i], cols[i] = "?", c.RawExpr() } return expr{e: clause.Expr{ - SQL: fmt.Sprintf("(%s) IN (?)", strings.Join(vars, ", ")), - Vars: append(queryCols, subQuery), + SQL: fmt.Sprintf("(%s) IN (?)", strings.Join(placeholders, ",")), + Vars: append(cols, subQuery), }} } } +func AssignSubQuery(columns []Expr, subQuery *gorm.DB) AssignExpr { + cols := make([]string, len(columns)) + for i, c := range columns { + cols[i] = string(c.BuildColumn(subQuery.Statement)) + } + + name := cols[0] + if len(cols) > 1 { + name = "(" + strings.Join(cols, ",") + ")" + } + + return expr{e: clause.Set{{ + Column: clause.Column{Name: name, Raw: true}, + Value: gorm.Expr("(?)", subQuery), + }}} +} + type CompareOperate string const ( diff --git a/field/expr.go b/field/expr.go index 85e13d2e..fe40e1f1 100644 --- a/field/expr.go +++ b/field/expr.go @@ -205,6 +205,10 @@ func (e expr) LteCol(col Expr) Expr { return e.setE(clause.Expr{SQL: "? <= ?", Vars: []interface{}{e.RawExpr(), col.RawExpr()}}) } +func (e expr) SetCol(col Expr) AssignExpr { + return e.setE(clause.Eq{Column: e.col.Name, Value: col.RawExpr()}) +} + // ======================== keyword ======================== func (e expr) As(alias string) Expr { if e.e != nil { diff --git a/generator.go b/generator.go index bf9ae594..37ab5aed 100644 --- a/generator.go +++ b/generator.go @@ -109,7 +109,6 @@ type genInfo struct { Interfaces []*check.InterfaceMethod } -// func (i *genInfo) appendMethods(methods []*check.InterfaceMethod) error { for _, newMethod := range methods { if i.methodInGenInfo(newMethod) {