Skip to content

Commit

Permalink
feat: generate with interface (go-gorm#492)
Browse files Browse the repository at this point in the history
* feat: interface style

* chore(deps): modules tidy up

* feat: generate with DIY methods interface

* feat: optimize template render

* feat: add mode `WithQueryIface`

Co-authored-by: zhangzitao <[email protected]>
  • Loading branch information
tr1v3r and mc-zzt authored Jun 27, 2022
1 parent 8012a3f commit 8a85f25
Show file tree
Hide file tree
Showing 16 changed files with 218 additions and 104 deletions.
2 changes: 1 addition & 1 deletion condition.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ func condToExpression(conds []Condition) ([]clause.Expression, error) {
}

switch cond.(type) {
case *condContainer, field.Expr, subQuery:
case *condContainer, field.Expr, SubQuery:
default:
return nil, fmt.Errorf("unsupported condition: %+v", cond)
}
Expand Down
3 changes: 3 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ const (

// WithoutContext generate code without context constrain
WithoutContext

// WithQueryIface generate code with exported interface object
WithQueryIface
)

// Config generator's basic configuration
Expand Down
32 changes: 16 additions & 16 deletions do.go
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ func (d *DO) Alias() string {
}

// Columns return columns for Subquery
func (*DO) Columns(cols ...field.Expr) columns { return cols }
func (*DO) Columns(cols ...field.Expr) Columns { return cols }

// ======================== chainable api ========================
func (d *DO) Not(conds ...Condition) Dao {
Expand Down Expand Up @@ -503,7 +503,7 @@ func (d *DO) Preload(field field.RelationField) Dao {
}

// UpdateFrom specify update sub query
func (d *DO) UpdateFrom(q subQuery) Dao {
func (d *DO) UpdateFrom(q SubQuery) Dao {
var tableName strings.Builder
d.db.Statement.QuoteTo(&tableName, d.TableName())
if d.alias != "" {
Expand Down Expand Up @@ -623,7 +623,7 @@ func (d *DO) Update(column field.Expr, value interface{}) (info ResultInfo, err
switch value := value.(type) {
case field.AssignExpr:
result = tx.Update(columnStr, value.AssignExpr())
case subQuery:
case SubQuery:
result = tx.Update(columnStr, value.underlyingDB())
default:
result = tx.Update(columnStr, value)
Expand Down Expand Up @@ -671,7 +671,7 @@ func (d *DO) UpdateColumn(column field.Expr, value interface{}) (info ResultInfo
switch value := value.(type) {
case field.Expr:
result = tx.UpdateColumn(columnStr, value.RawExpr())
case subQuery:
case SubQuery:
result = d.db.UpdateColumn(columnStr, value.underlyingDB())
default:
result = d.db.UpdateColumn(columnStr, value)
Expand Down Expand Up @@ -854,7 +854,7 @@ func toInterfaceSlice(value interface{}) []interface{} {
// Table(u.Select(u.ID, u.Name).Where(u.Age.Gt(18))).Select()
// the above usage is equivalent to SQL statement:
// SELECT * FROM (SELECT `id`, `name` FROM `users_info` WHERE `age` > ?)"
func Table(subQueries ...subQuery) Dao {
func Table(subQueries ...SubQuery) Dao {
if len(subQueries) == 0 {
return &DO{}
}
Expand All @@ -879,69 +879,69 @@ func Table(subQueries ...subQuery) Dao {

// ======================== sub query method ========================

type columns []field.Expr
type Columns []field.Expr

// Set assign value by subquery
func (cs columns) Set(query subQuery) field.AssignExpr {
func (cs Columns) Set(query SubQuery) field.AssignExpr {
return field.AssignSubQuery(cs, query.underlyingDB())
}

// In accept query or value
func (cs columns) In(queryOrValue Condition) field.Expr {
func (cs Columns) In(queryOrValue Condition) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
}

switch query := queryOrValue.(type) {
case field.Value:
return field.ContainsValue(cs, query)
case subQuery:
case SubQuery:
return field.ContainsSubQuery(cs, query.underlyingDB())
default:
return field.EmptyExpr()
}
}

func (cs columns) NotIn(queryOrValue Condition) field.Expr {
func (cs Columns) NotIn(queryOrValue Condition) field.Expr {
return field.Not(cs.In(queryOrValue))
}

func (cs columns) Eq(query subQuery) field.Expr {
func (cs Columns) Eq(query SubQuery) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
}
return field.CompareSubQuery(field.EqOp, cs[0], query.underlyingDB())
}

func (cs columns) Neq(query subQuery) field.Expr {
func (cs Columns) Neq(query SubQuery) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
}
return field.CompareSubQuery(field.NeqOp, cs[0], query.underlyingDB())
}

func (cs columns) Gt(query subQuery) field.Expr {
func (cs Columns) Gt(query SubQuery) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
}
return field.CompareSubQuery(field.GtOp, cs[0], query.underlyingDB())
}

func (cs columns) Gte(query subQuery) field.Expr {
func (cs Columns) Gte(query SubQuery) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
}
return field.CompareSubQuery(field.GteOp, cs[0], query.underlyingDB())
}

func (cs columns) Lt(query subQuery) field.Expr {
func (cs Columns) Lt(query SubQuery) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
}
return field.CompareSubQuery(field.LtOp, cs[0], query.underlyingDB())
}

func (cs columns) Lte(query subQuery) field.Expr {
func (cs Columns) Lte(query SubQuery) field.Expr {
if len(cs) == 0 {
return field.EmptyExpr()
}
Expand Down
4 changes: 2 additions & 2 deletions do_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ var (
// }
)

func checkBuildExpr(t *testing.T, e subQuery, opts []stmtOpt, result string, vars []interface{}) {
func checkBuildExpr(t *testing.T, e SubQuery, opts []stmtOpt, result string, vars []interface{}) {
stmt := build(e.underlyingDB().Statement, opts...)

sql := strings.TrimSpace(stmt.SQL.String())
Expand Down Expand Up @@ -75,7 +75,7 @@ func build(stmt *gorm.Statement, opts ...stmtOpt) *gorm.Statement {

func TestDO_methods(t *testing.T) {
testcases := []struct {
Expr subQuery
Expr SubQuery
Opts []stmtOpt
ExpectedVars []interface{}
Result string
Expand Down
24 changes: 16 additions & 8 deletions generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ func (g *Generator) GenerateModel(tableName string, opts ...FieldOpt) *check.Bas
return g.GenerateModelAs(tableName, g.db.Config.NamingStrategy.SchemaName(tableName), opts...)
}

// GenerateModel catch table info from db, return a BaseStruct
// GenerateModelAs catch table info from db, return a BaseStruct
func (g *Generator) GenerateModelAs(tableName string, modelName string, fieldOpts ...FieldOpt) *check.BaseStruct {
modelFieldOpts := make([]model.FieldOpt, len(fieldOpts))
for i, opt := range fieldOpts {
Expand Down Expand Up @@ -337,11 +337,11 @@ func (g *Generator) generateQueryFile() (err error) {
g.db.Logger.Error(context.Background(), "generate query unit test fail: %s", err)
return nil
}
err = render(tmpl.DIYMethod_TEST_Basic, &buf, nil)
err = render(tmpl.DIYMethodTestBasic, &buf, nil)
if err != nil {
return err
}
err = render(tmpl.QueryMethod_TEST, &buf, g)
err = render(tmpl.QueryMethodTest, &buf, g)
if err != nil {
g.db.Logger.Error(context.Background(), "generate query unit test fail: %s", err)
return nil
Expand Down Expand Up @@ -375,16 +375,24 @@ func (g *Generator) generateSingleQueryFile(data *genInfo) (err error) {
return err
}

structTmpl := tmpl.BaseStructWithContext
data.BaseStruct = data.BaseStruct.IfaceMode(g.judgeMode(WithQueryIface))

structTmpl := tmpl.TableQueryStructWithContext
if g.judgeMode(WithoutContext) {
structTmpl = tmpl.BaseStruct
structTmpl = tmpl.TableQueryStruct
}

err = render(structTmpl, &buf, data.BaseStruct)
if err != nil {
return err
}

if g.judgeMode(WithQueryIface) {
err = render(tmpl.TableQueryIface, &buf, data)
if err != nil {
return err
}
}

for _, method := range data.Interfaces {
err = render(tmpl.DIYMethod, &buf, method)
if err != nil {
Expand Down Expand Up @@ -418,13 +426,13 @@ func (g *Generator) generateQueryUnitTestFile(data *genInfo) (err error) {
return err
}

err = render(tmpl.CRUDMethod_TEST, &buf, data.BaseStruct)
err = render(tmpl.CRUDMethodTest, &buf, data.BaseStruct)
if err != nil {
return err
}

for _, method := range data.Interfaces {
err = render(tmpl.DIYMethod_TEST, &buf, method)
err = render(tmpl.DIYMethodTest, &buf, method)
if err != nil {
return err
}
Expand Down
13 changes: 1 addition & 12 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dv
github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
Expand Down Expand Up @@ -114,8 +113,8 @@ github.com/mattn/go-sqlite3 v1.14.12 h1:TJ1bhYJPV44phC+IMu1u2K/i5RriLTPe+yc68XDJ
github.com/mattn/go-sqlite3 v1.14.12/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/modocache/gover v0.0.0-20171022184752-b58185e213c5/go.mod h1:caMODM3PzxT8aQXRPkAt8xlV/e7d7w8GM5g0fa5F0D8=
github.com/pkg/browser v0.0.0-20180916011732-0a3d74bf9ce4/go.mod h1:4OwLy04Bl9Ef3GJJCoec+30X3LQs/0/m4HFRt/2LUSA=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
Expand Down Expand Up @@ -192,7 +191,6 @@ golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654 h1:id054HUawV2/6IGm2IV8KZQjqtwAOo2CYlOToYqa0d0=
golang.org/x/sys v0.0.0-20211019181941-9d821ace8654/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220222200937-f2425489ef4c h1:sSIdNI2Dd6vGv47bKc/xArpfxVmEz2+3j0E6I484xC4=
golang.org/x/sys v0.0.0-20220222200937-f2425489ef4c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
Expand Down Expand Up @@ -248,15 +246,6 @@ gorm.io/driver/sqlite v1.3.4/go.mod h1:B+8GyC9K7VgzJAcrcXMRPdnMcck+8FgJynEehEPM1
gorm.io/driver/sqlserver v1.3.1/go.mod h1:w25Vrx2BG+CJNUu/xKbFhaKlGxT/nzRkhWCCoptX8tQ=
gorm.io/driver/sqlserver v1.3.2 h1:yYt8f/xdAKLY7lCCyXxIUEgZ/WsURos3dHrx8MKFGAk=
gorm.io/driver/sqlserver v1.3.2/go.mod h1:w25Vrx2BG+CJNUu/xKbFhaKlGxT/nzRkhWCCoptX8tQ=
gorm.io/gorm v1.21.15/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0=
gorm.io/gorm v1.22.2/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0=
gorm.io/gorm v1.23.1/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.23.2/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.23.4/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.23.6 h1:KFLdNgri4ExFFGTRGGFWON2P1ZN28+9SJRN8voOoYe0=
gorm.io/gorm v1.23.6/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.23.7-0.20220614054850-8d457146283e h1:SpoT1b3pURhVG5iUd3Isbduzh9wSNZfxQWWigihm1HE=
gorm.io/gorm v1.23.7-0.20220614054850-8d457146283e/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/hints v1.1.0 h1:Lp4z3rxREufSdxn4qmkK3TLDltrM10FLTHiuqwDPvXw=
gorm.io/hints v1.1.0/go.mod h1:lKQ0JjySsPBj3uslFzY3JhYDtqEwzm+G1hv8rWujB6Y=
gorm.io/plugin/dbresolver v1.2.1 h1:moK7t4QJRh+Eer60UGuiANM/KG40uhnIqUOPLmnd/7Y=
Expand Down
6 changes: 3 additions & 3 deletions interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ type (
var (
_ Condition = (field.Expr)(nil)
_ Condition = (field.Value)(nil)
_ Condition = (subQuery)(nil)
_ Condition = (SubQuery)(nil)
_ Condition = (Dao)(nil)
)

type subQuery interface {
type SubQuery interface {
underlyingDB() *gorm.DB
underlyingDO() *DO

Expand All @@ -34,7 +34,7 @@ type subQuery interface {

// Dao CRUD methods
type Dao interface {
subQuery
SubQuery
schema.Tabler
As(alias string) Dao

Expand Down
25 changes: 16 additions & 9 deletions internal/check/checkinterface.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,23 @@ type InterfaceMethod struct { // feature will replace InterfaceMethod to parser.
Result []parser.Param // function output params
ResultData parser.Param // output data
Sections *Sections // Parse split SQL into sections
SqlParams []parser.Param // variable in sql need function input
SqlString string // SQL
SQLParams []parser.Param // variable in sql need function input
SQLString string // SQL
GormOption string // gorm execute method Find or Exec or Take
Table string // specified by user. if empty, generate it with gorm
InterfaceName string // origin interface name
Package string // interface package name
HasForParams bool //
}

// HasSqlData has variable or for params will creat params map
func (m *InterfaceMethod) HasSqlData() bool {
return len(m.SqlParams) > 0 || m.HasForParams
// FuncSign function signature
func (m *InterfaceMethod) FuncSign() string {
return fmt.Sprintf("%s(%s) (%s)", m.MethodName, m.GetParamInTmpl(), m.GetResultParamInTmpl())
}

// HasSQLData has variable or for params will creat params map
func (m *InterfaceMethod) HasSQLData() bool {
return len(m.SQLParams) > 0 || m.HasForParams
}

// HasGotPoint parameter has pointer or not
Expand All @@ -52,6 +57,7 @@ func (m *InterfaceMethod) GormRunMethodName() string {
return "Take"
}

// ReturnRowsAffected return rows affected
func (m *InterfaceMethod) ReturnRowsAffected() bool {
for _, res := range m.Result {
if res.Name == "rowsAffected" {
Expand All @@ -61,6 +67,7 @@ func (m *InterfaceMethod) ReturnRowsAffected() bool {
return false
}

// ReturnError return error
func (m *InterfaceMethod) ReturnError() bool {
for _, res := range m.Result {
if res.IsError() {
Expand Down Expand Up @@ -205,7 +212,7 @@ func (m *InterfaceMethod) checkResult(result []parser.Param) (err error) {

// checkSQL get sql from comment and check it
func (m *InterfaceMethod) checkSQL() (err error) {
m.SqlString = m.parseDocString()
m.SQLString = m.parseDocString()
if err = m.sqlStateCheckAndSplit(); err != nil {
err = fmt.Errorf("interface %s member method %s check sql err:%w", m.InterfaceName, m.MethodName, err)
}
Expand Down Expand Up @@ -260,7 +267,7 @@ func (m *InterfaceMethod) getSQLDocString() string {

// sqlStateCheckAndSplit check sql with an adeterministic finite automaton
func (m *InterfaceMethod) sqlStateCheckAndSplit() error {
sqlString := m.SqlString
sqlString := m.SQLString
m.Sections = NewSections()
var buf model.SQLBuffer
for i := 0; !strOutRange(i, sqlString); i++ {
Expand Down Expand Up @@ -390,7 +397,7 @@ func (m *InterfaceMethod) checkSQLVarByParams(param string, status model.Status)
switch status {
case model.DATA:
if !m.isParamExist(param) {
m.SqlParams = append(m.SqlParams, p)
m.SQLParams = append(m.SQLParams, p)
}
case model.VARIABLE:
if p.Type != "string" || p.IsArray {
Expand Down Expand Up @@ -419,7 +426,7 @@ func (m *InterfaceMethod) checkSQLVarByParams(param string, status model.Status)

// isParamExist check param duplicate
func (m *InterfaceMethod) isParamExist(paramName string) bool {
for _, param := range m.SqlParams {
for _, param := range m.SQLParams {
if param.Name == paramName {
return true
}
Expand Down
Loading

0 comments on commit 8a85f25

Please sign in to comment.