Skip to content

Commit

Permalink
style: optimize
Browse files Browse the repository at this point in the history
  • Loading branch information
tr1v3r committed Jun 29, 2022
1 parent 2221495 commit 6734181
Show file tree
Hide file tree
Showing 14 changed files with 149 additions and 44 deletions.
60 changes: 57 additions & 3 deletions do.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ func (d *DO) Alias() string {
func (*DO) Columns(cols ...field.Expr) Columns { return cols }

// ======================== chainable api ========================

// Not ...
func (d *DO) Not(conds ...Condition) Dao {
exprs, err := condToExpression(conds)
if err != nil {
Expand All @@ -189,6 +191,7 @@ func (d *DO) Not(conds ...Condition) Dao {
return d.getInstance(d.db.Clauses(clause.Where{Exprs: []clause.Expression{clause.Not(exprs...)}}))
}

// Or ...
func (d *DO) Or(conds ...Condition) Dao {
exprs, err := condToExpression(conds)
if err != nil {
Expand All @@ -200,6 +203,7 @@ func (d *DO) Or(conds ...Condition) Dao {
return d.getInstance(d.db.Clauses(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(exprs...))}}))
}

// Select ...
func (d *DO) Select(columns ...field.Expr) Dao {
if len(columns) == 0 {
return d.getInstance(d.db.Clauses(clause.Select{}))
Expand All @@ -211,6 +215,7 @@ func (d *DO) Select(columns ...field.Expr) Dao {
return d.getInstance(d.db.Select(strings.Join(query, ","), args...))
}

// Where ...
func (d *DO) Where(conds ...Condition) Dao {
exprs, err := condToExpression(conds)
if err != nil {
Expand All @@ -222,6 +227,7 @@ func (d *DO) Where(conds ...Condition) Dao {
return d.getInstance(d.db.Clauses(clause.Where{Exprs: exprs}))
}

// Order ...
func (d *DO) Order(columns ...field.Expr) Dao {
// lazy build Columns
// if c, ok := d.db.Statement.Clauses[clause.OrderBy{}.Name()]; ok {
Expand All @@ -236,10 +242,10 @@ func (d *DO) Order(columns ...field.Expr) Dao {
if len(columns) == 0 {
return d
}
return d.getInstance(d.db.Order(d.calcOrderValue(columns...)))
return d.getInstance(d.db.Order(d.toOrderValue(columns...)))
}

func (d *DO) calcOrderValue(columns ...field.Expr) string {
func (d *DO) toOrderValue(columns ...field.Expr) string {
// eager build Columns
orderArray := make([]string, len(columns))
for i, c := range columns {
Expand All @@ -248,17 +254,20 @@ func (d *DO) calcOrderValue(columns ...field.Expr) string {
return strings.Join(orderArray, ",")
}

// Distinct ...
func (d *DO) Distinct(columns ...field.Expr) Dao {
return d.getInstance(d.db.Distinct(toInterfaceSlice(toColExprFullName(d.db.Statement, columns...))...))
}

// Omit ...
func (d *DO) Omit(columns ...field.Expr) Dao {
if len(columns) == 0 {
return d
}
return d.getInstance(d.db.Omit(getColumnName(columns...)...))
}

// Group ...
func (d *DO) Group(columns ...field.Expr) Dao {
if len(columns) == 0 {
return d
Expand All @@ -270,6 +279,7 @@ func (d *DO) Group(columns ...field.Expr) Dao {
return d.getInstance(d.db.Group(name))
}

// Having ...
func (d *DO) Having(conds ...Condition) Dao {
exprs, err := condToExpression(conds)
if err != nil {
Expand All @@ -281,14 +291,17 @@ func (d *DO) Having(conds ...Condition) Dao {
return d.getInstance(d.db.Clauses(clause.GroupBy{Having: exprs}))
}

// Limit ...
func (d *DO) Limit(limit int) Dao {
return d.getInstance(d.db.Limit(limit))
}

// Offset ...
func (d *DO) Offset(offset int) Dao {
return d.getInstance(d.db.Offset(offset))
}

// Scopes ...
func (d *DO) Scopes(funcs ...func(Dao) Dao) Dao {
fcs := make([]func(*gorm.DB) *gorm.DB, len(funcs))
for i, f := range funcs {
Expand All @@ -298,18 +311,22 @@ func (d *DO) Scopes(funcs ...func(Dao) Dao) Dao {
return d.getInstance(d.db.Scopes(fcs...))
}

// Unscoped ...
func (d *DO) Unscoped() Dao {
return d.getInstance(d.db.Unscoped())
}

// Join ...
func (d *DO) Join(table schema.Tabler, conds ...field.Expr) Dao {
return d.join(table, clause.InnerJoin, conds)
}

// LeftJoin ...
func (d *DO) LeftJoin(table schema.Tabler, conds ...field.Expr) Dao {
return d.join(table, clause.LeftJoin, conds)
}

// RightJoin ...
func (d *DO) RightJoin(table schema.Tabler, conds ...field.Expr) Dao {
return d.join(table, clause.RightJoin, conds)
}
Expand All @@ -336,13 +353,15 @@ func (d *DO) join(table schema.Tabler, joinType clause.JoinType, conds []field.E
return d.getInstance(d.db.Clauses(from))
}

// Attrs ...
func (d *DO) Attrs(attrs ...field.AssignExpr) Dao {
if len(attrs) == 0 {
return d
}
return d.getInstance(d.db.Attrs(d.attrsValue(attrs)...))
}

// Assign ...
func (d *DO) Assign(attrs ...field.AssignExpr) Dao {
if len(attrs) == 0 {
return d
Expand All @@ -360,6 +379,7 @@ func (d *DO) attrsValue(attrs []field.AssignExpr) []interface{} {
return values
}

// Joins ...
func (d *DO) Joins(field field.RelationField) Dao {
var args []interface{}

Expand Down Expand Up @@ -473,6 +493,7 @@ func (d *DO) Joins(field field.RelationField) Dao {
// return d.getInstance(d.db.Preload(string(column.Path())))
// }

// Preload ...
func (d *DO) Preload(field field.RelationField) Dao {
var args []interface{}
if conds := field.GetConds(); len(conds) > 0 {
Expand All @@ -489,7 +510,7 @@ func (d *DO) Preload(field field.RelationField) Dao {
}
if columns := field.GetOrderCol(); len(columns) > 0 {
args = append(args, func(db *gorm.DB) *gorm.DB {
return db.Order(d.calcOrderValue(columns...))
return db.Order(d.toOrderValue(columns...))
})
}
if clauses := field.GetClauses(); len(clauses) > 0 {
Expand Down Expand Up @@ -544,26 +565,33 @@ func getFromClause(db *gorm.DB) *clause.From {
}

// ======================== finisher api ========================

// Create ...
func (d *DO) Create(value interface{}) error {
return d.db.Create(value).Error
}

// CreateInBatches ...
func (d *DO) CreateInBatches(value interface{}, batchSize int) error {
return d.db.CreateInBatches(value, batchSize).Error
}

// Save ...
func (d *DO) Save(value interface{}) error {
return d.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(value).Error
}

// First ...
func (d *DO) First() (result interface{}, err error) {
return d.singleQuery(d.db.First)
}

// Take ...
func (d *DO) Take() (result interface{}, err error) {
return d.singleQuery(d.db.Take)
}

// Last ...
func (d *DO) Last() (result interface{}, err error) {
return d.singleQuery(d.db.Last)
}
Expand All @@ -586,6 +614,7 @@ func (d *DO) singleScan() (result interface{}, err error) {
return
}

// Find ...
func (d *DO) Find() (results interface{}, err error) {
return d.multiQuery(d.db.Find)
}
Expand All @@ -606,18 +635,22 @@ func (d *DO) findToMap() (interface{}, error) {
return results, err
}

// FindInBatches ...
func (d *DO) FindInBatches(dest interface{}, batchSize int, fc func(tx Dao, batch int) error) error {
return d.db.FindInBatches(dest, batchSize, func(tx *gorm.DB, batch int) error { return fc(d.getInstance(tx), batch) }).Error
}

// FirstOrInit ...
func (d *DO) FirstOrInit() (result interface{}, err error) {
return d.singleQuery(d.db.FirstOrInit)
}

// FirstOrCreate ...
func (d *DO) FirstOrCreate() (result interface{}, err error) {
return d.singleQuery(d.db.FirstOrCreate)
}

// Update ...
func (d *DO) Update(column field.Expr, value interface{}) (info ResultInfo, err error) {
tx := d.db.Model(d.newResultPointer())
columnStr := column.BuildColumn(d.db.Statement, field.WithoutQuote).String()
Expand All @@ -634,6 +667,7 @@ func (d *DO) Update(column field.Expr, value interface{}) (info ResultInfo, err
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

// UpdateSimple ...
func (d *DO) UpdateSimple(columns ...field.AssignExpr) (info ResultInfo, err error) {
if len(columns) == 0 {
return
Expand All @@ -643,6 +677,7 @@ func (d *DO) UpdateSimple(columns ...field.AssignExpr) (info ResultInfo, err err
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

// Updates ...
func (d *DO) Updates(value interface{}) (info ResultInfo, err error) {
var result *gorm.DB
var rawTyp, typ reflect.Type
Expand All @@ -666,6 +701,7 @@ func (d *DO) Updates(value interface{}) (info ResultInfo, err error) {
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

// UpdateColumn ...
func (d *DO) UpdateColumn(column field.Expr, value interface{}) (info ResultInfo, err error) {
tx := d.db.Model(d.newResultPointer())
columnStr := column.BuildColumn(d.db.Statement, field.WithoutQuote).String()
Expand All @@ -682,6 +718,7 @@ func (d *DO) UpdateColumn(column field.Expr, value interface{}) (info ResultInfo
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

// UpdateColumnSimple ...
func (d *DO) UpdateColumnSimple(columns ...field.AssignExpr) (info ResultInfo, err error) {
if len(columns) == 0 {
return
Expand All @@ -691,6 +728,7 @@ func (d *DO) UpdateColumnSimple(columns ...field.AssignExpr) (info ResultInfo, e
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

// UpdateColumns ...
func (d *DO) UpdateColumns(value interface{}) (info ResultInfo, err error) {
result := d.db.Model(d.newResultPointer()).UpdateColumns(value)
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
Expand Down Expand Up @@ -718,35 +756,43 @@ func (d *DO) assignSet(exprs []field.AssignExpr) (set clause.Set) {
return append(set, callbacks.ConvertToAssignments(stmt)...)
}

// Delete ...
func (d *DO) Delete() (info ResultInfo, err error) {
result := d.db.Model(d.newResultPointer()).Delete(reflect.New(d.modelType).Interface())
return ResultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

// Count ...
func (d *DO) Count() (count int64, err error) {
return count, d.db.Session(&gorm.Session{}).Model(d.newResultPointer()).Count(&count).Error
}

// Row ...
func (d *DO) Row() *sql.Row {
return d.db.Model(d.newResultPointer()).Row()
}

// Rows ...
func (d *DO) Rows() (*sql.Rows, error) {
return d.db.Model(d.newResultPointer()).Rows()
}

// Scan ...
func (d *DO) Scan(dest interface{}) error {
return d.db.Model(d.newResultPointer()).Scan(dest).Error
}

// Pluck ...
func (d *DO) Pluck(column field.Expr, dest interface{}) error {
return d.db.Model(d.newResultPointer()).Pluck(column.ColumnName().String(), dest).Error
}

// ScanRows ...
func (d *DO) ScanRows(rows *sql.Rows, dest interface{}) error {
return d.db.Model(d.newResultPointer()).ScanRows(rows, dest)
}

// WithResult ...
func (d DO) WithResult(fc func(tx Dao)) ResultInfo {
d.db = d.db.Set("", "")
fc(&d)
Expand Down Expand Up @@ -882,6 +928,7 @@ func Table(subQueries ...SubQuery) Dao {

// ======================== sub query method ========================

// Columns columns array
type Columns []field.Expr

// Set assign value by subquery
Expand All @@ -905,45 +952,52 @@ func (cs Columns) In(queryOrValue Condition) field.Expr {
}
}

// NotIn ...
func (cs Columns) NotIn(queryOrValue Condition) field.Expr {
return field.Not(cs.In(queryOrValue))
}

// Eq ...
func (cs Columns) Eq(query SubQuery) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
}
return field.CompareSubQuery(field.EqOp, cs[0], query.underlyingDB())
}

// Neq ...
func (cs Columns) Neq(query SubQuery) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
}
return field.CompareSubQuery(field.NeqOp, cs[0], query.underlyingDB())
}

// Gt ...
func (cs Columns) Gt(query SubQuery) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
}
return field.CompareSubQuery(field.GtOp, cs[0], query.underlyingDB())
}

// Gte ...
func (cs Columns) Gte(query SubQuery) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
}
return field.CompareSubQuery(field.GteOp, cs[0], query.underlyingDB())
}

// Lt ...
func (cs Columns) Lt(query SubQuery) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
}
return field.CompareSubQuery(field.LtOp, cs[0], query.underlyingDB())
}

// Lte ...
func (cs Columns) Lte(query SubQuery) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
Expand Down
2 changes: 1 addition & 1 deletion do_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func checkBuildExpr(t *testing.T, e SubQuery, opts []stmtOpt, result string, var

sql := strings.TrimSpace(stmt.SQL.String())
if sql != result {
t.Errorf("Sql expects %v got %v", result, sql)
t.Errorf("SQL expects %v got %v", result, sql)
}

if !reflect.DeepEqual(stmt.Vars, vars) {
Expand Down
2 changes: 1 addition & 1 deletion generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ func (g *Generator) ApplyBasic(models ...interface{}) {
// ApplyInterface specifies method interfaces on structures, implment codes will be generated after calling g.Execute()
// eg: g.ApplyInterface(func(model.Method){}, model.User{}, model.Company{})
func (g *Generator) ApplyInterface(fc interface{}, models ...interface{}) {
structs, err := check.CheckStructs(g.db, models...)
structs, err := check.ConvertStructs(g.db, models...)
if err != nil {
g.db.Logger.Error(context.Background(), "check struct fail: %v", err)
panic("check struct fail")
Expand Down
Loading

0 comments on commit 6734181

Please sign in to comment.