Skip to content

Commit

Permalink
feat: refactor model
Browse files Browse the repository at this point in the history
  • Loading branch information
tr1v3r committed Sep 15, 2021
1 parent 7ec59fc commit 9dd48f0
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 68 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -838,11 +838,11 @@ u.WithContext(ctx).Where(u.Activate.Is(true)).UpdateSimple(u.Age.Add(1))
u := query.Use(db).User

// Update attributes with `map`
u.WithContext(ctx).Model(&model.User{ID: 111}).Updates(map[string]interface{}{"name": "hello", "age": 18, "active": false})
u.WithContext(ctx).Where(u.ID.Eq(111)).Updates(map[string]interface{}{"name": "hello", "age": 18, "active": false})
// UPDATE users SET name='hello', age=18, active=false, updated_at='2013-11-17 21:34:10' WHERE id=111;

// Update attributes with `struct`
u.WithContext(ctx)..Where(u.ID.Eq(111)).Updates(model.User{Name: "hello", Age: 18, Active: false})
u.WithContext(ctx).Where(u.ID.Eq(111)).Updates(model.User{Name: "hello", Age: 18, Active: false})
// UPDATE users SET name='hello', age=18, active=false, updated_at='2013-11-17 21:34:10' WHERE id=111;
```

Expand Down
87 changes: 45 additions & 42 deletions do.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ type resultInfo struct {
// DO (data object): implement basic query methods
// the structure embedded with a *gorm.DB, and has a element item "alias" will be used when used as a sub query
type DO struct {
db *gorm.DB
alias string // for subquery
model interface{}
db *gorm.DB
alias string // for subquery
model interface{}
schema *schema.Schema
}

func (d *DO) getInstance(db *gorm.DB) *DO { return &DO{db: db, alias: d.alias, model: d.model} }
Expand All @@ -54,11 +55,19 @@ func (d *DO) UseDB(db *gorm.DB, opts ...doOptions) {
d.db = db
}

func (d *DO) ReplaceDB(db *gorm.DB) {
d.db = db
}

// UseModel specify a data model structure as a source for table name
func (d *DO) UseModel(model interface{}) {
d.model = model
d.db = d.db.Model(model).Session(new(gorm.Session))
_ = d.db.Statement.Parse(model)

err := d.db.Statement.Parse(model)
if err != nil {
panic(fmt.Errorf("Cannot parse model: %+v", model))
}
d.schema = d.db.Statement.Schema
}

// UseTable specify table name
Expand All @@ -68,7 +77,7 @@ func (d *DO) UseTable(tableName string) {

// TableName return table name
func (d *DO) TableName() string {
return d.db.Statement.Table
return d.schema.Table
}

// UnderlyingDB return the underlying database connection
Expand Down Expand Up @@ -320,27 +329,27 @@ func getFromClause(db *gorm.DB) *clause.From {

// ======================== finisher api ========================
func (d *DO) Create(value interface{}) error {
return d.db.Create(value).Error
return d.db.Model(d.model).Create(value).Error
}

func (d *DO) CreateInBatches(value interface{}, batchSize int) error {
return d.db.CreateInBatches(value, batchSize).Error
return d.db.Model(d.model).CreateInBatches(value, batchSize).Error
}

func (d *DO) Save(value interface{}) error {
return d.db.Save(value).Error
return d.db.Model(d.model).Save(value).Error
}

func (d *DO) First() (result interface{}, err error) {
return d.singleQuery(d.db.First)
return d.singleQuery(d.db.Model(d.model).First)
}

func (d *DO) Take() (result interface{}, err error) {
return d.singleQuery(d.db.Take)
return d.singleQuery(d.db.Model(d.model).Take)
}

func (d *DO) Last() (result interface{}, err error) {
return d.singleQuery(d.db.Last)
return d.singleQuery(d.db.Model(d.model).Last)
}

func (d *DO) singleQuery(query func(dest interface{}, conds ...interface{}) *gorm.DB) (result interface{}, err error) {
Expand All @@ -352,7 +361,7 @@ func (d *DO) singleQuery(query func(dest interface{}, conds ...interface{}) *gor
}

func (d *DO) Find() (results interface{}, err error) {
return d.multiQuery(d.db.Find)
return d.multiQuery(d.db.Model(d.model).Find)
}

func (d *DO) multiQuery(query func(dest interface{}, conds ...interface{}) *gorm.DB) (results interface{}, err error) {
Expand All @@ -362,29 +371,20 @@ func (d *DO) multiQuery(query func(dest interface{}, conds ...interface{}) *gorm
}

func (d *DO) FindInBatches(dest interface{}, batchSize int, fc func(tx Dao, batch int) error) error {
return d.db.FindInBatches(dest, batchSize, func(tx *gorm.DB, batch int) error { return fc(d.getInstance(tx), batch) }).Error
}

// func (d *DO) FirstOrInit(dest interface{}, conds ...field.Expr) error {
// return d.db.Clauses(toExpression(conds)...).FirstOrInit(dest).Error
// }

// func (d *DO) FirstOrCreate(dest interface{}, conds ...field.Expr) error {
// return d.db.Clauses(toExpression(conds)...).FirstOrCreate(dest).Error
// }

func (d *DO) Model(model interface{}) Dao {
return d.getInstance(d.db.Model(model))
return d.db.Model(d.model).FindInBatches(dest, batchSize, func(tx *gorm.DB, batch int) error { return fc(d.getInstance(tx), batch) }).Error
}

func (d *DO) Update(column field.Expr, value interface{}) error {
tx := d.db.Model(d.model)
columnStr := column.BuildColumn(d.db.Statement, field.WithTable, field.WithoutQuote).String()

switch value := value.(type) {
case field.Expr:
return d.db.Update(column.BuildColumn(d.db.Statement, field.WithTable).String(), value.RawExpr()).Error
return tx.Update(columnStr, value.RawExpr()).Error
case subQuery:
return d.db.Update(column.BuildColumn(d.db.Statement, field.WithTable).String(), value.underlyingDB()).Error
return tx.Update(columnStr, value.underlyingDB()).Error
default:
return d.db.Update(column.BuildColumn(d.db.Statement, field.WithTable).String(), value).Error
return tx.Update(columnStr, value).Error
}
}

Expand All @@ -397,18 +397,21 @@ func (d *DO) UpdateSimple(column field.Expr) error {
}

func (d *DO) Updates(value interface{}) (info resultInfo, err error) {
result := d.db.Updates(value)
result := d.db.Model(d.model).Updates(value)
return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

func (d *DO) UpdateColumn(column field.Expr, value interface{}) error {
tx := d.db.Model(d.model)
columnStr := column.BuildColumn(d.db.Statement, field.WithTable, field.WithoutQuote).String()

switch value := value.(type) {
case field.Expr:
return d.db.UpdateColumn(column.BuildColumn(d.db.Statement, field.WithTable).String(), value.RawExpr()).Error
return tx.UpdateColumn(columnStr, value.RawExpr()).Error
case subQuery:
return d.db.UpdateColumn(column.BuildColumn(d.db.Statement, field.WithTable).String(), value.underlyingDB()).Error
return d.db.UpdateColumn(columnStr, value.underlyingDB()).Error
default:
return d.db.UpdateColumn(column.BuildColumn(d.db.Statement, field.WithTable).String(), value).Error
return d.db.UpdateColumn(columnStr, value).Error
}
}

Expand All @@ -417,41 +420,41 @@ func (d *DO) UpdateColumnSimple(column field.Expr) error {
if !ok {
return ErrInvalidExpression
}
return d.db.UpdateColumn(column.BuildColumn(d.db.Statement, field.WithTable).String(), expr).Error
return d.db.Model(d.model).UpdateColumn(column.BuildColumn(d.db.Statement, field.WithTable).String(), expr).Error
}

func (d *DO) UpdateColumns(value interface{}) (info resultInfo, err error) {
result := d.db.UpdateColumns(value)
result := d.db.Model(d.model).UpdateColumns(value)
return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

func (d *DO) Delete() (info resultInfo, err error) {
result := d.db.Delete(reflect.New(d.getModelType()).Interface())
result := d.db.Model(d.model).Delete(reflect.New(d.getModelType()).Interface())
return resultInfo{RowsAffected: result.RowsAffected, Error: result.Error}, result.Error
}

func (d *DO) Count() (count int64, err error) {
return count, d.db.Count(&count).Error
return count, d.db.Model(d.model).Count(&count).Error
}

func (d *DO) Row() *sql.Row {
return d.db.Row()
return d.db.Model(d.model).Row()
}

func (d *DO) Rows() (*sql.Rows, error) {
return d.db.Rows()
return d.db.Model(d.model).Rows()
}

func (d *DO) Scan(dest interface{}) error {
return d.db.Scan(dest).Error
return d.db.Model(d.model).Scan(dest).Error
}

func (d *DO) Pluck(column field.Expr, dest interface{}) error {
return d.db.Pluck(column.ColumnName().String(), dest).Error
return d.db.Model(d.model).Pluck(column.ColumnName().String(), dest).Error
}

func (d *DO) ScanRows(rows *sql.Rows, dest interface{}) error {
return d.db.ScanRows(rows, dest)
return d.db.Model(d.model).ScanRows(rows, dest)
}

func (d *DO) newResultPointer() interface{} {
Expand Down
5 changes: 0 additions & 5 deletions generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,11 +252,6 @@ func (u userDo) FindByPage(offset int, limit int) (result []*user, count int64,
return
}

func (u userDo) Model(result *user) *userDo {
u.DO = *u.DO.Model(result).(*DO)
return &u
}

var u = func() *user {
u := user{
ID: field.NewUint("", "id"),
Expand Down
1 change: 0 additions & 1 deletion interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@ type Dao interface {
Last() (result interface{}, err error)
Find() (results interface{}, err error)
FindInBatches(dest interface{}, batchSize int, fc func(tx Dao, batch int) error) error
Model(model interface{}) Dao
Update(column field.Expr, value interface{}) error
UpdateSimple(column field.Expr) error
Updates(values interface{}) (info resultInfo, err error)
Expand Down
5 changes: 0 additions & 5 deletions internal/template/method.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,4 @@ func ({{.S}} {{.NewStructName}}Do) FindByPage(offset int, limit int) (result []*
return
}
func ({{.S}} {{.NewStructName}}Do) Model(result *{{.StructInfo.Package}}.{{.StructInfo.Type}}) *{{.NewStructName}}Do {
{{.S}}.DO = *{{.S}}.DO.Model(result).(*gen.DO)
return &{{.S}}
}
`
41 changes: 32 additions & 9 deletions internal/template/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ func SetDefault(db *gorm.DB) {
`

// TODO remove pointer && query clone
const QueryTmpl = `
func Use(db *gorm.DB) *Query {
return &Query{
Expand All @@ -32,31 +31,55 @@ type Query struct{
db *gorm.DB
{{range $name,$d :=.Data -}}
{{$d.StructName}} *{{$d.NewStructName}}
{{$d.StructName}} {{$d.NewStructName}}
{{end}}
}
func (q *Query) clone(db *gorm.DB) *Query {
return &Query{
{{range $name,$d :=.Data -}}
{{$d.StructName}}: q.{{$d.StructName}}.clone(db),
{{end}}
}
}
type queryCtx struct{
{{range $name,$d :=.Data -}}
{{$d.StructName}} {{$d.NewStructName}}Do
{{end}}
}
func (q *Query) WithContext(ctx context.Context) *queryCtx {
return &queryCtx{
{{range $name,$d :=.Data -}}
{{$d.StructName}}: q.{{$d.StructName}}.{{$d.NewStructName}}Do,
{{end}}
}
}
func (q *Query) Transaction(fc func(db *Query) error, opts ...*sql.TxOptions) error {
return q.db.Transaction(func(tx *gorm.DB) error { return fc(Use(tx)) }, opts...)
return q.db.Transaction(func(tx *gorm.DB) error { return fc(q.clone(tx)) }, opts...)
}
func (q *Query) Begin(opts ...*sql.TxOptions) *Query {
return Use(q.db.Begin(opts...))
func (q *Query) Begin(opts ...*sql.TxOptions) *queryTx {
return &queryTx{q.clone(q.db.Begin(opts...))}
}
func (q *Query) Commit() error {
type queryTx struct{ *Query }
func (q *queryTx) Commit() error {
return q.db.Commit().Error
}
func (q *Query) Rollback() error {
func (q *queryTx) Rollback() error {
return q.db.Rollback().Error
}
func (q *Query) SavePoint(name string) error {
func (q *queryTx) SavePoint(name string) error {
return q.db.SavePoint(name).Error
}
func (q *Query) RollbackTo(name string) error {
func (q *queryTx) RollbackTo(name string) error {
return q.db.RollbackTo(name).Error
}
Expand Down
15 changes: 11 additions & 4 deletions internal/template/struct.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package template

const createMethod = `
func new{{.StructName}}(db *gorm.DB) *{{.NewStructName}} {
_{{.NewStructName}} := new({{.NewStructName}})
func new{{.StructName}}(db *gorm.DB) {{.NewStructName}} {
_{{.NewStructName}} := {{.NewStructName}}{}
_{{.NewStructName}}.{{.NewStructName}}Do.UseDB(db)
_{{.NewStructName}}.{{.NewStructName}}Do.UseModel({{.StructInfo.Package}}.{{.StructInfo.Type}}{})
Expand All @@ -18,6 +18,13 @@ func new{{.StructName}}(db *gorm.DB) *{{.NewStructName}} {

const defineMethodStruct = `type {{.NewStructName}}Do struct { gen.DO }`

const cloneMethod = `
func ({{.S}} {{.NewStructName}}) clone(db *gorm.DB) {{.NewStructName}} {
{{.S}}.{{.NewStructName}}Do.ReplaceDB(db)
return {{.S}}
}
`

const BaseStruct = createMethod + `
type {{.NewStructName}} struct {
{{.NewStructName}}Do
Expand All @@ -26,7 +33,7 @@ type {{.NewStructName}} struct {
{{end}}
}
` + defineMethodStruct
` + cloneMethod + defineMethodStruct

const BaseStructWithContext = createMethod + `
type {{.NewStructName}} struct {
Expand All @@ -38,4 +45,4 @@ type {{.NewStructName}} struct {
func ({{.S}} *{{.NewStructName}}) WithContext(ctx context.Context) *{{.NewStructName}}Do { return {{.S}}.{{.NewStructName}}Do.WithContext(ctx)}
` + defineMethodStruct
` + cloneMethod + defineMethodStruct

0 comments on commit 9dd48f0

Please sign in to comment.